77import argparse
88import copy
99import os
10+ from typing import List , Dict , Iterator , Tuple , Any
1011
1112import torch
1213from torch import nn
@@ -106,28 +107,46 @@ def __init__(self, args, task, models):
106107 self .tokenizer = encoders .build_tokenizer (args )
107108 self .bpe = encoders .build_bpe (args )
108109
110+ self .max_positions = utils .resolve_max_positions (
111+ self .task .max_positions (), * [model .max_positions () for model in models ]
112+ )
113+
109114 # this is useful for determining the device
110115 self .register_buffer ('_float_tensor' , torch .tensor ([0 ], dtype = torch .float ))
111116
@property
def device(self):
    """Return the device that currently holds the ``_float_tensor`` buffer."""
    anchor = self._float_tensor
    return anchor.device
115120
def translate(self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs) -> List[str]:
    """Translate ``sentences`` by delegating to :meth:`sample`.

    The only difference from calling ``sample`` directly is the default
    beam size (5 here). Whatever ``sample`` returns is passed back
    unchanged, so a single ``str`` input is handled the same way
    ``sample`` handles it.

    Args:
        sentences: sentences to translate.
        beam: beam size forwarded to ``sample``.
        verbose: verbosity flag forwarded to ``sample``.
        **kwargs: extra options forwarded to ``sample`` untouched.
    """
    translated = self.sample(sentences, beam, verbose, **kwargs)
    return translated
118123
def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> List[str]:
    """Encode, generate and decode one output string per input sentence.

    Args:
        sentences: sentences to process. A bare ``str`` is also accepted;
            it is treated as a batch of one and a single string is returned.
        beam: beam size forwarded to ``generate``.
        verbose: verbosity flag forwarded to ``generate``.
        **kwargs: extra options forwarded to ``generate``.

    Returns:
        One decoded string per input sentence, built from the ``'tokens'``
        of the first hypothesis returned for that sentence.
    """
    if isinstance(sentences, str):
        # Promote a lone sentence to a batch of one, then unwrap the result.
        batch_result = self.sample([sentences], beam=beam, verbose=verbose, **kwargs)
        return batch_result[0]

    encoded = list(map(self.encode, sentences))
    decoded = []
    for hypos in self.generate(encoded, beam, verbose, **kwargs):
        # Only the first hypothesis of each sentence is decoded.
        decoded.append(self.decode(hypos[0]['tokens']))
    return decoded
123130
def score(self, sentences: List[str], **kwargs):
    """Score the given sentences by running ``generate`` with ``score_reference=True``.

    Args:
        sentences: sentences to score. A bare ``str`` is also accepted;
            it is treated as a batch of one and a single result is returned.
        **kwargs: extra options forwarded to ``generate``.

    Returns:
        For each input sentence, the first hypothesis returned by
        ``generate`` for that sentence.
    """
    # NOTE: this doesn't support translation tasks currently
    if isinstance(sentences, str):
        single = self.score([sentences], **kwargs)
        return single[0]

    encoded = [self.encode(text) for text in sentences]
    scored = self.generate(encoded, score_reference=True, **kwargs)
    return [hypos[0] for hypos in scored]
