Skip to content

Move the tf.keras.layers.PeepholeLSTMCell to tfa #1944

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 13 commits into from
Jun 26, 2020
1 change: 1 addition & 0 deletions tensorflow_addons/rnn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@
from tensorflow_addons.rnn.cell import NASCell
from tensorflow_addons.rnn.cell import LayerNormSimpleRNNCell
from tensorflow_addons.rnn.cell import ESNCell
from tensorflow_addons.rnn.cell import PeepholeLSTMCell
90 changes: 90 additions & 0 deletions tensorflow_addons/rnn/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -799,3 +799,93 @@ def get_config(self):
}
base_config = super().get_config()
return {**base_config, **config}


@tf.keras.utils.register_keras_serializable(package="Addons")
class PeepholeLSTMCell(tf.keras.layers.LSTMCell):
"""Equivalent to `tf.keras.layers.LSTMCell` class but adds peephole connections.

Peephole connections allow the gates to utilize the previous internal state as
well as the previous hidden state (which is what LSTMCell is limited to).
This allows PeepholeLSTMCell to better learn precise timings over LSTMCell.

From [Gers et al., 2002](
http://www.jmlr.org/papers/volume3/gers02a/gers02a.pdf):

"We find that LSTM augmented by 'peephole connections' from its internal
cells to its multiplicative gates can learn the fine distinction between
sequences of spikes spaced either 50 or 49 time steps apart without the help
of any short training exemplars."

The peephole implementation is based on:

[Sak et al., 2014](https://research.google.com/pubs/archive/43905.pdf)

Example:

```python
# Create 2 PeepholeLSTMCells
peephole_lstm_cells = [PeepholeLSTMCell(size) for size in [128, 256]]
# Create a layer composed sequentially of the peephole LSTM cells.
layer = RNN(peephole_lstm_cells)
input = keras.Input((timesteps, input_dim))
output = layer(input)
```
"""

def build(self, input_shape):
super().build(input_shape)
# The following are the weight matrices for the peephole connections. These
# are multiplied with the previous internal state during the computation of
# carry and output.
self.input_gate_peephole_weights = self.add_weight(
shape=(self.units,),
name="input_gate_peephole_weights",
initializer=self.kernel_initializer,
)
self.forget_gate_peephole_weights = self.add_weight(
shape=(self.units,),
name="forget_gate_peephole_weights",
initializer=self.kernel_initializer,
)
self.output_gate_peephole_weights = self.add_weight(
shape=(self.units,),
name="output_gate_peephole_weights",
initializer=self.kernel_initializer,
)

def _compute_carry_and_output(self, x, h_tm1, c_tm1):
x_i, x_f, x_c, x_o = x
h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
i = self.recurrent_activation(
x_i
+ tf.keras.backend.dot(h_tm1_i, self.recurrent_kernel[:, : self.units])
+ self.input_gate_peephole_weights * c_tm1
)
f = self.recurrent_activation(
x_f
+ tf.keras.backend.dot(
h_tm1_f, self.recurrent_kernel[:, self.units : self.units * 2]
)
+ self.forget_gate_peephole_weights * c_tm1
)
c = f * c_tm1 + i * self.activation(
x_c
+ tf.keras.backend.dot(
h_tm1_c, self.recurrent_kernel[:, self.units * 2 : self.units * 3]
)
)
o = self.recurrent_activation(
x_o
+ tf.keras.backend.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3 :])
+ self.output_gate_peephole_weights * c
)
return c, o

def _compute_carry_and_output_fused(self, z, c_tm1):
z0, z1, z2, z3 = z
i = self.recurrent_activation(z0 + self.input_gate_peephole_weights * c_tm1)
f = self.recurrent_activation(z1 + self.forget_gate_peephole_weights * c_tm1)
c = f * c_tm1 + i * self.activation(z2)
o = self.recurrent_activation(z3 + self.output_gate_peephole_weights * c)
return c, o
41 changes: 41 additions & 0 deletions tensorflow_addons/rnn/tests/cell_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,3 +558,44 @@ def test_esn_config():
restored_cell = rnn_cell.ESNCell.from_config(config)
restored_config = restored_cell.get_config()
assert config == restored_config


def test_peephole_lstm_cell():
def _run_cell(cell_fn, **kwargs):
inputs = tf.one_hot([1, 2, 3, 4], 4)
cell = cell_fn(5, **kwargs)
cell.build(inputs.shape)
initial_state = cell.get_initial_state(
inputs=inputs, batch_size=4, dtype=tf.float32
)
output, _ = cell(inputs, initial_state)
return output

tf.random.set_seed(12345)
first_implementation_output = _run_cell(
rnn_cell.PeepholeLSTMCell,
kernel_initializer="ones",
recurrent_activation="sigmoid",
implementation=1,
)
second_implementation_output = _run_cell(
rnn_cell.PeepholeLSTMCell,
kernel_initializer="ones",
recurrent_activation="sigmoid",
implementation=2,
)
expected_output = np.asarray(
[
[0.417551, 0.417551, 0.417551, 0.417551, 0.417551],
[0.417551, 0.417551, 0.417551, 0.417551, 0.417551],
[0.417551, 0.417551, 0.417551, 0.417551, 0.417551],
[0.0, 0.0, 0.0, 0.0, 0.0],
],
dtype=np.float32,
)
np.testing.assert_allclose(
first_implementation_output, second_implementation_output
)
np.testing.assert_allclose(
first_implementation_output, expected_output, rtol=1e-6, atol=1e-6
)
4 changes: 3 additions & 1 deletion tools/testing/source_code_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,9 @@ def test_api_typed():
tfa.text,
]
# Files within this list will be exempt from verification.
exception_list = []
exception_list = [
tfa.rnn.PeepholeLSTMCell,
]
help_message = (
"You can also take a look at the section about it in the CONTRIBUTING.md:\n"
"https://github.com/tensorflow/addons/blob/master/CONTRIBUTING.md#about-type-hints"
Expand Down