From c30ac7b64838772eb578e23406780b5ce3e6e0a1 Mon Sep 17 00:00:00 2001 From: qlzh727 Date: Thu, 20 Feb 2020 11:03:01 -0800 Subject: [PATCH 01/10] Update AttentionStateWrapper to work with Keras. --- tensorflow_addons/seq2seq/attention_wrapper.py | 12 ++++++++---- .../seq2seq/attention_wrapper_test.py | 18 +++++++++++++++++- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/tensorflow_addons/seq2seq/attention_wrapper.py b/tensorflow_addons/seq2seq/attention_wrapper.py index 3700c1055a..7af18e1522 100644 --- a/tensorflow_addons/seq2seq/attention_wrapper.py +++ b/tensorflow_addons/seq2seq/attention_wrapper.py @@ -2033,10 +2033,14 @@ def call(self, inputs, state, **kwargs): TypeError: If `state` is not an instance of `AttentionWrapperState`. """ if not isinstance(state, AttentionWrapperState): - raise TypeError( - "Expected state to be instance of AttentionWrapperState. " - "Received type %s instead." % type(state) - ) + try: + state = AttentionWrapperState(*state) + except: + raise TypeError( + "Expected state to be instance of AttentionWrapperState or " + "values that can construct AttentionWrapperState. " + "Received type %s instead." % type(state) + ) # Step 1: Calculate the true inputs to the cell based on the # previous attention value. diff --git a/tensorflow_addons/seq2seq/attention_wrapper_test.py b/tensorflow_addons/seq2seq/attention_wrapper_test.py index 2fd9d17ea0..46095c6e43 100644 --- a/tensorflow_addons/seq2seq/attention_wrapper_test.py +++ b/tensorflow_addons/seq2seq/attention_wrapper_test.py @@ -322,7 +322,7 @@ def testCustomAttentionLayer(self): ) self.assertEqual(initial_state.attention.shape[-1], self.units * 2) first_input = self.decoder_inputs[:, 0].astype(np.float32) - output, next_state = attention_wrapper(first_input, initial_state) + output, _ = attention_wrapper(first_input, initial_state) self.assertEqual(output.shape[-1], self.units * 2) def _testWithAttention( @@ -987,6 +987,22 @@ def testLuongMonotonicScaled(self): create_attention_kwargs=create_attention_kwargs, ) + def test_attention_state_with_keras_rnn(self): + # See https://github.com/tensorflow/addons/issues/1095. + cell = tf.keras.layers.LSTMCell(8) + + mechanism = wrapper.LuongAttention(units=8, memory=tf.ones((2, 4, 8))) + + cell = wrapper.AttentionWrapper( + cell=cell, attention_mechanism=mechanism) + + layer = tf.keras.layers.RNN(cell) + _ = layer(inputs=tf.ones((2, 4, 8))) + + # Make sure the explicit initial_state also works. + initial_state = cell.get_initial_state(batch_size=2, dtype=tf.float32) + _ = layer(inputs=tf.ones((2, 4, 8)), initial_state=initial_state) + if __name__ == "__main__": tf.test.main() From 63d83d2f62f08c09dbbc6d31aacf22f86b09280b Mon Sep 17 00:00:00 2001 From: qlzh727 Date: Thu, 20 Feb 2020 11:30:26 -0800 Subject: [PATCH 02/10] Fix lint errors. --- tensorflow_addons/seq2seq/attention_wrapper.py | 2 +- tensorflow_addons/seq2seq/attention_wrapper_test.py | 3 +-- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/tensorflow_addons/seq2seq/attention_wrapper.py b/tensorflow_addons/seq2seq/attention_wrapper.py index 7af18e1522..81f2188132 100644 --- a/tensorflow_addons/seq2seq/attention_wrapper.py +++ b/tensorflow_addons/seq2seq/attention_wrapper.py @@ -2035,7 +2035,7 @@ def call(self, inputs, state, **kwargs): if not isinstance(state, AttentionWrapperState): try: state = AttentionWrapperState(*state) - except: + except TypeError: raise TypeError( "Expected state to be instance of AttentionWrapperState or " "values that can construct AttentionWrapperState. " diff --git a/tensorflow_addons/seq2seq/attention_wrapper_test.py b/tensorflow_addons/seq2seq/attention_wrapper_test.py index 46095c6e43..0defa72065 100644 --- a/tensorflow_addons/seq2seq/attention_wrapper_test.py +++ b/tensorflow_addons/seq2seq/attention_wrapper_test.py @@ -993,8 +993,7 @@ def test_attention_state_with_keras_rnn(self): mechanism = wrapper.LuongAttention(units=8, memory=tf.ones((2, 4, 8))) - cell = wrapper.AttentionWrapper( - cell=cell, attention_mechanism=mechanism) + cell = wrapper.AttentionWrapper(cell=cell, attention_mechanism=mechanism) layer = tf.keras.layers.RNN(cell) _ = layer(inputs=tf.ones((2, 4, 8))) From f107be50ddc7c9cf603b3065cfc1700f55b6dad1 Mon Sep 17 00:00:00 2001 From: qlzh727 Date: Mon, 8 Jun 2020 13:27:02 -0700 Subject: [PATCH 03/10] Revert "Fix lint errors." This reverts commit 63d83d2f62f08c09dbbc6d31aacf22f86b09280b. --- tensorflow_addons/seq2seq/attention_wrapper.py | 2 +- tensorflow_addons/seq2seq/attention_wrapper_test.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow_addons/seq2seq/attention_wrapper.py b/tensorflow_addons/seq2seq/attention_wrapper.py index 81f2188132..7af18e1522 100644 --- a/tensorflow_addons/seq2seq/attention_wrapper.py +++ b/tensorflow_addons/seq2seq/attention_wrapper.py @@ -2035,7 +2035,7 @@ def call(self, inputs, state, **kwargs): if not isinstance(state, AttentionWrapperState): try: state = AttentionWrapperState(*state) - except TypeError: + except: raise TypeError( "Expected state to be instance of AttentionWrapperState or " "values that can construct AttentionWrapperState. " diff --git a/tensorflow_addons/seq2seq/attention_wrapper_test.py b/tensorflow_addons/seq2seq/attention_wrapper_test.py index 0defa72065..46095c6e43 100644 --- a/tensorflow_addons/seq2seq/attention_wrapper_test.py +++ b/tensorflow_addons/seq2seq/attention_wrapper_test.py @@ -993,7 +993,8 @@ def test_attention_state_with_keras_rnn(self): mechanism = wrapper.LuongAttention(units=8, memory=tf.ones((2, 4, 8))) - cell = wrapper.AttentionWrapper(cell=cell, attention_mechanism=mechanism) + cell = wrapper.AttentionWrapper( + cell=cell, attention_mechanism=mechanism) layer = tf.keras.layers.RNN(cell) _ = layer(inputs=tf.ones((2, 4, 8))) From f7b4fd95d326958593026e33e3e93690d353048c Mon Sep 17 00:00:00 2001 From: qlzh727 Date: Mon, 8 Jun 2020 13:27:12 -0700 Subject: [PATCH 04/10] Revert "Update AttentionStateWrapper to work with Keras." This reverts commit c30ac7b64838772eb578e23406780b5ce3e6e0a1. --- tensorflow_addons/seq2seq/attention_wrapper.py | 12 ++++-------- .../seq2seq/attention_wrapper_test.py | 18 +----------------- 2 files changed, 5 insertions(+), 25 deletions(-) diff --git a/tensorflow_addons/seq2seq/attention_wrapper.py b/tensorflow_addons/seq2seq/attention_wrapper.py index 7af18e1522..3700c1055a 100644 --- a/tensorflow_addons/seq2seq/attention_wrapper.py +++ b/tensorflow_addons/seq2seq/attention_wrapper.py @@ -2033,14 +2033,10 @@ def call(self, inputs, state, **kwargs): TypeError: If `state` is not an instance of `AttentionWrapperState`. """ if not isinstance(state, AttentionWrapperState): - try: - state = AttentionWrapperState(*state) - except: - raise TypeError( - "Expected state to be instance of AttentionWrapperState or " - "values that can construct AttentionWrapperState. " - "Received type %s instead." % type(state) - ) + raise TypeError( + "Expected state to be instance of AttentionWrapperState. " + "Received type %s instead." % type(state) + ) # Step 1: Calculate the true inputs to the cell based on the # previous attention value. diff --git a/tensorflow_addons/seq2seq/attention_wrapper_test.py b/tensorflow_addons/seq2seq/attention_wrapper_test.py index 46095c6e43..2fd9d17ea0 100644 --- a/tensorflow_addons/seq2seq/attention_wrapper_test.py +++ b/tensorflow_addons/seq2seq/attention_wrapper_test.py @@ -322,7 +322,7 @@ def testCustomAttentionLayer(self): ) self.assertEqual(initial_state.attention.shape[-1], self.units * 2) first_input = self.decoder_inputs[:, 0].astype(np.float32) - output, _ = attention_wrapper(first_input, initial_state) + output, next_state = attention_wrapper(first_input, initial_state) self.assertEqual(output.shape[-1], self.units * 2) def _testWithAttention( @@ -987,22 +987,6 @@ def testLuongMonotonicScaled(self): create_attention_kwargs=create_attention_kwargs, ) - def test_attention_state_with_keras_rnn(self): - # See https://github.com/tensorflow/addons/issues/1095. - cell = tf.keras.layers.LSTMCell(8) - - mechanism = wrapper.LuongAttention(units=8, memory=tf.ones((2, 4, 8))) - - cell = wrapper.AttentionWrapper( - cell=cell, attention_mechanism=mechanism) - - layer = tf.keras.layers.RNN(cell) - _ = layer(inputs=tf.ones((2, 4, 8))) - - # Make sure the explicit initial_state also works. - initial_state = cell.get_initial_state(batch_size=2, dtype=tf.float32) - _ = layer(inputs=tf.ones((2, 4, 8)), initial_state=initial_state) - if __name__ == "__main__": tf.test.main() From da5d0bec037e2e79ac627858bfdccabcad479d68 Mon Sep 17 00:00:00 2001 From: qlzh727 Date: Sat, 20 Jun 2020 14:01:55 -0700 Subject: [PATCH 05/10] Move the tf.keras.layers.PeepholeLSTMCell to tfa This cell was exported under tf core API as an experimental since 2.0, but I think tfa should be a better place for this implementation. We are planing to deprecate and eventually remove the PeepholeLSTMCell in the tf core API once this is landed. --- tensorflow_addons/rnn/__init__.py | 1 + tensorflow_addons/rnn/cell.py | 76 ++++++++++++++++++++++++ tensorflow_addons/rnn/tests/cell_test.py | 34 +++++++++++ 3 files changed, 111 insertions(+) diff --git a/tensorflow_addons/rnn/__init__.py b/tensorflow_addons/rnn/__init__.py index 0a5ab832d3..cde3e86a46 100644 --- a/tensorflow_addons/rnn/__init__.py +++ b/tensorflow_addons/rnn/__init__.py @@ -18,3 +18,4 @@ from tensorflow_addons.rnn.cell import NASCell from tensorflow_addons.rnn.cell import LayerNormSimpleRNNCell from tensorflow_addons.rnn.cell import ESNCell +from tensorflow_addons.rnn.cell import PeepholeLSTMCell diff --git a/tensorflow_addons/rnn/cell.py b/tensorflow_addons/rnn/cell.py index 37ec7a9fb9..e7adacbb9e 100644 --- a/tensorflow_addons/rnn/cell.py +++ b/tensorflow_addons/rnn/cell.py @@ -799,3 +799,79 @@ def get_config(self): } base_config = super().get_config() return {**base_config, **config} + + +@tf.keras.utils.register_keras_serializable(package="Addons") +class PeepholeLSTMCell(tf.keras.layers.LSTMCell): + """Equivalent to `tf.keras.layers.LSTMCell` class but adds peephole connections. + + Peephole connections allow the gates to utilize the previous internal state as + well as the previous hidden state (which is what LSTMCell is limited to). + This allows PeepholeLSTMCell to better learn precise timings over LSTMCell. + + From [Gers et al., 2002]( + http://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf): + + "We find that LSTM augmented by 'peephole connections' from its internal + cells to its multiplicative gates can learn the fine distinction between + sequences of spikes spaced either 50 or 49 time steps apart without the help + of any short training exemplars." + + The peephole implementation is based on: + + [Sak et al., 2014](https://research.google.com/pubs/archive/43905.pdf) + + Example: + + ```python + # Create 2 PeepholeLSTMCells + peephole_lstm_cells = [PeepholeLSTMCell(size) for size in [128, 256]] + # Create a layer composed sequentially of the peephole LSTM cells. + layer = RNN(peephole_lstm_cells) + input = keras.Input((timesteps, input_dim)) + output = layer(input) + ``` + """ + + def build(self, input_shape): + super(PeepholeLSTMCell, self).build(input_shape) + # The following are the weight matrices for the peephole connections. These + # are multiplied with the previous internal state during the computation of + # carry and output. + self.input_gate_peephole_weights = self.add_weight( + shape=(self.units,), + name='input_gate_peephole_weights', + initializer=self.kernel_initializer) + self.forget_gate_peephole_weights = self.add_weight( + shape=(self.units,), + name='forget_gate_peephole_weights', + initializer=self.kernel_initializer) + self.output_gate_peephole_weights = self.add_weight( + shape=(self.units,), + name='output_gate_peephole_weights', + initializer=self.kernel_initializer) + + def _compute_carry_and_output(self, x, h_tm1, c_tm1): + x_i, x_f, x_c, x_o = x + h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1 + i = self.recurrent_activation( + x_i + tf.keras.backend.dot( + h_tm1_i, self.recurrent_kernel[:, :self.units]) + self.input_gate_peephole_weights * c_tm1) + f = self.recurrent_activation( + x_f + tf.keras.backend.dot( + h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]) + + self.forget_gate_peephole_weights * c_tm1) + c = f * c_tm1 + i * self.activation(x_c + tf.keras.backend.dot( + h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3])) + o = self.recurrent_activation( + x_o + tf.keras.backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]) + + self.output_gate_peephole_weights * c) + return c, o + + def _compute_carry_and_output_fused(self, z, c_tm1): + z0, z1, z2, z3 = z + i = self.recurrent_activation(z0 + self.input_gate_peephole_weights * c_tm1) + f = self.recurrent_activation(z1 + self.forget_gate_peephole_weights * c_tm1) + c = f * c_tm1 + i * self.activation(z2) + o = self.recurrent_activation(z3 + self.output_gate_peephole_weights * c) + return c, o diff --git a/tensorflow_addons/rnn/tests/cell_test.py b/tensorflow_addons/rnn/tests/cell_test.py index b2d76133c9..013c79ac51 100644 --- a/tensorflow_addons/rnn/tests/cell_test.py +++ b/tensorflow_addons/rnn/tests/cell_test.py @@ -558,3 +558,37 @@ def test_esn_config(): restored_cell = rnn_cell.ESNCell.from_config(config) restored_config = restored_cell.get_config() assert config == restored_config + + +def test_peephole_lstm_cell(): + def _run_cell(cell_fn, **kwargs): + inputs = tf.one_hot([1, 2, 3, 4], 4) + cell = cell_fn(5, **kwargs) + cell.build(inputs.shape) + initial_state = cell.get_initial_state( + inputs=inputs, batch_size=4, dtype=tf.float32) + inputs, _ = cell(inputs, initial_state) + output = inputs + return output + + tf.compat.v1.random.set_random_seed(12345) + # `recurrent_activation` kwarg is set to sigmoid as that is hardcoded into + # rnn_cell.LSTMCell. + first_implementation_output = _run_cell( + rnn_cell.PeepholeLSTMCell, + kernel_initializer='ones', + recurrent_activation='sigmoid', + implementation=1) + second_implementation_output = _run_cell( + rnn_cell.PeepholeLSTMCell, + kernel_initializer='ones', + recurrent_activation='sigmoid', + implementation=2) + tf_lstm_cell_output = _run_cell( + tf.compat.v1.nn.rnn_cell.LSTMCell, + use_peepholes=True, + initializer=tf.compat.v1.initializers.ones) + np.testing.assert_allclose(first_implementation_output, + second_implementation_output) + np.testing.assert_allclose(first_implementation_output, + tf_lstm_cell_output) From ca2f3b4d8334cbf8fa40c6232a7621225e41598b Mon Sep 17 00:00:00 2001 From: qlzh727 Date: Mon, 22 Jun 2020 09:49:41 -0700 Subject: [PATCH 06/10] Fix code style --- tensorflow_addons/rnn/cell.py | 44 ++++++++++++++++-------- tensorflow_addons/rnn/tests/cell_test.py | 28 ++++++++------- 2 files changed, 45 insertions(+), 27 deletions(-) diff --git a/tensorflow_addons/rnn/cell.py b/tensorflow_addons/rnn/cell.py index e7adacbb9e..706eb54100 100644 --- a/tensorflow_addons/rnn/cell.py +++ b/tensorflow_addons/rnn/cell.py @@ -840,32 +840,46 @@ def build(self, input_shape): # carry and output. self.input_gate_peephole_weights = self.add_weight( shape=(self.units,), - name='input_gate_peephole_weights', - initializer=self.kernel_initializer) + name="input_gate_peephole_weights", + initializer=self.kernel_initializer, + ) self.forget_gate_peephole_weights = self.add_weight( shape=(self.units,), - name='forget_gate_peephole_weights', - initializer=self.kernel_initializer) + name="forget_gate_peephole_weights", + initializer=self.kernel_initializer, + ) self.output_gate_peephole_weights = self.add_weight( shape=(self.units,), - name='output_gate_peephole_weights', - initializer=self.kernel_initializer) + name="output_gate_peephole_weights", + initializer=self.kernel_initializer, + ) def _compute_carry_and_output(self, x, h_tm1, c_tm1): x_i, x_f, x_c, x_o = x h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1 i = self.recurrent_activation( - x_i + tf.keras.backend.dot( - h_tm1_i, self.recurrent_kernel[:, :self.units]) + self.input_gate_peephole_weights * c_tm1) + x_i + + tf.keras.backend.dot(h_tm1_i, self.recurrent_kernel[:, : self.units]) + + self.input_gate_peephole_weights * c_tm1 + ) f = self.recurrent_activation( - x_f + tf.keras.backend.dot( - h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2]) + - self.forget_gate_peephole_weights * c_tm1) - c = f * c_tm1 + i * self.activation(x_c + tf.keras.backend.dot( - h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3])) + x_f + + tf.keras.backend.dot( + h_tm1_f, self.recurrent_kernel[:, self.units : self.units * 2] + ) + + self.forget_gate_peephole_weights * c_tm1 + ) + c = f * c_tm1 + i * self.activation( + x_c + + tf.keras.backend.dot( + h_tm1_c, self.recurrent_kernel[:, self.units * 2 : self.units * 3] + ) + ) o = self.recurrent_activation( - x_o + tf.keras.backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]) + - self.output_gate_peephole_weights * c) + x_o + + tf.keras.backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3 :]) + + self.output_gate_peephole_weights * c + ) return c, o def _compute_carry_and_output_fused(self, z, c_tm1): diff --git a/tensorflow_addons/rnn/tests/cell_test.py b/tensorflow_addons/rnn/tests/cell_test.py index 013c79ac51..00c9450062 100644 --- a/tensorflow_addons/rnn/tests/cell_test.py +++ b/tensorflow_addons/rnn/tests/cell_test.py @@ -566,7 +566,8 @@ def _run_cell(cell_fn, **kwargs): cell = cell_fn(5, **kwargs) cell.build(inputs.shape) initial_state = cell.get_initial_state( - inputs=inputs, batch_size=4, dtype=tf.float32) + inputs=inputs, batch_size=4, dtype=tf.float32 + ) inputs, _ = cell(inputs, initial_state) output = inputs return output @@ -576,19 +577,22 @@ def _run_cell(cell_fn, **kwargs): # rnn_cell.LSTMCell. first_implementation_output = _run_cell( rnn_cell.PeepholeLSTMCell, - kernel_initializer='ones', - recurrent_activation='sigmoid', - implementation=1) + kernel_initializer="ones", + recurrent_activation="sigmoid", + implementation=1, + ) second_implementation_output = _run_cell( rnn_cell.PeepholeLSTMCell, - kernel_initializer='ones', - recurrent_activation='sigmoid', - implementation=2) + kernel_initializer="ones", + recurrent_activation="sigmoid", + implementation=2, + ) tf_lstm_cell_output = _run_cell( tf.compat.v1.nn.rnn_cell.LSTMCell, use_peepholes=True, - initializer=tf.compat.v1.initializers.ones) - np.testing.assert_allclose(first_implementation_output, - second_implementation_output) - np.testing.assert_allclose(first_implementation_output, - tf_lstm_cell_output) + initializer=tf.compat.v1.initializers.ones, + ) + np.testing.assert_allclose( + first_implementation_output, second_implementation_output + ) + np.testing.assert_allclose(first_implementation_output, tf_lstm_cell_output) From 860814c760c303376ed756290a4a0ed29ceeb3e0 Mon Sep 17 00:00:00 2001 From: qlzh727 Date: Wed, 24 Jun 2020 20:50:21 -0700 Subject: [PATCH 07/10] Update peephole lstm cell test with golden values. We removed the v1 compat API since TFA only works with TF v2. --- tensorflow_addons/rnn/tests/cell_test.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/tensorflow_addons/rnn/tests/cell_test.py b/tensorflow_addons/rnn/tests/cell_test.py index 00c9450062..a19a00356f 100644 --- a/tensorflow_addons/rnn/tests/cell_test.py +++ b/tensorflow_addons/rnn/tests/cell_test.py @@ -568,13 +568,10 @@ def _run_cell(cell_fn, **kwargs): initial_state = cell.get_initial_state( inputs=inputs, batch_size=4, dtype=tf.float32 ) - inputs, _ = cell(inputs, initial_state) - output = inputs + output, _ = cell(inputs, initial_state) return output - tf.compat.v1.random.set_random_seed(12345) - # `recurrent_activation` kwarg is set to sigmoid as that is hardcoded into - # rnn_cell.LSTMCell. + tf.random.set_seed(12345) first_implementation_output = _run_cell( rnn_cell.PeepholeLSTMCell, kernel_initializer="ones", @@ -587,12 +584,14 @@ def _run_cell(cell_fn, **kwargs): recurrent_activation="sigmoid", implementation=2, ) - tf_lstm_cell_output = _run_cell( - tf.compat.v1.nn.rnn_cell.LSTMCell, - use_peepholes=True, - initializer=tf.compat.v1.initializers.ones, + expected_output = np.asarray( + [[0.417551, 0.417551, 0.417551, 0.417551, 0.417551], + [0.417551, 0.417551, 0.417551, 0.417551, 0.417551], + [0.417551, 0.417551, 0.417551, 0.417551, 0.417551], + [0., 0., 0., 0., 0.]], + dtype=np.float32 ) np.testing.assert_allclose( first_implementation_output, second_implementation_output ) - np.testing.assert_allclose(first_implementation_output, tf_lstm_cell_output) + np.testing.assert_allclose(first_implementation_output, expected_output, rtol=1e-6, atol=1e-6) From 4b64158a6ff7ed4f37b1966b45026ebcad1e965f Mon Sep 17 00:00:00 2001 From: qlzh727 Date: Wed, 24 Jun 2020 20:58:23 -0700 Subject: [PATCH 08/10] Add PeepholeLSTMCell to the exception list for typehint check. The cell itself doesn't have __init__ and it inherit the __init__ from keras.LSTMcell, which doesn't have type hint yet. --- tools/testing/source_code_test.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tools/testing/source_code_test.py b/tools/testing/source_code_test.py index 7167b13c00..0e42454f5b 100644 --- a/tools/testing/source_code_test.py +++ b/tools/testing/source_code_test.py @@ -37,7 +37,9 @@ def test_api_typed(): tfa.text, ] # Files within this list will be exempt from verification. - exception_list = [] + exception_list = [ + tfa.rnn.PeepholeLSTMCell, + ] help_message = ( "You can also take a look at the section about it in the CONTRIBUTING.md:\n" "https://github.com/tensorflow/addons/blob/master/CONTRIBUTING.md#about-type-hints" From deb57c329c92e844ad2eb2a9381f02f1526d009a Mon Sep 17 00:00:00 2001 From: qlzh727 Date: Wed, 24 Jun 2020 21:04:01 -0700 Subject: [PATCH 09/10] Fix format. --- tensorflow_addons/rnn/tests/cell_test.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/tensorflow_addons/rnn/tests/cell_test.py b/tensorflow_addons/rnn/tests/cell_test.py index a19a00356f..be3fd13d5e 100644 --- a/tensorflow_addons/rnn/tests/cell_test.py +++ b/tensorflow_addons/rnn/tests/cell_test.py @@ -585,13 +585,17 @@ def _run_cell(cell_fn, **kwargs): implementation=2, ) expected_output = np.asarray( - [[0.417551, 0.417551, 0.417551, 0.417551, 0.417551], - [0.417551, 0.417551, 0.417551, 0.417551, 0.417551], - [0.417551, 0.417551, 0.417551, 0.417551, 0.417551], - [0., 0., 0., 0., 0.]], - dtype=np.float32 + [ + [0.417551, 0.417551, 0.417551, 0.417551, 0.417551], + [0.417551, 0.417551, 0.417551, 0.417551, 0.417551], + [0.417551, 0.417551, 0.417551, 0.417551, 0.417551], + [0.0, 0.0, 0.0, 0.0, 0.0], + ], + dtype=np.float32, ) np.testing.assert_allclose( first_implementation_output, second_implementation_output ) - np.testing.assert_allclose(first_implementation_output, expected_output, rtol=1e-6, atol=1e-6) + np.testing.assert_allclose( + first_implementation_output, expected_output, rtol=1e-6, atol=1e-6 + ) From 5a99ecdd669b13f5a72dd08e769ba4c51a21ec5a Mon Sep 17 00:00:00 2001 From: qlzh727 Date: Thu, 25 Jun 2020 20:05:15 -0700 Subject: [PATCH 10/10] Update build method to be more aligned with py3 style. --- tensorflow_addons/rnn/cell.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow_addons/rnn/cell.py b/tensorflow_addons/rnn/cell.py index 706eb54100..6da33414ef 100644 --- a/tensorflow_addons/rnn/cell.py +++ b/tensorflow_addons/rnn/cell.py @@ -834,7 +834,7 @@ class PeepholeLSTMCell(tf.keras.layers.LSTMCell): """ def build(self, input_shape): - super(PeepholeLSTMCell, self).build(input_shape) + super().build(input_shape) # The following are the weight matrices for the peephole connections. These # are multiplied with the previous internal state during the computation of # carry and output.