137+
def generate(
    self,
    tokenized_sentences: List[torch.LongTensor],
    beam: int = 5,
    verbose: bool = False,
    skip_invalid_size_inputs=False,
    **kwargs
) -> List[List[Dict[str, torch.Tensor]]]:
    """Run generation over a batch of already-tokenized sentences.

    Args:
        tokenized_sentences: one token tensor per sentence. A single 1-D
            tensor is also accepted; it is treated as a batch of one and
            the hypotheses for that one sentence are returned directly.
        beam: beam size, applied to the generator arguments.
        verbose: when true, print the source (``S``), each hypothesis
            (``H``), its positional scores (``P``) and, if
            ``print_alignment`` is set, its alignment (``A``).
        skip_invalid_size_inputs: forwarded to ``_build_batches`` as its
            ``skip_invalid_size_inputs`` argument.
        **kwargs: every entry is set as an attribute on a copy of
            ``self.args`` before the generator is built.

    Returns:
        A list with one entry per generated sentence, ordered by sentence
        id; each entry is the list of hypotheses for that sentence.
    """
    if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1:
        # Single unbatched sentence: wrap, recurse, unwrap. The skip flag
        # must be forwarded explicitly, otherwise it silently falls back
        # to its default in the recursive call.
        batched = self.generate(
            tokenized_sentences.unsqueeze(0),
            beam=beam,
            verbose=verbose,
            skip_invalid_size_inputs=skip_invalid_size_inputs,
            **kwargs
        )
        # If the only sentence was skipped there is nothing to unwrap.
        return batched[0] if batched else []

    # build generator using current args as well as any kwargs
    gen_args = copy.copy(self.args)
    # NOTE(review): the next two lines sit between two diff hunks and were
    # not visible; they are reconstructed from the surrounding context
    # (``beam`` is otherwise unused and ``setattr`` needs ``k``/``v``).
    # Confirm against the full file.
    gen_args.beam = beam
    for k, v in kwargs.items():
        setattr(gen_args, k, v)
    generator = self.task.build_generator(gen_args)

    results = []
    for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
        batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
        translations = self.task.inference_step(generator, self.models, batch)
        results.extend(zip(batch["id"].tolist(), translations))

    # sort output to match input order
    results.sort(key=lambda pair: pair[0])
    outputs = [hypos for _, hypos in results]

    if verbose:

        def getarg(name, default):
            return getattr(gen_args, name, getattr(self.args, name, default))

        print_alignment = getarg('print_alignment', False)

        # Look each source up by its sentence id instead of zipping inputs
        # with outputs: when some inputs were skipped, a positional zip
        # would pair sources with the wrong hypotheses.
        # NOTE(review): assumes batch["id"] is the sentence's position in
        # tokenized_sentences (the sort above already relies on ids
        # reflecting input order) - confirm against the dataset.
        for sentence_id, target_hypotheses in results:
            src_str_with_unk = self.string(tokenized_sentences[sentence_id])
            print('S\t {}'.format(src_str_with_unk))
            for hypo in target_hypotheses:
                hypo_str = self.decode(hypo['tokens'])
                print('H\t {}\t {}'.format(hypo['score'], hypo_str))
                print('P\t {}'.format(
                    ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
                ))
                if hypo['alignment'] is not None and print_alignment:
                    print('A\t {}'.format(
                        ' '.join(map(lambda x: str(utils.item(x)), hypo['alignment'].int().cpu()))
                    ))
    return outputs
163187
164188 def encode (self , sentence : str ) -> torch .LongTensor :
165189 sentence = self .tokenize (sentence )
@@ -197,15 +221,18 @@ def binarize(self, sentence: str) -> torch.LongTensor:
def string(self, tokens: torch.LongTensor) -> str:
    """Convert ``tokens`` to a string via ``self.tgt_dict.string``."""
    dictionary = self.tgt_dict
    return dictionary.string(tokens)
199223
def _build_batches(
    self, tokens: List[torch.LongTensor], skip_invalid_size_inputs: bool
) -> Iterator[Dict[str, Any]]:
    """Group tokenized sentences into batches using the task's batch iterator.

    Args:
        tokens: one tensor per sentence; each element must support
            ``.numel()``, which is used as the sentence length.
        skip_invalid_size_inputs: passed to the task's batch iterator as
            ``ignore_invalid_inputs``.

    Returns:
        The iterator produced by ``next_epoch_itr(shuffle=False)`` on the
        task's batch iterator. Batch limits come from
        ``self.args.max_tokens``, ``self.args.max_sentences`` and
        ``self.max_positions``.
    """
    lengths = torch.LongTensor([t.numel() for t in tokens])
    dataset = self.task.build_dataset_for_inference(tokens, lengths)
    epoch_iterator = self.task.get_batch_iterator(
        dataset=dataset,
        max_tokens=self.args.max_tokens,
        max_sentences=self.args.max_sentences,
        max_positions=self.max_positions,
        ignore_invalid_inputs=skip_invalid_size_inputs,
    )
    return epoch_iterator.next_epoch_itr(shuffle=False)
209236
210237
211238class BPEHubInterface (object ):
0 commit comments