33# This source code is licensed under the MIT license found in the
44# LICENSE file in the root directory of this source tree.
55
6+ from collections import namedtuple
7+
68import torch
9+
710from fairseq import utils
8- from fairseq .models .levenshtein_transformer import LevenshteinTransformerModel
9- from fairseq .models .model_utils import script_skip_tensor_list , skip_tensors as _skip
10- from fairseq .models .nonautoregressive_ensembles import EnsembleLevT
11+
12+
# State record threaded through iterative-refinement decoding: the current
# token and score tensors, optional attention, and step/max_step bookkeeping.
# Field order matters — positional construction and indexing rely on it.
DecoderOut = namedtuple(
    'IterativeRefinementDecoderOut',
    'output_tokens output_scores attn step max_step',
)
1120
1221
1322class IterativeRefinementGenerator (object ):
@@ -88,6 +97,8 @@ def generate_batched_itr(
8897
8998 @torch .no_grad ()
9099 def generate (self , models , sample , prefix_tokens = None ):
100+ from fairseq .models .levenshtein_transformer import LevenshteinTransformerModel
101+ from fairseq .models .nonautoregressive_ensembles import EnsembleLevT
91102
92103 if len (models ) == 1 :
93104 # Keep this for other NAT models for which we have yet to implement ensemble wrappers. Later delete this.
@@ -110,7 +121,7 @@ def generate(self, models, sample, prefix_tokens=None):
110121
111122 # initialize buffers (very model specific, with length prediction or not)
112123 prev_decoder_out = model .initialize_output_tokens (encoder_out , src_tokens )
113- prev_output_tokens = prev_decoder_out [ 0 ] .clone ()
124+ prev_output_tokens = prev_decoder_out . output_tokens .clone ()
114125
115126 finalized = [[] for _ in range (bsz )]
116127
@@ -150,8 +161,10 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
150161 "max_ratio" : self .max_ratio ,
151162 "decoding_format" : self .decoding_format ,
152163 }
153- prev_decoder_out [3 ] = step
154- prev_decoder_out [4 ] = self .max_iter + 1
164+ prev_decoder_out = prev_decoder_out ._replace (
165+ step = step ,
166+ max_step = self .max_iter + 1 ,
167+ )
155168
156169 decoder_out = model .forward_decoder (
157170 prev_decoder_out , encoder_out , ** decoder_options
@@ -160,24 +173,26 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
160173 if self .adaptive :
161174 # terminate if there is a loop
162175 terminated , out_tokens , out_scores , out_attn = is_a_loop (
163- prev_output_tokens , decoder_out [0 ], decoder_out [1 ], decoder_out [2 ]
176+ prev_output_tokens , decoder_out .output_tokens , decoder_out .output_scores , decoder_out .attn
177+ )
178+ decoder_out = decoder_out ._replace (
179+ output_tokens = out_tokens ,
180+ output_scores = out_scores ,
181+ attn = out_attn ,
164182 )
165- decoder_out [0 ] = out_tokens
166- decoder_out [1 ] = out_scores
167- decoder_out [2 ] = out_attn
168183
169184 else :
170- terminated = decoder_out [ 0 ]. new_zeros (decoder_out [ 0 ] .size (0 )).bool ()
185+ terminated = decoder_out . output_tokens . new_zeros (decoder_out . output_tokens .size (0 )).bool ()
171186
172187 if step == self .max_iter : # reach last iteration, terminate
173188 terminated .fill_ (1 )
174189
175190 # collect finalized sentences
176191 finalized_idxs = sent_idxs [terminated ]
177- finalized_tokens = decoder_out [ 0 ] [terminated ]
178- finalized_scores = decoder_out [ 1 ] [terminated ]
192+ finalized_tokens = decoder_out . output_tokens [terminated ]
193+ finalized_scores = decoder_out . output_scores [terminated ]
179194 finalized_attn = (
180- None if decoder_out [ 2 ] is None else decoder_out [ 2 ] [terminated ]
195+ None if decoder_out . attn is None else decoder_out . attn [terminated ]
181196 )
182197
183198 for i in range (finalized_idxs .size (0 )):
@@ -194,10 +209,15 @@ def finalized_hypos(step, prev_out_token, prev_out_score, prev_out_attn):
194209 break
195210
196211 # for next step
197- prev_decoder_out = _skip (decoder_out , ~ terminated )
198- encoder_out = script_skip_tensor_list (encoder_out , ~ terminated )
199- sent_idxs = _skip (sent_idxs , ~ terminated )
212+ not_terminated = ~ terminated
213+ prev_decoder_out = decoder_out ._replace (
214+ output_tokens = decoder_out .output_tokens [not_terminated ],
215+ output_scores = decoder_out .output_scores [not_terminated ],
216+ attn = decoder_out .attn [not_terminated ] if decoder_out .attn is not None else None ,
217+ )
218+ encoder_out = model .encoder .reorder_encoder_out (encoder_out , not_terminated .nonzero ().squeeze ())
219+ sent_idxs = sent_idxs [not_terminated ]
200220
201- prev_output_tokens = prev_decoder_out [ 0 ] .clone ()
221+ prev_output_tokens = prev_decoder_out . output_tokens .clone ()
202222
203223 return finalized
0 commit comments