
Commit d18c487

Refactor the backbone builders of detection (#4656)
* Refactoring resnet_fpn backbone building.
* Passing the change to *_rcnn and retinanet.
* Applying for faster_rcnn + mobilenetv3
* Applying for ssdlite + mobilenetv3
* Applying for ssd + vgg16
* Update the expected file of retinanet_resnet50_fpn to fix order of initialization.
* Adding full model weights for the VGG16 features.
1 parent 5e84bab commit d18c487
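For context, the refactor replaces the name-based backbone builders with a two-step flow: construct the classification model first, then wrap it with a feature-extractor helper. A minimal sketch of the change for the ResNet50+FPN case (the weight flags and trainable-layer counts below are illustrative, and the new helpers are private, underscore-prefixed functions):

from torchvision.models import resnet50
from torchvision.models.detection.faster_rcnn import FasterRCNN
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone, _resnet_fpn_extractor
from torchvision.ops import misc as misc_nn_ops

# Before: a single builder picked the model by name and handled weights internally.
backbone = resnet_fpn_backbone("resnet50", pretrained=True, trainable_layers=3)

# After: the caller builds the model (choosing weights and norm layer) and then
# hands it to the extractor, which only does the FPN wrapping and layer freezing.
backbone = resnet50(pretrained=True, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
backbone = _resnet_fpn_extractor(backbone, trainable_layers=3)
model = FasterRCNN(backbone, num_classes=91)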

File tree

10 files changed (+77, -78 lines)

Binary file not shown (the updated expected file for retinanet_resnet50_fpn).

torchvision/models/detection/backbone_utils.py

Lines changed: 15 additions & 6 deletions
@@ -1,5 +1,5 @@
 import warnings
-from typing import Callable, Dict, Optional, List
+from typing import Callable, Dict, Optional, List, Union

 from torch import nn, Tensor
 from torchvision.ops import misc as misc_nn_ops
@@ -100,14 +100,14 @@ def resnet_fpn_backbone(
             default a ``LastLevelMaxPool`` is used.
     """
     backbone = resnet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
-    return _resnet_backbone_config(backbone, trainable_layers, returned_layers, extra_blocks)
+    return _resnet_fpn_extractor(backbone, trainable_layers, returned_layers, extra_blocks)


-def _resnet_backbone_config(
+def _resnet_fpn_extractor(
     backbone: resnet.ResNet,
     trainable_layers: int,
-    returned_layers: Optional[List[int]],
-    extra_blocks: Optional[ExtraFPNBlock],
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
 ) -> BackboneWithFPN:

     # select layers that wont be frozen
@@ -165,9 +165,18 @@ def mobilenet_backbone(
     returned_layers: Optional[List[int]] = None,
     extra_blocks: Optional[ExtraFPNBlock] = None,
 ) -> nn.Module:
+    backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer)
+    return _mobilenet_extractor(backbone, fpn, trainable_layers, returned_layers, extra_blocks)

-    backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features

+def _mobilenet_extractor(
+    backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
+    fpn: bool,
+    trainable_layers: int,
+    returned_layers: Optional[List[int]] = None,
+    extra_blocks: Optional[ExtraFPNBlock] = None,
+) -> nn.Module:
+    backbone = backbone.features
     # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
     # The first and last blocks are always included because they are the C0 (conv1) and Cn.
     stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
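The MobileNet path gets the same split: the public mobilenet_backbone builder now delegates to _mobilenet_extractor, which takes an already constructed model. A rough usage sketch based on the signature above (the frozen-norm layer, FPN flag, and trainable-layer count are illustrative choices, not defaults):

from torchvision.models import mobilenet_v3_large
from torchvision.models.detection.backbone_utils import _mobilenet_extractor
from torchvision.ops import misc as misc_nn_ops

# Build the classification model first, picking weights and normalization here...
backbone = mobilenet_v3_large(pretrained=False, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
# ...then turn its .features into a detection feature extractor (with an FPN here).
backbone = _mobilenet_extractor(backbone, fpn=True, trainable_layers=2)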

torchvision/models/detection/faster_rcnn.py

Lines changed: 11 additions & 4 deletions
@@ -3,9 +3,12 @@
 from torchvision.ops import MultiScaleRoIAlign

 from ..._internally_replaced_utils import load_state_dict_from_url
+from ...ops import misc as misc_nn_ops
+from ..mobilenetv3 import mobilenet_v3_large
+from ..resnet import resnet50
 from ._utils import overwrite_eps
 from .anchor_utils import AnchorGenerator
-from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers, _mobilenet_extractor
 from .generalized_rcnn import GeneralizedRCNN
 from .roi_heads import RoIHeads
 from .rpn import RPNHead, RegionProposalNetwork
@@ -385,7 +388,9 @@ def fasterrcnn_resnet50_fpn(
     if pretrained:
         # no need to download the backbone if pretrained is set
         pretrained_backbone = False
-    backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers)
+
+    backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
     model = FasterRCNN(backbone, num_classes, **kwargs)
     if pretrained:
         state_dict = load_state_dict_from_url(model_urls["fasterrcnn_resnet50_fpn_coco"], progress=progress)
@@ -409,9 +414,11 @@ def _fasterrcnn_mobilenet_v3_large_fpn(

     if pretrained:
         pretrained_backbone = False
-    backbone = mobilenet_backbone(
-        "mobilenet_v3_large", pretrained_backbone, True, trainable_layers=trainable_backbone_layers
+
+    backbone = mobilenet_v3_large(
+        pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d
     )
+    backbone = _mobilenet_extractor(backbone, True, trainable_backbone_layers)

     anchor_sizes = (
         (

torchvision/models/detection/keypoint_rcnn.py

Lines changed: 6 additions & 2 deletions
@@ -3,8 +3,10 @@
 from torchvision.ops import MultiScaleRoIAlign

 from ..._internally_replaced_utils import load_state_dict_from_url
+from ...ops import misc as misc_nn_ops
+from ..resnet import resnet50
 from ._utils import overwrite_eps
-from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
 from .faster_rcnn import FasterRCNN


@@ -367,7 +369,9 @@ def keypointrcnn_resnet50_fpn(
     if pretrained:
         # no need to download the backbone if pretrained is set
         pretrained_backbone = False
-    backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers)
+
+    backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
     model = KeypointRCNN(backbone, num_classes, num_keypoints=num_keypoints, **kwargs)
     if pretrained:
         key = "keypointrcnn_resnet50_fpn_coco"

torchvision/models/detection/mask_rcnn.py

Lines changed: 6 additions & 2 deletions
@@ -4,8 +4,10 @@
 from torchvision.ops import MultiScaleRoIAlign

 from ..._internally_replaced_utils import load_state_dict_from_url
+from ...ops import misc as misc_nn_ops
+from ..resnet import resnet50
 from ._utils import overwrite_eps
-from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
 from .faster_rcnn import FasterRCNN

 __all__ = [
@@ -364,7 +366,9 @@ def maskrcnn_resnet50_fpn(
     if pretrained:
         # no need to download the backbone if pretrained is set
         pretrained_backbone = False
-    backbone = resnet_fpn_backbone("resnet50", pretrained_backbone, trainable_layers=trainable_backbone_layers)
+
+    backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
     model = MaskRCNN(backbone, num_classes, **kwargs)
     if pretrained:
         state_dict = load_state_dict_from_url(model_urls["maskrcnn_resnet50_fpn_coco"], progress=progress)

torchvision/models/detection/retinanet.py

Lines changed: 7 additions & 7 deletions
@@ -9,11 +9,13 @@
 from ..._internally_replaced_utils import load_state_dict_from_url
 from ...ops import sigmoid_focal_loss
 from ...ops import boxes as box_ops
+from ...ops import misc as misc_nn_ops
 from ...ops.feature_pyramid_network import LastLevelP6P7
+from ..resnet import resnet50
 from . import _utils as det_utils
 from ._utils import overwrite_eps
 from .anchor_utils import AnchorGenerator
-from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
+from .backbone_utils import _resnet_fpn_extractor, _validate_trainable_layers
 from .transform import GeneralizedRCNNTransform


@@ -630,13 +632,11 @@ def retinanet_resnet50_fpn(
     if pretrained:
         # no need to download the backbone if pretrained is set
         pretrained_backbone = False
+
+    backbone = resnet50(pretrained=pretrained_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
     # skip P2 because it generates too many anchors (according to their paper)
-    backbone = resnet_fpn_backbone(
-        "resnet50",
-        pretrained_backbone,
-        returned_layers=[2, 3, 4],
-        extra_blocks=LastLevelP6P7(256, 256),
-        trainable_layers=trainable_backbone_layers,
+    backbone = _resnet_fpn_extractor(
+        backbone, trainable_backbone_layers, returned_layers=[2, 3, 4], extra_blocks=LastLevelP6P7(256, 256)
     )
     model = RetinaNet(backbone, num_classes, **kwargs)
     if pretrained:

torchvision/models/detection/ssd.py

Lines changed: 11 additions & 14 deletions
@@ -23,7 +23,8 @@
 backbone_urls = {
     # We port the features of a VGG16 backbone trained by amdegroot because unlike the one on TorchVision, it uses the
     # same input standardization method as the paper. Ref: https://s3.amazonaws.com/amdegroot-models/vgg16_reducedfc.pth
-    "vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot.pth"
+    # Only the `features` weights have proper values, those on the `classifier` module are filled with nans.
+    "vgg16_features": "https://download.pytorch.org/models/vgg16_features-amdegroot-88682ab5.pth"
 }


@@ -519,18 +520,8 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:
         return OrderedDict([(str(i), v) for i, v in enumerate(output)])


-def _vgg_extractor(backbone_name: str, highres: bool, progress: bool, pretrained: bool, trainable_layers: int):
-    if backbone_name in backbone_urls:
-        # Use custom backbones more appropriate for SSD
-        arch = backbone_name.split("_")[0]
-        backbone = vgg.__dict__[arch](pretrained=False, progress=progress).features
-        if pretrained:
-            state_dict = load_state_dict_from_url(backbone_urls[backbone_name], progress=progress)
-            backbone.load_state_dict(state_dict)
-    else:
-        # Use standard backbones from TorchVision
-        backbone = vgg.__dict__[backbone_name](pretrained=pretrained, progress=progress).features
-
+def _vgg_extractor(backbone: vgg.VGG, highres: bool, trainable_layers: int):
+    backbone = backbone.features
     # Gather the indices of maxpools. These are the locations of output blocks.
     stage_indices = [0] + [i for i, b in enumerate(backbone) if isinstance(b, nn.MaxPool2d)][:-1]
     num_stages = len(stage_indices)
@@ -609,7 +600,13 @@ def ssd300_vgg16(
         # no need to download the backbone if pretrained is set
         pretrained_backbone = False

-    backbone = _vgg_extractor("vgg16_features", False, progress, pretrained_backbone, trainable_backbone_layers)
+    # Use custom backbones more appropriate for SSD
+    backbone = vgg.vgg16(pretrained=False, progress=progress)
+    if pretrained_backbone:
+        state_dict = load_state_dict_from_url(backbone_urls["vgg16_features"], progress=progress)
+        backbone.load_state_dict(state_dict)
+
+    backbone = _vgg_extractor(backbone, False, trainable_backbone_layers)
     anchor_generator = DefaultBoxGenerator(
         [[2], [2, 3], [2, 3], [2, 3], [2], [2]],
         scales=[0.07, 0.15, 0.33, 0.51, 0.69, 0.87, 1.05],
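Since _vgg_extractor no longer downloads weights itself, choosing between the ported amdegroot checkpoint and a standard TorchVision one now happens at the call site. A sketch of the standard-weights path that the old helper used to cover internally (the trainable-layer count is illustrative; _vgg_extractor is a private helper):

from torchvision.models import vgg16
from torchvision.models.detection.ssd import _vgg_extractor

# Any VGG instance works; here the regular ImageNet-pretrained vgg16 is used
# instead of the ported SSD checkpoint that ssd300_vgg16 loads above.
backbone = vgg16(pretrained=True)
extractor = _vgg_extractor(backbone, highres=False, trainable_layers=3)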

torchvision/models/detection/ssdlite.py

Lines changed: 11 additions & 19 deletions
@@ -1,7 +1,7 @@
 import warnings
 from collections import OrderedDict
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union

 import torch
 from torch import nn, Tensor
@@ -117,7 +117,6 @@ def __init__(
         norm_layer: Callable[..., nn.Module],
         width_mult: float = 1.0,
         min_depth: int = 16,
-        **kwargs: Any,
     ):
         super().__init__()

@@ -156,20 +155,11 @@ def forward(self, x: Tensor) -> Dict[str, Tensor]:


 def _mobilenet_extractor(
-    backbone_name: str,
-    progress: bool,
-    pretrained: bool,
+    backbone: Union[mobilenet.MobileNetV2, mobilenet.MobileNetV3],
     trainable_layers: int,
     norm_layer: Callable[..., nn.Module],
-    **kwargs: Any,
 ):
-    backbone = mobilenet.__dict__[backbone_name](
-        pretrained=pretrained, progress=progress, norm_layer=norm_layer, **kwargs
-    ).features
-    if not pretrained:
-        # Change the default initialization scheme if not pretrained
-        _normal_init(backbone)
-
+    backbone = backbone.features
     # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
     # The first and last blocks are always included because they are the C0 (conv1) and Cn.
     stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "_is_cn", False)] + [len(backbone) - 1]
@@ -183,7 +173,7 @@ def _mobilenet_extractor(
         for parameter in b.parameters():
             parameter.requires_grad_(False)

-    return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer, **kwargs)
+    return SSDLiteFeatureExtractorMobileNet(backbone, stage_indices[-2], norm_layer)


 def ssdlite320_mobilenet_v3_large(
@@ -235,14 +225,16 @@ def ssdlite320_mobilenet_v3_large(
     if norm_layer is None:
         norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)

+    backbone = mobilenet.mobilenet_v3_large(
+        pretrained=pretrained_backbone, progress=progress, norm_layer=norm_layer, reduced_tail=reduce_tail, **kwargs
+    )
+    if not pretrained_backbone:
+        # Change the default initialization scheme if not pretrained
+        _normal_init(backbone)
     backbone = _mobilenet_extractor(
-        "mobilenet_v3_large",
-        progress,
-        pretrained_backbone,
+        backbone,
         trainable_backbone_layers,
         norm_layer,
-        reduced_tail=reduce_tail,
-        **kwargs,
     )

     size = (320, 320)
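With **kwargs removed from the SSDLite extractor, model-specific options such as reduced_tail now go straight to the MobileNetV3 constructor, and re-initializing an untrained backbone via _normal_init is the caller's job. A sketch of that flow outside ssdlite320_mobilenet_v3_large (the trainable-layer count is illustrative; both helpers are private):

from functools import partial

from torch import nn
from torchvision.models import mobilenet_v3_large
from torchvision.models.detection.ssdlite import _mobilenet_extractor, _normal_init

norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.03)
# Model-specific options are passed to the model constructor, not the extractor.
backbone = mobilenet_v3_large(pretrained=False, norm_layer=norm_layer, reduced_tail=True)
# Untrained weights are re-initialized by the caller before extraction.
_normal_init(backbone)
extractor = _mobilenet_extractor(backbone, 6, norm_layer)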

torchvision/prototype/models/detection/backbone_utils.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

torchvision/prototype/models/detection/faster_rcnn.py

Lines changed: 10 additions & 4 deletions
@@ -1,12 +1,17 @@
 import warnings
 from typing import Any, Optional

-from ....models.detection.faster_rcnn import FasterRCNN, overwrite_eps, _validate_trainable_layers
+from ....models.detection.faster_rcnn import (
+    _validate_trainable_layers,
+    _resnet_fpn_extractor,
+    FasterRCNN,
+    misc_nn_ops,
+    overwrite_eps,
+)
 from ...transforms.presets import CocoEval
 from .._api import Weights, WeightEntry
 from .._meta import _COCO_CATEGORIES
-from ..resnet import ResNet50Weights
-from .backbone_utils import resnet_fpn_backbone
+from ..resnet import ResNet50Weights, resnet50


 __all__ = ["FasterRCNN", "FasterRCNNResNet50FPNWeights", "fasterrcnn_resnet50_fpn"]
@@ -49,7 +54,8 @@ def fasterrcnn_resnet50_fpn(
         weights is not None or weights_backbone is not None, trainable_backbone_layers, 5, 3
     )

-    backbone = resnet_fpn_backbone("resnet50", weights_backbone, trainable_layers=trainable_backbone_layers)
+    backbone = resnet50(weights=weights_backbone, progress=progress, norm_layer=misc_nn_ops.FrozenBatchNorm2d)
+    backbone = _resnet_fpn_extractor(backbone, trainable_backbone_layers)
     model = FasterRCNN(backbone, num_classes=num_classes, **kwargs)

     if weights is not None:
