Commit 5966cee

Use detectron2 implementation for Attention layer in ViTDet.
1 parent 4fc15fd

3 files changed: +477 −46 lines

references/detection/train.py

Lines changed: 19 additions & 0 deletions
@@ -256,6 +256,25 @@ def main(args):
         )
     elif opt_name == "adamw":
         optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
+    elif opt_name == "vitdet":
+        from torchvision.models.vision_transformer import get_default_optimizer_params, get_vit_lr_decay_rate
+        from functools import partial
+
+        optimizer = torch.optim.AdamW(
+            params=get_default_optimizer_params(
+                model,
+                # params.model is meant to be set to the model object, before instantiating
+                # the optimizer.
+                base_lr=args.lr,
+                weight_decay_norm=0.0,
+                # TODO: Adjust num_layers for specific model. Currently this assumes ViT-B.
+                lr_factor_func=partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7),
+                overrides={"pos_embed": {"weight_decay": 0.0}},
+            ),
+            lr=args.lr,
+            betas=(0.9, 0.999),
+            weight_decay=0.1,
+        )
     else:
         raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD and AdamW are supported.")

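The new "vitdet" branch builds the optimizer from the detectron2 ViTDet recipe: layer-wise learning-rate decay (earlier transformer blocks train with smaller learning rates than blocks near the head), no weight decay on normalization layers (weight_decay_norm=0.0), and no weight decay on the position embedding (the pos_embed override). The sketch below illustrates the idea behind a get_vit_lr_decay_rate-style factor function with the values used in the diff (lr_decay_rate=0.7, num_layers=12, i.e. ViT-B); it is an assumption about how such a function typically works, not the exact code this commit imports, and the parameter-name patterns (pos_embed, patch_embed, blocks.<i>.) are illustrative.

from functools import partial

def lr_decay_factor_sketch(name, lr_decay_rate=0.7, num_layers=12):
    # Hedged sketch of a ViT layer-wise LR decay factor (detectron2-style assumption).
    # Earlier layers (patch/pos embeddings, low block indices) get a smaller multiplier;
    # the factor reaches 1.0 for parameters past the last block (e.g. the head).
    layer_id = num_layers + 1  # default: treat the parameter as coming after the last block
    if "pos_embed" in name or "patch_embed" in name:
        layer_id = 0  # embeddings decay the most
    elif ".blocks." in name:
        # e.g. "backbone.blocks.5.attn.qkv.weight" -> block index 5
        layer_id = int(name.split(".blocks.")[1].split(".")[0]) + 1
    return lr_decay_rate ** (num_layers + 1 - layer_id)

# Example: per-parameter LR multipliers for a 12-block ViT-B (illustrative names)
factor = partial(lr_decay_factor_sketch, lr_decay_rate=0.7, num_layers=12)
print(factor("backbone.pos_embed"))                 # 0.7 ** 13 ~= 0.0097
print(factor("backbone.blocks.0.attn.qkv.weight"))  # 0.7 ** 12 ~= 0.0138
print(factor("backbone.blocks.11.mlp.fc1.weight"))  # 0.7 ** 1  =  0.7
print(factor("head.weight"))                        # 0.7 ** 0  =  1.0

With these settings, the patch and position embeddings end up training at roughly 1% of the base learning rate while the detection head trains at the full rate, which is the intended effect of the lr_factor_func argument passed to get_default_optimizer_params above.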