Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
147 changes: 115 additions & 32 deletions fluid/neural_machine_translation/transformer/config.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,50 @@
class TrainTaskConfig(object):
use_gpu = False
use_gpu = True
# the epoch number to train.
pass_num = 2

pass_num = 30
# the number of sequences contained in a mini-batch.
batch_size = 64

batch_size = 32
# the hyper parameters for Adam optimizer.
learning_rate = 0.001
# This static learning_rate will be multiplied to the LearningRateScheduler
# derived learning rate the to get the final learning rate.
learning_rate = 1
beta1 = 0.9
beta2 = 0.98
eps = 1e-9

# the parameters for learning rate scheduling.
warmup_steps = 4000

# the flag indicating to use average loss or sum loss when training.
use_avg_cost = False

use_avg_cost = True
# the weight used to mix up the ground-truth distribution and the fixed
# uniform distribution in label smoothing when training.
# Set this as zero if label smoothing is not wanted.
label_smooth_eps = 0.1
# the directory for saving trained models.
model_dir = "trained_models"
# the directory for saving checkpoints.
ckpt_dir = "trained_ckpts"
# the directory for loading checkpoint.
# If provided, continue training from the checkpoint.
ckpt_path = None
# the parameter to initialize the learning rate scheduler.
# It should be provided if use checkpoints, since the checkpoint doesn't
# include the training step counter currently.
start_step = 0


class InferTaskConfig(object):
use_gpu = False
use_gpu = True
# the number of examples in one run for sequence generation.
batch_size = 10

# the parameters for beam search.
beam_size = 5
max_length = 30
# the number of decoded sentences to output.
n_best = 1

# the flags indicating whether to output the special tokens.
output_bos = False
output_eos = False
output_unk = False

# the directory for loading the trained model.
model_path = "trained_models/pass_1.infer.model"

Expand All @@ -47,30 +54,24 @@ class ModelHyperParams(object):
# <unk> token has alreay been added. As for the <pad> token, any token
# included in dict can be used to pad, since the paddings' loss will be
# masked out and make no effect on parameter gradients.

# size of source word dictionary.
src_vocab_size = 10000

# size of target word dictionay
trg_vocab_size = 10000

# index for <bos> token
bos_idx = 0
# index for <eos> token
eos_idx = 1
# index for <unk> token
unk_idx = 2

# max length of sequences.
# The size of position encoding table should at least plus 1, since the
# sinusoid position encoding starts from 1 and 0 can be used as the padding
# token for position encoding.
max_length = 50

# the dimension for word embeddings, which is also the last dimension of
# the input and output of multi-head attention, position-wise feed-forward
# networks, encoder and decoder.

d_model = 512
# size of the hidden layer in position-wise feed-forward networks.
d_inner_hid = 1024
Expand All @@ -86,34 +87,116 @@ class ModelHyperParams(object):
dropout = 0.1


def merge_cfg_from_list(cfg_list, g_cfgs):
"""
Set the above global configurations using the cfg_list.
"""
assert len(cfg_list) % 2 == 0
for key, value in zip(cfg_list[0::2], cfg_list[1::2]):
for g_cfg in g_cfgs:
if hasattr(g_cfg, key):
try:
value = eval(value)
except SyntaxError: # for file path
pass
setattr(g_cfg, key, value)
break


