Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
192 changes: 91 additions & 101 deletions nmt_without_attention/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,15 @@ RNN 的原始结构用一个向量来存储隐状态,然而这种结构的 RNN
在 PaddlePaddle 中,双向编码器可以很方便地调用相关 APIs 实现:

```python
#### Encoder
src_word_id = paddle.layer.data(
name='source_language_word',
type=paddle.data_type.integer_value_sequence(source_dict_dim))

# source embedding
src_embedding = paddle.layer.embedding(
input=src_word_id, size=word_vector_dim)
# use bidirectional_gru

# # bidierctional GRU as encoder
encoded_vector = paddle.networks.bidirectional_gru(
input=src_embedding,
size=encoder_size,
Expand All @@ -84,19 +85,17 @@ encoded_vector = paddle.networks.bidirectional_gru(


### 无注意力机制的解码器
PaddleBook中[机器翻译](https://github.com/PaddlePaddle/book/blob/develop/08.machine_translation/README.cn.md)的相关章节中,已介绍了带注意力机制(Attention Mechanism)的 Encoder-Decoder 结构,本例则介绍的是不带注意力机制的 Encoder-Decoder 结构。关于注意力机制,读者可进一步参考 PaddleBook 和参考文献\[[3](#参考文献)]。
-PaddleBook中[机器翻译](https://github.com/PaddlePaddle/book/blob/develop/08.machine_translation/README.cn.md)的相关章节中,已介绍了带注意力机制(Attention Mechanism)的 Encoder-Decoder 结构,本例则介绍的是不带注意力机制的 Encoder-Decoder 结构。关于注意力机制,读者可进一步参考 PaddleBook 和参考文献\[[3](#参考文献)]。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"本例则介绍的是" -> "本例介绍的则是"? 感觉会好一点

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


对于流行的RNN单元,PaddlePaddle 已有很好的实现均可直接调用。如果希望在 RNN 每一个时间步实现某些自定义操作,可使用 PaddlePaddle 中的`recurrent_layer_group`。首先,自定义单步逻辑函数,再利用函数 `recurrent_group()` 循环调用单步逻辑函数处理整个序列。本例中的无注意力机制的解码器便是使用`recurrent_layer_group`来实现,其中,单步逻辑函数`gru_decoder_without_attention()`相关代码如下:

```python
#### Decoder
# the initialization state for decoder GRU
encoder_last = paddle.layer.last_seq(input=encoded_vector)
encoder_last_projected = paddle.layer.mixed(
size=decoder_size,
act=paddle.activation.Tanh(),
input=paddle.layer.full_matrix_projection(input=encoder_last))
encoder_last_projected = paddle.layer.fc(
size=decoder_size, act=paddle.activation.Tanh(), input=encoder_last)

# gru step
# the step function for decoder GRU
def gru_decoder_without_attention(enc_vec, current_word):
'''
Step function for gru decoder
Expand All @@ -106,33 +105,29 @@ def gru_decoder_without_attention(enc_vec, current_word):
:type current_word: layer object
'''
decoder_mem = paddle.layer.memory(
name='gru_decoder',
size=decoder_size,
boot_layer=encoder_last_projected)
name="gru_decoder",
size=decoder_size,
boot_layer=encoder_last_projected)

context = paddle.layer.last_seq(input=enc_vec)

decoder_inputs = paddle.layer.mixed(
size=decoder_size * 3,
input=[
paddle.layer.full_matrix_projection(input=context),
paddle.layer.full_matrix_projection(input=current_word)
])
decoder_inputs = paddle.layer.fc(
size=decoder_size * 3, input=[context, current_word])

gru_step = paddle.layer.gru_step(
name='gru_decoder',
name="gru_decoder",
act=paddle.activation.Tanh(),
gate_act=paddle.activation.Sigmoid(),
input=decoder_inputs,
output_mem=decoder_mem,
size=decoder_size)

out = paddle.layer.mixed(
out = paddle.layer.fc(
size=target_dict_dim,
bias_attr=True,
act=paddle.activation.Softmax(),
input=paddle.layer.full_matrix_projection(input=gru_step))
return out
input=gru_step)
return out
```

在模型训练和测试阶段,解码器的行为有很大的不同:
Expand All @@ -143,34 +138,14 @@ def gru_decoder_without_attention(enc_vec, current_word):
训练和生成的逻辑分别实现在如下的`if-else`条件分支中:

```python
decoder_group_name = "decoder_group"
group_input1 = paddle.layer.StaticInput(input=encoded_vector, is_seq=True)
group_input1 = paddle.layer.StaticInput(input=encoded_vector)
group_inputs = [group_input1]
if not generating:
trg_embedding = paddle.layer.embedding(
input=paddle.layer.data(
name='target_language_word',
type=paddle.data_type.integer_value_sequence(target_dict_dim)),
size=word_vector_dim,
param_attr=paddle.attr.ParamAttr(name='_target_language_embedding'))
group_inputs.append(trg_embedding)

