@@ -1,7 +1,6 @@
 # -*- coding: utf-8 -*-
 from __future__ import absolute_import
 import numpy as np
-import functools
 import warnings
 
 from .. import backend as K
@@ -200,7 +199,9 @@ class RNN(Layer):
     # Arguments
         cell: A RNN cell instance. A RNN cell is a class that has:
             - a `call(input_at_t, states_at_t)` method, returning
-                `(output_at_t, states_at_t_plus_1)`.
+                `(output_at_t, states_at_t_plus_1)`. The call method of the
+                cell can also take the optional argument `constants`, see
+                section "Note on passing external constants" below.
             - a `state_size` attribute. This can be a single integer
                 (single state) in which case it is
                 the size of the recurrent state
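For reference, a minimal sketch of a cell satisfying this contract, in the spirit of the `MinimalRNNCell` example in this docstring's Examples section (the class body here is illustrative, not part of the diff):

```python
import keras
from keras import backend as K


class MinimalRNNCell(keras.layers.Layer):
    """Sketch of the cell contract: one state, output == new state."""

    def __init__(self, units, **kwargs):
        self.units = units
        self.state_size = units  # a single integer: one recurrent state
        super(MinimalRNNCell, self).__init__(**kwargs)

    def build(self, input_shape):
        self.kernel = self.add_weight(shape=(input_shape[-1], self.units),
                                      initializer='uniform',
                                      name='kernel')
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform',
            name='recurrent_kernel')
        self.built = True

    def call(self, inputs, states):
        prev_output = states[0]
        output = (K.dot(inputs, self.kernel) +
                  K.dot(prev_output, self.recurrent_kernel))
        # returns (output_at_t, states_at_t_plus_1)
        return output, [output]
```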
@@ -292,6 +293,14 @@ class RNN(Layer):
292
293
`states` should be a numpy array or list of numpy arrays representing
293
294
the initial state of the RNN layer.
294
295
296
+ # Note on passing external constants to RNNs
297
+ You can pass "external" constants to the cell using the `constants`
298
+ keyword argument of `RNN.__call__` (as well as `RNN.call`) method. This
299
+ requires that the `cell.call` method accepts the same keyword argument
300
+ `constants`. Such constants can be used to condition the cell
301
+ transformation on additional static inputs (not changing over time),
302
+ a.k.a. an attention mechanism.
303
+
295
304
# Examples
296
305
297
306
```python
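As a usage sketch of the `constants` mechanism described in the note above, assuming a cell whose `call` accepts the `constants` keyword (class name and shapes are illustrative; this mirrors the pattern the diff enables):

```python
import keras
from keras import backend as K


class RNNCellWithConstants(keras.layers.Layer):
    """Sketch: a cell whose step is conditioned on a static tensor."""

    def __init__(self, units, **kwargs):
        self.units = units
        self.state_size = units
        super(RNNCellWithConstants, self).__init__(**kwargs)

    def build(self, input_shape):
        # With constants, the cell is built on [step_input_shape, constant_shape]
        [input_shape, constant_shape] = input_shape
        self.input_kernel = self.add_weight(
            shape=(input_shape[-1], self.units),
            initializer='uniform', name='input_kernel')
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform', name='recurrent_kernel')
        self.constant_kernel = self.add_weight(
            shape=(constant_shape[-1], self.units),
            initializer='uniform', name='constant_kernel')
        self.built = True

    def call(self, inputs, states, constants):
        [prev_output] = states
        [constant] = constants  # static: the same tensor at every timestep
        output = (K.dot(inputs, self.input_kernel) +
                  K.dot(prev_output, self.recurrent_kernel) +
                  K.dot(constant, self.constant_kernel))
        return output, [output]


x = keras.layers.Input(shape=(5, 32))   # (timesteps, input_dim)
c = keras.layers.Input(shape=(16,))     # static conditioning input
y = keras.layers.RNN(RNNCellWithConstants(8))(x, constants=c)
model = keras.models.Model([x, c], y)
```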
@@ -363,12 +372,10 @@ def __init__(self, cell,
 
         self.supports_masking = True
         self.input_spec = [InputSpec(ndim=3)]
-        if hasattr(self.cell.state_size, '__len__'):
-            self.state_spec = [InputSpec(shape=(None, dim))
-                               for dim in self.cell.state_size]
-        else:
-            self.state_spec = InputSpec(shape=(None, self.cell.state_size))
+        self.state_spec = None
         self._states = None
+        self.constants_spec = None
+        self._num_constants = None
 
     @property
     def states(self):
@@ -415,19 +422,46 @@ def compute_mask(self, inputs, mask):
         return output_mask
 
     def build(self, input_shape):
+        # Note input_shape will be list of shapes of initial states and
+        # constants if these are passed in __call__.
+        if self._num_constants is not None:
+            constants_shape = input_shape[-self._num_constants:]
+        else:
+            constants_shape = None
+
         if isinstance(input_shape, list):
             input_shape = input_shape[0]
 
         batch_size = input_shape[0] if self.stateful else None
         input_dim = input_shape[-1]
         self.input_spec[0] = InputSpec(shape=(batch_size, None, input_dim))
 
-        if self.stateful:
-            self.reset_states()
-
+        # allow cell (if layer) to build before we set or validate state_spec
         if isinstance(self.cell, Layer):
             step_input_shape = (input_shape[0],) + input_shape[2:]
-            self.cell.build(step_input_shape)
+            if constants_shape is not None:
+                self.cell.build([step_input_shape] + constants_shape)
+            else:
+                self.cell.build(step_input_shape)
+
+        # set or validate state_spec
+        if hasattr(self.cell.state_size, '__len__'):
+            state_size = list(self.cell.state_size)
+        else:
+            state_size = [self.cell.state_size]
+
+        if self.state_spec is not None:
+            # initial_state was passed in call, check compatibility
+            if [spec.shape[-1] for spec in self.state_spec] != state_size:
+                raise ValueError(
+                    'an initial_state was passed that is not compatible with'
+                    ' cell.state_size, state_spec: {}, cell.state_size:'
+                    ' {}'.format(self.state_spec, self.cell.state_size))
+        else:
+            self.state_spec = [InputSpec(shape=(None, dim))
+                               for dim in state_size]
+        if self.stateful:
+            self.reset_states()
 
     def get_initial_state(self, inputs):
         # build an all-zero tensor of shape (samples, output_dim)
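To make the slicing at the top of the new `build` concrete, a worked sketch of the shapes involved when `__call__` received one initial state and two constants (all values illustrative):

```python
# input_shape as received by RNN.build, with self._num_constants == 2:
input_shape = [(None, 10, 32),  # main input: (batch, timesteps, input_dim)
               (None, 8),       # initial state, matching cell.state_size == 8
               (None, 16),      # constant 1
               (None, 16)]      # constant 2

constants_shape = input_shape[-2:]    # [(None, 16), (None, 16)]
input_shape = input_shape[0]          # (None, 10, 32)
step_input_shape = (input_shape[0],) + input_shape[2:]  # (None, 32)
# cell.build is then called on [step_input_shape] + constants_shape
```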
@@ -440,62 +474,65 @@ def get_initial_state(self, inputs):
         else:
             return [K.tile(initial_state, [1, self.cell.state_size])]
 
-    def __call__(self, inputs, initial_state=None, **kwargs):
-        # If there are multiple inputs, then
-        # they should be the main input and `initial_state`
-        # e.g. when loading model from file
-        if isinstance(inputs, (list, tuple)) and len(inputs) > 1 and initial_state is None:
-            initial_state = inputs[1:]
-            inputs = inputs[0]
-
-        # If `initial_state` is specified,
-        # and if it a Keras tensor,
-        # then add it to the inputs and temporarily
-        # modify the input spec to include the state.
-        if initial_state is None:
+    def __call__(self, inputs, initial_state=None, constants=None, **kwargs):
+        inputs, initial_state, constants = self._standardize_args(
+            inputs, initial_state, constants)
+
+        if initial_state is None and constants is None:
             return super(RNN, self).__call__(inputs, **kwargs)
 
-        if not isinstance(initial_state, (list, tuple)):
-            initial_state = [initial_state]
+        # If any of `initial_state` or `constants` are specified and are Keras
+        # tensors, then add them to the inputs and temporarily modify the
+        # input_spec to include them.
 
-        is_keras_tensor = hasattr(initial_state[0], '_keras_history')
-        for tensor in initial_state:
+        additional_inputs = []
+        additional_specs = []
+        if initial_state is not None:
+            kwargs['initial_state'] = initial_state
+            additional_inputs += initial_state
+            self.state_spec = [InputSpec(shape=K.int_shape(state))
+                               for state in initial_state]
+            additional_specs += self.state_spec
+        if constants is not None:
+            kwargs['constants'] = constants
+            additional_inputs += constants
+            self.constants_spec = [InputSpec(shape=K.int_shape(constant))
+                                   for constant in constants]
+            self._num_constants = len(constants)
+            additional_specs += self.constants_spec
+        # at this point additional_inputs cannot be empty
+        is_keras_tensor = hasattr(additional_inputs[0], '_keras_history')
+        for tensor in additional_inputs:
             if hasattr(tensor, '_keras_history') != is_keras_tensor:
-                raise ValueError('The initial state of an RNN layer cannot be'
-                                 ' specified with a mix of Keras tensors and'
-                                 ' non-Keras tensors')
+                raise ValueError('The initial state or constants of an RNN'
+                                 ' layer cannot be specified with a mix of'
+                                 ' Keras tensors and non-Keras tensors')
 
         if is_keras_tensor:
-            # Compute the full input spec, including state
-            input_spec = self.input_spec
-            state_spec = self.state_spec
-            if not isinstance(input_spec, list):
-                input_spec = [input_spec]
-            if not isinstance(state_spec, list):
-                state_spec = [state_spec]
-            self.input_spec = input_spec + state_spec
-
-            # Compute the full inputs, including state
-            inputs = [inputs] + list(initial_state)
-
-            # Perform the call
-            output = super(RNN, self).__call__(inputs, **kwargs)
-
-            # Restore original input spec
-            self.input_spec = input_spec
+            # Compute the full input spec, including state and constants
+            full_input = [inputs] + additional_inputs
+            full_input_spec = self.input_spec + additional_specs
+            # Perform the call with temporarily replaced input_spec
+            original_input_spec = self.input_spec
+            self.input_spec = full_input_spec
+            output = super(RNN, self).__call__(full_input, **kwargs)
+            self.input_spec = original_input_spec
             return output
         else:
-            kwargs['initial_state'] = initial_state
             return super(RNN, self).__call__(inputs, **kwargs)
 
-    def call(self, inputs, mask=None, training=None, initial_state=None):
+    def call(self,
+             inputs,
+             mask=None,
+             training=None,
+             initial_state=None,
+             constants=None):
         # input shape: `(samples, time (padded with zeros), input_dim)`
         # note that the .build() method of subclasses MUST define
         # self.input_spec and self.state_spec with complete input shapes.
         if isinstance(inputs, list):
-            initial_state = inputs[1:]
             inputs = inputs[0]
-        elif initial_state is not None:
+        if initial_state is not None:
             pass
         elif self.stateful:
             initial_state = self.states
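The practical effect of the rewritten `__call__`: Keras-tensor states and constants become real inputs of the layer call. A sketch reusing the hypothetical `RNNCellWithConstants` from the earlier example:

```python
import keras

x = keras.layers.Input(shape=(5, 32))
s = keras.layers.Input(shape=(8,))    # initial state, a Keras tensor
c = keras.layers.Input(shape=(16,))   # constant, a Keras tensor

layer = keras.layers.RNN(RNNCellWithConstants(8))  # cell sketched earlier
y = layer(x, initial_state=s, constants=c)
# Internally, super().__call__ received full_input = [x, s, c] with
# input_spec temporarily extended by state_spec + constants_spec, so all
# three tensors are recorded as inputs of this node:
model = keras.models.Model([x, s, c], y)
```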
@@ -525,13 +562,27 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
                              'the time dimension by passing a `shape` '
                              'or `batch_shape` argument to your Input layer.')
 
+        kwargs = {}
         if has_arg(self.cell.call, 'training'):
-            step = functools.partial(self.cell.call, training=training)
+            kwargs['training'] = training
+
+        if constants:
+            if not has_arg(self.cell.call, 'constants'):
+                raise ValueError('RNN cell does not support constants')
+
+            def step(inputs, states):
+                constants = states[-self._num_constants:]
+                states = states[:-self._num_constants]
+                return self.cell.call(inputs, states, constants=constants,
+                                      **kwargs)
         else:
-            step = self.cell.call
+            def step(inputs, states):
+                return self.cell.call(inputs, states, **kwargs)
+
         last_output, outputs, states = K.rnn(step,
                                              inputs,
                                              initial_state,
+                                             constants=constants,
                                              go_backwards=self.go_backwards,
                                              mask=mask,
                                              unroll=self.unroll,
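`K.rnn` forwards `constants` by appending them to the state list it hands to the step function, which is why the wrapper above slices them back off. A plain-Python sketch of that flow (`make_step` is a hypothetical stand-in for the inline `step` above):

```python
def make_step(cell_call, num_constants):
    """Mimics the wrapper above: split constants back out of `states`."""
    def step(inputs, states):
        constants = states[-num_constants:]
        states = states[:-num_constants]
        return cell_call(inputs, states, constants=constants)
    return step

# With one state tensor s and two constants c1, c2, K.rnn effectively calls
#   step(x_t, [s, c1, c2])  ->  cell.call(x_t, [s], constants=[c1, c2])
```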
@@ -560,6 +611,47 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
         else:
             return output
 
+    def _standardize_args(self, inputs, initial_state, constants):
+        """Brings the arguments of `__call__` that can contain input tensors
+        to standard format.
+
+        When running a model loaded from file, the input tensors
+        `initial_state` and `constants` can be passed to `RNN.__call__` as part
+        of `inputs` instead of by the dedicated keyword arguments. This method
+        makes sure the arguments are separated and that `initial_state` and
+        `constants` are lists of tensors (or None).
+
+        # Arguments
+            inputs: tensor or list/tuple of tensors
+            initial_state: tensor or list of tensors or None
+            constants: tensor or list of tensors or None
+
+        # Returns
+            inputs: tensor
+            initial_state: list of tensors or None
+            constants: list of tensors or None
+        """
+        if isinstance(inputs, list):
+            assert initial_state is None and constants is None
+            if self._num_constants is not None:
+                constants = inputs[-self._num_constants:]
+                inputs = inputs[:-self._num_constants]
+            if len(inputs) > 1:
+                initial_state = inputs[1:]
+            inputs = inputs[0]
+
+        def to_list_or_none(x):
+            if x is None or isinstance(x, list):
+                return x
+            if isinstance(x, tuple):
+                return list(x)
+            return [x]
+
+        initial_state = to_list_or_none(initial_state)
+        constants = to_list_or_none(constants)
+
+        return inputs, initial_state, constants
+
     def reset_states(self, states=None):
         if not self.stateful:
             raise AttributeError('Layer must be stateful.')
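A plain-Python sketch of the unpacking `_standardize_args` performs when a reloaded model passes everything through `inputs`, assuming `self._num_constants == 1` (the strings stand in for tensors):

```python
inputs = ['x', 's1', 's2', 'c1']  # main input, two states, one constant

constants = inputs[-1:]           # ['c1']
inputs = inputs[:-1]              # ['x', 's1', 's2']
initial_state = inputs[1:]        # ['s1', 's2']
inputs = inputs[0]                # 'x'
```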
@@ -618,6 +710,9 @@ def get_config(self):
                   'go_backwards': self.go_backwards,
                   'stateful': self.stateful,
                   'unroll': self.unroll}
+        if self._num_constants is not None:
+            config['num_constants'] = self._num_constants
+
         cell_config = self.cell.get_config()
         config['cell'] = {'class_name': self.cell.__class__.__name__,
                           'config': cell_config}
@@ -629,7 +724,10 @@ def from_config(cls, config, custom_objects=None):
         from . import deserialize as deserialize_layer
         cell = deserialize_layer(config.pop('cell'),
                                  custom_objects=custom_objects)
-        return cls(cell, **config)
+        num_constants = config.pop('num_constants', None)
+        layer = cls(cell, **config)
+        layer._num_constants = num_constants
+        return layer
 
     @property
     def trainable_weights(self):
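With `num_constants` stored in the config, a layer that was called with constants can round-trip through serialization. A sketch, assuming the hypothetical `RNNCellWithConstants` and `layer` from the earlier examples:

```python
config = layer.get_config()  # includes 'num_constants' when constants were used
restored = keras.layers.RNN.from_config(
    config, custom_objects={'RNNCellWithConstants': RNNCellWithConstants})
assert restored._num_constants == layer._num_constants  # set after cls(cell, **config)
```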