Commit 5724c1c

Support custom weight decay for Normalization layers.

1 parent 33a90f7

File tree

3 files changed: +52 -10 lines

- references/classification/train.py
- test/test_ops.py
- torchvision/ops/_utils.py
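In practice, the new option is passed alongside the existing `--wd` flag; for example, a command along the lines of `python train.py --model mobilenet_v3_large --wd 1e-4 --norm-weight-decay 0.0` would keep the regular weight decay on convolutional and linear weights while disabling it for normalization layers (the model name and values here are illustrative, not from the commit).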

references/classification/train.py (14 additions, 6 deletions)

@@ -197,16 +197,22 @@ def main(args):
 
     criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)
 
+    if args.norm_weight_decay is None:
+        parameters = model.parameters()
+    else:
+        param_groups = torchvision.ops._utils.split_normalization_params(model)
+        wd_groups = [args.norm_weight_decay, args.weight_decay]
+        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
+
     opt_name = args.opt.lower()
     if opt_name.startswith("sgd"):
-        optimizer = torch.optim.SGD(
-            model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
-            nesterov="nesterov" in opt_name)
+        optimizer = torch.optim.SGD(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
+                                    nesterov="nesterov" in opt_name)
     elif opt_name == 'rmsprop':
-        optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum,
-                                        weight_decay=args.weight_decay, eps=0.0316, alpha=0.9)
+        optimizer = torch.optim.RMSprop(parameters, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay,
+                                        eps=0.0316, alpha=0.9)
     elif opt_name == 'adamw':
-        optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
+        optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
     else:
         raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD, RMSprop and AdamW are supported.")
 

@@ -326,6 +332,8 @@ def get_args_parser(add_help=True):
     parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
                         metavar='W', help='weight decay (default: 1e-4)',
                         dest='weight_decay')
+    parser.add_argument('--norm-weight-decay', default=None, type=float,
+                        help='weight decay for Normalization layers (default: None, same value as --wd)')
     parser.add_argument('--label-smoothing', default=0.0, type=float,
                         help='label smoothing (default: 0.0)',
                         dest='label_smoothing')
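The key mechanism here is PyTorch's parameter-group API: when an optimizer receives a list of dicts instead of a flat iterable, any key set inside a dict (here `weight_decay`) overrides the optimizer-level default for that group, and the `if p` filter drops empty groups (e.g. for a model with no normalization layers). A minimal sketch of the same pattern outside the training script; the model choice and hyper-parameter values are illustrative, not from the commit:

```python
import torch
import torchvision
from torchvision.models import resnet18  # any torchvision model would do

model = resnet18()

# Split parameters the same way train.py does: [norm params, everything else].
param_groups = torchvision.ops._utils.split_normalization_params(model)
wd_groups = [0.0, 1e-4]  # e.g. --norm-weight-decay 0.0 and --wd 1e-4
parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]

# Per-group "weight_decay" entries override the optimizer-level default below.
optimizer = torch.optim.SGD(parameters, lr=0.1, momentum=0.9, weight_decay=1e-4)
for group in optimizer.param_groups:
    print(len(group["params"]), group["weight_decay"])  # norm group gets 0.0
```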

test/test_ops.py (12 additions, 2 deletions)

@@ -9,10 +9,10 @@
 from PIL import Image
 import torch
 from functools import lru_cache
-from torch import Tensor
+from torch import nn, Tensor
 from torch.autograd import gradcheck
 from torch.nn.modules.utils import _pair
-from torchvision import ops
+from torchvision import models, ops
 from typing import Tuple
 
 

@@ -1062,5 +1062,15 @@ def test_stochastic_depth(self, mode, p):
         assert p_value > 0.0001
 
 
+class TestUtils:
+    @pytest.mark.parametrize('norm_layer', [None, nn.BatchNorm2d, nn.LayerNorm])
+    def test_split_normalization_params(self, norm_layer):
+        model = models.mobilenet_v3_large(norm_layer=norm_layer)
+        params = ops._utils.split_normalization_params(model, None if norm_layer is None else [norm_layer])
+
+        assert len(params[0]) == 92
+        assert len(params[1]) == 82
+
+
 if __name__ == '__main__':
     pytest.main([__file__])
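The test pins the expected split for mobilenet_v3_large: 92 parameters from normalization layers and 82 from everything else, both with the default norm classes and with an explicitly passed `norm_layer`. It can be run in isolation with something like `pytest test/test_ops.py -k test_split_normalization_params`.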

torchvision/ops/_utils.py (26 additions, 2 deletions)

@@ -1,6 +1,6 @@
 import torch
-from torch import Tensor
-from typing import List, Union
+from torch import nn, Tensor
+from typing import List, Optional, Tuple, Union
 
 
 def _cat(tensors: List[Tensor], dim: int = 0) -> Tensor:

@@ -34,3 +34,27 @@ def check_roi_boxes_shape(boxes: Union[Tensor, List[Tensor]]):
     else:
         assert False, 'boxes is expected to be a Tensor[L, 5] or a List[Tensor[K, 4]]'
     return
+
+
+def split_normalization_params(model: nn.Module,
+                               norm_classes: Optional[List[type]] = None) -> Tuple[List[Tensor], List[Tensor]]:
+    # Adapted from https://github.com/facebookresearch/ClassyVision/blob/659d7f78/classy_vision/generic/util.py#L501
+    if not norm_classes:
+        norm_classes = [nn.modules.batchnorm._BatchNorm, nn.LayerNorm, nn.GroupNorm]
+
+    for t in norm_classes:
+        if not issubclass(t, nn.Module):
+            raise ValueError(f"Class {t} is not a subclass of nn.Module.")
+
+    classes = tuple(norm_classes)
+
+    norm_params = []
+    other_params = []
+    for module in model.modules():
+        if next(module.children(), None):
+            other_params.extend(p for p in module.parameters(recurse=False) if p.requires_grad)
+        elif isinstance(module, classes):
+            norm_params.extend(p for p in module.parameters() if p.requires_grad)
+        else:
+            other_params.extend(p for p in module.parameters() if p.requires_grad)
+    return norm_params, other_params
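The traversal visits every module exactly once via `model.modules()`: container modules (those with children) contribute only their directly owned parameters (`recurse=False`), so nothing is double-counted, while leaf modules are routed to one bucket or the other depending on whether they are instances of a normalization class. A toy check of that behavior; illustrative, assuming a torchvision build that includes this commit:

```python
from torch import nn
from torchvision.ops._utils import split_normalization_params

toy = nn.Sequential(nn.Conv2d(3, 4, 3, bias=False), nn.BatchNorm2d(4), nn.ReLU())
norm_params, other_params = split_normalization_params(toy)

assert len(norm_params) == 2   # BatchNorm2d affine weight and bias
assert len(other_params) == 1  # the Conv2d weight (bias disabled above)
```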
