Skip to content

Commit 5b62434

Browse files
gabrieldemarmiesse and Frédéric Branchaud-Charron
authored and committed
Breaking down the attention API PR: part 2 (#11140)
### Summary

This refactoring will allow the simplification of some code in #8296.

### Related Issues

### PR Overview

- [ ] This PR requires new unit tests [y/n] (make sure tests are included)
- [ ] This PR requires to update the documentation [y/n] (make sure the docs are up-to-date)
- [x] This PR is backwards compatible [y/n]
- [ ] This PR changes the current API [y/n] (all API changes need to be approved by fchollet)
1 parent 842d360 commit 5b62434

File tree

8 files changed

+27
-54
lines changed

8 files changed

+27
-54
lines changed

keras/engine/network.py

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -139,14 +139,8 @@ def _base_init(self, name=None):
139139
def _init_graph_network(self, inputs, outputs, name=None):
140140
self._uses_inputs_arg = True
141141
# Normalize and set self.inputs, self.outputs.
142-
if isinstance(inputs, (list, tuple)):
143-
self.inputs = list(inputs) # Tensor or list of tensors.
144-
else:
145-
self.inputs = [inputs]
146-
if isinstance(outputs, (list, tuple)):
147-
self.outputs = list(outputs)
148-
else:
149-
self.outputs = [outputs]
142+
self.inputs = to_list(inputs, allow_tuple=True)
143+
self.outputs = to_list(outputs, allow_tuple=True)
150144

151145
# User-provided argument validation.
152146
# Check for redundancy in inputs.

keras/engine/training.py

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -596,10 +596,7 @@ def _set_inputs(self, inputs, outputs=None, training=None):
596596
self._feed_inputs = []
597597
self._feed_input_names = []
598598
self._feed_input_shapes = []
599-
if isinstance(inputs, (list, tuple)):
600-
inputs = list(inputs)
601-
else:
602-
inputs = [inputs]
599+
inputs = to_list(inputs, allow_tuple=True)
603600

604601
for i, v in enumerate(inputs):
605602
name = 'input_%d' % (i + 1)
@@ -633,10 +630,7 @@ def _set_inputs(self, inputs, outputs=None, training=None):
633630
outputs = self.call(unpack_singleton(self.inputs), training=training)
634631
else:
635632
outputs = self.call(unpack_singleton(self.inputs))
636-
if isinstance(outputs, (list, tuple)):
637-
outputs = list(outputs)
638-
else:
639-
outputs = [outputs]
633+
outputs = to_list(outputs, allow_tuple=True)
640634
self.outputs = outputs
641635
self.output_names = [
642636
'output_%d' % (i + 1) for i in range(len(self.outputs))]
@@ -704,10 +698,7 @@ def _standardize_user_data(self, x,
704698
'You passed: y=' + str(y))
705699
# Typecheck that all inputs are *either* value *or* symbolic.
706700
if y is not None:
707-
if isinstance(y, (list, tuple)):
708-
all_inputs += list(y)
709-
else:
710-
all_inputs.append(y)
701+
all_inputs += to_list(y, allow_tuple=True)
711702
if any(K.is_tensor(v) for v in all_inputs):
712703
if not all(K.is_tensor(v) for v in all_inputs):
713704
raise ValueError('Do not pass inputs that mix Numpy '
@@ -716,8 +707,7 @@ def _standardize_user_data(self, x,
716707
'; y=' + str(y))
717708

718709
# Handle target tensors if any passed.
719-
if not isinstance(y, (list, tuple)):
720-
y = [y]
710+
y = to_list(y, allow_tuple=True)
721711
target_tensors = [v for v in y if K.is_tensor(v)]
722712
if not target_tensors:
723713
target_tensors = None

keras/layers/advanced_activations.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from ..engine.base_layer import InputSpec
1414
from .. import backend as K
1515
from ..legacy import interfaces
16+
from ..utils.generic_utils import to_list
1617

1718

1819
class LeakyReLU(Layer):
@@ -100,10 +101,8 @@ def __init__(self, alpha_initializer='zeros',
100101
self.alpha_constraint = constraints.get(alpha_constraint)
101102
if shared_axes is None:
102103
self.shared_axes = None
103-
elif not isinstance(shared_axes, (list, tuple)):
104-
self.shared_axes = [shared_axes]
105104
else:
106-
self.shared_axes = list(shared_axes)
105+
self.shared_axes = to_list(shared_axes, allow_tuple=True)
107106

108107
def build(self, input_shape):
109108
param_shape = list(input_shape[1:])

keras/layers/convolutional_recurrent.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from ..legacy.layers import Recurrent, ConvRecurrent2D
2222
from .recurrent import RNN
2323
from ..utils.generic_utils import has_arg
24+
from ..utils.generic_utils import to_list
2425
from ..utils.generic_utils import transpose_shape
2526

2627

@@ -387,10 +388,7 @@ def step(inputs, states):
387388
output._uses_learning_phase = True
388389

389390
if self.return_state:
390-
if not isinstance(states, (list, tuple)):
391-
states = [states]
392-
else:
393-
states = list(states)
391+
states = to_list(states, allow_tuple=True)
394392
return [output] + states
395393
else:
396394
return output
@@ -443,8 +441,7 @@ def get_tuple_shape(nb_channels):
443441
K.set_value(self.states[0],
444442
np.zeros(get_tuple_shape(self.cell.state_size)))
445443
else:
446-
if not isinstance(states, (list, tuple)):
447-
states = [states]
444+
states = to_list(states, allow_tuple=True)
448445
if len(states) != len(self.states):
449446
raise ValueError('Layer ' + self.name + ' expects ' +
450447
str(len(self.states)) + ' states, '

keras/layers/embeddings.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .. import constraints
1111
from ..engine.base_layer import Layer
1212
from ..legacy import interfaces
13+
from ..utils.generic_utils import to_list
1314

1415

1516
class Embedding(Layer):
@@ -117,10 +118,7 @@ def compute_output_shape(self, input_shape):
117118
return input_shape + (self.output_dim,)
118119
else:
119120
# input_length can be tuple if input is 3D or higher
120-
if isinstance(self.input_length, (list, tuple)):
121-
in_lens = list(self.input_length)
122-
else:
123-
in_lens = [self.input_length]
121+
in_lens = to_list(self.input_length, allow_tuple=True)
124122
if len(in_lens) != len(input_shape) - 1:
125123
raise ValueError('"input_length" is %s, but received input has shape %s' %
126124
(str(self.input_length), str(input_shape)))

keras/layers/recurrent.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from ..engine.base_layer import Layer
1717
from ..engine.base_layer import InputSpec
1818
from ..utils.generic_utils import has_arg
19+
from ..utils.generic_utils import to_list
1920

2021
# Legacy support.
2122
from ..legacy.layers import Recurrent
@@ -664,10 +665,7 @@ def step(inputs, states):
664665
state._uses_learning_phase = True
665666

666667
if self.return_state:
667-
if not isinstance(states, (list, tuple)):
668-
states = [states]
669-
else:
670-
states = list(states)
668+
states = to_list(states, allow_tuple=True)
671669
return [output] + states
672670
else:
673671
return output
@@ -702,8 +700,7 @@ def reset_states(self, states=None):
702700
K.set_value(self.states[0],
703701
np.zeros((batch_size, self.cell.state_size)))
704702
else:
705-
if not isinstance(states, (list, tuple)):
706-
states = [states]
703+
states = to_list(states, allow_tuple=True)
707704
if len(states) != len(self.states):
708705
raise ValueError('Layer ' + self.name + ' expects ' +
709706
str(len(self.states)) + ' states, '

keras/legacy/layers.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -508,8 +508,7 @@ def __call__(self, inputs, initial_state=None, **kwargs):
508508
if initial_state is None:
509509
return super(Recurrent, self).__call__(inputs, **kwargs)
510510

511-
if not isinstance(initial_state, (list, tuple)):
512-
initial_state = [initial_state]
511+
initial_state = to_list(initial_state, allow_tuple=True)
513512

514513
is_keras_tensor = hasattr(initial_state[0], '_keras_history')
515514
for tensor in initial_state:
@@ -602,10 +601,7 @@ def call(self, inputs, mask=None, training=None, initial_state=None):
602601
output = last_output
603602

604603
if self.return_state:
605-
if not isinstance(states, (list, tuple)):
606-
states = [states]
607-
else:
608-
states = list(states)
604+
states = to_list(states, allow_tuple=True)
609605
return [output] + states
610606
else:
611607
return output
@@ -633,8 +629,7 @@ def reset_states(self, states=None):
633629
for state in self.states:
634630
K.set_value(state, np.zeros((batch_size, self.units)))
635631
else:
636-
if not isinstance(states, (list, tuple)):
637-
states = [states]
632+
states = to_list(states, allow_tuple=True)
638633
if len(states) != len(self.states):
639634
raise ValueError('Layer ' + self.name + ' expects ' +
640635
str(len(self.states)) + ' states, '

keras/utils/generic_utils.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -444,20 +444,26 @@ def add(self, n, values=None):
444444
self.update(self._seen_so_far + n, values)
445445

446446

447-
def to_list(x):
447+
def to_list(x, allow_tuple=False):
448448
"""Normalizes a list/tensor into a list.
449449
450450
If a tensor is passed, we return
451451
a list of size 1 containing the tensor.
452452
453453
# Arguments
454454
x: target object to be normalized.
455+
allow_tuple: If False and x is a tuple,
456+
it will be converted into a list
457+
with a single element (the tuple).
458+
Else converts the tuple to a list.
455459
456460
# Returns
457461
A list.
458462
"""
459463
if isinstance(x, list):
460464
return x
465+
if allow_tuple and isinstance(x, tuple):
466+
return list(x)
461467
return [x]
462468

463469

@@ -483,10 +489,7 @@ def object_list_uid(object_list):
483489

484490

485491
def is_all_none(iterable_or_element):
486-
if not isinstance(iterable_or_element, (list, tuple)):
487-
iterable = [iterable_or_element]
488-
else:
489-
iterable = iterable_or_element
492+
iterable = to_list(iterable_or_element, allow_tuple=True)
490493
for element in iterable:
491494
if element is not None:
492495
return False

0 commit comments

Comments (0)