@@ -258,15 +258,16 @@ def _get_param_groups(
258
258
at the module level. Possible keys, including the "self" key do not have to
259
259
be defined. By default all parameters have the learning rate defined in the
260
260
optimizer. This can be overridden by setting the parameter group in `param_groups`
261
- member of a specific module, it can be overridden at the:
262
- - module level with “self” key, all the parameters and child
263
- module's parameters will inherit it
264
- - member level, which is the same as if the `param_groups` in that
265
- member has key=“self” and value equal to that parameter group.
261
+ member of a specific module. Values are a parameter group name. The keys
262
+ specify what parameters will be affected as follows:
263
+ - “self”: All the parameters of the module and its child modules
264
+ - name of a parameter: A parameter with that name.
265
+ - name of a module member: All the parameters of the module and its
266
+ child modules.
266
267
This is useful if members do not have `param_groups`, for
267
268
example torch.nn.Linear.
268
- - parameter level, only parameter with the same name as the key
269
- will have it .
269
+ - <name of module member>.<something>: recursive. Same as if <something>
270
+ was used in param_groups of that submodule/member .
270
271
271
272
Args:
272
273
module: module from which to extract the parameters and their parameter
@@ -277,7 +278,18 @@ def _get_param_groups(
277
278
278
279
param_groups = defaultdict (list )
279
280
280
- def traverse (module , default_group ):
281
+ def traverse (module , default_group : str , mapping : Dict [str , str ]) -> None :
282
+ """
283
+ Visitor for module to assign its parameters to the relevant member of
284
+ param_groups.
285
+
286
+ Args:
287
+ module: the module being visited in a depth-first search
288
+ default_group: the param group to assign parameters to unless
289
+ otherwise overriden.
290
+ mapping: known mappings of parameters to groups for this module,
291
+ destructively modified by this function.
292
+ """
281
293
# If key self is defined in param_groups then chenge the default param
282
294
# group for all parameters and children in the module.
283
295
if hasattr (module , "param_groups" ) and "self" in module .param_groups :
@@ -286,25 +298,26 @@ def traverse(module, default_group):
286
298
# Collect all the parameters that are directly inside the `module`,
287
299
# they will be in the default param group if they don't have
288
300
# defined group.
301
+ if hasattr (module , "param_groups" ):
302
+ mapping .update (module .param_groups )
303
+
289
304
for name , param in module .named_parameters (recurse = False ):
290
305
if param .requires_grad :
291
- if hasattr (module , "param_groups" ) and name in module .param_groups :
292
- param_groups [module .param_groups [name ]].append (param )
293
- else :
294
- param_groups [default_group ].append (param )
306
+ group_name = mapping .get (name , default_group )
307
+ logger .info (f"Assigning { name } to param_group { group_name } " )
308
+ param_groups [group_name ].append (param )
295
309
296
310
# If children have defined default param group then use it else pass
297
311
# own default.
298
312
for child_name , child in module .named_children ():
299
- if (
300
- hasattr (module , "param_groups" )
301
- and child_name in module .param_groups
302
- ):
303
- traverse (child , module .param_groups [child_name ])
304
- else :
305
- traverse (child , default_group )
306
-
307
- traverse (module , "default" )
313
+ mapping_to_add = {
314
+ name [len (child_name ) + 1 :]: group
315
+ for name , group in mapping .items ()
316
+ if name .startswith (child_name + "." )
317
+ }
318
+ traverse (child , mapping .get (child_name , default_group ), mapping_to_add )
319
+
320
+ traverse (module , "default" , {})
308
321
return param_groups
309
322
310
323
def _get_group_learning_rate (self , group_name : str ) -> float :
0 commit comments