# Here list the data shapes and data types of all inputs.
# The shapes here act as placeholder and are set to pass the infer-shape in
# compile time.
input_descs = {
# The actual data shape of src_word is:
# [batch_size * max_src_len_in_batch, 1]
"src_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
# The actual data shape of src_pos is:
# [batch_size * max_src_len_in_batch, 1]
"src_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
# This input is used to remove attention weights on paddings in the
# encoder.
# The actual data shape of src_slf_attn_bias is:
# [batch_size, n_head, max_src_len_in_batch, max_src_len_in_batch]
"src_slf_attn_bias":
[(1, ModelHyperParams.n_head, (ModelHyperParams.max_length + 1),
(ModelHyperParams.max_length + 1)), "float32"],
# This shape input is used to reshape the output of embedding layer.
"src_data_shape": [(3L, ), "int32"],
# This shape input is used to reshape before softmax in self attention.
"src_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
# This shape input is used to reshape after softmax in self attention.
"src_slf_attn_post_softmax_shape": [(4L, ), "int32"],
# The actual data shape of trg_word is:
# [batch_size * max_trg_len_in_batch, 1]
"trg_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
# The actual data shape of trg_pos is:
# [batch_size * max_trg_len_in_batch, 1]
"trg_pos": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
# This input is used to remove attention weights on paddings and
# subsequent words in the decoder.
# The actual data shape of trg_slf_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_trg_len_in_batch]
"trg_slf_attn_bias": [(1, ModelHyperParams.n_head,
(ModelHyperParams.max_length + 1),
(ModelHyperParams.max_length + 1)), "float32"],
# This input is used to remove attention weights on paddings of the source
# input in the encoder-decoder attention.
# The actual data shape of trg_src_attn_bias is:
# [batch_size, n_head, max_trg_len_in_batch, max_src_len_in_batch]
"trg_src_attn_bias": [(1, ModelHyperParams.n_head,
(ModelHyperParams.max_length + 1),
(ModelHyperParams.max_length + 1)), "float32"],
# This shape input is used to reshape the output of embedding layer.
"trg_data_shape": [(3L, ), "int32"],
# This shape input is used to reshape before softmax in self attention.
"trg_slf_attn_pre_softmax_shape": [(2L, ), "int32"],
# This shape input is used to reshape after softmax in self attention.
"trg_slf_attn_post_softmax_shape": [(4L, ), "int32"],
# This shape input is used to reshape before softmax in encoder-decoder
# attention.
"trg_src_attn_pre_softmax_shape": [(2L, ), "int32"],
# This shape input is used to reshape after softmax in encoder-decoder
# attention.
"trg_src_attn_post_softmax_shape": [(4L, ), "int32"],
# This input is used in independent decoder program for inference.
# The actual data shape of enc_output is:
# [batch_size, max_src_len_in_batch, d_model]
"enc_output": [(1, (ModelHyperParams.max_length + 1),
ModelHyperParams.d_model), "float32"],
# The actual data shape of label_word is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_word": [(1 * (ModelHyperParams.max_length + 1), 1L), "int64"],
# This input is used to mask out the loss of paddding tokens.
# The actual data shape of label_weight is:
# [batch_size * max_trg_len_in_batch, 1]
"lbl_weight": [(1 * (ModelHyperParams.max_length + 1), 1L), "float32"],
}

# Names of position encoding table which will be initialized externally.
pos_enc_param_names = (
"src_pos_enc_table",
"trg_pos_enc_table", )

