Add inference program for Transformer. #727
Changes from 1 commit
config.py
@@ -15,6 +15,23 @@ class TrainTaskConfig(object):
    # the params for learning rate scheduling
    warmup_steps = 4000

    # the directory for saving inference models
    model_dir = "transformer_model"


class InferTaskConfig(object):
    use_gpu = False
    # number of sequences contained in a mini-batch
    batch_size = 1

    # the params for beam search
    beam_size = 5
    max_length = 30
    n_best = 1
Collaborator: please comment
Author: Done.

    # the directory for loading inference model
    model_path = "transformer_model/pass_1.infer.model"


class ModelHyperParams(object):
    # Dictionary size for source and target language. This model directly uses
@@ -33,6 +50,11 @@ class ModelHyperParams(object): | |
| # index for <pad> token in target language. | ||
| trg_pad_idx = trg_vocab_size | ||
|
|
||
| # index for <bos> token | ||
| bos_idx = 0 | ||
| # index for <eos> token | ||
| eos_idx = 1 | ||
|
|
||
| # position value corresponding to the <pad> token. | ||
| pos_pad_idx = 0 | ||
|
|
||
|
|
@@ -64,14 +86,21 @@ class ModelHyperParams(object): | |
| "src_pos_enc_table", | ||
| "trg_pos_enc_table", ) | ||
|
|
||
| # Names of all data layers listed in order. | ||
| input_data_names = ( | ||
| # Names of all data layers in encoder listed in order. | ||
| encoder_input_data_names = ( | ||
| "src_word", | ||
| "src_pos", | ||
| "src_slf_attn_bias", ) | ||
|
|
||
| # Names of all data layers in decoder listed in order. | ||
| decoder_input_data_names = ( | ||
| "trg_word", | ||
| "trg_pos", | ||
| "src_slf_attn_bias", | ||
| "trg_slf_attn_bias", | ||
| "trg_src_attn_bias", | ||
| "enc_output", ) | ||
|
|
||
| # Names of label related data layers listed in order. | ||
| label_data_names = ( | ||
| "lbl_word", | ||
| "lbl_weight", ) | ||
The inference program (added as a new file):
@@ -0,0 +1,220 @@
import numpy as np

import paddle.v2 as paddle
import paddle.fluid as fluid

import model
from model import wrap_encoder as encoder
from model import wrap_decoder as decoder
from config import InferTaskConfig, ModelHyperParams, \
    encoder_input_data_names, decoder_input_data_names
from train import pad_batch_data


def translate_batch(exe, src_words, encoder, enc_in_names, enc_out_names,
                    decoder, dec_in_names, dec_out_names, beam_size, max_length,
                    n_best, batch_size, n_head, src_pad_idx, trg_pad_idx,
                    bos_idx, eos_idx):
    """
    Run the encoder program once and run the decoder program multiple times to
    implement beam search externally.
    """
    # Prepare data for encoder and run the encoder.
    enc_in_data = pad_batch_data(
        src_words,
        src_pad_idx,
        n_head,
        is_target=False,
        return_pos=True,
        return_attn_bias=True,
        return_max_len=True)
    enc_output = exe.run(encoder,
                         feed=dict(zip(enc_in_names, enc_in_data)),
                         fetch_list=enc_out_names)[0]

    # Beam Search.
    # To store the beam info.
    scores = np.zeros((batch_size, beam_size), dtype="float32")
    prev_branchs = [[]] * batch_size
    next_ids = [[]] * batch_size
    # Use beam_map to map the instance idx in the batch to the beam idx, since
    # the size of the fed batch changes as instances finish.
    beam_map = range(batch_size)

    def beam_backtrace(prev_branchs, next_ids, n_best=beam_size, add_bos=True):
        """
        Decode and select n_best sequences for one instance by backtrace.
        """
        seqs = []
        for i in range(n_best):
            k = i
            seq = []
            for j in range(len(prev_branchs) - 1, -1, -1):
                seq.append(next_ids[j][k])
                k = prev_branchs[j][k]
            seq = seq[::-1]
            seq = [bos_idx] + seq if add_bos else seq
            seqs.append(seq)
        return seqs

    def init_dec_in_data(batch_size, beam_size, enc_in_data, enc_output):
        """
        Initialize the input data for the decoder.
        """
        trg_words = np.array(
            [[bos_idx]] * batch_size * beam_size, dtype="int64")
        trg_pos = np.array([[1]] * batch_size * beam_size, dtype="int64")
        src_max_length, src_slf_attn_bias, trg_max_len = enc_in_data[
            -1], enc_in_data[-2], 1
        trg_src_attn_bias = np.tile(
            src_slf_attn_bias[:, :, ::src_max_length, :],
            [beam_size, 1, trg_max_len, 1])
        enc_output = np.tile(enc_output, [beam_size, 1, 1])
        # No need for trg_slf_attn_bias since there is no padding here.
        return trg_words, trg_pos, None, trg_src_attn_bias, enc_output

    def update_dec_in_data(dec_in_data, next_ids, active_beams):
        """
        Update the input data of the decoder, mainly by slicing from the
        previous input data and dropping the finished instance beams.
        """
        trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output = dec_in_data
        trg_words = np.array(
            [
                beam_backtrace(
                    prev_branchs[beam_idx], next_ids[beam_idx], add_bos=True)
                for beam_idx in active_beams
            ],
            dtype="int64")
        trg_words = trg_words.reshape([-1, 1])
        trg_pos = np.array(
            [range(1, len(next_ids[0]) + 2)] * len(active_beams) * beam_size,
            dtype="int64").reshape([-1, 1])
        active_beams_indice = (
            (np.array(active_beams) * beam_size)[:, np.newaxis] +
            np.array(range(beam_size))[np.newaxis, :]).flatten()
        trg_src_attn_bias = np.tile(trg_src_attn_bias[
            active_beams_indice, :, ::trg_src_attn_bias.shape[2], :],
                                    [1, 1, len(next_ids[0]) + 1, 1])
        enc_output = enc_output[active_beams_indice, :, :]
        return trg_words, trg_pos, trg_slf_attn_bias, trg_src_attn_bias, enc_output

    dec_in_data = init_dec_in_data(batch_size, beam_size, enc_in_data,
                                   enc_output)
    for i in range(max_length):
        predict_all = exe.run(decoder,
                              feed=dict(
                                  filter(lambda item: item[1] is not None,
                                         zip(dec_in_names, dec_in_data))),
                              fetch_list=dec_out_names)[0]
        predict_all = np.log(predict_all)
        predict_all = (
            predict_all.reshape(
                [len(beam_map) * beam_size, i + 1, -1])[:, -1, :] +
            scores[beam_map].reshape([len(beam_map) * beam_size, -1])).reshape(
                [len(beam_map), beam_size, -1])
        active_beams = []
        for inst_idx, beam_idx in enumerate(beam_map):
            predict = (predict_all[inst_idx, :, :]
                       if i != 0 else predict_all[inst_idx, 0, :]).flatten()
            top_k_indice = np.argpartition(predict, -beam_size)[-beam_size:]
            top_scores_ids = top_k_indice[np.argsort(predict[top_k_indice])[::-1]]
            top_scores = predict[top_scores_ids]
            scores[beam_idx] = top_scores
            prev_branchs[beam_idx].append(top_scores_ids /
                                          predict_all.shape[-1])
            next_ids[beam_idx].append(top_scores_ids % predict_all.shape[-1])
            if next_ids[beam_idx][-1][0] != eos_idx:
                active_beams.append(beam_idx)
        beam_map = active_beams
        if len(beam_map) == 0:
            break
        dec_in_data = update_dec_in_data(dec_in_data, next_ids, active_beams)

    # Decode beams and select n_best sequences for each instance by backtrace.
    seqs = [beam_backtrace(prev_branchs[beam_idx], next_ids[beam_idx], n_best)]

    return seqs, scores[:, :n_best].tolist()
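As an aside (not part of the diff), a tiny worked example of the bookkeeping that beam_backtrace unwinds may help: prev_branchs[j][k] records which beam slot the k-th candidate at step j extended, and next_ids[j][k] records the token it appended, so walking the lists backwards recovers a full sequence.

# Made-up two-step search with beam_size = 2 and bos_idx = 0 (illustrative only).
prev_branchs = [[0, 0], [1, 0]]  # branch taken at each step, per beam slot
next_ids = [[3, 5], [7, 2]]      # token id emitted at each step, per beam slot

k, seq = 0, []                   # backtrace the best candidate (k = 0)
for j in range(len(prev_branchs) - 1, -1, -1):
    seq.append(next_ids[j][k])
    k = prev_branchs[j][k]
print([0] + seq[::-1])           # [0, 5, 7]: <bos>, then the decoded tokens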

def main():
    place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
    exe = fluid.Executor(place)
    # The current program desc is coupled with batch_size, and the only
    # supported batch size is currently 1.
    encoder_program = fluid.Program()
    model.batch_size = InferTaskConfig.batch_size
    with fluid.program_guard(main_program=encoder_program):
        enc_output = encoder(
            ModelHyperParams.src_vocab_size + 1,
            ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
            ModelHyperParams.n_head, ModelHyperParams.d_key,
            ModelHyperParams.d_value, ModelHyperParams.d_model,
            ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
            ModelHyperParams.src_pad_idx, ModelHyperParams.pos_pad_idx)

    model.batch_size = InferTaskConfig.batch_size * InferTaskConfig.beam_size
    decoder_program = fluid.Program()
    with fluid.program_guard(main_program=decoder_program):
        predict = decoder(
            ModelHyperParams.trg_vocab_size + 1,
            ModelHyperParams.max_length + 1, ModelHyperParams.n_layer,
            ModelHyperParams.n_head, ModelHyperParams.d_key,
            ModelHyperParams.d_value, ModelHyperParams.d_model,
            ModelHyperParams.d_inner_hid, ModelHyperParams.dropout,
            ModelHyperParams.trg_pad_idx, ModelHyperParams.pos_pad_idx)

    # Load the model parameters of encoder and decoder separately from the
    # saved transformer model.
    encoder_var_names = []
    for op in encoder_program.block(0).ops:
        encoder_var_names += op.input_arg_names
    encoder_param_names = filter(
        lambda var_name: isinstance(encoder_program.block(0).var(var_name),
                                    fluid.framework.Parameter),
        encoder_var_names)
    encoder_params = map(encoder_program.block(0).var, encoder_param_names)
    decoder_var_names = []
    for op in decoder_program.block(0).ops:
        decoder_var_names += op.input_arg_names
    decoder_param_names = filter(
        lambda var_name: isinstance(decoder_program.block(0).var(var_name),
                                    fluid.framework.Parameter),
        decoder_var_names)
    decoder_params = map(decoder_program.block(0).var, decoder_param_names)
    fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=encoder_params)
    fluid.io.load_vars(exe, InferTaskConfig.model_path, vars=decoder_params)

    # This is used here to set dropout to the test mode.
    encoder_program = fluid.io.get_inference_program(
        target_vars=[enc_output], main_program=encoder_program)
    decoder_program = fluid.io.get_inference_program(
        target_vars=[predict], main_program=decoder_program)

    test_data = paddle.batch(
        paddle.dataset.wmt16.test(ModelHyperParams.src_vocab_size,
                                  ModelHyperParams.trg_vocab_size),
        batch_size=InferTaskConfig.batch_size)

    trg_idx2word = paddle.dataset.wmt16.get_dict(
        "de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)

    for batch_id, data in enumerate(test_data()):
        batch_seqs, batch_scores = translate_batch(
            exe, [item[0] for item in data], encoder_program,
            encoder_input_data_names, [enc_output.name], decoder_program,
            decoder_input_data_names, [predict.name], InferTaskConfig.beam_size,
            InferTaskConfig.max_length, InferTaskConfig.n_best,
            InferTaskConfig.batch_size, ModelHyperParams.n_head,
            ModelHyperParams.src_pad_idx, ModelHyperParams.trg_pad_idx,
            ModelHyperParams.bos_idx, ModelHyperParams.eos_idx)
        for i in range(len(batch_seqs)):
            seqs = batch_seqs[i]
            scores = batch_scores[i]
            for seq in seqs:
                print(" ".join([trg_idx2word[idx] for idx in seq]))


if __name__ == "__main__":
    main()
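To make the overall control flow easier to follow, here is a stripped-down, single-sentence sketch of the same "encoder once, decoder repeatedly" beam search, with a random stand-in for the decoder program. Nothing below comes from the PR; the fake scorer, seed, and sizes are made up purely for illustration.

import numpy as np

np.random.seed(0)
beam_size, max_length, vocab_size, bos_idx, eos_idx = 3, 5, 10, 0, 1

scores = np.zeros(beam_size, dtype="float32")
prev_branchs, next_ids = [], []

for step in range(max_length):
    # Stand-in for re-running the decoder program on the current prefixes:
    # one fake next-token distribution per beam slot.
    log_probs = np.log(np.random.dirichlet(np.ones(vocab_size), size=beam_size))
    # At step 0 every slot holds the same <bos> prefix, so expand only one row.
    cand = log_probs[:1] if step == 0 else log_probs + scores[:, np.newaxis]
    flat = cand.flatten()
    top = np.argsort(flat)[::-1][:beam_size]   # best beam_size continuations
    scores = flat[top]
    prev_branchs.append(top // vocab_size)     # which slot each one extends
    next_ids.append(top % vocab_size)          # which token it appends
    if next_ids[-1][0] == eos_idx:             # best slot just emitted <eos>
        break

# prev_branchs / next_ids can now be unwound exactly as in beam_backtrace above
# to recover the decoded sequences.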
Collaborator: for saving inference models --> for saving trained models
Author: Done.