@@ -51,7 +51,7 @@ class Parameters(object):
5151 def __init__ (self ):
5252 self .__param_conf__ = dict ()
5353 self .__gradient_machines__ = []
54- self .__tmp_params__ = []
54+ self .__tmp_params__ = dict ()
5555
5656 def __append_config__ (self , param_conf ):
5757 """
@@ -128,13 +128,10 @@ def __getitem__(self, key):
128128
129129 if len (self .__gradient_machines__ ) == 0 :
130130 # create new parameter in python numpy.
131- if len (self .__tmp_params__ ) != 0 :
132- ret_list = [
133- mat for name , mat in self .__tmp_params__ if name == key
134- ]
135- if len (ret_list ) == 1 :
136- return ret_list [0 ]
137- return np .ndarray (shape = shape , dtype = np .float32 )
131+ if key in self .__tmp_params__ :
132+ return self .__tmp_params__ [key ]
133+ else :
134+ return np .ndarray (shape = shape , dtype = np .float32 )
138135 else :
139136 for each_gradient_machine in self .__gradient_machines__ :
140137 param = __get_parameter_in_gradient_machine__ (
@@ -187,7 +184,7 @@ def __setitem__(self, key, value):
187184 (shape , value .shape ))
188185
189186 if len (self .__gradient_machines__ ) == 0 :
190- self .__tmp_params__ . append (( key , value ))
187+ self .__tmp_params__ [ key ] = value
191188 else :
192189 for each_gradient_machine in self .__gradient_machines__ :
193190 __copy_parameter_to_gradient_machine__ (each_gradient_machine ,
@@ -231,7 +228,7 @@ def append_gradient_machine(self, gradient_machine):
231228 raise ValueError ("gradient_machine should be api.GradientMachine" )
232229
233230 if len (self .__tmp_params__ ) != 0 :
234- for name , val in self .__tmp_params__ :
231+ for name , val in self .__tmp_params__ . iteritems () :
235232 try :
236233 __copy_parameter_to_gradient_machine__ (gradient_machine ,
237234 name , val )
@@ -287,6 +284,18 @@ def to_tar(self, f):
287284
288285 @staticmethod
289286 def from_tar (f ):
287+ """
288+ Create a `Parameters` object from the given file. And
289+ the `Parameters` only contains the parameters in this
290+ file. It is adapted the parameters are same in the
291+ defined network and the given file. For example, it
292+ can be used in the inference.
293+
294+ :param f: the initialized model file.
295+ :type f: tar file
296+ :return: A Parameters object.
297+ :rtype: Parameters.
298+ """
290299 params = Parameters ()
291300 tar = tarfile .TarFile (fileobj = f , mode = 'r' )
292301 for finfo in tar :
@@ -302,6 +311,21 @@ def from_tar(f):
302311 params .deserialize (param_name , f )
303312 return params
304313
314+ def init_from_tar (self , f ):
315+ """
316+ Different from `from_tar`, this interface can be used to
317+ init partial network parameters from another saved model.
318+
319+ :param f: the initialized model file.
320+ :type f: tar file
321+ :return: Nothing.
322+ """
323+
324+ tar_param = Parameters .from_tar (f )
325+ for pname in tar_param .names ():
326+ if pname in self .names ():
327+ self .set (pname , tar_param .get (pname ))
328+
305329
306330def __get_parameter_in_gradient_machine__ (gradient_machine , name ):
307331 """
0 commit comments