diff --git a/keras/layers/recurrent.py b/keras/layers/recurrent.py
index c397917324be..d3c1119fecd1 100644
--- a/keras/layers/recurrent.py
+++ b/keras/layers/recurrent.py
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import
 import numpy as np
-import functools
 import warnings
 
 from .. import backend as K
@@ -200,7 +199,9 @@ class RNN(Layer):
     # Arguments
         cell: A RNN cell instance. A RNN cell is a class that has:
             - a `call(input_at_t, states_at_t)` method, returning
-                `(output_at_t, states_at_t_plus_1)`.
+                `(output_at_t, states_at_t_plus_1)`. The call method of the
+                cell can also take the optional argument `constants`; see
+                section "Note on passing external constants" below.
             - a `state_size` attribute. This can be a single integer
                 (single state) in which case it is
                 the size of the recurrent state
@@ -292,6 +293,14 @@ class RNN(Layer):
         `states` should be a numpy array or
         list of numpy arrays representing
         the initial state of the RNN layer.
 
+    # Note on passing external constants to RNNs
+        You can pass "external" constants to the cell using the `constants`
+        keyword argument of the `RNN.__call__` (and `RNN.call`) method. This
+        requires that the `cell.call` method accept the same keyword argument
+        `constants`. Such constants can be used to condition the cell
+        transformation on additional static inputs (not changing over time),
+        e.g. as done in attention mechanisms.
+
     # Examples
 
     ```python
@@ -363,12 +372,10 @@ def __init__(self, cell,
         self.supports_masking = True
         self.input_spec = [InputSpec(ndim=3)]
-        if hasattr(self.cell.state_size, '__len__'):
-            self.state_spec = [InputSpec(shape=(None, dim))
-                               for dim in self.cell.state_size]
-        else:
-            self.state_spec = InputSpec(shape=(None, self.cell.state_size))
+        self.state_spec = None
         self._states = None
+        self.constants_spec = None
+        self._num_constants = None
 
     @property
     def states(self):
@@ -415,6 +422,13 @@ def compute_mask(self, inputs, mask):
         return output_mask
 
     def build(self, input_shape):
+        # Note: `input_shape` will be the list of shapes of initial states
+        # and constants if these are passed in `__call__`.
+        if self._num_constants is not None:
+            constants_shape = input_shape[-self._num_constants:]
+        else:
+            constants_shape = None
+
         if isinstance(input_shape, list):
             input_shape = input_shape[0]
 
@@ -422,12 +436,32 @@ def build(self, input_shape):
         input_dim = input_shape[-1]
         self.input_spec[0] = InputSpec(shape=(batch_size, None, input_dim))
 
-        if self.stateful:
-            self.reset_states()
-
+        # allow cell (if layer) to build before we set or validate state_spec
         if isinstance(self.cell, Layer):
             step_input_shape = (input_shape[0],) + input_shape[2:]
-            self.cell.build(step_input_shape)
+            if constants_shape is not None:
+                self.cell.build([step_input_shape] + constants_shape)
+            else:
+                self.cell.build(step_input_shape)
+
+        # set or validate state_spec
+        if hasattr(self.cell.state_size, '__len__'):
+            state_size = list(self.cell.state_size)
+        else:
+            state_size = [self.cell.state_size]
+
+        if self.state_spec is not None:
+            # initial_state was passed in call, check compatibility
+            if [spec.shape[-1] for spec in self.state_spec] != state_size:
+                raise ValueError(
+                    'An initial_state was passed that is not compatible'
+                    ' with cell.state_size: state_spec={}, cell.state_size='
+                    '{}'.format(self.state_spec, self.cell.state_size))
+        else:
+            self.state_spec = [InputSpec(shape=(None, dim))
+                               for dim in state_size]
+        if self.stateful:
+            self.reset_states()
 
     def get_initial_state(self, inputs):
         # build an all-zero tensor of shape (samples, output_dim)
@@ -440,62 +474,65 @@ def get_initial_state(self, inputs):
         else:
             return [K.tile(initial_state, [1, self.cell.state_size])]
 
-    def __call__(self, inputs, initial_state=None, **kwargs):
-        # If there are multiple inputs, then
-        # they should be the main input and `initial_state`
-        # e.g. when loading model from file
-        if isinstance(inputs, (list, tuple)) and len(inputs) > 1 and initial_state is None:
-            initial_state = inputs[1:]
-            inputs = inputs[0]
+    def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
+        inputs, initial_state, constants = self._standardize_args(
+            inputs, initial_state, constants)
 
-        # If `initial_state` is specified,
-        # and if it a Keras tensor,
-        # then add it to the inputs and temporarily
-        # modify the input spec to include the state.
-        if initial_state is None:
+        if initial_state is None and constants is None:
             return super(RNN, self).__call__(inputs, **kwargs)
 
-        if not isinstance(initial_state, (list, tuple)):
-            initial_state = [initial_state]
+        # If any of `initial_state` or `constants` are specified and are Keras
+        # tensors, then add them to the inputs and temporarily modify the
+        # input_spec to include them.
-        is_keras_tensor = hasattr(initial_state[0], '_keras_history')
-        for tensor in initial_state:
+        additional_inputs = []
+        additional_specs = []
+        if initial_state is not None:
+            kwargs['initial_state'] = initial_state
+            additional_inputs += initial_state
+            self.state_spec = [InputSpec(shape=K.int_shape(state))
+                               for state in initial_state]
+            additional_specs += self.state_spec
+        if constants is not None:
+            kwargs['constants'] = constants
+            additional_inputs += constants
+            self.constants_spec = [InputSpec(shape=K.int_shape(constant))
+                                   for constant in constants]
+            self._num_constants = len(constants)
+            additional_specs += self.constants_spec
+        # at this point additional_inputs cannot be empty
+        is_keras_tensor = hasattr(additional_inputs[0], '_keras_history')
+        for tensor in additional_inputs:
             if hasattr(tensor, '_keras_history') != is_keras_tensor:
-                raise ValueError('The initial state of an RNN layer cannot be'
-                                 ' specified with a mix of Keras tensors and'
-                                 ' non-Keras tensors')
+                raise ValueError('The initial state or constants of an RNN'
+                                 ' layer cannot be specified with a mix of'
+                                 ' Keras tensors and non-Keras tensors')
 
         if is_keras_tensor:
-            # Compute the full input spec, including state
-            input_spec = self.input_spec
-            state_spec = self.state_spec
-            if not isinstance(input_spec, list):
-                input_spec = [input_spec]
-            if not isinstance(state_spec, list):
-                state_spec = [state_spec]
-            self.input_spec = input_spec + state_spec
-
-            # Compute the full inputs, including state
-            inputs = [inputs] + list(initial_state)
-
-            # Perform the call
-            output = super(RNN, self).__call__(inputs, **kwargs)
-
-            # Restore original input spec
-            self.input_spec = input_spec
+            # Compute the full input spec, including state and constants
+            full_input = [inputs] + additional_inputs
+            full_input_spec = self.input_spec + additional_specs
+            # Perform the call with temporarily replaced input_spec
+            original_input_spec = self.input_spec
+            self.input_spec = full_input_spec
+            output = super(RNN, self).__call__(full_input, **kwargs)
+            self.input_spec = original_input_spec
             return output
         else:
-            kwargs['initial_state'] = initial_state
             return super(RNN, self).__call__(inputs, **kwargs)
 
-    def call(self, inputs, mask=None, training=None, initial_state=None):
+    def call(self,
+             inputs,
+             mask=None,
+             training=None,
+             initial_state=None,
+             constants=None):
         # input shape: `(samples, time (padded with zeros), input_dim)`
         # note that the .build() method of subclasses MUST define
         # self.input_spec and self.state_spec with complete input shapes.
         if isinstance(inputs, list):
-            initial_state = inputs[1:]
             inputs = inputs[0]
-        elif initial_state is not None:
+        if initial_state is not None:
             pass
         elif self.stateful:
             initial_state = self.states
@@ -525,13 +562,27 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
                              'the time dimension by passing a `shape` '
                              'or `batch_shape` argument to your Input layer.')
 
+        kwargs = {}
         if has_arg(self.cell.call, 'training'):
-            step = functools.partial(self.cell.call, training=training)
+            kwargs['training'] = training
+
+        if constants:
+            if not has_arg(self.cell.call, 'constants'):
+                raise ValueError('RNN cell does not support constants')
+
+            def step(inputs, states):
+                constants = states[-self._num_constants:]
+                states = states[:-self._num_constants]
+                return self.cell.call(inputs, states, constants=constants,
+                                      **kwargs)
         else:
-            step = self.cell.call
+            def step(inputs, states):
+                return self.cell.call(inputs, states, **kwargs)
+
         last_output, outputs, states = K.rnn(step,
                                              inputs,
                                              initial_state,
+                                             constants=constants,
                                              go_backwards=self.go_backwards,
                                              mask=mask,
                                              unroll=self.unroll,
@@ -560,6 +611,47 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
         else:
             return output
 
+    def _standardize_args(self, inputs, initial_state, constants):
+        """Brings the arguments of `__call__` that can contain input
+        tensors to standard format.
+
+        When running a model loaded from file, the input tensors
+        `initial_state` and `constants` can be passed to `RNN.__call__` as
+        part of `inputs` instead of by the dedicated keyword arguments.
+        This method makes sure the arguments are separated and that
+        `initial_state` and `constants` are lists of tensors (or None).
+
+        # Arguments
+            inputs: tensor or list/tuple of tensors
+            initial_state: tensor or list of tensors or None
+            constants: tensor or list of tensors or None
+
+        # Returns
+            inputs: tensor
+            initial_state: list of tensors or None
+            constants: list of tensors or None
+        """
+        if isinstance(inputs, list):
+            assert initial_state is None and constants is None
+            if self._num_constants is not None:
+                constants = inputs[-self._num_constants:]
+                inputs = inputs[:-self._num_constants]
+            if len(inputs) > 1:
+                initial_state = inputs[1:]
+            inputs = inputs[0]
+
+        def to_list_or_none(x):
+            if x is None or isinstance(x, list):
+                return x
+            if isinstance(x, tuple):
+                return list(x)
+            return [x]
+
+        initial_state = to_list_or_none(initial_state)
+        constants = to_list_or_none(constants)
+
+        return inputs, initial_state, constants
+
     def reset_states(self, states=None):
         if not self.stateful:
             raise AttributeError('Layer must be stateful.')
@@ -618,6 +710,9 @@ def get_config(self):
                   'go_backwards': self.go_backwards,
                   'stateful': self.stateful,
                   'unroll': self.unroll}
+        if self._num_constants is not None:
+            config['num_constants'] = self._num_constants
+
         cell_config = self.cell.get_config()
         config['cell'] = {'class_name': self.cell.__class__.__name__,
                           'config': cell_config}
@@ -629,7 +724,10 @@ def from_config(cls, config, custom_objects=None):
         from . import deserialize as deserialize_layer
         cell = deserialize_layer(config.pop('cell'),
                                  custom_objects=custom_objects)
-        return cls(cell, **config)
+        num_constants = config.pop('num_constants', None)
+        layer = cls(cell, **config)
+        layer._num_constants = num_constants
+        return layer
 
     @property
     def trainable_weights(self):
diff --git a/tests/keras/layers/recurrent_test.py b/tests/keras/layers/recurrent_test.py
index 2dce96ee67fa..19d318a060a3 100644
--- a/tests/keras/layers/recurrent_test.py
+++ b/tests/keras/layers/recurrent_test.py
@@ -568,5 +568,174 @@ def test_batch_size_equal_one(layer_class):
     model.train_on_batch(x, y)
 
 
+def test_rnn_cell_with_constants_layer():
+
+    class RNNCellWithConstants(keras.layers.Layer):
+
+        def __init__(self, units, **kwargs):
+            self.units = units
+            self.state_size = units
+            super(RNNCellWithConstants, self).__init__(**kwargs)
+
+        def build(self, input_shape):
+            if not isinstance(input_shape, list):
+                raise TypeError('expects constants shape')
+            [input_shape, constant_shape] = input_shape
+            # will (and should) raise if more than one constant passed
+
+            self.input_kernel = self.add_weight(
+                shape=(input_shape[-1], self.units),
+                initializer='uniform',
+                name='kernel')
+            self.recurrent_kernel = self.add_weight(
+                shape=(self.units, self.units),
+                initializer='uniform',
+                name='recurrent_kernel')
+            self.constant_kernel = self.add_weight(
+                shape=(constant_shape[-1], self.units),
+                initializer='uniform',
+                name='constant_kernel')
+            self.built = True
+
+        def call(self, inputs, states, constants):
+            [prev_output] = states
+            [constant] = constants
+            h_input = keras.backend.dot(inputs, self.input_kernel)
+            h_state = keras.backend.dot(prev_output, self.recurrent_kernel)
+            h_const = keras.backend.dot(constant, self.constant_kernel)
+            output = h_input + h_state + h_const
+            return output, [output]
+
+        def get_config(self):
+            config = {'units': self.units}
+            base_config = super(RNNCellWithConstants, self).get_config()
+            return dict(list(base_config.items()) + list(config.items()))
+
+    # Test basic case.
+    x = keras.Input((None, 5))
+    c = keras.Input((3,))
+    cell = RNNCellWithConstants(32)
+    layer = recurrent.RNN(cell)
+    y = layer(x, constants=c)
+    model = keras.models.Model([x, c], y)
+    model.compile(optimizer='rmsprop', loss='mse')
+    model.train_on_batch(
+        [np.zeros((6, 5, 5)), np.zeros((6, 3))],
+        np.zeros((6, 32))
+    )
+
+    # Test basic case serialization.
+    x_np = np.random.random((6, 5, 5))
+    c_np = np.random.random((6, 3))
+    y_np = model.predict([x_np, c_np])
+    weights = model.get_weights()
+    config = layer.get_config()
+    custom_objects = {'RNNCellWithConstants': RNNCellWithConstants}
+    with keras.utils.CustomObjectScope(custom_objects):
+        layer = recurrent.RNN.from_config(config.copy())
+        y = layer(x, constants=c)
+        model = keras.models.Model([x, c], y)
+        model.set_weights(weights)
+    y_np_2 = model.predict([x_np, c_np])
+    assert_allclose(y_np, y_np_2, atol=1e-4)
+
+    # test flat list inputs
+    with keras.utils.CustomObjectScope(custom_objects):
+        layer = recurrent.RNN.from_config(config.copy())
+        y = layer([x, c])
+        model = keras.models.Model([x, c], y)
+        model.set_weights(weights)
+    y_np_3 = model.predict([x_np, c_np])
+    assert_allclose(y_np, y_np_3, atol=1e-4)
+
+
+def test_rnn_cell_with_constants_layer_passing_initial_state():
+
+    class RNNCellWithConstants(keras.layers.Layer):
+
+        def __init__(self, units, **kwargs):
+            self.units = units
+            self.state_size = units
+            super(RNNCellWithConstants, self).__init__(**kwargs)
+
+        def build(self, input_shape):
+            if not isinstance(input_shape, list):
+                raise TypeError('expects constants shape')
+            [input_shape, constant_shape] = input_shape
+            # will (and should) raise if more than one constant passed
+
+            self.input_kernel = self.add_weight(
+                shape=(input_shape[-1], self.units),
+                initializer='uniform',
+                name='kernel')
+            self.recurrent_kernel = self.add_weight(
+                shape=(self.units, self.units),
+                initializer='uniform',
+                name='recurrent_kernel')
+            self.constant_kernel = self.add_weight(
+                shape=(constant_shape[-1], self.units),
+                initializer='uniform',
+                name='constant_kernel')
+            self.built = True
+
+        def call(self, inputs, states, constants):
+            [prev_output] = states
+            [constant] = constants
+            h_input = keras.backend.dot(inputs, self.input_kernel)
+            h_state = keras.backend.dot(prev_output, self.recurrent_kernel)
+            h_const = keras.backend.dot(constant, self.constant_kernel)
+            output = h_input + h_state + h_const
+            return output, [output]
+
+        def get_config(self):
+            config = {'units': self.units}
+            base_config = super(RNNCellWithConstants, self).get_config()
+            return dict(list(base_config.items()) + list(config.items()))
+
+    # Test basic case.
+    x = keras.Input((None, 5))
+    c = keras.Input((3,))
+    s = keras.Input((32,))
+    cell = RNNCellWithConstants(32)
+    layer = recurrent.RNN(cell)
+    y = layer(x, initial_state=s, constants=c)
+    model = keras.models.Model([x, s, c], y)
+    model.compile(optimizer='rmsprop', loss='mse')
+    model.train_on_batch(
+        [np.zeros((6, 5, 5)), np.zeros((6, 32)), np.zeros((6, 3))],
+        np.zeros((6, 32))
+    )
+
+    # Test basic case serialization.
+    x_np = np.random.random((6, 5, 5))
+    s_np = np.random.random((6, 32))
+    c_np = np.random.random((6, 3))
+    y_np = model.predict([x_np, s_np, c_np])
+    weights = model.get_weights()
+    config = layer.get_config()
+    custom_objects = {'RNNCellWithConstants': RNNCellWithConstants}
+    with keras.utils.CustomObjectScope(custom_objects):
+        layer = recurrent.RNN.from_config(config.copy())
+        y = layer(x, initial_state=s, constants=c)
+        model = keras.models.Model([x, s, c], y)
+        model.set_weights(weights)
+    y_np_2 = model.predict([x_np, s_np, c_np])
+    assert_allclose(y_np, y_np_2, atol=1e-4)
+
+    # verify that state is used
+    y_np_2_different_s = model.predict([x_np, s_np + 10., c_np])
+    with pytest.raises(AssertionError):
+        assert_allclose(y_np, y_np_2_different_s, atol=1e-4)
+
+    # test flat list inputs
+    with keras.utils.CustomObjectScope(custom_objects):
+        layer = recurrent.RNN.from_config(config.copy())
+        y = layer([x, s, c])
+        model = keras.models.Model([x, s, c], y)
+        model.set_weights(weights)
+    y_np_3 = model.predict([x_np, s_np, c_np])
+    assert_allclose(y_np, y_np_3, atol=1e-4)
+
+
 if __name__ == '__main__':
     pytest.main([__file__])
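The tests above double as API documentation; condensed below is a minimal usage sketch of the `constants` mechanism this patch introduces. `ConditionedCell` and its weight names are hypothetical, but the cell contract (a `call(inputs, states, constants)` method plus a `state_size` attribute) and the `RNN(cell)(x, constants=c)` call follow the tests in this diff.

```python
import numpy as np
import keras
from keras.layers.recurrent import RNN


class ConditionedCell(keras.layers.Layer):
    """Hypothetical cell conditioning each step on a static vector."""

    def __init__(self, units, **kwargs):
        self.units = units
        self.state_size = units
        super(ConditionedCell, self).__init__(**kwargs)

    def build(self, input_shape):
        # When constants are passed in `__call__`, `input_shape` is a list:
        # [step_input_shape, constant_shape] (see RNN.build in this patch).
        [step_input_shape, constant_shape] = input_shape
        self.kernel = self.add_weight(
            shape=(step_input_shape[-1], self.units),
            initializer='uniform', name='kernel')
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform', name='recurrent_kernel')
        self.constant_kernel = self.add_weight(
            shape=(constant_shape[-1], self.units),
            initializer='uniform', name='constant_kernel')
        self.built = True

    def call(self, inputs, states, constants):
        [prev_output] = states
        [condition] = constants  # static over time, supplied once per sample
        output = keras.backend.tanh(
            keras.backend.dot(inputs, self.kernel) +
            keras.backend.dot(prev_output, self.recurrent_kernel) +
            keras.backend.dot(condition, self.constant_kernel))
        return output, [output]


x = keras.Input((None, 8))  # (batch, time, features)
c = keras.Input((4,))       # static conditioning input, no time axis
y = RNN(ConditionedCell(16))(x, constants=c)

model = keras.models.Model([x, c], y)
model.compile(optimizer='rmsprop', loss='mse')
model.train_on_batch([np.zeros((2, 7, 8)), np.zeros((2, 4))],
                     np.zeros((2, 16)))
```

Because `c` enters through `constants` rather than as a regular input, `K.rnn` hands it to every step unchanged, which is the hook an attention mechanism needs in order to re-read a fixed encoding at each timestep.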