@@ -258,15 +258,16 @@ def _get_param_groups(
258258 at the module level. Possible keys, including the "self" key do not have to
259259 be defined. By default all parameters have the learning rate defined in the
260260 optimizer. This can be overridden by setting the parameter group in `param_groups`
261- member of a specific module, it can be overridden at the:
262- - module level with “self” key, all the parameters and child
263- module's parameters will inherit it
264- - member level, which is the same as if the `param_groups` in that
265- member has key=“self” and value equal to that parameter group.
261+ member of a specific module. Values are a parameter group name. The keys
262+ specify what parameters will be affected as follows:
263+ - “self”: All the parameters of the module and its child modules
264+ - name of a parameter: A parameter with that name.
265+ - name of a module member: All the parameters of the module and its
266+ child modules.
266267 This is useful if members do not have `param_groups`, for
267268 example torch.nn.Linear.
268- - parameter level, only parameter with the same name as the key
269- will have it .
269+ - <name of module member>.<something>: recursive. Same as if <something>
270+ was used in param_groups of that submodule/member .
270271
271272 Args:
272273 module: module from which to extract the parameters and their parameter
@@ -277,7 +278,18 @@ def _get_param_groups(
277278
278279 param_groups = defaultdict (list )
279280
280- def traverse (module , default_group ):
281+ def traverse (module , default_group : str , mapping : Dict [str , str ]) -> None :
282+ """
283+ Visitor for module to assign its parameters to the relevant member of
284+ param_groups.
285+
286+ Args:
287+ module: the module being visited in a depth-first search
288+ default_group: the param group to assign parameters to unless
289+ otherwise overriden.
290+ mapping: known mappings of parameters to groups for this module,
291+ destructively modified by this function.
292+ """
281293 # If key self is defined in param_groups then chenge the default param
282294 # group for all parameters and children in the module.
283295 if hasattr (module , "param_groups" ) and "self" in module .param_groups :
@@ -286,25 +298,26 @@ def traverse(module, default_group):
286298 # Collect all the parameters that are directly inside the `module`,
287299 # they will be in the default param group if they don't have
288300 # defined group.
301+ if hasattr (module , "param_groups" ):
302+ mapping .update (module .param_groups )
303+
289304 for name , param in module .named_parameters (recurse = False ):
290305 if param .requires_grad :
291- if hasattr (module , "param_groups" ) and name in module .param_groups :
292- param_groups [module .param_groups [name ]].append (param )
293- else :
294- param_groups [default_group ].append (param )
306+ group_name = mapping .get (name , default_group )
307+ logger .info (f"Assigning { name } to param_group { group_name } " )
308+ param_groups [group_name ].append (param )
295309
296310 # If children have defined default param group then use it else pass
297311 # own default.
298312 for child_name , child in module .named_children ():
299- if (
300- hasattr (module , "param_groups" )
301- and child_name in module .param_groups
302- ):
303- traverse (child , module .param_groups [child_name ])
304- else :
305- traverse (child , default_group )
306-
307- traverse (module , "default" )
313+ mapping_to_add = {
314+ name [len (child_name ) + 1 :]: group
315+ for name , group in mapping .items ()
316+ if name .startswith (child_name + "." )
317+ }
318+ traverse (child , mapping .get (child_name , default_group ), mapping_to_add )
319+
320+ traverse (module , "default" , {})
308321 return param_groups
309322
310323 def _get_group_learning_rate (self , group_name : str ) -> float :
0 commit comments