
Commit 3b53962

sai-prasanna authored and facebook-github-bot committed

Refactor hub interface for batched inference (#1539)
Summary:

# Before submitting

- [x] Was this discussed/approved via a Github issue? (no need for typos, doc improvements)
- [x] Did you read the [contributor guideline](https://github.com/pytorch/fairseq/blob/master/CONTRIBUTING.md)?
- [x] Did you make sure to update the docs?
- [ ] Did you write any new necessary tests?

## What does this PR do?

Fixes #1508.

Pull Request resolved: #1539

Pulled By: myleott

Differential Revision: D19216104

fbshipit-source-id: 14917c1459b8794eeb74c09a16b9899c366242d2
1 parent 9b19ede · commit 3b53962

3 files changed: 80 additions & 41 deletions
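In short, the hub interface's `translate`, `sample`, and `score` methods now accept either a single string or a list of strings; list inputs are encoded, batched through the task's batch iterator, moved to the model's device, and returned in input order. A minimal usage sketch of the new surface (assuming the torch.hub checkpoints download successfully; the README diffs below show the same calls):

```python
import torch

# Load a translation model via torch.hub (downloads weights on first use)
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt16.en-de',
                       tokenizer='moses', bpe='subword_nmt')
en2de.eval()  # disable dropout for inference

# A bare string still returns a string (backward compatible) ...
en2de.translate('Hello world!')

# ... while a list is batched internally and returns a list, in input order.
en2de.translate(['Hello world!', 'The cat sat on the mat.'])
```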


examples/language_model/README.md

Lines changed: 4 additions & 0 deletions
````diff
@@ -26,6 +26,10 @@ torch.hub.list('pytorch/fairseq') # [..., 'transformer_lm.wmt19.en', ...]
 
 # Load an English LM trained on WMT'19 News Crawl data
 en_lm = torch.hub.load('pytorch/fairseq', 'transformer_lm.wmt19.en', tokenizer='moses', bpe='fastbpe')
+en_lm.eval()  # disable dropout
+
+# Move model to GPU
+en_lm.cuda()
 
 # Sample from the language model
 en_lm.sample('Barack Obama', beam=1, sampling=True, sampling_topk=10, temperature=0.8)
````
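Since `sample` shares the batched code path (see the `fairseq/hub_utils.py` diff below), the language-model interface also accepts a list of prompts. A hedged sketch beyond what this hunk shows:

```python
# Batched sampling: a list of prompts yields a list of continuations.
# (A sketch; the hunk above only adds eval()/cuda() for a single prompt.)
en_lm.sample(['Barack Obama', 'The weather today'],
             beam=1, sampling=True, sampling_topk=10, temperature=0.8)
```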

examples/translation/README.md

Lines changed: 8 additions & 0 deletions
````diff
@@ -34,13 +34,21 @@ torch.hub.list('pytorch/fairseq') # [..., 'transformer.wmt16.en-de', ... ]
 
 # Load a transformer trained on WMT'16 En-De
 en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt16.en-de', tokenizer='moses', bpe='subword_nmt')
+en2de.eval()  # disable dropout
 
 # The underlying model is available under the *models* attribute
 assert isinstance(en2de.models[0], fairseq.models.transformer.TransformerModel)
 
+# Move model to GPU for faster translation
+en2de.cuda()
+
 # Translate a sentence
 en2de.translate('Hello world!')
 # 'Hallo Welt!'
+
+# Batched translation
+en2de.translate(['Hello world!', 'The cat sat on the mat.'])
+# ['Hallo Welt!', 'Die Katze saß auf der Matte.']
 ```
 
 Loading custom models:
````
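`score` follows the same list-in/list-out convention, though the code notes it does not support translation tasks yet, so it is meaningful only for language models. A sketch, reusing the `en_lm` model from the language-model README:

```python
# score() returns the top hypothesis dict per sentence (a single dict for a
# bare string). The 'score' and 'positional_scores' fields are the same ones
# that generate() prints in verbose mode.
hypo = en_lm.score('Barack Obama is coming to Sydney.')
hypo['score']  # total model score of the given sentence
```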

fairseq/hub_utils.py

Lines changed: 68 additions & 41 deletions
````diff
@@ -7,6 +7,7 @@
 import argparse
 import copy
 import os
+from typing import List, Dict, Iterator, Tuple, Any
 
 import torch
 from torch import nn
````
````diff
@@ -106,28 +107,46 @@ def __init__(self, args, task, models):
         self.tokenizer = encoders.build_tokenizer(args)
         self.bpe = encoders.build_bpe(args)
 
+        self.max_positions = utils.resolve_max_positions(
+            self.task.max_positions(), *[model.max_positions() for model in models]
+        )
+
         # this is useful for determining the device
         self.register_buffer('_float_tensor', torch.tensor([0], dtype=torch.float))
 
     @property
     def device(self):
         return self._float_tensor.device
 
-    def translate(self, sentence: str, beam: int = 5, verbose: bool = False, **kwargs) -> str:
-        return self.sample(sentence, beam, verbose, **kwargs)
+    def translate(self, sentences: List[str], beam: int = 5, verbose: bool = False, **kwargs) -> List[str]:
+        return self.sample(sentences, beam, verbose, **kwargs)
 
-    def sample(self, sentence: str, beam: int = 1, verbose: bool = False, **kwargs) -> str:
-        input = self.encode(sentence)
-        hypo = self.generate(input, beam, verbose, **kwargs)[0]['tokens']
-        return self.decode(hypo)
+    def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> List[str]:
+        if isinstance(sentences, str):
+            return self.sample([sentences], beam=beam, verbose=verbose, **kwargs)[0]
+        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
+        batched_hypos = self.generate(tokenized_sentences, beam, verbose, **kwargs)
+        return [self.decode(hypos[0]['tokens']) for hypos in batched_hypos]
 
-    def score(self, sentence: str, **kwargs):
+    def score(self, sentences: List[str], **kwargs):
+        if isinstance(sentences, str):
+            return self.score([sentences], **kwargs)[0]
         # NOTE: this doesn't support translation tasks currently
-        input = self.encode(sentence)
-        return self.generate(input, score_reference=True, **kwargs)[0]
-
-    def generate(self, tokens: torch.LongTensor, beam: int = 5, verbose: bool = False, **kwargs) -> torch.LongTensor:
-        sample = self._build_sample(tokens)
+        tokenized_sentences = [self.encode(sentence) for sentence in sentences]
+        return [hypos[0] for hypos in self.generate(tokenized_sentences, score_reference=True, **kwargs)]
+
+    def generate(
+        self,
+        tokenized_sentences: List[torch.LongTensor],
+        beam: int = 5,
+        verbose: bool = False,
+        skip_invalid_size_inputs=False,
+        **kwargs
+    ) -> List[List[Dict[str, torch.Tensor]]]:
+        if torch.is_tensor(tokenized_sentences) and tokenized_sentences.dim() == 1:
+            return self.generate(
+                tokenized_sentences.unsqueeze(0), beam=beam, verbose=verbose, **kwargs
+            )[0]
 
         # build generator using current args as well as any kwargs
         gen_args = copy.copy(self.args)
````
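Note the backward-compatibility shims above: `sample` and `score` unwrap a bare string, and `generate` unwraps a single 1-D tensor, so pre-refactor call sites keep working. For instance:

```python
# Old-style call: a single 1-D LongTensor of token ids is wrapped into a
# batch of one, and the hypotheses for that sentence are returned directly
# (a list of dicts rather than a list of lists).
tokens = en2de.encode('Hello world!')
hypos = en2de.generate(tokens, beam=5)
```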
````diff
@@ -136,30 +155,35 @@ def generate(self, tokens: torch.LongTensor, beam: int = 5, verbose: bool = Fals
             setattr(gen_args, k, v)
         generator = self.task.build_generator(gen_args)
 
-        translations = self.task.inference_step(generator, self.models, sample)
-
-        if verbose:
-            src_str_with_unk = self.string(tokens)
-            print('S\t{}'.format(src_str_with_unk))
+        results = []
+        for batch in self._build_batches(tokenized_sentences, skip_invalid_size_inputs):
+            batch = utils.apply_to_sample(lambda t: t.to(self.device), batch)
+            translations = self.task.inference_step(generator, self.models, batch)
+            for id, hypos in zip(batch["id"].tolist(), translations):
+                results.append((id, hypos))
 
-        def getarg(name, default):
-            return getattr(gen_args, name, getattr(self.args, name, default))
+        # sort output to match input order
+        outputs = [hypos for _, hypos in sorted(results, key=lambda x: x[0])]
 
-        # Process top predictions
-        hypos = translations[0]
         if verbose:
-            for hypo in hypos:
-                hypo_str = self.decode(hypo['tokens'])
-                print('H\t{}\t{}'.format(hypo['score'], hypo_str))
-                print('P\t{}'.format(
-                    ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
-                ))
-                if hypo['alignment'] is not None and getarg('print_alignment', False):
-                    print('A\t{}'.format(
-                        ' '.join(map(lambda x: str(utils.item(x)), hypo['alignment'].int().cpu()))
-                    ))
 
-        return hypos
+            def getarg(name, default):
+                return getattr(gen_args, name, getattr(self.args, name, default))
+
+            for source_tokens, target_hypotheses in zip(tokenized_sentences, outputs):
+                src_str_with_unk = self.string(source_tokens)
+                print('S\t{}'.format(src_str_with_unk))
+                for hypo in target_hypotheses:
+                    hypo_str = self.decode(hypo['tokens'])
+                    print('H\t{}\t{}'.format(hypo['score'], hypo_str))
+                    print('P\t{}'.format(
+                        ' '.join(map(lambda x: '{:.4f}'.format(x), hypo['positional_scores'].tolist()))
+                    ))
+                    if hypo['alignment'] is not None and getarg('print_alignment', False):
+                        print('A\t{}'.format(
+                            ' '.join(map(lambda x: str(utils.item(x)), hypo['alignment'].int().cpu()))
+                        ))
+        return outputs
 
     def encode(self, sentence: str) -> torch.LongTensor:
         sentence = self.tokenize(sentence)
````
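With `verbose=True`, the per-sentence output format is unchanged from the single-sentence version, just emitted for every input: an `S` line with the source string (possibly containing unknown tokens), then for each hypothesis an `H` line with its score and detokenized text, a `P` line with per-token positional scores, and an `A` line with the alignment when `print_alignment` is set.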
````diff
@@ -197,15 +221,18 @@ def binarize(self, sentence: str) -> torch.LongTensor:
     def string(self, tokens: torch.LongTensor) -> str:
         return self.tgt_dict.string(tokens)
 
-    def _build_sample(self, src_tokens: torch.LongTensor):
-        assert torch.is_tensor(src_tokens)
-        dataset = self.task.build_dataset_for_inference([src_tokens], [src_tokens.numel()])
-        sample = dataset.collater([dataset[0]])
-        sample = utils.apply_to_sample(
-            lambda tensor: tensor.to(self.device),
-            sample
-        )
-        return sample
+    def _build_batches(
+        self, tokens: List[List[int]], skip_invalid_size_inputs: bool
+    ) -> Iterator[Dict[str, Any]]:
+        lengths = torch.LongTensor([t.numel() for t in tokens])
+        batch_iterator = self.task.get_batch_iterator(
+            dataset=self.task.build_dataset_for_inference(tokens, lengths),
+            max_tokens=self.args.max_tokens,
+            max_sentences=self.args.max_sentences,
+            max_positions=self.max_positions,
+            ignore_invalid_inputs=skip_invalid_size_inputs,
+        ).next_epoch_itr(shuffle=False)
+        return batch_iterator
 
 
 class BPEHubInterface(object):
````
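Finally, `_build_sample` is replaced by `_build_batches`, which delegates batching to `task.get_batch_iterator`: batch sizes respect the checkpoint's `max_tokens`/`max_sentences` args and the `max_positions` resolved in `__init__`, and `skip_invalid_size_inputs` maps to the iterator's `ignore_invalid_inputs`. One consequence, sketched under the assumption that oversized inputs are silently dropped by the iterator:

```python
# An input longer than the model's max positions would normally raise during
# batching; with skip_invalid_size_inputs=True it is dropped instead, so the
# output list may be shorter than the input list.
too_long = ' '.join(['word'] * 10000)
outputs = en2de.translate(['Hello world!', too_long],
                          skip_invalid_size_inputs=True)
```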
