
add set_weight_decay to support custom weight decay setting #5671


Merged
merged 30 commits into from
Apr 1, 2022

Commits (30)
c5a1821
add set_weight_decay
xiaohu2015 Mar 24, 2022
639bd97
Merge branch 'main' into patch-1
xiaohu2015 Mar 25, 2022
3955d44
Update _utils.py
xiaohu2015 Mar 25, 2022
568d515
refactor code
xiaohu2015 Mar 25, 2022
c56a01d
fix import
xiaohu2015 Mar 27, 2022
c572393
Merge branch 'main' into patch-1
xiaohu2015 Mar 27, 2022
d109343
add set_weight_decay
xiaohu2015 Mar 27, 2022
b84da51
fix lint
xiaohu2015 Mar 27, 2022
8cc4eeb
fix lint
xiaohu2015 Mar 27, 2022
a812387
replace split_normalization_params with set_weight_decay
xiaohu2015 Mar 27, 2022
4f2c206
simplfy the code
xiaohu2015 Mar 28, 2022
5f2d527
Merge branch 'main' into patch-1
xiaohu2015 Mar 28, 2022
f12bd08
refactor code
xiaohu2015 Mar 30, 2022
f9f7f18
refactor code
xiaohu2015 Mar 30, 2022
3529782
fix lint
xiaohu2015 Mar 30, 2022
5fbbb2e
Merge branch 'main' into patch-1
xiaohu2015 Mar 30, 2022
78116ba
Merge branch 'main' into patch-1
xiaohu2015 Mar 31, 2022
d0a0efc
remove unused
xiaohu2015 Mar 31, 2022
2867964
Update test_ops.py
xiaohu2015 Mar 31, 2022
e4aba9d
Update train.py
xiaohu2015 Mar 31, 2022
489227c
Merge branch 'main' into patch-1
xiaohu2015 Mar 31, 2022
bbc0005
Update _utils.py
xiaohu2015 Mar 31, 2022
5c0bd12
Update train.py
xiaohu2015 Mar 31, 2022
a65b43f
Merge branch 'main' into patch-1
xiaohu2015 Mar 31, 2022
f0635dd
add set_weight_decay
xiaohu2015 Apr 1, 2022
3564ae6
add set_weight_decay
xiaohu2015 Apr 1, 2022
8251dad
Update _utils.py
xiaohu2015 Apr 1, 2022
5a0b0cc
Update test_ops.py
xiaohu2015 Apr 1, 2022
62ca42b
Merge branch 'main' into patch-1
datumbox Apr 1, 2022
755ccd5
Change `--transformer-weight-decay` to `--transformer-embedding-decay`
datumbox Apr 1, 2022
30 changes: 24 additions & 6 deletions references/classification/train.py
@@ -229,12 +229,18 @@ def main(args):

    criterion = nn.CrossEntropyLoss(label_smoothing=args.label_smoothing)

    if args.norm_weight_decay is None:
        parameters = [p for p in model.parameters() if p.requires_grad]
    else:
        param_groups = torchvision.ops._utils.split_normalization_params(model)
        wd_groups = [args.norm_weight_decay, args.weight_decay]
        parameters = [{"params": p, "weight_decay": w} for p, w in zip(param_groups, wd_groups) if p]
    custom_keys_weight_decay = []
    if args.bias_weight_decay is not None:
        custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
    if args.transformer_embedding_decay is not None:
        for key in ["class_token", "position_embedding", "relative_position_bias"]:
            custom_keys_weight_decay.append((key, args.transformer_embedding_decay))
    parameters = utils.set_weight_decay(
        model,
        args.weight_decay,
        norm_weight_decay=args.norm_weight_decay,
        custom_keys_weight_decay=custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None,
    )

    opt_name = args.opt.lower()
    if opt_name.startswith("sgd"):
@@ -393,6 +399,18 @@ def get_args_parser(add_help=True):
        type=float,
        help="weight decay for Normalization layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--bias-weight-decay",
        default=None,
        type=float,
        help="weight decay for bias parameters of all layers (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--transformer-embedding-decay",
        default=None,
        type=float,
        help="weight decay for embedding parameters for vision transformer models (default: None, same value as --wd)",
    )
    parser.add_argument(
        "--label-smoothing", default=0.0, type=float, help="label smoothing (default: 0.0)", dest="label_smoothing"
    )
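
For context, a minimal sketch (not part of the diff) of how the two new flags are turned into custom_keys_weight_decay entries before the utils.set_weight_decay call shown above; the flag values below are hypothetical.

# Hypothetical flag values, mirroring the logic added to main() above.
import types

args = types.SimpleNamespace(
    weight_decay=0.3,                 # --wd / --weight-decay
    norm_weight_decay=0.0,            # --norm-weight-decay
    bias_weight_decay=0.0,            # --bias-weight-decay (new flag)
    transformer_embedding_decay=0.0,  # --transformer-embedding-decay (new flag)
)

custom_keys_weight_decay = []
if args.bias_weight_decay is not None:
    custom_keys_weight_decay.append(("bias", args.bias_weight_decay))
if args.transformer_embedding_decay is not None:
    for key in ["class_token", "position_embedding", "relative_position_bias"]:
        custom_keys_weight_decay.append((key, args.transformer_embedding_decay))

# custom_keys_weight_decay is now:
# [("bias", 0.0), ("class_token", 0.0), ("position_embedding", 0.0), ("relative_position_bias", 0.0)]
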
63 changes: 63 additions & 0 deletions references/classification/utils.py
@@ -5,6 +5,7 @@
import os
import time
from collections import defaultdict, deque, OrderedDict
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist
@@ -400,3 +401,65 @@ def reduce_across_processes(val):
    dist.barrier()
    dist.all_reduce(t)
    return t


def set_weight_decay(
    model: torch.nn.Module,
    weight_decay: float,
    norm_weight_decay: Optional[float] = None,
    norm_classes: Optional[List[type]] = None,
    custom_keys_weight_decay: Optional[List[Tuple[str, float]]] = None,
):
    if not norm_classes:
        norm_classes = [
            torch.nn.modules.batchnorm._BatchNorm,
            torch.nn.LayerNorm,
            torch.nn.GroupNorm,
            torch.nn.modules.instancenorm._InstanceNorm,
            torch.nn.LocalResponseNorm,
        ]
    norm_classes = tuple(norm_classes)

    params = {
        "other": [],
        "norm": [],
    }
    params_weight_decay = {
        "other": weight_decay,
        "norm": norm_weight_decay,
    }
    custom_keys = []
    if custom_keys_weight_decay is not None:
        for key, weight_decay in custom_keys_weight_decay:
            params[key] = []
            params_weight_decay[key] = weight_decay
            custom_keys.append(key)

    def _add_params(module, prefix=""):
        for name, p in module.named_parameters(recurse=False):
            if not p.requires_grad:
                continue
            is_custom_key = False
            for key in custom_keys:
                target_name = f"{prefix}.{name}" if prefix != "" and "." in key else name
                if key == target_name:
                    params[key].append(p)
                    is_custom_key = True
                    break
            if not is_custom_key:
                if norm_weight_decay is not None and isinstance(module, norm_classes):
                    params["norm"].append(p)
                else:
                    params["other"].append(p)

        for child_name, child_module in module.named_children():
            child_prefix = f"{prefix}.{child_name}" if prefix != "" else child_name
            _add_params(child_module, prefix=child_prefix)

    _add_params(model)

    param_groups = []
    for key in params:
        if len(params[key]) > 0:
            param_groups.append({"params": params[key], "weight_decay": params_weight_decay[key]})
    return param_groups
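
As a usage sketch (not part of the diff), the parameter groups returned by set_weight_decay can be passed directly to a torch.optim optimizer. The model and hyper-parameter values below are illustrative only, and set_weight_decay is assumed to be imported from references/classification/utils.py.

import torch
import torchvision

# Illustrative values: disable weight decay for normalization layers and biases,
# keep the regular value for everything else.
model = torchvision.models.resnet18()
param_groups = set_weight_decay(
    model,
    weight_decay=1e-4,
    norm_weight_decay=0.0,
    custom_keys_weight_decay=[("bias", 0.0)],
)
optimizer = torch.optim.SGD(param_groups, lr=0.1, momentum=0.9)

# Each group carries its own weight decay ("other", "norm" and the custom "bias" group).
for group in optimizer.param_groups:
    print(len(group["params"]), group["weight_decay"])

Note that custom keys are matched before the normalization-class check, so with this configuration the bias parameters of normalization layers land in the "bias" group rather than the "norm" group.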