diff --git a/CHANGELOG.md b/CHANGELOG.md
index e393a9018..5b9ffdf79 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -80,12 +80,14 @@ To release a new version, please update the changelog as followed:
 ### Fixed
 
 - RNN updates: remove warnings, fix if seq_len=0, unitest (#PR 1033)
+- BN updates: fix BatchNorm1d for 2D data, refactored (#PR 1040)
 
 ### Removed
 
 ### Security
 
 ### Contributors
 
+- @ChrisWu1997: #1040
 
 ## [2.2.1]

diff --git a/tensorlayer/files/utils.py b/tensorlayer/files/utils.py
index 242590c04..1a3a429f0 100644
--- a/tensorlayer/files/utils.py
+++ b/tensorlayer/files/utils.py
@@ -2666,6 +2666,10 @@ def _load_weights_from_hdf5_group(f, layers, skip=False):
     elif isinstance(layer, tl.layers.Layer):
         weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
         for iid, w_name in enumerate(weight_names):
+            # FIXME : this is only for compatibility
+            if isinstance(layer, tl.layers.BatchNorm) and np.asarray(g[w_name]).ndim > 1:
+                assign_tf_variable(layer.all_weights[iid], np.asarray(g[w_name]).squeeze())
+                continue
             assign_tf_variable(layer.all_weights[iid], np.asarray(g[w_name]))
     else:
         raise Exception("Only layer or model can be saved into hdf5.")

diff --git a/tensorlayer/layers/normalization.py b/tensorlayer/layers/normalization.py
index 226795981..a609f5671 100644
--- a/tensorlayer/layers/normalization.py
+++ b/tensorlayer/layers/normalization.py
@@ -108,6 +108,19 @@ def _bias_add(x, b, data_format):
 
 def batch_normalization(x, mean, variance, offset, scale, variance_epsilon, data_format, name=None):
     """Data Format aware version of tf.nn.batch_normalization."""
+    if data_format == 'channels_last':
+        mean = tf.reshape(mean, [1] * (len(x.shape) - 1) + [-1])
+        variance = tf.reshape(variance, [1] * (len(x.shape) - 1) + [-1])
+        offset = tf.reshape(offset, [1] * (len(x.shape) - 1) + [-1])
+        scale = tf.reshape(scale, [1] * (len(x.shape) - 1) + [-1])
+    elif data_format == 'channels_first':
+        mean = tf.reshape(mean, [1] + [-1] + [1] * (len(x.shape) - 2))
+        variance = tf.reshape(variance, [1] + [-1] + [1] * (len(x.shape) - 2))
+        offset = tf.reshape(offset, [1] + [-1] + [1] * (len(x.shape) - 2))
+        scale = tf.reshape(scale, [1] + [-1] + [1] * (len(x.shape) - 2))
+    else:
+        raise ValueError('invalid data_format: %s' % data_format)
+
     with ops.name_scope(name, 'batchnorm', [x, mean, variance, scale, offset]):
         inv = math_ops.rsqrt(variance + variance_epsilon)
         if scale is not None:
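The hunk above exists because the per-channel parameters are now stored flat as shape `[C]`, so `batch_normalization` has to expand them to a rank that broadcasts against the input. A minimal standalone sketch of the same shape arithmetic; the tensor names are illustrative, not part of the patch:

```python
import tensorflow as tf

# Mimic the reshape added in the patched batch_normalization().
x = tf.zeros([2, 16, 8, 8])      # channels_first input, C = 16
mean = tf.zeros([16])            # flat per-channel statistic, shape [C]

# [1] + [-1] + [1] * (len(x.shape) - 2)  ->  [1, -1, 1, 1]
mean_b = tf.reshape(mean, [1] + [-1] + [1] * (len(x.shape) - 2))
print(mean_b.shape)              # (1, 16, 1, 1); broadcasts against x
```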
@@ -204,13 +217,10 @@ def __init__(
         self.moving_var_init = moving_var_init
         self.num_features = num_features
 
+        self.channel_axis = -1 if data_format == 'channels_last' else 1
+        self.axes = None
+
         if num_features is not None:
-            if not isinstance(self, BatchNorm1d) and not isinstance(self, BatchNorm2d) and not isinstance(self,
-                                                                                                           BatchNorm3d):
-                raise ValueError(
-                    "Please use BatchNorm1d or BatchNorm2d or BatchNorm3d instead of BatchNorm "
-                    "if you want to specify 'num_features'."
-                )
             self.build(None)
             self._built = True
@@ -233,21 +243,23 @@ def __repr__(self):
 
     def _get_param_shape(self, inputs_shape):
         if self.data_format == 'channels_last':
-            axis = len(inputs_shape) - 1
+            axis = -1
         elif self.data_format == 'channels_first':
             axis = 1
         else:
             raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))
 
         channels = inputs_shape[axis]
-        params_shape = [1] * len(inputs_shape)
-        params_shape[axis] = channels
+        params_shape = [channels]
 
-        axes = [i for i in range(len(inputs_shape)) if i != axis]
-        return params_shape, axes
+        return params_shape
+
+    def _check_input_shape(self, inputs):
+        if inputs.ndim <= 1:
+            raise ValueError('expected input at least 2D, but got {}D input'.format(inputs.ndim))
 
     def build(self, inputs_shape):
-        params_shape, self.axes = self._get_param_shape(inputs_shape)
+        params_shape = [self.num_features] if self.num_features is not None else self._get_param_shape(inputs_shape)
 
         self.beta, self.gamma = None, None
         if self.beta_init:
@@ -264,7 +276,12 @@ def build(self, inputs_shape):
         )
 
     def forward(self, inputs):
-        mean, var = tf.nn.moments(inputs, self.axes, keepdims=True)
+        self._check_input_shape(inputs)
+
+        if self.axes is None:
+            self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]
+
+        mean, var = tf.nn.moments(inputs, self.axes, keepdims=False)
         if self.is_train:
             # update moving_mean and moving_var
             self.moving_mean = moving_averages.assign_moving_average(
@@ -282,8 +299,8 @@ def forward(self, inputs):
 
 
 class BatchNorm1d(BatchNorm):
-    """The :class:`BatchNorm1d` applies Batch Normalization over 3D input (a mini-batch of 1D
-    inputs with additional channel dimension), of shape (N, L, C) or (N, C, L).
+    """The :class:`BatchNorm1d` applies Batch Normalization over 2D/3D input (a mini-batch of 1D
+    inputs (optional) with additional channel dimension), of shape (N, C) or (N, L, C) or (N, C, L).
     See more details in :class:`BatchNorm`.
 
     Examples
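The switch to `keepdims=False` in `forward` above is consistent with the flat parameter layout: reducing over every non-channel axis yields rank-1 statistics shaped like `beta`/`gamma`. A hedged sketch of that reduction (the positive channel index here is an illustrative normalization, not code lifted from the patch):

```python
import tensorflow as tf

x = tf.random.normal([4, 100, 32])   # (N, L, C), channels_last
channel_axis = len(x.shape) - 1      # positive index of the C axis
axes = [i for i in range(len(x.shape)) if i != channel_axis]

# With keepdims=False the statistics come back as rank-1 [C] tensors,
# matching the new flat shapes of beta/gamma and the moving averages.
mean, var = tf.nn.moments(x, axes, keepdims=False)
print(mean.shape, var.shape)         # (32,) (32,)
```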
@@ -299,23 +316,9 @@ class BatchNorm1d(BatchNorm):
 
     """
 
-    def _get_param_shape(self, inputs_shape):
-        if self.data_format == 'channels_last':
-            axis = 2
-        elif self.data_format == 'channels_first':
-            axis = 1
-        else:
-            raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))
-
-        if self.num_features is None:
-            channels = inputs_shape[axis]
-        else:
-            channels = self.num_features
-        params_shape = [1] * 3
-        params_shape[axis] = channels
-
-        axes = [i for i in range(3) if i != axis]
-        return params_shape, axes
+    def _check_input_shape(self, inputs):
+        if inputs.ndim != 2 and inputs.ndim != 3:
+            raise ValueError('expected input to be 2D or 3D, but got {}D input'.format(inputs.ndim))
 
 
 class BatchNorm2d(BatchNorm):
@@ -336,23 +339,9 @@ class BatchNorm2d(BatchNorm):
 
     """
 
-    def _get_param_shape(self, inputs_shape):
-        if self.data_format == 'channels_last':
-            axis = 3
-        elif self.data_format == 'channels_first':
-            axis = 1
-        else:
-            raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))
-
-        if self.num_features is None:
-            channels = inputs_shape[axis]
-        else:
-            channels = self.num_features
-        params_shape = [1] * 4
-        params_shape[axis] = channels
-
-        axes = [i for i in range(4) if i != axis]
-        return params_shape, axes
+    def _check_input_shape(self, inputs):
+        if inputs.ndim != 4:
+            raise ValueError('expected input to be 4D, but got {}D input'.format(inputs.ndim))
 
 
 class BatchNorm3d(BatchNorm):
@@ -373,23 +362,9 @@ class BatchNorm3d(BatchNorm):
 
     """
 
-    def _get_param_shape(self, inputs_shape):
-        if self.data_format == 'channels_last':
-            axis = 4
-        elif self.data_format == 'channels_first':
-            axis = 1
-        else:
-            raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))
-
-        if self.num_features is None:
-            channels = inputs_shape[axis]
-        else:
-            channels = self.num_features
-        params_shape = [1] * 5
-        params_shape[axis] = channels
-
-        axes = [i for i in range(5) if i != axis]
-        return params_shape, axes
+    def _check_input_shape(self, inputs):
+        if inputs.ndim != 5:
+            raise ValueError('expected input to be 5D, but got {}D input'.format(inputs.ndim))
 
 
 class InstanceNorm(Layer):

diff --git a/tensorlayer/models/mobilenetv1.py b/tensorlayer/models/mobilenetv1.py
index 4908b3d89..82ea7be46 100644
--- a/tensorlayer/models/mobilenetv1.py
+++ b/tensorlayer/models/mobilenetv1.py
@@ -43,9 +43,9 @@ def restore_params(network, path='models'):
         expected_bytes=25600116
     )  # ls -al
     params = load_npz(name=os.path.join(path, 'mobilenet.npz'))
-    for idx, net_weight in enumerate(network.all_weights):
-        if 'batchnorm' in net_weight.name:
-            params[idx] = params[idx].reshape(1, 1, 1, -1)
+    # for idx, net_weight in enumerate(network.all_weights):
+    #     if 'batchnorm' in net_weight.name:
+    #         params[idx] = params[idx].reshape(1, 1, 1, -1)
     assign_weights(params[:len(network.all_weights)], network)
     del params

diff --git a/tensorlayer/models/resnet.py b/tensorlayer/models/resnet.py
index 87bdc5641..7df069468 100644
--- a/tensorlayer/models/resnet.py
+++ b/tensorlayer/models/resnet.py
@@ -194,8 +194,8 @@ def restore_params(network, path='models'):
                 continue
             w_names = list(f[layer.name])
             params = [f[layer.name][n][:] for n in w_names]
-            if 'bn' in layer.name:
-                params = [x.reshape(1, 1, 1, -1) for x in params]
+            # if 'bn' in layer.name:
+            #     params = [x.reshape(1, 1, 1, -1) for x in params]
             assign_weights(params, layer)
             del params
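The two pretrained-model loaders above stop reshaping BN weights to `(1, 1, 1, C)` because the layer now expects them flat; the hdf5 shim earlier in this diff handles the reverse direction, squeezing legacy checkpoints on load. A small NumPy sketch of that conversion (the array here is made up for illustration):

```python
import numpy as np

# Hypothetical BN gamma saved by an older TensorLayer in broadcast shape.
legacy = np.ones((1, 1, 1, 64), dtype=np.float32)

# Same idea as the compatibility branch in _load_weights_from_hdf5_group:
# squeeze legacy (1, 1, 1, C) weights down to the new flat (C,) layout.
flat = legacy.squeeze() if legacy.ndim > 1 else legacy
print(flat.shape)   # (64,)
```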
diff --git a/tests/layers/test_layers_normalization.py b/tests/layers/test_layers_normalization.py
index a25e47f76..c223f61ed 100644
--- a/tests/layers/test_layers_normalization.py
+++ b/tests/layers/test_layers_normalization.py
@@ -18,11 +18,13 @@ class Laye_BatchNorm_Test(CustomTestCase):
 
     @classmethod
     def setUpClass(cls):
+        x_0_input_shape = [None, 10]
         x_1_input_shape = [None, 100, 1]
         x_2_input_shape = [None, 100, 100, 3]
         x_3_input_shape = [None, 100, 100, 100, 3]
         batchsize = 2
 
+        cls.x0 = tf.random.normal([batchsize] + x_0_input_shape[1:])
         cls.x1 = tf.random.normal([batchsize] + x_1_input_shape[1:])
         cls.x2 = tf.random.normal([batchsize] + x_2_input_shape[1:])
         cls.x3 = tf.random.normal([batchsize] + x_3_input_shape[1:])
@@ -36,16 +38,58 @@ def setUpClass(cls):
         ni_2 = Input(x_2_input_shape, name='test_ni2')
         nn_2 = Conv2d(n_filter=32, filter_size=(3, 3), strides=(2, 2), name='test_conv2d')(ni_2)
-        n2_b = BatchNorm2d(name='test_bn2d')(nn_2)
+        n2_b = BatchNorm(name='test_bn2d')(nn_2)
         cls.n2_b = n2_b
         cls.base_2d = Model(inputs=ni_2, outputs=n2_b, name='test_base_2d')
 
         ni_3 = Input(x_3_input_shape, name='test_ni2')
         nn_3 = Conv3d(n_filter=32, filter_size=(3, 3, 3), strides=(2, 2, 2), name='test_conv3d')(ni_3)
-        n3_b = BatchNorm3d(name='test_bn3d')(nn_3)
+        n3_b = BatchNorm(name='test_bn3d')(nn_3)
         cls.n3_b = n3_b
         cls.base_3d = Model(inputs=ni_3, outputs=n3_b, name='test_base_3d')
 
+        class bn_0d_model(Model):
+
+            def __init__(self):
+                super(bn_0d_model, self).__init__()
+                self.fc = Dense(32, in_channels=10)
+                self.bn = BatchNorm(num_features=32, name='test_bn1d')
+
+            def forward(self, x):
+                x = self.bn(self.fc(x))
+                return x
+
+        dynamic_base = bn_0d_model()
+        cls.n0_b = dynamic_base(cls.x0, is_train=True)
+
+        ## 0D ========================================================================
+
+        nin_0 = Input(x_0_input_shape, name='test_in1')
+
+        n0 = Dense(32)(nin_0)
+        n0 = BatchNorm1d(name='test_bn0d')(n0)
+
+        cls.n0 = n0
+
+        cls.static_0d = Model(inputs=nin_0, outputs=n0)
+
+        class bn_0d_model(Model):
+
+            def __init__(self):
+                super(bn_0d_model, self).__init__(name='test_bn_0d_model')
+                self.fc = Dense(32, in_channels=10)
+                self.bn = BatchNorm1d(num_features=32, name='test_bn1d')
+
+            def forward(self, x):
+                x = self.bn(self.fc(x))
+                return x
+
+        cls.dynamic_0d = bn_0d_model()
+
+        print("Printing BatchNorm0d")
+        print(cls.static_0d)
+        print(cls.dynamic_0d)
+
         ## 1D ========================================================================
 
         nin_1 = Input(x_1_input_shape, name='test_in1')
@@ -147,6 +191,14 @@ def test_BatchNorm(self):
         self.assertEqual(self.n3_b.shape[1:], (50, 50, 50, 32))
         out = self.base_3d(self.x3, is_train=True)
 
+        self.assertEqual(self.n0_b.shape[1:], (32))
+        print("test_BatchNorm OK")
+
+    def test_BatchNorm0d(self):
+        self.assertEqual(self.n0.shape[1:], (32))
+        out = self.static_0d(self.x0, is_train=True)
+        out = self.dynamic_0d(self.x0, is_train=True)
+
     def test_BatchNorm1d(self):
         self.assertEqual(self.n1.shape[1:], (50, 32))
         out = self.static_1d(self.x1, is_train=True)
@@ -189,6 +241,26 @@ def test_exception(self):
             self.assertIsInstance(e, ValueError)
             print(e)
 
+    def test_input_shape(self):
+        try:
+            bn = BatchNorm1d(num_features=32)
+            out = bn(self.x2)
+        except Exception as e:
+            self.assertIsInstance(e, ValueError)
+            print(e)
+        try:
+            bn = BatchNorm2d(num_features=32)
+            out = bn(self.x3)
+        except Exception as e:
+            self.assertIsInstance(e, ValueError)
+            print(e)
+        try:
+            bn = BatchNorm3d(num_features=32)
+            out = bn(self.x1)
+        except Exception as e:
+            self.assertIsInstance(e, ValueError)
+            print(e)
+
 
 if __name__ == '__main__':
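For reference, a usage-level sketch of the headline fix, mirroring the new dynamic-model test case above (assumes the TensorLayer 2.x API used in this test file; the model name is made up):

```python
import tensorflow as tf
from tensorlayer.layers import BatchNorm1d, Dense
from tensorlayer.models import Model

# BatchNorm1d over plain 2D (N, C) data, which raised an error before this PR.
class MLP(Model):

    def __init__(self):
        super(MLP, self).__init__()
        self.fc = Dense(32, in_channels=10)
        self.bn = BatchNorm1d(num_features=32)

    def forward(self, x):
        return self.bn(self.fc(x))

net = MLP()
out = net(tf.random.normal([8, 10]), is_train=True)
print(out.shape)   # (8, 32)
```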