Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions fluid/neural_machine_translation/transformer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@ class InferTaskConfig(object):
# the number of decoded sentences to output.
n_best = 1

# the flags indicating whether to output the special tokens.
output_bos = False
output_eos = False
output_unk = False

# the directory for loading the trained model.
model_path = "trained_models/pass_1.infer.model"

Expand All @@ -56,6 +61,8 @@ class ModelHyperParams(object):
bos_idx = 0
# index for <eos> token
eos_idx = 1
# index for <unk> token
unk_idx = 2

# position value corresponding to the <pad> token.
pos_pad_idx = 0
Expand Down
84 changes: 67 additions & 17 deletions fluid/neural_machine_translation/transformer/infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,25 @@
from train import pad_batch_data


def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
decoder, dec_in_names, dec_out_names, beam_size, max_length,
n_best, batch_size, n_head, src_pad_idx, trg_pad_idx,
bos_idx, eos_idx):
def translate_batch(exe,
src_words,
encoder,
enc_in_names,
enc_out_names,
decoder,
dec_in_names,
dec_out_names,
beam_size,
max_length,
n_best,
batch_size,
n_head,
src_pad_idx,
trg_pad_idx,
bos_idx,
eos_idx,
unk_idx,
output_unk=True):
"""
Run the encoder program once and run the decoder program multiple times to
implement beam search externally.
Expand Down Expand Up @@ -48,7 +63,7 @@ def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
# size of feeded batch is changing.
beam_map = range(batch_size)

def beam_backtrace(prev_branchs, next_ids, n_best=beam_size, add_bos=True):
def beam_backtrace(prev_branchs, next_ids, n_best=beam_size):
"""
Decode and select n_best sequences for one instance by backtrace.
"""
Expand All @@ -60,7 +75,8 @@ def beam_backtrace(prev_branchs, next_ids, n_best=beam_size, add_bos=True):
seq.append(next_ids[j][k])
k = prev_branchs[j][k]
seq = seq[::-1]
seq = [bos_idx] + seq if add_bos else seq
# Add the <bos>, since next_ids don't include the <bos>.
seq = [bos_idx] + seq
seqs.append(seq)
return seqs

Expand Down Expand Up @@ -114,8 +130,7 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
trg_cur_len = len(next_ids[0]) + 1 # include the <bos>
trg_words = np.array(
[
beam_backtrace(
prev_branchs[beam_idx], next_ids[beam_idx], add_bos=True)
beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx])
for beam_idx in active_beams
],
dtype="int64")
Expand Down Expand Up @@ -167,6 +182,8 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
predict_all = (predict_all + scores[beam_map].reshape(
[len(beam_map) * beam_size, -1])).reshape(
[len(beam_map), beam_size, -1])
if not output_unk: # To exclude the <unk> token.
predict_all[:, :, unk_idx] = -1e9
active_beams = []
for inst_idx, beam_idx in enumerate(beam_map):
predict = (predict_all[inst_idx, :, :]
Expand All @@ -187,7 +204,10 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams):
dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams)

# Decode beams and select n_best sequences for each instance by backtrace.
seqs = [beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)]
seqs = [
beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)
for beam_idx in range(batch_size)
]

return seqs, scores[:, :n_best].tolist()

Expand Down Expand Up @@ -254,17 +274,47 @@ def main():
trg_idx2word = paddle.dataset.wmt16.get_dict(
"de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)

def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
eos_idx=ModelHyperParams.eos_idx,
output_bos=InferTaskConfig.output_bos,
output_eos=InferTaskConfig.output_eos):
"""
Post-process the beam-search decoded sequence. Truncate from the first
<eos> and remove the <bos> and <eos> tokens currently.
"""
eos_pos = len(seq) - 1
for i, idx in enumerate(seq):
if idx == eos_idx:
eos_pos = i
break
seq = seq[:eos_pos + 1]
return filter(
lambda idx: (output_bos or idx != bos_idx) and \
(output_eos or idx != eos_idx),
seq)

for batch_id, data in enumerate(test_data()):
batch_seqs, batch_scores = translate_batch(
exe, [item[0] for item in data], encoder_program,
encoder_input_data_names, [enc_output.name], decoder_program,
decoder_input_data_names, [predict.name], InferTaskConfig.beam_size,
InferTaskConfig.max_length, InferTaskConfig.n_best,
len(data), ModelHyperParams.n_head, ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx, ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx)
exe, [item[0] for item in data],
encoder_program,
encoder_input_data_names, [enc_output.name],
decoder_program,
decoder_input_data_names, [predict.name],
InferTaskConfig.beam_size,
InferTaskConfig.max_length,
InferTaskConfig.n_best,
len(data),
ModelHyperParams.n_head,
ModelHyperParams.src_pad_idx,
ModelHyperParams.trg_pad_idx,
ModelHyperParams.bos_idx,
ModelHyperParams.eos_idx,
ModelHyperParams.unk_idx,
output_unk=InferTaskConfig.output_unk)
for i in range(len(batch_seqs)):
seqs = batch_seqs[i]
# Post-process the beam-search decoded sequences.
seqs = map(post_process_seq, batch_seqs[i])
scores = batch_scores[i]
for seq in seqs:
print(" ".join([trg_idx2word[idx] for idx in seq]))
Expand Down