Skip to content

Commit 776b751

Browse files
Squadrickseanpmorgan
authored andcommitted
FIX: Invariant input_spec for WeightNormalization (#687)
* Make the first dimension `None` to support invariant batch size. * Add test case to check compatibility of WeightNormalization with TimeDistributed.
1 parent 895d11d commit 776b751

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

tensorflow_addons/layers/wrappers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ def __init__(self, layer, data_init=True, **kwargs):
6262
def build(self, input_shape):
6363
"""Build `Layer`"""
6464
input_shape = tf.TensorShape(input_shape).as_list()
65-
self.input_spec = tf.keras.layers.InputSpec(shape=input_shape)
65+
self.input_spec = tf.keras.layers.InputSpec(
66+
shape=[None] + input_shape[1:])
6667

6768
if not self.layer.built:
6869
self.layer.build(input_shape)

tensorflow_addons/layers/wrappers_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,14 @@ def test_weightnorm_non_kernel_layer(self):
7373
wn_wrapper = wrappers.WeightNormalization(non_kernel_layer)
7474
wn_wrapper(images)
7575

76+
def test_weightnorm_with_time_dist(self):
77+
batch_shape = (32, 16, 64, 64, 3)
78+
inputs = tf.keras.layers.Input(batch_shape=batch_shape)
79+
a = tf.keras.layers.Conv2D(3, 5)
80+
b = wrappers.WeightNormalization(a)
81+
out = tf.keras.layers.TimeDistributed(b)(inputs)
82+
model = tf.keras.Model(inputs, out)
83+
7684

7785
if __name__ == "__main__":
7886
tf.test.main()

0 commit comments

Comments
 (0)