# Names of all data layers in encoder listed in order.
encoder_input_data_names = (
# separated inputs for different usages.
encoder_data_input_fields = (
"src_word",
"src_pos",
"src_slf_attn_bias",
"src_slf_attn_bias", )
encoder_util_input_fields = (
"src_data_shape",
"src_slf_attn_pre_softmax_shape",
"src_slf_attn_post_softmax_shape", )

# Names of all data layers in decoder listed in order.
decoder_input_data_names = (
decoder_data_input_fields = (
"trg_word",
"trg_pos",
"trg_slf_attn_bias",
"trg_src_attn_bias",
"enc_output", )
decoder_util_input_fields = (
"trg_data_shape",
"trg_slf_attn_pre_softmax_shape",
"trg_slf_attn_post_softmax_shape",
"trg_src_attn_pre_softmax_shape",
"trg_src_attn_post_softmax_shape",
"enc_output", )

# Names of label related data layers listed in order.
label_data_names = (
"trg_src_attn_post_softmax_shape", )
label_data_input_fields = (
"lbl_word",
"lbl_weight", )
92 changes: 79 additions & 13 deletions fluid/neural_machine_translation/transformer/infer.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import argparse
import numpy as np

import paddle
Expand All @@ -6,9 +7,62 @@
import model
from model import wrap_encoder as encoder
from model import wrap_decoder as decoder
from config import InferTaskConfig, ModelHyperParams, \
encoder_input_data_names, decoder_input_data_names
from config import *
from train import pad_batch_data
import reader


def parse_args():
parser = argparse.ArgumentParser("Training for Transformer.")
parser.add_argument(
"--src_vocab_fpath",
type=str,
required=True,
help="The path of vocabulary file of source language.")
parser.add_argument(
"--trg_vocab_fpath",
type=str,
required=True,
help="The path of vocabulary file of target language.")
parser.add_argument(
"--test_file_pattern",
type=str,
required=True,
help="The pattern to match test data files.")
parser.add_argument(
"--batch_size",
type=int,
default=50,
help="The number of examples in one run for sequence generation.")
parser.add_argument(
"--pool_size",
type=int,
default=10000,
help="The buffer size to pool data.")
parser.add_argument(
"--special_token",
type=str,
default=["<s>", "<e>", "<unk>"],
nargs=3,
help="The <bos>, <eos> and <unk> tokens in the dictionary.")
parser.add_argument(
'opts',
help='See config.py for all options',
default=None,
nargs=argparse.REMAINDER)
args = parser.parse_args()
# Append args related to dict
src_dict = reader.DataReader.load_dict(args.src_vocab_fpath)
trg_dict = reader.DataReader.load_dict(args.trg_vocab_fpath)
dict_args = [
"src_vocab_size", str(len(src_dict)), "trg_vocab_size",
str(len(trg_dict)), "bos_idx", str(src_dict[args.special_token[0]]),
"eos_idx", str(src_dict[args.special_token[1]]), "unk_idx",
str(src_dict[args.special_token[2]])
]
merge_cfg_from_list(args.opts + dict_args,
[InferTaskConfig, ModelHyperParams])
return args


def translate_batch(exe,
Expand Down Expand Up @@ -243,7 +297,7 @@ def update_dec_in_data(dec_in_data, next_ids, active_beams, beam_inst_map):
return seqs, scores[:, :n_best].tolist()


def main():
def infer(args):
place = fluid.CUDAPlace(0) if InferTaskConfig.use_gpu else fluid.CPUPlace()
exe = fluid.Executor(place)

Expand Down Expand Up @@ -292,13 +346,23 @@ def main():
decoder_program = fluid.io.get_inference_program(
target_vars=[predict], main_program=decoder_program)

test_data = paddle.batch(
paddle.dataset.wmt16.test(ModelHyperParams.src_vocab_size,
ModelHyperParams.trg_vocab_size),
batch_size=InferTaskConfig.batch_size)
test_data = reader.DataReader(
src_vocab_fpath=args.src_vocab_fpath,
trg_vocab_fpath=args.trg_vocab_fpath,
fpattern=args.test_file_pattern,
batch_size=args.batch_size,
use_token_batch=False,
pool_size=args.pool_size,
sort_type=reader.SortType.NONE,
shuffle=False,
shuffle_batch=False,
start_mark=args.special_token[0],
end_mark=args.special_token[1],
unk_mark=args.special_token[2],
clip_last_batch=False)

trg_idx2word = paddle.dataset.wmt16.get_dict(
"de", dict_size=ModelHyperParams.trg_vocab_size, reverse=True)
trg_idx2word = test_data.load_dict(
dict_path=args.trg_vocab_fpath, reverse=True)

def post_process_seq(seq,
bos_idx=ModelHyperParams.bos_idx,
Expand All @@ -320,15 +384,16 @@ def post_process_seq(seq,
(output_eos or idx != eos_idx),
seq)

for batch_id, data in enumerate(test_data()):
for batch_id, data in enumerate(test_data.batch_generator()):
batch_seqs, batch_scores = translate_batch(
exe,
[item[0] for item in data],
encoder_program,
encoder_input_data_names,
encoder_data_input_fields + encoder_util_input_fields,
[enc_output.name],
decoder_program,
decoder_input_data_names,
decoder_data_input_fields[:-1] + decoder_util_input_fields +
(decoder_data_input_fields[-1], ),
[predict.name],
InferTaskConfig.beam_size,
InferTaskConfig.max_length,
Expand All @@ -351,4 +416,5 @@ def post_process_seq(seq,


if __name__ == "__main__":
main()
args = parse_args()
infer(args)
Loading