Skip to content

Commit 7be49bf

Browse files
bottler authored and facebook-github-bot committed
allow dots in param_groups
Summary: Allow a module's param_group member to specify overrides to the param groups of its members or their members. Also logging for param group assignments. This allows defining `params.basis_matrix` in the param_groups of a voxel_grid. Reviewed By: shapovalov Differential Revision: D41080667 fbshipit-source-id: 49f3b0e5b36e496f78701db0699cbb8a7e20c51e
1 parent a1f2ded commit 7be49bf

File tree

2 files changed

+57
-23
lines changed

2 files changed

+57
-23
lines changed

projects/implicitron_trainer/impl/optimizer_factory.py

Lines changed: 34 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -258,15 +258,16 @@ def _get_param_groups(
258258
at the module level. Possible keys, including the "self" key do not have to
259259
be defined. By default all parameters have the learning rate defined in the
260260
optimizer. This can be overridden by setting the parameter group in `param_groups`
261-
member of a specific module, it can be overridden at the:
262-
- module level with “self” key, all the parameters and child
263-
module's parameters will inherit it
264-
- member level, which is the same as if the `param_groups` in that
265-
member has key=“self” and value equal to that parameter group.
261+
member of a specific module. Values are a parameter group name. The keys
262+
specify what parameters will be affected as follows:
263+
- “self”: All the parameters of the module and its child modules
264+
- name of a parameter: A parameter with that name.
265+
- name of a module member: All the parameters of the module and its
266+
child modules.
266267
This is useful if members do not have `param_groups`, for
267268
example torch.nn.Linear.
268-
- parameter level, only parameter with the same name as the key
269-
will have it.
269+
- <name of module member>.<something>: recursive. Same as if <something>
270+
was used in param_groups of that submodule/member.
270271
271272
Args:
272273
module: module from which to extract the parameters and their parameter
@@ -277,7 +278,18 @@ def _get_param_groups(
277278

278279
param_groups = defaultdict(list)
279280

280-
def traverse(module, default_group):
281+
def traverse(module, default_group: str, mapping: Dict[str, str]) -> None:
282+
"""
283+
Visitor for module to assign its parameters to the relevant member of
284+
param_groups.
285+
286+
Args:
287+
module: the module being visited in a depth-first search
288+
default_group: the param group to assign parameters to unless
289+
otherwise overridden.
290+
mapping: known mappings of parameters to groups for this module,
291+
destructively modified by this function.
292+
"""
281293
# If key self is defined in param_groups then change the default param
282294
# group for all parameters and children in the module.
283295
if hasattr(module, "param_groups") and "self" in module.param_groups:
@@ -286,25 +298,26 @@ def traverse(module, default_group):
286298
# Collect all the parameters that are directly inside the `module`,
287299
# they will be in the default param group if they don't have
288300
# defined group.
301+
if hasattr(module, "param_groups"):
302+
mapping.update(module.param_groups)
303+
289304
for name, param in module.named_parameters(recurse=False):
290305
if param.requires_grad:
291-
if hasattr(module, "param_groups") and name in module.param_groups:
292-
param_groups[module.param_groups[name]].append(param)
293-
else:
294-
param_groups[default_group].append(param)
306+
group_name = mapping.get(name, default_group)
307+
logger.info(f"Assigning {name} to param_group {group_name}")
308+
param_groups[group_name].append(param)
295309

296310
# If children have defined default param group then use it else pass
297311
# own default.
298312
for child_name, child in module.named_children():
299-
if (
300-
hasattr(module, "param_groups")
301-
and child_name in module.param_groups
302-
):
303-
traverse(child, module.param_groups[child_name])
304-
else:
305-
traverse(child, default_group)
306-
307-
traverse(module, "default")
313+
mapping_to_add = {
314+
name[len(child_name) + 1 :]: group
315+
for name, group in mapping.items()
316+
if name.startswith(child_name + ".")
317+
}
318+
traverse(child, mapping.get(child_name, default_group), mapping_to_add)
319+
320+
traverse(module, "default", {})
308321
return param_groups
309322

310323
def _get_group_learning_rate(self, group_name: str) -> float:

projects/implicitron_trainer/tests/test_optimizer_factory.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,17 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
import logging
78
import os
89
import unittest
910

1011
import torch
1112
from pytorch3d.implicitron.tools.config import expand_args_fields, get_default_args
1213

13-
from ..impl.optimizer_factory import ImplicitronOptimizerFactory
14+
from ..impl.optimizer_factory import (
15+
ImplicitronOptimizerFactory,
16+
logger as factory_logger,
17+
)
1418

1519
internal = os.environ.get("FB_TEST", False)
1620

@@ -23,9 +27,17 @@ def setUp(self) -> None:
2327
def _get_param_groups(self, model):
2428
default_cfg = get_default_args(ImplicitronOptimizerFactory)
2529
factory = ImplicitronOptimizerFactory(default_cfg)
26-
return factory._get_param_groups(model)
30+
oldlevel = factory_logger.level
31+
factory_logger.setLevel(logging.ERROR)
32+
out = factory._get_param_groups(model)
33+
factory_logger.setLevel(oldlevel)
34+
return out
2735

2836
def _assert_allin(self, a, param_groups, key):
37+
"""
38+
Asserts that all the parameters in a are in the group
39+
named by key.
40+
"""
2941
with self.subTest(f"Testing key {key}"):
3042
b = param_groups[key]
3143
for el in a:
@@ -83,6 +95,15 @@ def test_no_param_groups_defined(self):
8395
param_groups = self._get_param_groups(root)
8496
self._assert_allin([pa, pb, pc], param_groups, "default")
8597

98+
def test_double_dotted(self):
99+
pa, pb = [torch.nn.Parameter(data=torch.tensor(i * 1.0)) for i in range(2)]
100+
na = Node(params=[pa, pb])
101+
nb = Node(children=[na])
102+
root = Node(children=[nb], param_groups={"m0.m0.p0": "X", "m0.m0": "Y"})
103+
param_groups = self._get_param_groups(root)
104+
self._assert_allin([pa], param_groups, "X")
105+
self._assert_allin([pb], param_groups, "Y")
106+
86107
def test_tree_param_groups_defined(self):
87108
"""
88109
Test generic tree assignment.

0 commit comments

Comments
 (0)