Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 83 additions & 0 deletions python/paddle/v2/fluid/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,3 +285,86 @@ def __call__(self, var, block):
})
var.op = op
return op


class MSRAInitializer(Initializer):
"""Implements the MSRA initializer a.k.a. Kaiming Initializer

This class implements the weight initialization from the paper
Delving Deep into Rectifiers: Surpassing Human-Level Performance on
ImageNet Classification[1] by Kaiming He, Xiangyu Zhang, Shaoqing Ren
and Jian Sun. This is a robust initialization method that particularly
considers the rectifier nonlinearities. In case of Uniform distribution,
the range is [-x, x], where x = sqrt(6 / fan_in). In case of Normal
distribution, the mean is 0 and the standard deviation
is sqrt(2/ fan_in).

References:
[1] Delving Deep into Rectifiers: Surpassing Human-Level Performance
on ImageNet Classification
(https://arxiv.org/abs/1502.01852)
"""

def __init__(self, uniform=True, fan_in=None, seed=0):
"""Constructor for MSRAInitializer

Args:
uniform: whether to use uniform or normal distribution
fan_in: fan_in for MSRAInitializer. If None, it is
inferred from the variable.
seed: random seed

Note: It is recommended to set fan_in to None for most cases.
"""
assert uniform is not None
assert seed is not None
super(MSRAInitializer, self).__init__()
self._uniform = uniform
self._fan_in = fan_in
self._seed = seed

def __call__(self, var, block):
"""Add MSRA initialization ops for a variable

Args:
var: Variable that needs to be initialized
block: The block in which initialization ops
should be added

Returns:
the initialization op
"""
assert isinstance(var, framework.Variable)
assert isinstance(block, framework.Block)
f_in, f_out = self._compute_fans(var)

# If fan_in is passed, use it
fan_in = f_in if self._fan_in is None else self._fan_in

if self._uniform:
limit = np.sqrt(6.0 / float(fan_in))
op = block.prepend_op(
type="uniform_random",
outputs={"Out": var},
attrs={
"shape": var.shape,
"data_type": int(var.data_type),
"min": -limit,
"max": limit,
"seed": self._seed
})

else:
std = np.sqrt(2.0 / float(fan_in))
op = block.prepend_op(
type="gaussian_random",
outputs={"Out": var},
attrs={
"shape": var.shape,
"data_type": int(var.data_type),
"mean": 0.0,
"std": std,
"seed": self._seed
})
var.op = op
return op
104 changes: 104 additions & 0 deletions python/paddle/v2/fluid/tests/test_initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,5 +223,109 @@ def test_xavier_initializer_supplied_arguments(self):
self.assertEqual(init_op.attr('seed'), 134)


class TestMSRAInitializer(unittest.TestCase):
def test_uniform_msra_initializer(self):
"""Test MSRA initializer with uniform distribution on
for matrix multiply.
"""
program = framework.Program()
block = program.global_block()
param = block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="param",
initializer=initializer.MSRAInitializer())
self.assertEqual(len(block.ops), 1)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'uniform_random')
limit = np.sqrt(6.0 / param.shape[0])
self.assertAlmostEqual(init_op.attr('min'), -limit, delta=DELTA)
self.assertAlmostEqual(init_op.attr('max'), limit, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)

def test_uniform_msra_initializer_conv(self):
"""Test MSRA initializer with uniform distribution on
for convolutions.
"""
program = framework.Program()
block = program.global_block()
param = block.create_parameter(
dtype="float32",
shape=[5, 10, 15, 20],
lod_level=0,
name="param",
initializer=initializer.MSRAInitializer())
self.assertEqual(len(block.ops), 1)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'uniform_random')
receptive_field_size = float(15 * 20)
limit = np.sqrt(6.0 / (param.shape[1] * receptive_field_size))
self.assertAlmostEqual(init_op.attr('min'), -limit, delta=DELTA)
self.assertAlmostEqual(init_op.attr('max'), limit, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)

def test_normal_msra_initializer(self):
"""Test MSRA initializer with normal distribution on
for matrix multiply.
"""
program = framework.Program()
block = program.global_block()
param = block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="param",
initializer=initializer.MSRAInitializer(uniform=False))
self.assertEqual(len(block.ops), 1)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'gaussian_random')
std = np.sqrt(2.0 / param.shape[0])
self.assertAlmostEqual(init_op.attr('mean'), 0.0, delta=DELTA)
self.assertAlmostEqual(init_op.attr('std'), std, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)

def test_normal_msra_initializer_conv(self):
"""Test MSRA initializer with normal distribution on
for convolutions.
"""
program = framework.Program()
block = program.global_block()
param = block.create_parameter(
dtype="float32",
shape=[5, 10, 15, 20],
lod_level=0,
name="param",
initializer=initializer.MSRAInitializer(uniform=False))
self.assertEqual(len(block.ops), 1)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'gaussian_random')
receptive_field_size = float(15 * 20)
std = np.sqrt(2.0 / (param.shape[1] * receptive_field_size))
self.assertAlmostEqual(init_op.attr('mean'), 0.0, delta=DELTA)
self.assertAlmostEqual(init_op.attr('std'), std, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 0)

def test_msra_initializer_supplied_arguments(self):
"""Test the MSRA initializer with supplied arguments
"""
program = framework.Program()
block = program.global_block()
block.create_parameter(
dtype="float32",
shape=[5, 10],
lod_level=0,
name="param",
initializer=initializer.MSRAInitializer(
fan_in=12, seed=134))
self.assertEqual(len(block.ops), 1)
init_op = block.ops[0]
self.assertEqual(init_op.type, 'uniform_random')
limit = np.sqrt(6.0 / 12)
self.assertAlmostEqual(init_op.attr('min'), -limit, delta=DELTA)
self.assertAlmostEqual(init_op.attr('max'), limit, delta=DELTA)
self.assertEqual(init_op.attr('seed'), 134)


if __name__ == '__main__':
unittest.main()