Commit aaee28b

Merge pull request #2664 from qingqing01/from_tar

Init partial network parameters from another saved model.

2 parents 9af8d86 + 23d6c59
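
A minimal usage sketch of the new interface in the v2 Python API style (the `cost` topology and the file name are hypothetical, not part of this commit):

    import paddle.v2 as paddle

    # Create freshly initialized parameters for the defined network.
    parameters = paddle.parameters.create(cost)

    # Overwrite only the parameters whose names also appear in the
    # saved model; every other parameter keeps its fresh initialization.
    with open('pretrained_model.tar', 'r') as f:
        parameters.init_from_tar(f)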

3 files changed (+87 −16 lines)


paddle/py_paddle/dataprovider_converter.py

Lines changed: 1 addition & 1 deletion

@@ -144,7 +144,7 @@ def finish_scan(self, argument):
         if len(self.__shape__) > 1:
             # The last-two dimenstions are the frame height and width.
             # For example, the layout is CHW for 3-D feature of image.
-            # The H and W are the fram height and width.
+            # The H and W are the frame height and width.
             h, w = self.__shape__[-2:]
             argument.setSlotFrameHeight(self.pos, h)
             argument.setSlotFrameWidth(self.pos, w)
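
As a standalone illustration of the CHW convention described in the fixed comment (the values are made up), the height/width extraction above works like this:

    shape = (3, 32, 48)    # C, H, W for a 3-channel 32x48 image feature
    h, w = shape[-2:]      # h = 32 (frame height), w = 48 (frame width)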

python/paddle/v2/parameters.py

Lines changed: 34 additions & 10 deletions

@@ -51,7 +51,7 @@ class Parameters(object):
     def __init__(self):
         self.__param_conf__ = dict()
         self.__gradient_machines__ = []
-        self.__tmp_params__ = []
+        self.__tmp_params__ = dict()
 
     def __append_config__(self, param_conf):
         """

@@ -128,13 +128,10 @@ def __getitem__(self, key):
 
         if len(self.__gradient_machines__) == 0:
             # create new parameter in python numpy.
-            if len(self.__tmp_params__) != 0:
-                ret_list = [
-                    mat for name, mat in self.__tmp_params__ if name == key
-                ]
-                if len(ret_list) == 1:
-                    return ret_list[0]
-            return np.ndarray(shape=shape, dtype=np.float32)
+            if key in self.__tmp_params__:
+                return self.__tmp_params__[key]
+            else:
+                return np.ndarray(shape=shape, dtype=np.float32)
         else:
             for each_gradient_machine in self.__gradient_machines__:
                 param = __get_parameter_in_gradient_machine__(

@@ -187,7 +184,7 @@ def __setitem__(self, key, value):
                              (shape, value.shape))
 
         if len(self.__gradient_machines__) == 0:
-            self.__tmp_params__.append((key, value))
+            self.__tmp_params__[key] = value
         else:
             for each_gradient_machine in self.__gradient_machines__:
                 __copy_parameter_to_gradient_machine__(each_gradient_machine,

@@ -231,7 +228,7 @@ def append_gradient_machine(self, gradient_machine):
             raise ValueError("gradient_machine should be api.GradientMachine")
 
         if len(self.__tmp_params__) != 0:
-            for name, val in self.__tmp_params__:
+            for name, val in self.__tmp_params__.iteritems():
                 try:
                     __copy_parameter_to_gradient_machine__(gradient_machine,
                                                            name, val)

@@ -287,6 +284,18 @@ def to_tar(self, f):
 
     @staticmethod
     def from_tar(f):
+        """
+        Create a `Parameters` object from the given file. The
+        resulting `Parameters` object contains only the parameters
+        stored in that file. It assumes the parameters in the defined
+        network and in the given file are the same, so it can be
+        used, for example, in inference.
+
+        :param f: the initialized model file.
+        :type f: tar file
+        :return: A Parameters object.
+        :rtype: Parameters.
+        """
         params = Parameters()
         tar = tarfile.TarFile(fileobj=f, mode='r')
         for finfo in tar:

@@ -302,6 +311,21 @@ def from_tar(f):
             params.deserialize(param_name, f)
         return params
 
+    def init_from_tar(self, f):
+        """
+        Different from `from_tar`, this interface can be used to
+        initialize partial network parameters from another saved model.
+
+        :param f: the initialized model file.
+        :type f: tar file
+        :return: Nothing.
+        """
+
+        tar_param = Parameters.from_tar(f)
+        for pname in tar_param.names():
+            if pname in self.names():
+                self.set(pname, tar_param.get(pname))
+
 
 def __get_parameter_in_gradient_machine__(gradient_machine, name):
     """

python/paddle/v2/tests/test_parameters.py

Lines changed: 52 additions & 5 deletions

@@ -20,14 +20,17 @@
 import numpy
 
 
-def __rand_param_config__(name):
+def __rand_param_config__(name, psize=None):
     conf = ParameterConfig()
     conf.name = name
     size = 1
-    for i in xrange(2):
-        dim = random.randint(1, 1000)
-        conf.dims.append(dim)
-        size *= dim
+    if psize is None:
+        for i in xrange(2):
+            dim = random.randint(1, 1000)
+            conf.dims.append(dim)
+            size *= dim
+    else:
+        size = psize
     conf.size = size
     assert conf.IsInitialized()
     return conf

@@ -77,6 +80,50 @@ def initializer(name):
         expected = numpy.array([[1, 1], [1, 2], [1, 1]], numpy.float32)
         assert numpy.logical_and.reduce(numpy.reshape(val == expected, 6))
 
+    def test_init_from_tar(self):
+        def get_param(names, size):
+            p = parameters.Parameters()
+            for k, v in zip(names, size):
+                p.__append_config__(__rand_param_config__(k, v))
+            for name in p.names():
+                param = p.get(name)
+                param[:] = numpy.random.uniform(
+                    -1.0, 1.0, size=p.get_shape(name))
+                p.set(name, param)
+            return p
+
+        def get_parames():
+            name1 = ['param_0', 'param_1']
+            size1 = [128, 256]
+            p1 = get_param(name1, size1)
+            file1 = cStringIO.StringIO()
+            p1.to_tar(file1)
+            file1.seek(0)
+
+            name2 = ['param_0', 'param_1', 'param_2']
+            size2 = [128, 256, 288]
+            p2 = get_param(name2, size2)
+            file2 = cStringIO.StringIO()
+            p2.to_tar(file2)
+            file2.seek(0)
+            return p1, file1, p2, file2
+
+        p1, file1, p2, file2 = get_parames()
+        p2.init_from_tar(file1)
+        for name in p1.names():
+            self.assertEqual(p1.get_shape(name), p2.get_shape(name))
+            v1 = p1.get(name)
+            v2 = p2.get(name)
+            self.assertTrue(numpy.isclose(v1, v2).all())
+
+        p1, file1, p2, file2 = get_parames()
+        p1.init_from_tar(file2)
+        for name in p1.names():
+            self.assertEqual(p1.get_shape(name), p2.get_shape(name))
+            v1 = p1.get(name)
+            v2 = p2.get(name)
+            self.assertTrue(numpy.isclose(v1, v2).all())
+
 
 if __name__ == '__main__':
     unittest.main()
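
The two assertion loops exercise both directions of the name intersection: initializing the three-parameter `p2` from the two-parameter `file1` overwrites only the shared `param_0` and `param_1`, while initializing the two-parameter `p1` from the three-parameter `file2` simply skips `param_2`, which does not exist in `p1`.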