decoder = paddle.layer.recurrent_group(
name=decoder_group_name,
step=gru_decoder_without_attention,
input=group_inputs)

lbl = paddle.layer.data(
name='target_language_next_word',
type=paddle.data_type.integer_value_sequence(target_dict_dim))
cost = paddle.layer.classification_cost(input=decoder, label=lbl)

return cost
else:

decoder_group_name = "decoder_group"
if is_generating:
trg_embedding = paddle.layer.GeneratedInput(
size=target_dict_dim,
embedding_name='_target_language_embedding',
embedding_name="_target_language_embedding",
embedding_size=word_vector_dim)
group_inputs.append(trg_embedding)

Expand All @@ -184,6 +159,26 @@ else:
max_length=max_length)

return beam_gen
else:
trg_embedding = paddle.layer.embedding(
input=paddle.layer.data(
name="target_language_word",
type=paddle.data_type.integer_value_sequence(target_dict_dim)),
size=word_vector_dim,
param_attr=paddle.attr.ParamAttr(name="_target_language_embedding"))
group_inputs.append(trg_embedding)

decoder = paddle.layer.recurrent_group(
name=decoder_group_name,
step=gru_decoder_without_attention,
input=group_inputs)

lbl = paddle.layer.data(
name="target_language_next_word",
type=paddle.data_type.integer_value_sequence(target_dict_dim))
cost = paddle.layer.classification_cost(input=decoder, label=lbl)

return cost
```

## 数据准备
Expand All @@ -198,22 +193,25 @@ else:

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to reorganize line 189 because we don't have params --train and --generate now.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

**a) 由网络定义,解析网络结构,初始化模型参数**

```
```python
# initialize model
cost = seq2seq_net(source_dict_dim, target_dict_dim)
parameters = paddle.parameters.create(cost)
```

**b) 设定训练过程中的优化策略、定义训练数据读取 `reader`**

```
# define optimize method and trainer
```python
# define optimization method
optimizer = paddle.optimizer.RMSProp(
learning_rate=1e-3,
gradient_clipping_threshold=10.0,
regularization=paddle.optimizer.L2Regularization(rate=8e-4))

# define the trainer instance
trainer = paddle.trainer.SGD(
cost=cost, parameters=parameters, update_equation=optimizer)

