Refactor batch normalization #1040

Merged · 5 commits · Sep 9, 2019
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -80,12 +80,14 @@ To release a new version, please update the changelog as followed:

### Fixed
- RNN updates: remove warnings, fix if seq_len=0, unitest (#PR 1033)
- BN updates: fix BatchNorm1d for 2D data, refactored (#PR 1040)

### Removed

### Security

### Contributors
- @ChrisWu1997: #1040


## [2.2.1]
4 changes: 4 additions & 0 deletions tensorlayer/files/utils.py
@@ -2666,6 +2666,10 @@ def _load_weights_from_hdf5_group(f, layers, skip=False):
elif isinstance(layer, tl.layers.Layer):
weight_names = [n.decode('utf8') for n in g.attrs['weight_names']]
for iid, w_name in enumerate(weight_names):
# FIXME : this is only for compatibility
if isinstance(layer, tl.layers.BatchNorm) and np.asarray(g[w_name]).ndim > 1:
assign_tf_variable(layer.all_weights[iid], np.asarray(g[w_name]).squeeze())
continue
assign_tf_variable(layer.all_weights[iid], np.asarray(g[w_name]))
else:
raise Exception("Only layer or model can be saved into hdf5.")
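
For context, a minimal sketch (illustration only, not part of the changeset) of what the compatibility branch above does: BatchNorm weights saved before this refactor carry broadcastable singleton dimensions, while the refactored layer stores rank-1 vectors, so the loader squeezes the array before assigning it.

# Illustration: a legacy gamma saved as (1, 1, 1, 64) is squeezed to (64,)
# so it matches the refactored layer's rank-1 weight.
import numpy as np

legacy_gamma = np.ones((1, 1, 1, 64), dtype=np.float32)
print(legacy_gamma.squeeze().shape)  # (64,)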
107 changes: 41 additions & 66 deletions tensorlayer/layers/normalization.py
@@ -108,6 +108,19 @@ def _bias_add(x, b, data_format):

def batch_normalization(x, mean, variance, offset, scale, variance_epsilon, data_format, name=None):
"""Data Format aware version of tf.nn.batch_normalization."""
if data_format == 'channels_last':
mean = tf.reshape(mean, [1] * (len(x.shape) - 1) + [-1])
variance = tf.reshape(variance, [1] * (len(x.shape) - 1) + [-1])
offset = tf.reshape(offset, [1] * (len(x.shape) - 1) + [-1])
scale = tf.reshape(scale, [1] * (len(x.shape) - 1) + [-1])
elif data_format == 'channels_first':
mean = tf.reshape(mean, [1] + [-1] + [1] * (len(x.shape) - 2))
variance = tf.reshape(variance, [1] + [-1] + [1] * (len(x.shape) - 2))
offset = tf.reshape(offset, [1] + [-1] + [1] * (len(x.shape) - 2))
scale = tf.reshape(scale, [1] + [-1] + [1] * (len(x.shape) - 2))
else:
raise ValueError('invalid data_format: %s' % data_format)

with ops.name_scope(name, 'batchnorm', [x, mean, variance, scale, offset]):
inv = math_ops.rsqrt(variance + variance_epsilon)
if scale is not None:
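
As a worked illustration (not part of the changeset), the channels_first branch turns a rank-1 parameter into a broadcast-compatible shape for NCHW input:

# Illustration: reshaping a (C,) parameter so it broadcasts over a 4D NCHW tensor.
import tensorflow as tf

x = tf.zeros([8, 32, 28, 28])                                      # NCHW, C = 32
scale = tf.ones([32])                                              # rank-1 parameter
scale = tf.reshape(scale, [1] + [-1] + [1] * (len(x.shape) - 2))   # -> (1, 32, 1, 1)
print((x * scale).shape)                                           # (8, 32, 28, 28)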
@@ -204,13 +217,10 @@ def __init__(
self.moving_var_init = moving_var_init
self.num_features = num_features

self.channel_axis = -1 if data_format == 'channels_last' else 1
self.axes = None

if num_features is not None:
if not isinstance(self, BatchNorm1d) and not isinstance(self, BatchNorm2d) and not isinstance(self,
BatchNorm3d):
raise ValueError(
"Please use BatchNorm1d or BatchNorm2d or BatchNorm3d instead of BatchNorm "
"if you want to specify 'num_features'."
)
self.build(None)
self._built = True

@@ -233,21 +243,23 @@ def __repr__(self):

def _get_param_shape(self, inputs_shape):
if self.data_format == 'channels_last':
axis = len(inputs_shape) - 1
axis = -1
elif self.data_format == 'channels_first':
axis = 1
else:
raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))

channels = inputs_shape[axis]
params_shape = [1] * len(inputs_shape)
params_shape[axis] = channels
params_shape = [channels]

axes = [i for i in range(len(inputs_shape)) if i != axis]
return params_shape, axes
return params_shape

def _check_input_shape(self, inputs):
if inputs.ndim <= 1:
raise ValueError('expected input at least 2D, but got {}D input'.format(inputs.ndim))

def build(self, inputs_shape):
params_shape, self.axes = self._get_param_shape(inputs_shape)
params_shape = [self.num_features] if self.num_features is not None else self._get_param_shape(inputs_shape)

self.beta, self.gamma = None, None
if self.beta_init:
@@ -264,7 +276,12 @@ def build(self, inputs_shape):
)

def forward(self, inputs):
mean, var = tf.nn.moments(inputs, self.axes, keepdims=True)
self._check_input_shape(inputs)

if self.axes is None:
self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]

mean, var = tf.nn.moments(inputs, self.axes, keepdims=False)
if self.is_train:
# update moving_mean and moving_var
self.moving_mean = moving_averages.assign_moving_average(
@@ -282,8 +299,8 @@


class BatchNorm1d(BatchNorm):
"""The :class:`BatchNorm1d` applies Batch Normalization over 3D input (a mini-batch of 1D
inputs with additional channel dimension), of shape (N, L, C) or (N, C, L).
"""The :class:`BatchNorm1d` applies Batch Normalization over 2D/3D input (a mini-batch of 1D
inputs (optional) with additional channel dimension), of shape (N, C) or (N, L, C) or (N, C, L).
See more details in :class:`BatchNorm`.

Examples
@@ -299,23 +316,9 @@ class BatchNorm1d(BatchNorm):

"""

def _get_param_shape(self, inputs_shape):
if self.data_format == 'channels_last':
axis = 2
elif self.data_format == 'channels_first':
axis = 1
else:
raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))

if self.num_features is None:
channels = inputs_shape[axis]
else:
channels = self.num_features
params_shape = [1] * 3
params_shape[axis] = channels

axes = [i for i in range(3) if i != axis]
return params_shape, axes
def _check_input_shape(self, inputs):
if inputs.ndim != 2 and inputs.ndim != 3:
raise ValueError('expected input to be 2D or 3D, but got {}D input'.format(inputs.ndim))


class BatchNorm2d(BatchNorm):
@@ -336,23 +339,9 @@ class BatchNorm2d(BatchNorm):

"""

def _get_param_shape(self, inputs_shape):
if self.data_format == 'channels_last':
axis = 3
elif self.data_format == 'channels_first':
axis = 1
else:
raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))

if self.num_features is None:
channels = inputs_shape[axis]
else:
channels = self.num_features
params_shape = [1] * 4
params_shape[axis] = channels

axes = [i for i in range(4) if i != axis]
return params_shape, axes
def _check_input_shape(self, inputs):
if inputs.ndim != 4:
raise ValueError('expected input to be 4D, but got {}D input'.format(inputs.ndim))


class BatchNorm3d(BatchNorm):
@@ -373,23 +362,9 @@ class BatchNorm3d(BatchNorm):

"""

def _get_param_shape(self, inputs_shape):
if self.data_format == 'channels_last':
axis = 4
elif self.data_format == 'channels_first':
axis = 1
else:
raise ValueError('data_format should be either %s or %s' % ('channels_last', 'channels_first'))

if self.num_features is None:
channels = inputs_shape[axis]
else:
channels = self.num_features
params_shape = [1] * 5
params_shape[axis] = channels

axes = [i for i in range(5) if i != axis]
return params_shape, axes
def _check_input_shape(self, inputs):
if inputs.ndim != 5:
raise ValueError('expected input to be 5D, but got {}D input'.format(inputs.ndim))
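
The per-dimensionality subclasses now differ only in this rank check; a hedged sketch of the resulting behaviour (mirrors test_input_shape added below):

# Illustration: a 5D tensor fed to BatchNorm2d is rejected by _check_input_shape.
import tensorflow as tf
from tensorlayer.layers import BatchNorm2d

bn2d = BatchNorm2d(num_features=32)
try:
    bn2d(tf.random.normal([2, 100, 100, 100, 3]))   # 5D input
except ValueError as e:
    print(e)   # expected: "expected input to be 4D, but got 5D input"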


class InstanceNorm(Layer):
6 changes: 3 additions & 3 deletions tensorlayer/models/mobilenetv1.py
@@ -43,9 +43,9 @@ def restore_params(network, path='models'):
expected_bytes=25600116
) # ls -al
params = load_npz(name=os.path.join(path, 'mobilenet.npz'))
for idx, net_weight in enumerate(network.all_weights):
if 'batchnorm' in net_weight.name:
params[idx] = params[idx].reshape(1, 1, 1, -1)
# for idx, net_weight in enumerate(network.all_weights):
# if 'batchnorm' in net_weight.name:
# params[idx] = params[idx].reshape(1, 1, 1, -1)
assign_weights(params[:len(network.all_weights)], network)
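
The reshape is presumably unnecessary now because the refactored BatchNorm keeps each weight as a rank-1 (C,) vector, matching the flat arrays stored in mobilenet.npz; a standalone check of that assumption (illustration only):

# Sanity check of the assumption behind dropping the reshape: every weight of a
# refactored BatchNorm2d (gamma, beta, moving mean, moving variance) is rank-1.
from tensorlayer.layers import BatchNorm2d

bn = BatchNorm2d(num_features=64)
print([w.shape for w in bn.all_weights])   # expected: four (64,) vectors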
del params

4 changes: 2 additions & 2 deletions tensorlayer/models/resnet.py
@@ -194,8 +194,8 @@ def restore_params(network, path='models'):
continue
w_names = list(f[layer.name])
params = [f[layer.name][n][:] for n in w_names]
if 'bn' in layer.name:
params = [x.reshape(1, 1, 1, -1) for x in params]
# if 'bn' in layer.name:
# params = [x.reshape(1, 1, 1, -1) for x in params]
assign_weights(params, layer)
del params

76 changes: 74 additions & 2 deletions tests/layers/test_layers_normalization.py
@@ -18,11 +18,13 @@ class Laye_BatchNorm_Test(CustomTestCase):
@classmethod
def setUpClass(cls):

x_0_input_shape = [None, 10]
x_1_input_shape = [None, 100, 1]
x_2_input_shape = [None, 100, 100, 3]
x_3_input_shape = [None, 100, 100, 100, 3]
batchsize = 2

cls.x0 = tf.random.normal([batchsize] + x_0_input_shape[1:])
cls.x1 = tf.random.normal([batchsize] + x_1_input_shape[1:])
cls.x2 = tf.random.normal([batchsize] + x_2_input_shape[1:])
cls.x3 = tf.random.normal([batchsize] + x_3_input_shape[1:])
@@ -36,16 +38,58 @@ def setUpClass(cls):

ni_2 = Input(x_2_input_shape, name='test_ni2')
nn_2 = Conv2d(n_filter=32, filter_size=(3, 3), strides=(2, 2), name='test_conv2d')(ni_2)
n2_b = BatchNorm2d(name='test_bn2d')(nn_2)
n2_b = BatchNorm(name='test_bn2d')(nn_2)
cls.n2_b = n2_b
cls.base_2d = Model(inputs=ni_2, outputs=n2_b, name='test_base_2d')

ni_3 = Input(x_3_input_shape, name='test_ni2')
nn_3 = Conv3d(n_filter=32, filter_size=(3, 3, 3), strides=(2, 2, 2), name='test_conv3d')(ni_3)
n3_b = BatchNorm3d(name='test_bn3d')(nn_3)
n3_b = BatchNorm(name='test_bn3d')(nn_3)
cls.n3_b = n3_b
cls.base_3d = Model(inputs=ni_3, outputs=n3_b, name='test_base_3d')

class bn_0d_model(Model):

def __init__(self):
super(bn_0d_model, self).__init__()
self.fc = Dense(32, in_channels=10)
self.bn = BatchNorm(num_features=32, name='test_bn1d')

def forward(self, x):
x = self.bn(self.fc(x))
return x

dynamic_base = bn_0d_model()
cls.n0_b = dynamic_base(cls.x0, is_train=True)

## 0D ========================================================================

nin_0 = Input(x_0_input_shape, name='test_in1')

n0 = Dense(32)(nin_0)
n0 = BatchNorm1d(name='test_bn0d')(n0)

cls.n0 = n0

cls.static_0d = Model(inputs=nin_0, outputs=n0)

class bn_0d_model(Model):

def __init__(self):
super(bn_0d_model, self).__init__(name='test_bn_0d_model')
self.fc = Dense(32, in_channels=10)
self.bn = BatchNorm1d(num_features=32, name='test_bn1d')

def forward(self, x):
x = self.bn(self.fc(x))
return x

cls.dynamic_0d = bn_0d_model()

print("Printing BatchNorm0d")
print(cls.static_0d)
print(cls.dynamic_0d)

## 1D ========================================================================

nin_1 = Input(x_1_input_shape, name='test_in1')
@@ -147,6 +191,14 @@ def test_BatchNorm(self):
self.assertEqual(self.n3_b.shape[1:], (50, 50, 50, 32))
out = self.base_3d(self.x3, is_train=True)

self.assertEqual(self.n0_b.shape[1:], (32))
print("test_BatchNorm OK")

def test_BatchNorm0d(self):
self.assertEqual(self.n0.shape[1:], (32))
out = self.static_0d(self.x0, is_train=True)
out = self.dynamic_0d(self.x0, is_train=True)

def test_BatchNorm1d(self):
self.assertEqual(self.n1.shape[1:], (50, 32))
out = self.static_1d(self.x1, is_train=True)
@@ -189,6 +241,26 @@ def test_exception(self):
self.assertIsInstance(e, ValueError)
print(e)

def test_input_shape(self):
try:
bn = BatchNorm1d(num_features=32)
out = bn(self.x2)
except Exception as e:
self.assertIsInstance(e, ValueError)
print(e)
try:
bn = BatchNorm2d(num_features=32)
out = bn(self.x3)
except Exception as e:
self.assertIsInstance(e, ValueError)
print(e)
try:
bn = BatchNorm3d(num_features=32)
out = bn(self.x1)
except Exception as e:
self.assertIsInstance(e, ValueError)
print(e)


if __name__ == '__main__':
