diff --git a/CHANGELOG.md b/CHANGELOG.md index fa75b2d2b..8efd7c448 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -91,6 +91,9 @@ To release a new version, please update the changelog as followed: - Add version_info in model.config. (PR #992) - Replace tf.nn.func with tf.nn.func.\_\_name\_\_ in model config. - Add Reinforcement learning tutorials. (PR #995) +- Add RNN layers with simple rnn cell, GRU cell, LSTM cell. (PR #998) +- Update Seq2seq (#998) +- Add Seq2seqLuongAttention model (#998) ### Fixed @@ -100,12 +103,14 @@ To release a new version, please update the changelog as followed: - @Tokarev-TT-33: #995 - @initial-h: #995 - @Officium: #995 +- @ArnoldLIULJ: #998 +- @JingqingZ: #998 + ## [2.0.2] - 2019-6-5 ### Changed - change the format of network config, change related code and files; change layer act (PR #980) -- update Seq2seq (#989) ### Fixed - Fix dynamic model cannot track PRelu weights gradients problem (PR #982) @@ -113,7 +118,6 @@ To release a new version, please update the changelog as followed: ### Contributors - @warshallrho: #980 -- @ArnoldLIULJ: #989 - @1FengL: #982 ## [2.0.1] - 2019-5-17 diff --git a/docs/modules/layers.rst b/docs/modules/layers.rst index 7a70b54dc..f6c86a542 100644 --- a/docs/modules/layers.rst +++ b/docs/modules/layers.rst @@ -80,11 +80,15 @@ Layer list SwitchNorm RNN + SimpleRNN + GRURNN + LSTMRNN BiRNN retrieve_seq_length_op retrieve_seq_length_op2 retrieve_seq_length_op3 + target_mask_op Flatten Reshape @@ -579,6 +583,18 @@ RNN layer """""""""""""""""""""""""" .. autoclass:: RNN +RNN layer with Simple RNN Cell +"""""""""""""""""""""""""""""""""" +.. autoclass:: SimpleRNN + +RNN layer with GRU Cell +"""""""""""""""""""""""""""""""""" +.. autoclass:: GRURNN + +RNN layer with LSTM Cell +"""""""""""""""""""""""""""""""""" +.. autoclass:: LSTMRNN + Bidirectional layer """"""""""""""""""""""""""""""""" .. autoclass:: BiRNN @@ -593,13 +609,16 @@ Compute Sequence length 1 .. 
autofunction:: retrieve_seq_length_op Compute Sequence length 2 -"""""""""""""""""""""""""" +""""""""""""""""""""""""""""" .. autofunction:: retrieve_seq_length_op2 Compute Sequence length 3 -"""""""""""""""""""""""""" +"""""""""""""""""""""""""""" .. autofunction:: retrieve_seq_length_op3 +Compute mask of the target sequence +""""""""""""""""""""""""""""""""""""""" +.. autofunction:: target_mask_op diff --git a/docs/modules/models.rst b/docs/modules/models.rst index cdfd6ccc6..46b8d7e1b 100644 --- a/docs/modules/models.rst +++ b/docs/modules/models.rst @@ -13,6 +13,9 @@ TensorLayer provides many pretrained models, you can easily use the whole or a p VGG19 SqueezeNetV1 MobileNetV1 + Seq2seq + Seq2seqLuongAttention + Base Model ----------- @@ -37,3 +40,14 @@ MobileNetV1 ---------------- .. autofunction:: MobileNetV1 + +Seq2seq +------------------------ + +.. autoclass:: Seq2seq + + +Seq2seq Luong Attention +------------------------ + +.. autoclass:: Seq2seqLuongAttention diff --git a/tensorlayer/layers/recurrent.py b/tensorlayer/layers/recurrent.py index d91288dda..0bd6315ee 100644 --- a/tensorlayer/layers/recurrent.py +++ b/tensorlayer/layers/recurrent.py @@ -7,17 +7,12 @@ from tensorlayer.decorators import deprecated_alias from tensorlayer.layers.core import Layer -# from tensorflow.python.ops import array_ops -# from tensorflow.python.util.tf_inspect import getfullargspec -# from tensorflow.contrib.rnn import stack_bidirectional_dynamic_rnn -# from tensorflow.python.ops.rnn_cell import LSTMStateTuple - -# from tensorlayer.layers.core import LayersConfig -# from tensorlayer.layers.core import TF_GRAPHKEYS_VARIABLES - # TODO: uncomment __all__ = [ 'RNN', + 'SimpleRNN', + 'GRURNN', + 'LSTMRNN', 'BiRNN', # 'ConvRNNCell', # 'BasicConvLSTMCell', @@ -25,8 +20,7 @@ 'retrieve_seq_length_op', 'retrieve_seq_length_op2', 'retrieve_seq_length_op3', - # 'target_mask_op', - # 'Seq2Seq', + 'target_mask_op', ] @@ -222,7 +216,238 @@ def forward(self, inputs, initial_state=None, 
**kwargs): return outputs -# TODO: write tl.layers.SimpleRNN, tl.layers.GRU, tl.layers.LSTM +class SimpleRNN(RNN): + """ + The :class:`SimpleRNN` class is a fixed length recurrent layer for implementing simple RNN. + + Parameters + ---------- + units: int + Positive integer, the dimension of hidden space. + return_last_output : boolean + Whether return last output or all outputs in a sequence. + - If True, return the last output, "Sequence input and single output" + - If False, return all outputs, "Synced sequence input and output" + - In other words, if you want to stack more RNNs on this layer, set to False + + In a dynamic model, `return_last_output` can be updated when it is called in customised forward(). + By default, `False`. + return_seq_2d : boolean + Only consider this argument when `return_last_output` is `False` + - If True, return 2D Tensor [batch_size * n_steps, n_hidden], for stacking Dense layer after it. + - If False, return 3D Tensor [batch_size, n_steps, n_hidden], for stacking multiple RNN after it. + + In a dynamic model, `return_seq_2d` can be updated when it is called in customised forward(). + By default, `False`. + return_last_state: boolean + Whether to return the last state of the RNN cell. The state is a list of Tensor. + For simple RNN, last_state = [last_output] + + - If True, the layer will return outputs and the final state of the cell. + - If False, the layer will return outputs only. + + In a dynamic model, `return_last_state` can be updated when it is called in customised forward(). + By default, `True`. + in_channels: int + Optional, the number of channels of the previous layer which is normally the size of embedding. + If given, the layer will be built when init. + If None, it will be automatically detected when the layer is forwarded for the first time. + name : str + A unique layer name. + `**kwargs`: + Advanced arguments to configure the simple RNN cell. + Please check tf.keras.layers.SimpleRNNCell. 
+ + Examples + -------- + + A simple regression model below. + + >>> inputs = tl.layers.Input([batch_size, num_steps, embedding_size]) + >>> rnn_out, rnn_state = tl.layers.SimpleRNN( + >>> units=hidden_size, dropout=0.1, # both units and dropout are used to configure the simple rnn cell. + >>> in_channels=embedding_size, + >>> return_last_output=True, return_last_state=True, name='simplernn' + >>> )(inputs) + >>> outputs = tl.layers.Dense(n_units=1)(rnn_out) + >>> rnn_model = tl.models.Model(inputs=inputs, outputs=[outputs, rnn_state[0]], name='rnn_model') + + Notes + ----- + Input dimension should be rank 3 : [batch_size, n_steps, n_features], if no, please see layer :class:`Reshape`. + + """ + + def __init__( + self, + units, + return_last_output=False, + return_seq_2d=False, + return_last_state=True, + in_channels=None, + name=None, # 'simplernn' + **kwargs + ): + super(SimpleRNN, self).__init__( + cell=tf.keras.layers.SimpleRNNCell(units=units, **kwargs), return_last_output=return_last_output, + return_seq_2d=return_seq_2d, return_last_state=return_last_state, in_channels=in_channels, name=name + ) + + +class GRURNN(RNN): + """ + The :class:`GRURNN` class is a fixed length recurrent layer for implementing RNN with GRU cell. + + Parameters + ---------- + units: int + Positive integer, the dimension of hidden space. + return_last_output : boolean + Whether return last output or all outputs in a sequence. + - If True, return the last output, "Sequence input and single output" + - If False, return all outputs, "Synced sequence input and output" + - In other words, if you want to stack more RNNs on this layer, set to False + + In a dynamic model, `return_last_output` can be updated when it is called in customised forward(). + By default, `False`. + return_seq_2d : boolean + Only consider this argument when `return_last_output` is `False` + - If True, return 2D Tensor [batch_size * n_steps, n_hidden], for stacking Dense layer after it. 
+ - If False, return 3D Tensor [batch_size, n_steps, n_hidden], for stacking multiple RNN after it. + + In a dynamic model, `return_seq_2d` can be updated when it is called in customised forward(). + By default, `False`. + return_last_state: boolean + Whether to return the last state of the RNN cell. The state is a list of Tensor. + For GRU, last_state = [last_output] + + - If True, the layer will return outputs and the final state of the cell. + - If False, the layer will return outputs only. + + In a dynamic model, `return_last_state` can be updated when it is called in customised forward(). + By default, `True`. + in_channels: int + Optional, the number of channels of the previous layer which is normally the size of embedding. + If given, the layer will be built when init. + If None, it will be automatically detected when the layer is forwarded for the first time. + name : str + A unique layer name. + `**kwargs`: + Advanced arguments to configure the GRU cell. + Please check tf.keras.layers.GRUCell. + + Examples + -------- + + A simple regression model below. + + >>> inputs = tl.layers.Input([batch_size, num_steps, embedding_size]) + >>> rnn_out, rnn_state = tl.layers.GRURNN( + >>> units=hidden_size, dropout=0.1, # both units and dropout are used to configure the GRU cell. + >>> in_channels=embedding_size, + >>> return_last_output=True, return_last_state=True, name='grurnn' + >>> )(inputs) + >>> outputs = tl.layers.Dense(n_units=1)(rnn_out) + >>> rnn_model = tl.models.Model(inputs=inputs, outputs=[outputs, rnn_state[0]], name='rnn_model') + + Notes + ----- + Input dimension should be rank 3 : [batch_size, n_steps, n_features], if no, please see layer :class:`Reshape`. 
+ + """ + + def __init__( + self, + units, + return_last_output=False, + return_seq_2d=False, + return_last_state=True, + in_channels=None, + name=None, # 'grurnn' + **kwargs + ): + super(GRURNN, self).__init__( + cell=tf.keras.layers.GRUCell(units=units, **kwargs), return_last_output=return_last_output, + return_seq_2d=return_seq_2d, return_last_state=return_last_state, in_channels=in_channels, name=name + ) + + +class LSTMRNN(RNN): + """ + The :class:`LSTMRNN` class is a fixed length recurrent layer for implementing RNN with LSTM cell. + + Parameters + ---------- + units: int + Positive integer, the dimension of hidden space. + return_last_output : boolean + Whether return last output or all outputs in a sequence. + - If True, return the last output, "Sequence input and single output" + - If False, return all outputs, "Synced sequence input and output" + - In other words, if you want to stack more RNNs on this layer, set to False + + In a dynamic model, `return_last_output` can be updated when it is called in customised forward(). + By default, `False`. + return_seq_2d : boolean + Only consider this argument when `return_last_output` is `False` + - If True, return 2D Tensor [batch_size * n_steps, n_hidden], for stacking Dense layer after it. + - If False, return 3D Tensor [batch_size, n_steps, n_hidden], for stacking multiple RNN after it. + + In a dynamic model, `return_seq_2d` can be updated when it is called in customised forward(). + By default, `False`. + return_last_state: boolean + Whether to return the last state of the RNN cell. The state is a list of Tensor. + For LSTM, last_state = [last_output, last_cell_state] + + - If True, the layer will return outputs and the final state of the cell. + - If False, the layer will return outputs only. + + In a dynamic model, `return_last_state` can be updated when it is called in customised forward(). + By default, `True`. 
+ in_channels: int + Optional, the number of channels of the previous layer which is normally the size of embedding. + If given, the layer will be built when init. + If None, it will be automatically detected when the layer is forwarded for the first time. + name : str + A unique layer name. + `**kwargs`: + Advanced arguments to configure the LSTM cell. + Please check tf.keras.layers.LSTMCell. + + Examples + -------- + + A simple regression model below. + + >>> inputs = tl.layers.Input([batch_size, num_steps, embedding_size]) + >>> rnn_out, rnn_state = tl.layers.LSTMRNN( + >>> units=hidden_size, dropout=0.1, # both units and dropout are used to configure the LSTM cell. + >>> in_channels=embedding_size, + >>> return_last_output=True, return_last_state=True, name='lstmrnn' + >>> )(inputs) + >>> outputs = tl.layers.Dense(n_units=1)(rnn_out) + >>> rnn_model = tl.models.Model(inputs=inputs, outputs=[outputs, rnn_state[0]], name='rnn_model') + + Notes + ----- + Input dimension should be rank 3 : [batch_size, n_steps, n_features], if no, please see layer :class:`Reshape`. + + """ + + def __init__( + self, + units, + return_last_output=False, + return_seq_2d=False, + return_last_state=True, + in_channels=None, + name=None, # 'lstmrnn' + **kwargs + ): + super(LSTMRNN, self).__init__( + cell=tf.keras.layers.LSTMCell(units=units, **kwargs), return_last_output=return_last_output, + return_seq_2d=return_seq_2d, return_last_state=return_last_state, in_channels=in_channels, name=name + ) class BiRNN(Layer): @@ -865,203 +1090,64 @@ def retrieve_seq_length_op3(data, pad_val=0): ) -def target_mask_op(data, pad_val=0): # HangSheng: return tensor for mask,if input is tf.string - """Return tensor for mask, if input is ``tf.string``.""" +def target_mask_op(data, pad_val=0): + """ Return the mask of the input sequence data based on the padding values. + + Parameters + ----------- + data : tf.Tensor + A tensor with 2 or 3 dimensions. 
+ pad_val: int, float, string, etc + The value that represent padding. By default, 0. For tf.string, you may use empty string. + + Examples + ----------- + >>> data = [['hello', 'world', '', '', ''], + >>> ['hello', 'world', 'tensorlayer', '', ''], + >>> ['hello', 'world', 'tensorlayer', '2.0', '']] + >>> data = tf.convert_to_tensor(data, dtype=tf.string) + >>> mask = tl.layers.target_mask_op(data, pad_val='') + >>> print(mask) + tf.Tensor( + [[1 1 0 0 0] + [1 1 1 0 0] + [1 1 1 1 0]], shape=(3, 5), dtype=int32) + >>> data = [[[1], [0], [0], [0], [0]], + >>> [[1], [2], [3], [0], [0]], + >>> [[1], [2], [0], [1], [0]]] + >>> data = tf.convert_to_tensor(data, dtype=tf.float32) + >>> mask = tl.layers.target_mask_op(data) + >>> print(mask) + tf.Tensor( + [[1 0 0 0 0] + [1 1 1 0 0] + [1 1 0 1 0]], shape=(3, 5), dtype=int32) + >>> data = [[[0,0],[2,2],[1,2],[1,2],[0,0]], + >>> [[2,3],[2,4],[3,2],[1,0],[0,0]], + >>> [[3,3],[0,1],[5,3],[1,2],[0,0]]] + >>> data = tf.convert_to_tensor(data, dtype=tf.float32) + >>> mask = tl.layers.target_mask_op(data) + >>> print(mask) + tf.Tensor( + [[0 1 1 1 0] + [1 1 1 1 0] + [1 1 1 1 0]], shape=(3, 5), dtype=int32) + """ + + if not isinstance(data, tf.Tensor): + raise AttributeError("target_mask_op: the type of input data should be tf.Tensor but got %s." % type(data)) data_shape_size = data.get_shape().ndims if data_shape_size == 3: return tf.cast(tf.reduce_any(input_tensor=tf.not_equal(data, pad_val), axis=2), dtype=tf.int32) elif data_shape_size == 2: return tf.cast(tf.not_equal(data, pad_val), dtype=tf.int32) elif data_shape_size == 1: - raise ValueError("target_mask_op: data has wrong shape!") + raise ValueError( + "target_mask_op: data_shape %s is not supported. " + "The shape of data should have 2 or 3 dims." % (data.get_shape()) + ) else: - raise ValueError("target_mask_op: handling data_shape_size %s hasn't been implemented!" 
% (data_shape_size)) - - -class Seq2Seq(Layer): - """ - The :class:`Seq2Seq` class is a simple :class:`DynamicRNNLayer` based Seq2seq layer without using `tl.contrib.seq2seq `__. - See `Model `__ - and `Sequence to Sequence Learning with Neural Networks `__. - - - Please check this example `Chatbot in 200 lines of code `__. - - The Author recommends users to read the source code of :class:`DynamicRNNLayer` and :class:`Seq2Seq`. - - Parameters - ---------- - net_encode_in : :class:`Layer` - Encode sequences, [batch_size, None, n_features]. - net_decode_in : :class:`Layer` - Decode sequences, [batch_size, None, n_features]. - cell_fn : TensorFlow cell function - A TensorFlow core RNN cell - - see `RNN Cells in TensorFlow `__ - - Note TF1.0+ and TF1.0- are different - - cell_init_args : dictionary or None - The arguments for the cell initializer. - n_hidden : int - The number of hidden units in the layer. - initializer : initializer - The initializer for the parameters. - encode_sequence_length : tensor - For encoder sequence length, see :class:`DynamicRNNLayer` . - decode_sequence_length : tensor - For decoder sequence length, see :class:`DynamicRNNLayer` . - initial_state_encode : None or RNN state - If None, `initial_state_encode` is zero state, it can be set by placeholder or other RNN. - initial_state_decode : None or RNN state - If None, `initial_state_decode` is the final state of the RNN encoder, it can be set by placeholder or other RNN. - dropout : tuple of float or int - The input and output keep probability (input_keep_prob, output_keep_prob). - - If one int, input and output keep probability are the same. - - n_layer : int - The number of RNN layers, default is 1. - return_seq_2d : boolean - Only consider this argument when `return_last_output` is `False` - - If True, return 2D Tensor [n_example, 2 * n_hidden], for stacking DenseLayer after it. - - If False, return 3D Tensor [n_example/n_steps, n_steps, 2 * n_hidden], for stacking multiple RNN after it. 
- - name : str - A unique layer name. - - Attributes - ------------ - outputs : tensor - The output of RNN decoder. - initial_state_encode : tensor or StateTuple - Initial state of RNN encoder. - initial_state_decode : tensor or StateTuple - Initial state of RNN decoder. - final_state_encode : tensor or StateTuple - Final state of RNN encoder. - final_state_decode : tensor or StateTuple - Final state of RNN decoder. - - Notes - -------- - - How to feed data: `Sequence to Sequence Learning with Neural Networks `__ - - input_seqs : ``['how', 'are', 'you', '']`` - - decode_seqs : ``['', 'I', 'am', 'fine', '']`` - - target_seqs : ``['I', 'am', 'fine', '', '']`` - - target_mask : ``[1, 1, 1, 1, 0]`` - - related functions : tl.prepro - - Examples - ---------- - >>> from tensorlayer.layers import * - >>> batch_size = 32 - >>> encode_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="encode_seqs") - >>> decode_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="decode_seqs") - >>> target_seqs = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="target_seqs") - >>> target_mask = tf.placeholder(dtype=tf.int64, shape=[batch_size, None], name="target_mask") # tl.prepro.sequences_get_mask() - >>> with tf.variable_scope("model"): - >>> # for chatbot, you can use the same embedding layer, - >>> # for translation, you may want to use 2 seperated embedding layers - >>> with tf.variable_scope("embedding") as vs: - >>> net_encode = EmbeddingInput( - ... inputs = encode_seqs, - ... vocabulary_size = 10000, - ... embedding_size = 200, - ... name = 'seq_embedding') - >>> vs.reuse_variables() - >>> net_decode = EmbeddingInput( - ... inputs = decode_seqs, - ... vocabulary_size = 10000, - ... embedding_size = 200, - ... name = 'seq_embedding') - >>> net = Seq2Seq(net_encode, net_decode, - ... cell_fn = tf.contrib.rnn.BasicLSTMCell, - ... n_hidden = 200, - ... initializer = tf.random_uniform_initializer(-0.1, 0.1), - ... 
encode_sequence_length = retrieve_seq_length_op2(encode_seqs), - ... decode_sequence_length = retrieve_seq_length_op2(decode_seqs), - ... initial_state_encode = None, - ... dropout = None, - ... n_layer = 1, - ... return_seq_2d = True, - ... name = 'seq2seq') - >>> net_out = Dense(net, n_units=10000, act=None, name='output') - >>> e_loss = tl.cost.cross_entropy_seq_with_mask(logits=net_out.outputs, target_seqs=target_seqs, input_mask=target_mask, return_details=False, name='cost') - >>> y = tf.nn.softmax(net_out.outputs) - >>> net_out.print_params(False) - - """ - - def __init__( - self, - net_encode_in, - net_decode_in, - cell_fn, #tf.nn.rnn_cell.LSTMCell, - cell_init_args=None, - n_hidden=256, - initializer=tf.compat.v1.initializers.random_uniform(-0.1, 0.1), - encode_sequence_length=None, - decode_sequence_length=None, - initial_state_encode=None, - initial_state_decode=None, - dropout=None, - n_layer=1, - return_seq_2d=False, - name='seq2seq', - ): - super(Seq2Seq, - self).__init__(prev_layer=[net_encode_in, net_decode_in], cell_init_args=cell_init_args, name=name) - - if self.cell_init_args: - self.cell_init_args['state_is_tuple'] = True # 'use_peepholes': True, - - if cell_fn is None: - raise ValueError("cell_fn cannot be set to None") - - if 'GRU' in cell_fn.__name__: - try: - cell_init_args.pop('state_is_tuple') - except Exception: - logging.warning("pop state_is_tuple fails.") - - logging.info( - "[*] Seq2Seq %s: n_hidden: %d cell_fn: %s dropout: %s n_layer: %d" % - (self.name, n_hidden, cell_fn.__name__, dropout, n_layer) + raise ValueError( + "target_mask_op: handling data_shape %s hasn't been implemented! 
" + "The shape of data should have 2 or 3 dims" % (data.get_shape()) ) - - with tf.compat.v1.variable_scope(name): - # tl.layers.set_name_reuse(reuse) - # network = InputLayer(self.inputs, name=name+'/input') - network_encode = DynamicRNN( - net_encode_in, cell_fn=cell_fn, cell_init_args=self.cell_init_args, n_hidden=n_hidden, - initializer=initializer, initial_state=initial_state_encode, dropout=dropout, n_layer=n_layer, - sequence_length=encode_sequence_length, return_last=False, return_seq_2d=True, name='encode' - ) - # vs.reuse_variables() - # tl.layers.set_name_reuse(True) - network_decode = DynamicRNN( - net_decode_in, cell_fn=cell_fn, cell_init_args=self.cell_init_args, n_hidden=n_hidden, - initializer=initializer, - initial_state=(network_encode.final_state if initial_state_decode is None else initial_state_decode), - dropout=dropout, n_layer=n_layer, sequence_length=decode_sequence_length, return_last=False, - return_seq_2d=return_seq_2d, name='decode' - ) - self.outputs = network_decode.outputs - - # rnn_variables = tf.get_collection(TF_GRAPHKEYS_VARIABLES, scope=vs.name) - - # Initial state - self.initial_state_encode = network_encode.initial_state - self.initial_state_decode = network_decode.initial_state - - # Final state - self.final_state_encode = network_encode.final_state - self.final_state_decode = network_decode.final_state - - # self.sequence_length = sequence_length - self._add_layers(network_encode.all_layers) - self._add_params(network_encode.all_params) - self._add_dropout_layers(network_encode.all_drop) - - self._add_layers(network_decode.all_layers) - self._add_params(network_decode.all_params) - self._add_dropout_layers(network_decode.all_drop) - - self._add_layers(self.outputs) diff --git a/tensorlayer/models/__init__.py b/tensorlayer/models/__init__.py index ec4b021d2..065b94885 100644 --- a/tensorlayer/models/__init__.py +++ b/tensorlayer/models/__init__.py @@ -7,3 +7,5 @@ from .mobilenetv1 import MobileNetV1 from .squeezenetv1 import 
SqueezeNetV1 from .vgg import * +from .seq2seq import Seq2seq +from .seq2seq_with_attention import Seq2seqLuongAttention diff --git a/tensorlayer/models/seq2seq.py b/tensorlayer/models/seq2seq.py index ca6931463..e0c20ef56 100644 --- a/tensorlayer/models/seq2seq.py +++ b/tensorlayer/models/seq2seq.py @@ -1,12 +1,14 @@ #! /usr/bin/python # -*- coding: utf-8 -*- +import numpy as np import tensorflow as tf import tensorlayer as tl -import numpy as np -from tensorlayer.models import Model from tensorlayer.layers import Dense, Dropout, Input from tensorlayer.layers.core import Layer +from tensorlayer.models import Model + +__all__ = ['Seq2seq'] class Seq2seq(Model): @@ -16,9 +18,9 @@ class Seq2seq(Model): ---------- decoder_seq_length: int The length of your target sequence - cell_enc : str, tf.function + cell_enc : TensorFlow cell function The RNN function cell for your encoder stack, e.g tf.keras.layers.GRUCell - cell_dec : str, tf.function + cell_dec : TensorFlow cell function The RNN function cell for your decoder stack, e.g. tf.keras.layers.GRUCell n_layer : int The number of your RNN layers for both encoder and decoder block diff --git a/tensorlayer/models/seq2seq_with_attention.py b/tensorlayer/models/seq2seq_with_attention.py new file mode 100644 index 000000000..d601e33c8 --- /dev/null +++ b/tensorlayer/models/seq2seq_with_attention.py @@ -0,0 +1,209 @@ +#! 
/usr/bin/python +# -*- coding: utf-8 -*- + +import numpy as np +import tensorflow as tf +import tensorlayer as tl +from tensorlayer.layers import Dense, Dropout, Input +from tensorlayer.layers.core import Layer +from tensorlayer.models import Model + +__all__ = ['Seq2seqLuongAttention'] + + +class Encoder(Layer): + + def __init__(self, hidden_size, cell, embedding_layer, name=None): + super(Encoder, self).__init__(name) + self.cell = cell(hidden_size) + self.hidden_size = hidden_size + self.embedding_layer = embedding_layer + self.build((None, None, self.embedding_layer.embedding_size)) + self._built = True + + def build(self, inputs_shape): + self.cell.build(input_shape=tuple(inputs_shape)) + self._built = True + if self._trainable_weights is None: + self._trainable_weights = list() + + for var in self.cell.trainable_variables: + self._trainable_weights.append(var) + + def forward(self, src_seq, initial_state=None): + + states = initial_state if initial_state is not None else self.cell.get_initial_state(src_seq) + encoding_hidden_states = list() + total_steps = src_seq.get_shape().as_list()[1] + for time_step in range(total_steps): + if not isinstance(states, list): + states = [states] + output, states = self.cell.call(src_seq[:, time_step, :], states, training=self.is_train) + encoding_hidden_states.append(states[0]) + return output, encoding_hidden_states, states[0] + + +class Decoder_Attention(Layer): + + def __init__(self, hidden_size, cell, embedding_layer, method, name=None): + super(Decoder_Attention, self).__init__(name) + self.cell = cell(hidden_size) + self.hidden_size = hidden_size + self.embedding_layer = embedding_layer + self.method = method + self.build((None, hidden_size + self.embedding_layer.embedding_size)) + self._built = True + + def build(self, inputs_shape): + self.cell.build(input_shape=tuple(inputs_shape)) + self._built = True + if self.method == "concat": # compare by value: `is` tests identity and only works for interned literals + 
self.W = self._get_weights("W", shape=(2 * self.hidden_size, self.hidden_size)) + 
self.V = self._get_weights("V", shape=(self.hidden_size, 1)) + elif self.method == "general": + self.W = self._get_weights("W", shape=(self.hidden_size, self.hidden_size)) + if self._trainable_weights is None: + self._trainable_weights = list() + + for var in self.cell.trainable_variables: + self._trainable_weights.append(var) + + def score(self, encoding_hidden, hidden, method): + # encoding = [B, T, H] + # hidden = [B, H] + # combined = [B,T,2H] + if method == "concat": + # hidden = [B,H]->[B,1,H]->[B,T,H] + hidden = tf.expand_dims(hidden, 1) + hidden = tf.tile(hidden, [1, encoding_hidden.shape[1], 1]) + # combined = [B,T,2H] + combined = tf.concat([hidden, encoding_hidden], 2) + combined = tf.cast(combined, tf.float32) + score = tf.tensordot(combined, self.W, axes=[[2], [0]]) # score = [B,T,H] + score = tf.nn.tanh(score) # score = [B,T,H] + score = tf.tensordot(self.V, score, axes=[[0], [2]]) # score = [1,B,T] + score = tf.squeeze(score, axis=0) # score = [B,T] + + elif method == "dot": + # hidden = [B,H]->[B,H,1] + hidden = tf.expand_dims(hidden, 2) + score = tf.matmul(encoding_hidden, hidden) + score = tf.squeeze(score, axis=2) + elif method == "general": + # hidden @ W = [B,H]->[B,H,1] + score = tf.matmul(hidden, self.W) + score = tf.expand_dims(score, 2) + score = tf.matmul(encoding_hidden, score) + score = tf.squeeze(score, axis=2) + + score = tf.nn.softmax(score, axis=-1) # score = [B,T] + return score + + def forward(self, dec_seq, enc_hiddens, last_hidden, method, return_last_state=False): + # dec_seq = [B, T_, V], enc_hiddens = [B, T, H], last_hidden = [B, H] + total_steps = dec_seq.get_shape().as_list()[1] + states = last_hidden + cell_outputs = list() + for time_step in range(total_steps): + attention_weights = self.score(enc_hiddens, last_hidden, method) + attention_weights = tf.expand_dims(attention_weights, 1) #[B, 1, T] + context = tf.matmul(attention_weights, enc_hiddens) #[B, 1, H] + context = tf.squeeze(context, 1) #[B, H] + inputs = 
tf.concat([dec_seq[:, time_step, :], context], 1) + if not isinstance(states, list): + states = [states] + cell_output, states = self.cell.call(inputs, states, training=self.is_train) + cell_outputs.append(cell_output) + last_hidden = states[0] + + cell_outputs = tf.convert_to_tensor(cell_outputs) + cell_outputs = tf.transpose(cell_outputs, perm=[1, 0, 2]) + if (return_last_state): + return cell_outputs, last_hidden + return cell_outputs + + +class Seq2seqLuongAttention(Model): + """Luong Attention-based Seq2Seq model. Implementation based on https://arxiv.org/pdf/1508.04025.pdf. + + Parameters + ---------- + hidden_size: int + The hidden size of both encoder and decoder RNN cells + cell : TensorFlow cell function + The RNN function cell for your encoder and decoder stack, e.g. tf.keras.layers.GRUCell + embedding_layer : tl.Layer + A embedding layer, e.g. tl.layers.Embedding(vocabulary_size=voc_size, embedding_size=emb_dim) + method : str + The three alternatives to calculate the attention scores, e.g. "dot", "general" and "concat" + name : str + The model name + + + Returns + ------- + static single layer attention-based Seq2Seq model. 
+ """ + + def __init__(self, hidden_size, embedding_layer, cell, method, name=None): + super(Seq2seqLuongAttention, self).__init__(name) + self.enc_layer = Encoder(hidden_size, cell, embedding_layer) + self.dec_layer = Decoder_Attention(hidden_size, cell, embedding_layer, method=method) + self.embedding_layer = embedding_layer + self.dense_layer = tl.layers.Dense(n_units=self.embedding_layer.vocabulary_size, in_channels=hidden_size) + self.method = method + + def inference(self, src_seq, encoding_hidden_states, last_hidden_states, seq_length, sos): + """Inference mode""" + """ + Parameters + ---------- + src_seq : input tensor + The source sequences + encoding_hidden_states : a list of tensor + The list of encoder's hidden states at each time step + last_hidden_states: tensor + The last hidden_state from encoder + seq_length : int + The expected length of your predicted sequence. + sos : int + : The token of "start of sequence" + """ + + batch_size = src_seq.shape[0] + decoding = [[sos] for i in range(batch_size)] + dec_output = self.embedding_layer(decoding) + outputs = [[0] for i in range(batch_size)] + for step in range(seq_length): + dec_output, last_hidden_states = self.dec_layer( + dec_output, encoding_hidden_states, last_hidden_states, method=self.method, return_last_state=True + ) + dec_output = tf.reshape(dec_output, [-1, dec_output.shape[-1]]) + dec_output = self.dense_layer(dec_output) + dec_output = tf.reshape(dec_output, [batch_size, -1, dec_output.shape[-1]]) + dec_output = tf.argmax(dec_output, -1) + outputs = tf.concat([outputs, dec_output], 1) + dec_output = self.embedding_layer(dec_output) + + return outputs[:, 1:] + + def forward(self, inputs, seq_length=20, sos=None): + src_seq = inputs[0] + src_seq = self.embedding_layer(src_seq) + enc_output, encoding_hidden_states, last_hidden_states = self.enc_layer(src_seq) + encoding_hidden_states = tf.convert_to_tensor(encoding_hidden_states) + encoding_hidden_states = tf.transpose(encoding_hidden_states, 
perm=[1, 0, 2]) + last_hidden_states = tf.convert_to_tensor(last_hidden_states) + + if (self.is_train): + dec_seq = inputs[1] + dec_seq = self.embedding_layer(dec_seq) + dec_output = self.dec_layer(dec_seq, encoding_hidden_states, last_hidden_states, method=self.method) + batch_size = dec_output.shape[0] + dec_output = tf.reshape(dec_output, [-1, dec_output.shape[-1]]) + dec_output = self.dense_layer(dec_output) + dec_output = tf.reshape(dec_output, [batch_size, -1, dec_output.shape[-1]]) + else: + dec_output = self.inference(src_seq, encoding_hidden_states, last_hidden_states, seq_length, sos) + + return dec_output diff --git a/tests/layers/test_layers_recurrent.py b/tests/layers/test_layers_recurrent.py index 38c014ee3..65fbd2442 100644 --- a/tests/layers/test_layers_recurrent.py +++ b/tests/layers/test_layers_recurrent.py @@ -68,6 +68,33 @@ def test_basic_simplernn(self): if (epoch + 1) % 10 == 0: print("epoch %d, loss %f" % (epoch, loss)) + def test_basic_simplernn_class(self): + + inputs = tl.layers.Input([self.batch_size, self.num_steps, self.embedding_size]) + rnnlayer = tl.layers.SimpleRNN( + units=self.hidden_size, dropout=0.1, return_last_output=True, return_seq_2d=False, return_last_state=True + ) + rnn, rnn_state = rnnlayer(inputs) + outputs = tl.layers.Dense(n_units=1)(rnn) + rnn_model = tl.models.Model(inputs=inputs, outputs=[outputs, rnn_state[0]]) + print(rnn_model) + + optimizer = tf.optimizers.Adam(learning_rate=0.01) + + rnn_model.train() + assert rnnlayer.is_train + + for epoch in range(50): + with tf.GradientTape() as tape: + pred_y, final_state = rnn_model(self.data_x) + loss = tl.cost.mean_squared_error(pred_y, self.data_y) + + gradients = tape.gradient(loss, rnn_model.trainable_weights) + optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights)) + + if (epoch + 1) % 10 == 0: + print("epoch %d, loss %f" % (epoch, loss)) + def test_basic_simplernn2(self): inputs = tl.layers.Input([self.batch_size, self.num_steps, 
self.embedding_size]) @@ -137,6 +164,39 @@ def forward(self, x): if (epoch + 1) % 10 == 0: print("epoch %d, loss %f" % (epoch, loss)) + def test_basic_simplernn_dynamic_class(self): + + class CustomisedModel(tl.models.Model): + + def __init__(self): + super(CustomisedModel, self).__init__() + self.rnnlayer = tl.layers.SimpleRNN( + units=8, dropout=0.1, in_channels=4, return_last_output=False, return_seq_2d=False, + return_last_state=False + ) + self.dense = tl.layers.Dense(in_channels=8, n_units=1) + + def forward(self, x): + z = self.rnnlayer(x) + z = self.dense(z[:, -1, :]) + return z + + rnn_model = CustomisedModel() + print(rnn_model) + optimizer = tf.optimizers.Adam(learning_rate=0.01) + rnn_model.train() + + for epoch in range(50): + with tf.GradientTape() as tape: + pred_y = rnn_model(self.data_x) + loss = tl.cost.mean_squared_error(pred_y, self.data_y) + + gradients = tape.gradient(loss, rnn_model.trainable_weights) + optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights)) + + if (epoch + 1) % 10 == 0: + print("epoch %d, loss %f" % (epoch, loss)) + def test_basic_simplernn_dynamic_2(self): class CustomisedModel(tl.models.Model): @@ -238,6 +298,32 @@ def test_basic_lstmrnn(self): if (epoch + 1) % 10 == 0: print("epoch %d, loss %f" % (epoch, loss)) + def test_basic_lstmrnn_class(self): + + inputs = tl.layers.Input([self.batch_size, self.num_steps, self.embedding_size]) + rnnlayer = tl.layers.LSTMRNN( + units=self.hidden_size, dropout=0.1, return_last_output=True, return_seq_2d=False, return_last_state=True + ) + rnn, rnn_state = rnnlayer(inputs) + outputs = tl.layers.Dense(n_units=1)(rnn) + rnn_model = tl.models.Model(inputs=inputs, outputs=[outputs, rnn_state[0], rnn_state[1]]) + print(rnn_model) + + optimizer = tf.optimizers.Adam(learning_rate=0.01) + + rnn_model.train() + + for epoch in range(50): + with tf.GradientTape() as tape: + pred_y, final_h, final_c = rnn_model(self.data_x) + loss = tl.cost.mean_squared_error(pred_y, self.data_y) + + 
gradients = tape.gradient(loss, rnn_model.trainable_weights) + optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights)) + + if (epoch + 1) % 10 == 0: + print("epoch %d, loss %f" % (epoch, loss)) + def test_basic_grurnn(self): inputs = tl.layers.Input([self.batch_size, self.num_steps, self.embedding_size]) @@ -265,6 +351,32 @@ def test_basic_grurnn(self): if (epoch + 1) % 10 == 0: print("epoch %d, loss %f" % (epoch, loss)) + def test_basic_grurnn_class(self): + + inputs = tl.layers.Input([self.batch_size, self.num_steps, self.embedding_size]) + rnnlayer = tl.layers.GRURNN( + units=self.hidden_size, dropout=0.1, return_last_output=True, return_seq_2d=False, return_last_state=True + ) + rnn, rnn_state = rnnlayer(inputs) + outputs = tl.layers.Dense(n_units=1)(rnn) + rnn_model = tl.models.Model(inputs=inputs, outputs=[outputs, rnn_state[0]]) + print(rnn_model) + + optimizer = tf.optimizers.Adam(learning_rate=0.01) + + rnn_model.train() + + for epoch in range(50): + with tf.GradientTape() as tape: + pred_y, final_h = rnn_model(self.data_x) + loss = tl.cost.mean_squared_error(pred_y, self.data_y) + + gradients = tape.gradient(loss, rnn_model.trainable_weights) + optimizer.apply_gradients(zip(gradients, rnn_model.trainable_weights)) + + if (epoch + 1) % 10 == 0: + print("epoch %d, loss %f" % (epoch, loss)) + def test_basic_birnn_simplernncell(self): inputs = tl.layers.Input([self.batch_size, self.num_steps, self.embedding_size]) @@ -604,6 +716,59 @@ def test_sequence_length3(self): except Exception as e: print(e) + def test_target_mask_op(self): + fail_flag = False + data = [ + ['hello', 'world', '', '', ''], ['hello', 'world', 'tensorlayer', '', ''], + ['hello', 'world', 'tensorlayer', '2.0', ''] + ] + try: + tl.layers.target_mask_op(data, pad_val='') + fail_flag = True + except AttributeError as e: + print(e) + if fail_flag: + self.fail("Type error not raised") + + data = tf.convert_to_tensor(data, dtype=tf.string) + mask = tl.layers.target_mask_op(data, 
pad_val='') + print(mask) + + data = [[[1], [0], [0], [0], [0]], [[1], [2], [3], [0], [0]], [[1], [2], [0], [1], [0]]] + data = tf.convert_to_tensor(data, dtype=tf.float32) + mask = tl.layers.target_mask_op(data) + print(mask) + + data = [ + [[0, 0], [2, 2], [1, 2], [1, 2], [0, 0]], [[2, 3], [2, 4], [3, 2], [1, 0], [0, 0]], + [[3, 3], [0, 1], [5, 3], [1, 2], [0, 0]] + ] + data = tf.convert_to_tensor(data, dtype=tf.float32) + mask = tl.layers.target_mask_op(data) + print(mask) + + fail_flag = False + try: + data = [1, 2, 0, 0, 0] + data = tf.convert_to_tensor(data, dtype=tf.float32) + tl.layers.target_mask_op(data) + fail_flag = True + except ValueError as e: + print(e) + if fail_flag: + self.fail("Wrong data shape not detected.") + + fail_flag = False + try: + data = np.random.random([4, 2, 6, 2]) + data = tf.convert_to_tensor(data, dtype=tf.float32) + tl.layers.target_mask_op(data) + fail_flag = True + except ValueError as e: + print(e) + if fail_flag: + self.fail("Wrong data shape not detected.") + if __name__ == '__main__': diff --git a/tests/models/test_seq2seq_with_attention.py b/tests/models/test_seq2seq_with_attention.py new file mode 100644 index 000000000..d7dbeae34 --- /dev/null +++ b/tests/models/test_seq2seq_with_attention.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- + +import os +import unittest + +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + +import numpy as np +import tensorflow as tf +import tensorlayer as tl +from tqdm import tqdm +from sklearn.utils import shuffle +from tensorlayer.models.seq2seq_with_attention import Seq2seqLuongAttention +from tests.utils import CustomTestCase +from tensorlayer.cost import cross_entropy_seq + + +class Model_SEQ2SEQ_WITH_ATTENTION_Test(CustomTestCase): + + @classmethod + def setUpClass(cls): + + cls.batch_size = 16 + + cls.vocab_size = 200 + cls.embedding_size = 32 + cls.dec_seq_length = 5 + cls.pure_time = np.linspace(-1, 1, 21) + cls.pure_signal = 100 * np.sin(cls.pure_time) + cls.dataset 
= np.zeros((100, 21)) + for i in range(100): + noise = 100 + 1 * np.random.normal(0, 1, cls.pure_signal.shape) + cls.dataset[i] = cls.pure_signal + noise + cls.dataset = cls.dataset.astype(int) + np.random.shuffle(cls.dataset) + cls.trainX = cls.dataset[:80, :15] + cls.trainY = cls.dataset[:80, 15:] + cls.testX = cls.dataset[80:, :15] + cls.testY = cls.dataset[80:, 15:] + + cls.trainY[:, 0] = 0 # start_token == 0 + cls.testY[:, 0] = 0 # start_token == 0 + + # Parameters + cls.src_len = len(cls.trainX) + cls.tgt_len = len(cls.trainY) + + assert cls.src_len == cls.tgt_len + + cls.num_epochs = 500 + cls.n_step = cls.src_len // cls.batch_size + + @classmethod + def tearDownClass(cls): + pass + + def test_basic_simpleSeq2Seq(self): + + model_ = Seq2seqLuongAttention( + hidden_size=128, cell=tf.keras.layers.SimpleRNNCell, + embedding_layer=tl.layers.Embedding(vocabulary_size=self.vocab_size, + embedding_size=self.embedding_size), method='dot' + ) + optimizer = tf.optimizers.Adam(learning_rate=0.001) + + for epoch in range(self.num_epochs): + model_.train() + trainX, trainY = shuffle(self.trainX, self.trainY) + total_loss, n_iter = 0, 0 + for X, Y in tqdm(tl.iterate.minibatches(inputs=trainX, targets=trainY, batch_size=self.batch_size, + shuffle=False), total=self.n_step, + desc='Epoch[{}/{}]'.format(epoch + 1, self.num_epochs), leave=False): + dec_seq = Y[:, :-1] + target_seq = Y[:, 1:] + + with tf.GradientTape() as tape: + ## compute outputs + output = model_(inputs=[X, dec_seq]) + # print(output) + output = tf.reshape(output, [-1, self.vocab_size]) + + loss = cross_entropy_seq(logits=output, target_seqs=target_seq) + grad = tape.gradient(loss, model_.trainable_weights) + optimizer.apply_gradients(zip(grad, model_.trainable_weights)) + + total_loss += loss + n_iter += 1 + + model_.eval() + test_sample = self.testX[:5, :].tolist() # Can't capture the sequence. 
+ top_n = 1 + for i in range(top_n): + prediction = model_([test_sample], seq_length=self.dec_seq_length, sos=0) + print("Prediction: >>>>> ", prediction, "\n Target: >>>>> ", self.testY[:5, 1:], "\n\n") + + # printing average loss after every epoch + print('Epoch [{}/{}]: loss {:.4f}'.format(epoch + 1, self.num_epochs, total_loss / n_iter)) + + +if __name__ == '__main__': + unittest.main()