# define data reader
wmt14_reader = paddle.batch(
paddle.reader.shuffle(
Expand All @@ -223,40 +221,39 @@ wmt14_reader = paddle.batch(

**c) 定义事件句柄,打印训练中间结果、保存模型快照**

```
# define event_handler callback
```python
# define the event_handler callback
def event_handler(event):
if isinstance(event, paddle.event.EndIteration):
if event.batch_id % 100 == 0 and event.batch_id > 0:
with gzip.open('models/nmt_without_att_params_batch_%d.tar.gz' %
event.batch_id, 'w') as f:
if not event.batch_id % 100 and event.batch_id:
with gzip.open(
os.path.join(save_path,
"nmt_without_att_%05d_batch_%05d.tar.gz" %
event.pass_id, event.batch_id), "w") as f:
parameters.to_tar(f)

if event.batch_id % 10 == 0:
print "\nPass %d, Batch %d, Cost%f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics)
else:
sys.stdout.write('.')
sys.stdout.flush()
if event.batch_id and not event.batch_id % 10:
logger.info("Pass %d, Batch %d, Cost %f, %s" % (
event.pass_id, event.batch_id, event.cost, event.metrics))
```

**d) 开始训练**

```
```python
# start to train
trainer.train(
reader=wmt14_reader, event_handler=event_handler, num_passes=2)
```

启动模型训练的十分简单,只需在命令行窗口中执行

```
python nmt_without_attention_v2.py --train
```bash
python train.py
```

输出样例为

```
```text
Pass 0, Batch 0, Cost 267.674663, {'classification_error_evaluator': 1.0}
.........
Pass 0, Batch 10, Cost 172.892294, {'classification_error_evaluator': 0.953895092010498}
Expand All @@ -274,7 +271,7 @@ Pass 0, Batch 40, Cost 168.170543, {'classification_error_evaluator': 0.83481836

**a) 加载测试样本**

```
```python
# load data samples for generation
gen_creator = paddle.dataset.wmt14.gen(source_dict_dim)
gen_data = []
Expand All @@ -284,7 +281,7 @@ for item in gen_creator():

**b) 初始化模型,执行`infer()`为每个输入样本生成`beam search`的翻译结果**

```
```python
beam_gen = seq2seq_net(source_dict_dim, target_dict_dim, True)
with gzip.open(init_models_path) as f:
parameters = paddle.parameters.Parameters.from_tar(f)
Expand All @@ -298,51 +295,44 @@ beam_result = paddle.infer(

**c) 加载源语言和目标语言词典,将`id`序列表示的句子转化成原语言并输出结果**

```
# get the dictionary
src_dict, trg_dict = paddle.dataset.wmt14.get_dict(source_dict_dim)

# the delimited element of generated sequences is -1,
# the first element of each generated sequence is the sequence length
seq_list = []
seq = []
for w in beam_result[1]:
if w != -1:
seq.append(w)
else:
seq_list.append(' '.join([trg_dict.get(w) for w in seq[1:]]))
seq = []

prob = beam_result[0]
for i in xrange(len(gen_data)):
print "\n*******************************************************\n"
print "src:", ' '.join([src_dict.get(w) for w in gen_data[i][0]]), "\n"
```python
beam_result = inferer.infer(input=test_batch, field=["prob", "id"])

gen_sen_idx = np.where(beam_result[1] == -1)[0]
assert len(gen_sen_idx) == len(test_batch) * beam_size

start_pos, end_pos = 1, 0
for i, sample in enumerate(test_batch):
print(" ".join([
src_dict[w] for w in sample[0][1:-1]
])) # skip the start and ending mark when print the source sentence
for j in xrange(beam_size):
print "prob = %f:" % (prob[i][j]), seq_list[i * beam_size + j]
end_pos = gen_sen_idx[i * beam_size + j]
print("%.4f\t%s" % (beam_result[0][i][j], " ".join(
trg_dict[w] for w in beam_result[1][start_pos:end_pos])))
start_pos = end_pos + 2
print("\n")
```

模型测试的执行与模型训练类似,只需执行

```bash
python generate.py
```
python nmt_without_attention_v2.py --generate
```
则自动为测试数据生成了对应的翻译结果。
设置beam search的宽度为3,输入某个法文句子

```
src: <s> Elles connaissent leur entreprise mieux que personne . <e>
```
设置beam search的宽度为3,输入为一个法文句子,则自动为测试数据生成对应的翻译结果,输出格式如下:

其对应的英文翻译结果为
```text
Elles connaissent leur entreprise mieux que personne .
-3.754819 They know their business better than anyone . <e>
-4.445528 They know their businesses better than anyone . <e>
-5.026885 They know their business better than anybody . <e>

```
prob = -3.754819: They know their business better than anyone . <e>
prob = -4.445528: They know their businesses better than anyone . <e>
prob = -5.026885: They know their business better than anybody . <e>
```

* `prob`表示生成句子的得分,随之其后则是翻译生成的句子;
* `<s>` 表示句子的开始,`<e>`表示一个句子的结束,如果出现了在词典中未包含的词,则用`<unk>`替代。
- 第一行为输入的源语言句子。
- 第二 ~ `beam_size + 1` 行是柱搜索生成的 `beam_size` 条翻译结果
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

第二 ~ beam_size + 1 行 -> 第二至第beam_size + 1 行,
在文字中少用符号是不是有好点?anyway,个人偏好。

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done.

- 一行之内以“\t”分隔为两列,第一列是句子的log 概率,第二列是翻译结果的文本。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

一行之内 -> 相同行的输出

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

- `<s>` 表示句子的开始,`<e>`表示一个句子的结束,如果出现了在词典中未包含的词,则用`<unk>`替代。
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

<s> -> 符号<s><unk>->  符号<unk>

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done


至此,我们在 PaddlePaddle 上实现了一个初步的机器翻译模型。我们可以看到,PaddlePaddle 提供了灵活丰富的API供大家选择和使用,使得我们能够很方便完成各种复杂网络的配置。机器翻译本身也是个快速发展的领域,各种新方法新思想在不断涌现。在学习完本例后,读者若有兴趣和余力,可基于 PaddlePaddle 平台实现更为复杂、性能更优的机器翻译模型。

Expand Down
Loading