
Commit e46b924

Xing Zhou authored and facebook-github-bot committed
Nucleus (top-P) sampling (#710)
Summary:
Implement Nucleus (top-P) sampling: sample among the smallest set of elements whose cumulative probability mass exceeds p.

To test it:

    python generate.py ~myleott/data/data-bin/wmt17_zh_en_full/ --path ~myleott/zh_en/model.pt --remove-bpe --nbest 5 --beam 5 --sampling --sampling-topp 0.3

Pull Request resolved: fairinternal/fairseq-py#710

Test Plan:

    python generate.py ~myleott/data/data-bin/wmt17_zh_en_full/ --path ~myleott/zh_en/model.pt --remove-bpe --nbest 5 --beam 5 --sampling --sampling-topp 0.3
    python tests/test_sequence_generator.py
    python tests/test_binaries.py

Reviewed By: myleott

Differential Revision: D16286688

Pulled By: xingz9

fbshipit-source-id: 1776d21e17c4532a3d24ac75bb7e75da9acad58f
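For intuition, here is a minimal standalone sketch of the idea on a single 1-D distribution (illustrative only; nucleus_sample is a made-up name, not part of this commit):

    import torch

    def nucleus_sample(probs, p):
        # sort the distribution in descending order
        sorted_probs, sorted_indices = probs.sort(descending=True)
        # smallest prefix whose cumulative mass exceeds p: count the
        # entries strictly below p, then include one more
        cutoff = int(sorted_probs.cumsum(dim=0).lt(p).sum()) + 1
        kept = sorted_probs[:cutoff]
        # torch.multinomial renormalizes the kept mass before sampling
        choice = torch.multinomial(kept, 1)
        return sorted_indices[choice]

    probs = torch.tensor([0.1, 0.5, 0.3, 0.1])
    print(nucleus_sample(probs, p=0.3))  # always tensor([1]) here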
1 parent 473389a commit e46b924

6 files changed

Lines changed: 242 additions & 44 deletions

File tree

fairseq/options.py
fairseq/search.py
fairseq/sequence_generator.py
fairseq/tasks/fairseq_task.py
tests/test_binaries.py
tests/test_sequence_generator.py

fairseq/options.py

Lines changed: 2 additions & 0 deletions
@@ -472,6 +472,8 @@ def add_generation_args(parser):
                        help='sample hypotheses instead of using beam search')
     group.add_argument('--sampling-topk', default=-1, type=int, metavar='PS',
                        help='sample from top K likely next words instead of all words')
+    group.add_argument('--sampling-topp', default=-1.0, type=float, metavar='PS',
+                       help='sample from the smallest set whose cumulative probability mass exceeds p for next words')
     group.add_argument('--temperature', default=1., type=float, metavar='N',
                        help='temperature for generation')
     group.add_argument('--diverse-beam-groups', default=-1, type=int, metavar='N',
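With the new option in place, top-P sampling can be enabled from the command line, as in the test plan above:

    python generate.py ~myleott/data/data-bin/wmt17_zh_en_full/ --path ~myleott/zh_en/model.pt \
        --remove-bpe --nbest 5 --beam 5 --sampling --sampling-topp 0.3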

fairseq/search.py

Lines changed: 58 additions & 8 deletions
@@ -168,9 +168,54 @@ def step(self, step, lprobs, scores):
 
 class Sampling(Search):
 
-    def __init__(self, tgt_dict, sampling_topk=-1):
+    def __init__(self, tgt_dict, sampling_topk=-1, sampling_topp=-1.0):
         super().__init__(tgt_dict)
         self.sampling_topk = sampling_topk
+        self.sampling_topp = sampling_topp
+
+    def _sample_topp(self, lprobs):
+        """Sample among the smallest set of elements whose cumulative probability mass exceeds p.
+
+        See `"The Curious Case of Neural Text Degeneration"
+        (Holtzman et al., 2019) <https://arxiv.org/abs/1904.09751>`_.
+
+        Args:
+            lprobs: (bsz x input_beam_size x vocab_size)
+                the model's log-probabilities over the vocabulary at the current step
+
+        Return: A tuple of (trimed_probs, truncated_indices) where:
+            trimed_probs: (bsz x input_beam_size x ?)
+                the model's probabilities over the elements selected to sample from. The
+                width of the third dimension is determined by top-P.
+            truncated_indices: (bsz x input_beam_size x ?)
+                the indices of the chosen elements.
+        """
+        probs = lprobs.exp_()
+
+        # sort the last dimension (vocab dimension) in descending order
+        sorted_probs, sorted_indices = probs.sort(descending=True)
+
+        # compute a mask to indicate the words to be included in the top-P set.
+        cumsum_probs = sorted_probs.cumsum(dim=2)
+        mask = cumsum_probs.lt(self.sampling_topp)
+
+        # note that mask was computed by 'lt'. One more word needs to be included
+        # so that the cumulative probability mass can exceed p.
+        cumsum_mask = mask.cumsum(dim=2)
+        last_included = cumsum_mask[:, :, -1:]
+        mask = mask.scatter_(2, last_included, 1)
+
+        # truncate unnecessary dims.
+        max_dim = last_included.max()
+        truncated_mask = mask[:, :, :max_dim + 1]
+        truncated_probs = sorted_probs[:, :, :max_dim + 1]
+        truncated_indices = sorted_indices[:, :, :max_dim + 1]
+
+        # trim the words that are not in top-P by setting their probabilities
+        # to 0, so that they would not be sampled later.
+        trim_mask = 1 - truncated_mask
+        trimed_probs = truncated_probs.masked_fill_(trim_mask, 0)
+        return trimed_probs, truncated_indices
 
     def step(self, step, lprobs, scores):
         super()._init_buffers(lprobs)
@@ -185,12 +230,17 @@ def step(self, step, lprobs, scores):
         assert self.pad <= 1, 'sampling assumes the first two symbols can be ignored'
         lprobs_nopad = lprobs[:, :, 2:]
 
-        # only sample from top-k candidates
-        if self.sampling_topk > 0:
-            lprobs_nopad, topk_indices = lprobs_nopad.topk(self.sampling_topk)
+        if self.sampling_topp > 0:
+            # only sample from the smallest set of words whose cumulative probability mass exceeds p
+            probs_nopad, top_indices = self._sample_topp(lprobs_nopad)
+        elif self.sampling_topk > 0:
+            # only sample from top-k candidates
+            lprobs_nopad, top_indices = lprobs_nopad.topk(self.sampling_topk)
+            probs_nopad = lprobs_nopad.exp_()
+        else:
+            probs_nopad = lprobs_nopad.exp_()
 
         # sample
-        probs_nopad = lprobs_nopad.exp_()
         if step == 0:
             self.indices_buf = torch.multinomial(
                 probs_nopad.view(bsz, -1),
@@ -219,10 +269,10 @@ def step(self, step, lprobs, scores):
             )
         self.scores_buf = self.scores_buf.log_().view(bsz, -1)
 
-        # remap indices if using top-k sampling
-        if self.sampling_topk > 0:
+        # remap indices if using top-k or top-P sampling
+        if self.sampling_topk > 0 or self.sampling_topp > 0:
             self.indices_buf = torch.gather(
-                topk_indices.expand(bsz, beam_size, -1),
+                top_indices.expand(bsz, beam_size, -1),
                 dim=2,
                 index=self.indices_buf.unsqueeze(-1),
             ).squeeze(2)
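For a concrete sense of what _sample_topp computes, a small hand-checked example in the shapes the docstring describes (bsz = 1, input_beam_size = 1, vocab_size = 4; the numbers are illustrative):

    import torch

    probs = torch.tensor([[[0.1, 0.5, 0.3, 0.1]]])
    sorted_probs, sorted_indices = probs.sort(descending=True)
    print(sorted_probs.cumsum(dim=2))
    # tensor([[[0.5000, 0.8000, 0.9000, 1.0000]]])

    # sampling_topp = 0.3: no cumulative prefix is strictly below 0.3, so the
    # "one more word" correction keeps just the top word:
    #   trimed_probs -> [[[0.5]]], truncated_indices -> [[[1]]]
    # sampling_topp = 0.7: 0.5 < 0.7, so a second word is included to exceed p:
    #   trimed_probs -> [[[0.5, 0.3]]], truncated_indices -> [[[1, 2]]]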

fairseq/sequence_generator.py

Lines changed: 6 additions & 1 deletion
@@ -28,6 +28,7 @@ def __init__(
         retain_dropout=False,
         sampling=False,
         sampling_topk=-1,
+        sampling_topp=-1.0,
         temperature=1.,
         diverse_beam_groups=-1,
         diverse_beam_strength=0.5,
@@ -58,6 +59,9 @@ def __init__(
                 (default: False)
             sampling_topk (int, optional): only sample among the top-k choices
                 at each step (default: -1)
+            sampling_topp (float, optional): only sample among the smallest set
+                of words whose cumulative probability mass exceeds p
+                at each step (default: -1.0)
             temperature (float, optional): temperature, where values
                 >1.0 produce more uniform samples and values <1.0 produce
                 sharper samples (default: 1.0)
@@ -86,10 +90,11 @@ def __init__(
         self.no_repeat_ngram_size = no_repeat_ngram_size
 
         assert sampling_topk < 0 or sampling, '--sampling-topk requires --sampling'
+        assert sampling_topp < 0 or sampling, '--sampling-topp requires --sampling'
         assert temperature > 0, '--temperature must be greater than 0'
 
         if sampling:
-            self.search = search.Sampling(tgt_dict, sampling_topk)
+            self.search = search.Sampling(tgt_dict, sampling_topk, sampling_topp)
         elif diverse_beam_groups > 0:
             self.search = search.DiverseBeamSearch(tgt_dict, diverse_beam_groups, diverse_beam_strength)
         elif match_source_len:
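For reference, the generator can then be constructed with top-P sampling enabled; this mirrors the usage in tests/test_sequence_generator.py further below (tgt_dict is assumed to be a fairseq Dictionary, and 0.3 is just an example threshold):

    from fairseq.sequence_generator import SequenceGenerator

    generator = SequenceGenerator(
        tgt_dict, beam_size=2, sampling=True, sampling_topp=0.3,
    )
    # passing sampling_topp >= 0 without sampling=True trips the new assert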

fairseq/tasks/fairseq_task.py

Lines changed: 1 addition & 0 deletions
@@ -201,6 +201,7 @@ def build_generator(self, args):
             unk_penalty=getattr(args, 'unkpen', 0),
             sampling=getattr(args, 'sampling', False),
             sampling_topk=getattr(args, 'sampling_topk', -1),
+            sampling_topp=getattr(args, 'sampling_topp', -1.0),
             temperature=getattr(args, 'temperature', 1.),
             diverse_beam_groups=getattr(args, 'diverse_beam_groups', -1),
             diverse_beam_strength=getattr(args, 'diverse_beam_strength', 0.5),

tests/test_binaries.py

Lines changed: 6 additions & 0 deletions
@@ -104,6 +104,12 @@ def test_generation(self):
             '--beam', '2',
             '--nbest', '2',
         ])
+        generate_main(data_dir, [
+            '--sampling',
+            '--sampling-topp', '0.2',
+            '--beam', '2',
+            '--nbest', '2',
+        ])
         generate_main(data_dir, ['--prefix-size', '2'])
 
     def test_lstm(self):

tests/test_sequence_generator.py

Lines changed: 169 additions & 35 deletions
@@ -15,7 +15,30 @@
 import tests.utils as test_utils
 
 
-class TestSequenceGenerator(unittest.TestCase):
+class TestSequenceGeneratorBase(unittest.TestCase):
+
+    def assertHypoTokens(self, hypo, tokens):
+        self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens))
+
+    def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
+        pos_scores = torch.FloatTensor(pos_probs).log()
+        self.assertAlmostEqual(hypo['positional_scores'], pos_scores)
+        self.assertEqual(pos_scores.numel(), hypo['tokens'].numel())
+        score = pos_scores.sum()
+        if normalized:
+            score /= pos_scores.numel()**lenpen
+        self.assertLess(abs(score - hypo['score']), 1e-6)
+
+    def assertAlmostEqual(self, t1, t2):
+        self.assertEqual(t1.size(), t2.size(), "size mismatch")
+        self.assertLess((t1 - t2).abs().max(), 1e-4)
+
+    def assertTensorEqual(self, t1, t2):
+        self.assertEqual(t1.size(), t2.size(), "size mismatch")
+        self.assertEqual(t1.ne(t2).long().sum(), 0)
+
+
+class TestSequenceGenerator(TestSequenceGeneratorBase):
 
     def setUp(self):
         self.tgt_dict, self.w1, self.w2, src_tokens, src_lengths, self.model = (
@@ -133,28 +156,8 @@ def test_no_stop_early(self):
         self.assertHypoTokens(hypos[1][1], [w1, w2, w1, eos])
         self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.4, 1.0])
 
-    def assertHypoTokens(self, hypo, tokens):
-        self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens))
 
-    def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
-        pos_scores = torch.FloatTensor(pos_probs).log()
-        self.assertAlmostEqual(hypo['positional_scores'], pos_scores)
-        self.assertEqual(pos_scores.numel(), hypo['tokens'].numel())
-        score = pos_scores.sum()
-        if normalized:
-            score /= pos_scores.numel()**lenpen
-        self.assertLess(abs(score - hypo['score']), 1e-6)
-
-    def assertAlmostEqual(self, t1, t2):
-        self.assertEqual(t1.size(), t2.size(), "size mismatch")
-        self.assertLess((t1 - t2).abs().max(), 1e-4)
-
-    def assertTensorEqual(self, t1, t2):
-        self.assertEqual(t1.size(), t2.size(), "size mismatch")
-        self.assertEqual(t1.ne(t2).long().sum(), 0)
-
-
-class TestDiverseBeamSearch(unittest.TestCase):
+class TestDiverseBeamSearch(TestSequenceGeneratorBase):
 
     def setUp(self):
         # construct dummy dictionary
@@ -232,25 +235,156 @@ def test_diverse_beam_search(self):
         self.assertHypoTokens(hypos[1][1], [w1, w2, eos])
         self.assertHypoScore(hypos[1][1], [0.7, 0.4, 0.9])
 
-    def assertHypoTokens(self, hypo, tokens):
-        self.assertTensorEqual(hypo['tokens'], torch.LongTensor(tokens))
 
-    def assertHypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
+class TestTopPSamplingSearch(TestSequenceGeneratorBase):
+
+    def setUp(self):
+        # construct dummy dictionary
+        d = test_utils.dummy_dictionary(vocab_size=2)
+        self.assertEqual(d.pad(), 1)
+        self.assertEqual(d.eos(), 2)
+        self.assertEqual(d.unk(), 3)
+        self.eos = d.eos()
+        self.w1 = 4
+        self.w2 = 5
+
+        # construct source data
+        self.src_tokens = torch.LongTensor([
+            [self.w1, self.w2, self.eos],
+            [self.w1, self.w2, self.eos],
+        ])
+        self.src_lengths = torch.LongTensor([2, 2])
+
+        args = argparse.Namespace()
+        unk = 0.
+        # The minimal probability of the top 2 tokens.
+        self.min_top2_prob = 0.75
+        # The minimal probability of the top 1 token.
+        self.min_top1_prob = 0.4
+
+        w1_prob = self.min_top1_prob
+        w2_prob = self.min_top2_prob - self.min_top1_prob
+        eos_prob = 1 - self.min_top2_prob
+
+        args.beam_probs = [
+            # step 0:
+            torch.FloatTensor([
+                # eos      w1   w2
+                [0.0, unk, 1.0, 0.0],
+                [0.0, unk, 1.0, 0.0],
+                [0.0, unk, 1.0, 0.0],
+                [0.0, unk, 1.0, 0.0],
+            ]),
+            # step 1:
+            torch.FloatTensor([
+                # eos           w1       w2
+                [eos_prob, unk, w1_prob, w2_prob],
+                [eos_prob, unk, w1_prob, w2_prob],
+                [eos_prob, unk, w1_prob, w2_prob],
+                [eos_prob, unk, w1_prob, w2_prob],
+            ]),
+            # step 2:
+            torch.FloatTensor([
+                # eos      w1   w2
+                [1.0, unk, 0.0, 0.0],
+                [1.0, unk, 0.0, 0.0],
+                [1.0, unk, 0.0, 0.0],
+                [1.0, unk, 0.0, 0.0],
+            ]),
+        ]
+
+        task = test_utils.TestTranslationTask.setup_task(args, d, d)
+        self.model = task.build_model(args)
+        self.tgt_dict = task.target_dictionary
+
+    def test_topp_sampling_search_low_prob(self):
+        # Given a prob low enough for top-P sampling, we expect only the top
+        # 1 token to be sampled, which always results in the same output.
+        low_sampling_topp = self.min_top1_prob/2.0
+        generator = SequenceGenerator(
+            self.tgt_dict, beam_size=2, sampling=True,
+            sampling_topp=low_sampling_topp
+        )
+        sample = {
+            'net_input': {
+                'src_tokens': self.src_tokens,
+                'src_lengths': self.src_lengths
+            }
+        }
+        hypos = generator.generate([self.model], sample)
+        eos, w1 = self.eos, self.w1
+        # sentence 1, beam 1
+        self.assertHypoTokens(hypos[0][0], [w1, w1, eos])
+        self.assertHypoScore(hypos[0][0], [1.0, 0.4, 1.0])
+        # sentence 1, beam 2
+        self.assertHypoTokens(hypos[0][1], [w1, w1, eos])
+        self.assertHypoScore(hypos[0][1], [1.0, 0.4, 1.0])
+        # sentence 2, beam 1
+        self.assertHypoTokens(hypos[1][0], [w1, w1, eos])
+        self.assertHypoScore(hypos[1][0], [1.0, 0.4, 1.0])
+        # sentence 2, beam 2
+        self.assertHypoTokens(hypos[1][1], [w1, w1, eos])
+        self.assertHypoScore(hypos[1][1], [1.0, 0.4, 1.0])
+
+    def test_topp_sampling_search_high_prob(self):
+        # Given a prob high enough for top-P sampling, any of the top 2
+        # tokens could be sampled. This can cause different outputs.
+        high_sampling_topp = (self.min_top1_prob+self.min_top2_prob)/2.0
+        generator = SequenceGenerator(
+            self.tgt_dict, beam_size=2, sampling=True,
+            sampling_topp=high_sampling_topp
+        )
+        sample = {
+            'net_input': {
+                'src_tokens': self.src_tokens,
+                'src_lengths': self.src_lengths
+            }
+        }
+        hypos = generator.generate([self.model], sample)
+        eos, w1, w2 = self.eos, self.w1, self.w2
+        # sentence 1, beam 1
+        self.assertTrue(self.hypoTokens(hypos[0][0], [w1, w1, eos]) or
+                        self.hypoTokens(hypos[0][0], [w1, w2, eos]))
+        self.assertTrue(self.hypoScore(hypos[0][0], [1.0, 0.4, 1.0]) or
+                        self.hypoScore(hypos[0][0], [1.0, 0.35, 1.0]))
+
+        # sentence 1, beam 2
+        self.assertTrue(self.hypoTokens(hypos[0][1], [w1, w1, eos]) or
+                        self.hypoTokens(hypos[0][1], [w1, w2, eos]))
+        self.assertTrue(self.hypoScore(hypos[0][1], [1.0, 0.4, 1.0]) or
+                        self.hypoScore(hypos[0][1], [1.0, 0.35, 1.0]))
+
+        # sentence 2, beam 1
+        self.assertTrue(self.hypoTokens(hypos[1][0], [w1, w1, eos]) or
+                        self.hypoTokens(hypos[1][0], [w1, w2, eos]))
+        self.assertTrue(self.hypoScore(hypos[1][0], [1.0, 0.4, 1.0]) or
+                        self.hypoScore(hypos[1][0], [1.0, 0.35, 1.0]))
+
+        # sentence 2, beam 2
+        self.assertTrue(self.hypoTokens(hypos[1][1], [w1, w1, eos]) or
+                        self.hypoTokens(hypos[1][1], [w1, w2, eos]))
+        self.assertTrue(self.hypoScore(hypos[1][1], [1.0, 0.4, 1.0]) or
+                        self.hypoScore(hypos[1][1], [1.0, 0.35, 1.0]))
+
+    def hypoTokens(self, hypo, tokens):
+        return self.tensorEqual(hypo['tokens'], torch.LongTensor(tokens))
+
+    def hypoScore(self, hypo, pos_probs, normalized=True, lenpen=1.):
         pos_scores = torch.FloatTensor(pos_probs).log()
-        self.assertAlmostEqual(hypo['positional_scores'], pos_scores)
-        self.assertEqual(pos_scores.numel(), hypo['tokens'].numel())
+        if not self.almostEqual(hypo['positional_scores'], pos_scores):
+            return False
+        if pos_scores.numel() != hypo['tokens'].numel():
+            return False
         score = pos_scores.sum()
         if normalized:
-            score /= pos_scores.numel()**lenpen
-        self.assertLess(abs(score - hypo['score']), 1e-6)
+            score /= pos_scores.numel() ** lenpen
+        return abs(score - hypo['score']) < 1e-6
 
-    def assertAlmostEqual(self, t1, t2):
-        self.assertEqual(t1.size(), t2.size(), "size mismatch")
-        self.assertLess((t1 - t2).abs().max(), 1e-4)
+    def almostEqual(self, t1, t2):
+        return t1.size() == t2.size() and (t1 - t2).abs().max() < 1e-4
 
-    def assertTensorEqual(self, t1, t2):
-        self.assertEqual(t1.size(), t2.size(), "size mismatch")
-        self.assertEqual(t1.ne(t2).long().sum(), 0)
+    def tensorEqual(self, t1, t2):
+        return t1.size() == t2.size() and t1.ne(t2).long().sum() == 0
 
 
 if __name__ == '__main__':
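A quick check of the two thresholds used in these tests: low_sampling_topp = min_top1_prob / 2 = 0.4 / 2 = 0.2, so at step 1 the nucleus contains only w1 (probability 0.4) and the output is deterministic. high_sampling_topp = (min_top1_prob + min_top2_prob) / 2 = (0.4 + 0.75) / 2 = 0.575, so w1 alone (cumulative mass 0.4) does not exceed the threshold and w2 (probability 0.35) is also included, which is why either token may be sampled at step 1.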
