Commit 3f148e4

andhus authored and fchollet committed
Recurrent Attention API Additions (#7980)
* Added support for passing external constants to RNN, which will pass them on to the cell
* Added a class allowing functional composition of RNN cells, supporting constants
* Put back accidentally commented-out recurrent tests
* Added a basic example of a functional cell
* New class AttentionRNN
* Restored RNN layer
* Renamed constants to attended in FunctionalRNNCell, avoided duplicating outputs in the wrapped model
* Minor clean-up of docs
* Minor cleanup & improvements in docs, fixed PEP-breaking formatting in attention test
* Removed FunctionalRNNCell and AttentionRNN, added back support for constants in RNN
* Fixed PEP8 violations
* Fixed minor review comments
* Added a test case for when both initial_state and constants are passed to RNN.__call__
1 parent 3c69f98 commit 3f148e4

File tree

2 files changed (+321, -54 lines)


keras/layers/recurrent.py

Lines changed: 152 additions & 54 deletions
````diff
@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import
 import numpy as np
-import functools
 import warnings
 
 from .. import backend as K
@@ -200,7 +199,9 @@ class RNN(Layer):
     # Arguments
         cell: A RNN cell instance. A RNN cell is a class that has:
             - a `call(input_at_t, states_at_t)` method, returning
-                `(output_at_t, states_at_t_plus_1)`.
+                `(output_at_t, states_at_t_plus_1)`. The call method of the
+                cell can also take the optional argument `constants`; see
+                section "Note on passing external constants" below.
             - a `state_size` attribute. This can be a single integer
                 (single state) in which case it is
                 the size of the recurrent state
@@ -292,6 +293,14 @@ class RNN(Layer):
         `states` should be a numpy array or list of numpy arrays representing
         the initial state of the RNN layer.
 
+    # Note on passing external constants to RNNs
+        You can pass "external" constants to the cell using the `constants`
+        keyword argument of the `RNN.__call__` (as well as `RNN.call`)
+        method. This requires that the `cell.call` method accepts the same
+        keyword argument `constants`. Such constants can be used to condition
+        the cell transformation on additional static inputs (not changing
+        over time), a.k.a. an attention mechanism.
+
     # Examples
 
     ```python
````
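
To make the new docstring note concrete, here is a minimal, hedged sketch of a cell whose `call` accepts the `constants` keyword. It is illustrative only; the class name `MinimalConditionedCell` and its weights are invented for this example, not taken from the commit:

```python
from keras import backend as K
from keras.layers import Input, Layer, RNN


class MinimalConditionedCell(Layer):
    """Toy cell whose `call` accepts the new `constants` keyword."""

    def __init__(self, units, **kwargs):
        self.units = units
        self.state_size = units
        super(MinimalConditionedCell, self).__init__(**kwargs)

    def build(self, input_shape):
        # When constants are passed, `input_shape` is a list:
        # [step_input_shape, constants_shape] (see `RNN.build` below).
        input_dim = input_shape[0][-1]
        constant_dim = input_shape[1][-1]
        self.kernel = self.add_weight(shape=(input_dim, self.units),
                                      initializer='uniform',
                                      name='kernel')
        self.constant_kernel = self.add_weight(shape=(constant_dim, self.units),
                                               initializer='uniform',
                                               name='constant_kernel')
        self.built = True

    def call(self, inputs, states, constants):
        # Condition the step on the static ("attended") input.
        [prev_output] = states
        [constant] = constants
        output = (K.dot(inputs, self.kernel) + prev_output +
                  K.dot(constant, self.constant_kernel))
        return output, [output]


x = Input(shape=(None, 8))   # time-varying input
c = Input(shape=(4,))        # static input, constant over all timesteps
y = RNN(MinimalConditionedCell(10))(x, constants=c)
```
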
````diff
@@ -363,12 +372,10 @@ def __init__(self, cell,
 
         self.supports_masking = True
         self.input_spec = [InputSpec(ndim=3)]
-        if hasattr(self.cell.state_size, '__len__'):
-            self.state_spec = [InputSpec(shape=(None, dim))
-                               for dim in self.cell.state_size]
-        else:
-            self.state_spec = InputSpec(shape=(None, self.cell.state_size))
+        self.state_spec = None
         self._states = None
+        self.constants_spec = None
+        self._num_constants = None
 
     @property
     def states(self):
@@ -415,19 +422,46 @@ def compute_mask(self, inputs, mask):
         return output_mask
 
     def build(self, input_shape):
+        # Note input_shape will be list of shapes of initial states and
+        # constants if these are passed in __call__.
+        if self._num_constants is not None:
+            constants_shape = input_shape[-self._num_constants:]
+        else:
+            constants_shape = None
+
         if isinstance(input_shape, list):
             input_shape = input_shape[0]
 
         batch_size = input_shape[0] if self.stateful else None
         input_dim = input_shape[-1]
         self.input_spec[0] = InputSpec(shape=(batch_size, None, input_dim))
 
-        if self.stateful:
-            self.reset_states()
-
+        # allow cell (if layer) to build before we set or validate state_spec
         if isinstance(self.cell, Layer):
             step_input_shape = (input_shape[0],) + input_shape[2:]
-            self.cell.build(step_input_shape)
+            if constants_shape is not None:
+                self.cell.build([step_input_shape] + constants_shape)
+            else:
+                self.cell.build(step_input_shape)
+
+        # set or validate state_spec
+        if hasattr(self.cell.state_size, '__len__'):
+            state_size = list(self.cell.state_size)
+        else:
+            state_size = [self.cell.state_size]
+
+        if self.state_spec is not None:
+            # initial_state was passed in call, check compatibility
+            if not [spec.shape[-1] for spec in self.state_spec] == state_size:
+                raise ValueError(
+                    'an initial_state was passed that is not compatible with'
+                    ' cell.state_size, state_spec: {}, cell.state_size:'
+                    ' {}'.format(self.state_spec, self.cell.state_size))
+        else:
+            self.state_spec = [InputSpec(shape=(None, dim))
+                               for dim in state_size]
+        if self.stateful:
+            self.reset_states()
 
     def get_initial_state(self, inputs):
         # build an all-zero tensor of shape (samples, output_dim)
````
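
For concreteness, an illustrative walk-through of the new shape bookkeeping in `build` when one initial state and two constants were passed to `__call__` (the shapes below are made up):

```python
# `input_shape` arrives as a list: main input shape, then state shapes,
# then constant shapes.
input_shape = [(None, None, 8),  # main input: (batch, time, features)
               (None, 32),       # initial_state, cell.state_size == 32
               (None, 4),        # first constant
               (None, 6)]        # second constant
num_constants = 2                # stands in for self._num_constants

constants_shape = input_shape[-num_constants:]  # [(None, 4), (None, 6)]
input_shape = input_shape[0]                    # (None, None, 8)
step_input_shape = (input_shape[0],) + input_shape[2:]  # (None, 8)
# The cell is then built on [step_input_shape] + constants_shape.
```
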
````diff
@@ -440,62 +474,65 @@ def get_initial_state(self, inputs):
         else:
             return [K.tile(initial_state, [1, self.cell.state_size])]
 
-    def __call__(self, inputs, initial_state=None, **kwargs):
-        # If there are multiple inputs, then
-        # they should be the main input and `initial_state`
-        # e.g. when loading model from file
-        if isinstance(inputs, (list, tuple)) and len(inputs) > 1 and initial_state is None:
-            initial_state = inputs[1:]
-            inputs = inputs[0]
+    def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
+        inputs, initial_state, constants = self._standardize_args(
+            inputs, initial_state, constants)
 
-        # If `initial_state` is specified,
-        # and if it a Keras tensor,
-        # then add it to the inputs and temporarily
-        # modify the input spec to include the state.
-        if initial_state is None:
+        if initial_state is None and constants is None:
             return super(RNN, self).__call__(inputs, **kwargs)
 
-        if not isinstance(initial_state, (list, tuple)):
-            initial_state = [initial_state]
+        # If any of `initial_state` or `constants` are specified and are Keras
+        # tensors, then add them to the inputs and temporarily modify the
+        # input_spec to include them.
 
-        is_keras_tensor = hasattr(initial_state[0], '_keras_history')
-        for tensor in initial_state:
+        additional_inputs = []
+        additional_specs = []
+        if initial_state is not None:
+            kwargs['initial_state'] = initial_state
+            additional_inputs += initial_state
+            self.state_spec = [InputSpec(shape=K.int_shape(state))
+                               for state in initial_state]
+            additional_specs += self.state_spec
+        if constants is not None:
+            kwargs['constants'] = constants
+            additional_inputs += constants
+            self.constants_spec = [InputSpec(shape=K.int_shape(constant))
+                                   for constant in constants]
+            self._num_constants = len(constants)
+            additional_specs += self.constants_spec
+        # at this point additional_inputs cannot be empty
+        is_keras_tensor = hasattr(additional_inputs[0], '_keras_history')
+        for tensor in additional_inputs:
             if hasattr(tensor, '_keras_history') != is_keras_tensor:
-                raise ValueError('The initial state of an RNN layer cannot be'
-                                 ' specified with a mix of Keras tensors and'
-                                 ' non-Keras tensors')
+                raise ValueError('The initial state or constants of an RNN'
+                                 ' layer cannot be specified with a mix of'
+                                 ' Keras tensors and non-Keras tensors')
 
         if is_keras_tensor:
-            # Compute the full input spec, including state
-            input_spec = self.input_spec
-            state_spec = self.state_spec
-            if not isinstance(input_spec, list):
-                input_spec = [input_spec]
-            if not isinstance(state_spec, list):
-                state_spec = [state_spec]
-            self.input_spec = input_spec + state_spec
-
-            # Compute the full inputs, including state
-            inputs = [inputs] + list(initial_state)
-
-            # Perform the call
-            output = super(RNN, self).__call__(inputs, **kwargs)
-
-            # Restore original input spec
-            self.input_spec = input_spec
+            # Compute the full input spec, including state and constants
+            full_input = [inputs] + additional_inputs
+            full_input_spec = self.input_spec + additional_specs
+            # Perform the call with temporarily replaced input_spec
+            original_input_spec = self.input_spec
+            self.input_spec = full_input_spec
+            output = super(RNN, self).__call__(full_input, **kwargs)
+            self.input_spec = original_input_spec
             return output
         else:
-            kwargs['initial_state'] = initial_state
             return super(RNN, self).__call__(inputs, **kwargs)
 
-    def call(self, inputs, mask=None, training=None, initial_state=None):
+    def call(self,
+             inputs,
+             mask=None,
+             training=None,
+             initial_state=None,
+             constants=None):
         # input shape: `(samples, time (padded with zeros), input_dim)`
         # note that the .build() method of subclasses MUST define
         # self.input_spec and self.state_spec with complete input shapes.
         if isinstance(inputs, list):
-            initial_state = inputs[1:]
             inputs = inputs[0]
-        elif initial_state is not None:
+        if initial_state is not None:
             pass
         elif self.stateful:
             initial_state = self.states
````
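
Both keyword arguments can also be combined (the commit message mentions a test case for exactly this combination). A short sketch, reusing the illustrative cell from above:

```python
x = Input(shape=(None, 8))
s = Input(shape=(10,))   # initial state; its size must match cell.state_size
c = Input(shape=(4,))    # static constant

layer = RNN(MinimalConditionedCell(10))
y = layer(x, initial_state=s, constants=c)
```
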
````diff
@@ -525,13 +562,27 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
                              'the time dimension by passing a `shape` '
                              'or `batch_shape` argument to your Input layer.')
 
+        kwargs = {}
         if has_arg(self.cell.call, 'training'):
-            step = functools.partial(self.cell.call, training=training)
+            kwargs['training'] = training
+
+        if constants:
+            if not has_arg(self.cell.call, 'constants'):
+                raise ValueError('RNN cell does not support constants')
+
+            def step(inputs, states):
+                constants = states[-self._num_constants:]
+                states = states[:-self._num_constants]
+                return self.cell.call(inputs, states, constants=constants,
+                                      **kwargs)
         else:
-            step = self.cell.call
+            def step(inputs, states):
+                return self.cell.call(inputs, states, **kwargs)
+
         last_output, outputs, states = K.rnn(step,
                                              inputs,
                                              initial_state,
+                                             constants=constants,
                                              go_backwards=self.go_backwards,
                                              mask=mask,
                                              unroll=self.unroll,
````
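
The step closures above lean on `K.rnn`'s contract: when `constants` is forwarded to the backend, the constants are appended to the state list handed to the step function at every timestep, and the closure peels them back off the tail. Roughly, in illustrative pure Python (a simplification, not the actual backend code):

```python
def rnn_like(step, inputs_over_time, initial_states, constants):
    states = list(initial_states)
    outputs = []
    for x_t in inputs_over_time:
        # the step function receives the states with constants appended...
        output, states = step(x_t, states + list(constants))
        outputs.append(output)
    # ...while the states it returns never include the constants.
    return outputs[-1], outputs, states
```
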
````diff
@@ -560,6 +611,47 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
         else:
             return output
 
+    def _standardize_args(self, inputs, initial_state, constants):
+        """Brings the arguments of `__call__` that can contain input tensors
+        to standard format.
+
+        When running a model loaded from file, the input tensors
+        `initial_state` and `constants` can be passed to `RNN.__call__` as
+        part of `inputs` instead of by the dedicated keyword arguments. This
+        method makes sure the arguments are separated and that
+        `initial_state` and `constants` are lists of tensors (or None).
+
+        # Arguments
+            inputs: tensor or list/tuple of tensors
+            initial_state: tensor or list of tensors or None
+            constants: tensor or list of tensors or None
+
+        # Returns
+            inputs: tensor
+            initial_state: list of tensors or None
+            constants: list of tensors or None
+        """
+        if isinstance(inputs, list):
+            assert initial_state is None and constants is None
+            if self._num_constants is not None:
+                constants = inputs[-self._num_constants:]
+                inputs = inputs[:-self._num_constants]
+            if len(inputs) > 1:
+                initial_state = inputs[1:]
+            inputs = inputs[0]
+
+        def to_list_or_none(x):
+            if x is None or isinstance(x, list):
+                return x
+            if isinstance(x, tuple):
+                return list(x)
+            return [x]
+
+        initial_state = to_list_or_none(initial_state)
+        constants = to_list_or_none(constants)
+
+        return inputs, initial_state, constants
+
     def reset_states(self, states=None):
         if not self.stateful:
             raise AttributeError('Layer must be stateful.')
````
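
For illustration, assuming the `layer` from the earlier sketch has already been called with one constant (so `layer._num_constants == 1`), `_standardize_args` would unpack a flat input list like this:

```python
inputs, initial_state, constants = layer._standardize_args(
    inputs=[x, s, c], initial_state=None, constants=None)
# -> inputs == x, initial_state == [s], constants == [c]
```
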
````diff
@@ -618,6 +710,9 @@ def get_config(self):
                   'go_backwards': self.go_backwards,
                   'stateful': self.stateful,
                   'unroll': self.unroll}
+        if self._num_constants is not None:
+            config['num_constants'] = self._num_constants
+
         cell_config = self.cell.get_config()
         config['cell'] = {'class_name': self.cell.__class__.__name__,
                           'config': cell_config}
@@ -629,7 +724,10 @@ def from_config(cls, config, custom_objects=None):
         from . import deserialize as deserialize_layer
         cell = deserialize_layer(config.pop('cell'),
                                  custom_objects=custom_objects)
-        return cls(cell, **config)
+        num_constants = config.pop('num_constants', None)
+        layer = cls(cell, **config)
+        layer._num_constants = num_constants
+        return layer
 
     @property
     def trainable_weights(self):
````
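
Serializing `num_constants` is what closes the loop: a reloaded model calls the RNN layer with one flat tensor list, and the restored `_num_constants` lets `_standardize_args` split the constants back off that list. A hedged sketch of the round trip, reusing the illustrative model from the examples above:

```python
from keras.models import Model

m = Model([x, c], y)     # y = RNN(MinimalConditionedCell(10))(x, constants=c)
config = m.get_config()  # the RNN layer's config now carries 'num_constants'
m2 = Model.from_config(
    config, custom_objects={'MinimalConditionedCell': MinimalConditionedCell})
```
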
