Skip to content

Commit 2a24c9b

Browse files
committed
Use detectron2 implementation for Attention layer in ViTDet.
1 parent 4fc15fd commit 2a24c9b

File tree

4 files changed

+482
-50
lines changed

4 files changed

+482
-50
lines changed

references/detection/train.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -256,6 +256,25 @@ def main(args):
256256
)
257257
elif opt_name == "adamw":
258258
optimizer = torch.optim.AdamW(parameters, lr=args.lr, weight_decay=args.weight_decay)
259+
elif opt_name == "vitdet":
260+
from torchvision.models.vision_transformer import get_default_optimizer_params, get_vit_lr_decay_rate
261+
from functools import partial
262+
263+
optimizer = torch.optim.AdamW(
264+
params=get_default_optimizer_params(
265+
model,
266+
# The model is passed in directly, so it must already be constructed before
267+
# the optimizer is instantiated.
268+
base_lr=args.lr,
269+
weight_decay_norm=0.0,
270+
# TODO: Adjust num_layers for specific model. Currently this assumes ViT-B.
271+
lr_factor_func=partial(get_vit_lr_decay_rate, num_layers=12, lr_decay_rate=0.7),
272+
overrides={"pos_embed": {"weight_decay": 0.0}},
273+
),
274+
lr=args.lr,
275+
betas=(0.9, 0.999),
276+
weight_decay=0.1,
277+
)
259278
else:
260279
raise RuntimeError(f"Invalid optimizer {args.opt}. Only SGD and AdamW are supported.")
261280

torchvision/models/detection/mask_rcnn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -672,7 +672,7 @@ def maskrcnn_vit_b_16_sfpn(
672672
elif num_classes is None:
673673
num_classes = 91
674674

675-
backbone = vit_b_16(weights=weights_backbone, progress=progress, include_head=False)
675+
backbone = vit_b_16(weights=weights_backbone, progress=progress, include_head=False, image_size=1024)
676676
backbone = _vit_sfpn_extractor(backbone)
677677
model = MaskRCNN(backbone, num_classes=num_classes, **kwargs)
678678

0 commit comments

Comments
 (0)