Skip to content

Commit e211494

Browse files
committed
making _validate_resnet_trainable_layers generic and using it in mobilenet
1 parent 6dc1724 commit e211494

File tree

5 files changed

+32
-21
lines changed

5 files changed

+32
-21
lines changed

torchvision/models/detection/backbone_utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -108,19 +108,20 @@ def resnet_fpn_backbone(
108108
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
109109

110110

111-
def _validate_resnet_trainable_layers(pretrained, trainable_backbone_layers):
111+
def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value):
112112
# don't freeze any layers if pretrained model or backbone is not used
113113
if not pretrained:
114114
if trainable_backbone_layers is not None:
115115
warnings.warn(
116116
"Changing trainable_backbone_layers has not effect if "
117117
"neither pretrained nor pretrained_backbone have been set to True, "
118-
"falling back to trainable_backbone_layers=5 so that all layers are trainable")
119-
trainable_backbone_layers = 5
120-
# by default, freeze first 2 blocks following Faster R-CNN
118+
"falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value))
119+
trainable_backbone_layers = max_value
120+
121+
# by default freeze first blocks
121122
if trainable_backbone_layers is None:
122-
trainable_backbone_layers = 3
123-
assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0
123+
trainable_backbone_layers = default_value
124+
assert 0 <= trainable_backbone_layers <= max_value
124125
return trainable_backbone_layers
125126

126127

torchvision/models/detection/faster_rcnn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from .rpn import RPNHead, RegionProposalNetwork
1616
from .roi_heads import RoIHeads
1717
from .transform import GeneralizedRCNNTransform
18-
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
18+
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
1919

2020

2121
__all__ = [
@@ -350,8 +350,8 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
350350
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
351351
"""
352352
# check default parameters and by default set it to 3 if possible
353-
trainable_backbone_layers = _validate_resnet_trainable_layers(
354-
pretrained or pretrained_backbone, trainable_backbone_layers)
353+
trainable_backbone_layers = _validate_trainable_layers(
354+
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
355355

356356
if pretrained:
357357
# no need to download the backbone if pretrained is set

torchvision/models/detection/keypoint_rcnn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ..utils import load_state_dict_from_url
88

99
from .faster_rcnn import FasterRCNN
10-
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
10+
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
1111

1212

1313
__all__ = [
@@ -319,8 +319,8 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
319319
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
320320
"""
321321
# check default parameters and by default set it to 3 if possible
322-
trainable_backbone_layers = _validate_resnet_trainable_layers(
323-
pretrained or pretrained_backbone, trainable_backbone_layers)
322+
trainable_backbone_layers = _validate_trainable_layers(
323+
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
324324

325325
if pretrained:
326326
# no need to download the backbone if pretrained is set

torchvision/models/detection/mask_rcnn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ..utils import load_state_dict_from_url
99

1010
from .faster_rcnn import FasterRCNN
11-
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
11+
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
1212

1313
__all__ = [
1414
"MaskRCNN", "maskrcnn_resnet50_fpn",
@@ -314,8 +314,8 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
314314
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
315315
"""
316316
# check default parameters and by default set it to 3 if possible
317-
trainable_backbone_layers = _validate_resnet_trainable_layers(
318-
pretrained or pretrained_backbone, trainable_backbone_layers)
317+
trainable_backbone_layers = _validate_trainable_layers(
318+
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
319319

320320
if pretrained:
321321
# no need to download the backbone if pretrained is set

torchvision/models/detection/retinanet.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from . import _utils as det_utils
1313
from .anchor_utils import AnchorGenerator
1414
from .transform import GeneralizedRCNNTransform
15-
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers, mobilenet_fpn_backbone
15+
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_fpn_backbone
1616
from ...ops.feature_pyramid_network import LastLevelP6P7
1717
from ...ops import sigmoid_focal_loss
1818
from ...ops import boxes as box_ops
@@ -556,6 +556,7 @@ def forward(self, images, targets=None):
556556
return losses, detections
557557
return self.eager_outputs(losses, detections)
558558

559+
559560
# TODO: replace with pytorch links
560561
model_urls = {
561562
'retinanet_mobilenet_v3_large_fpn_coco':
@@ -608,8 +609,8 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
608609
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
609610
"""
610611
# check default parameters and by default set it to 3 if possible
611-
trainable_backbone_layers = _validate_resnet_trainable_layers(
612-
pretrained or pretrained_backbone, trainable_backbone_layers)
612+
trainable_backbone_layers = _validate_trainable_layers(
613+
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
613614

614615
if pretrained:
615616
# no need to download the backbone if pretrained is set
@@ -626,8 +627,8 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
626627
return model
627628

628629

629-
def retinanet_mobilenet_v3_large_fpn(pretrained=False, progress=True,
630-
num_classes=91, pretrained_backbone=True, **kwargs):
630+
def retinanet_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
631+
trainable_backbone_layers=None, **kwargs):
631632
"""
632633
Constructs a RetinaNet model with a MobileNetV3-Large-FPN backbone. It works similarly
633634
to RetinaNet with ResNet-50-FPN backbone. See `retinanet_resnet50_fpn` for more details.
@@ -642,10 +643,19 @@ def retinanet_mobilenet_v3_large_fpn(pretrained=False, progress=True,
642643
Args:
643644
pretrained (bool): If True, returns a model pre-trained on COCO train2017
644645
progress (bool): If True, displays a progress bar of the download to stderr
646+
num_classes (int): number of output classes of the model (including the background)
647+
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
648+
trainable_backbone_layers (int): number of trainable (not frozen) backbone layers starting from final block.
649+
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
645650
"""
651+
# check default parameters and by default set it to 2 if possible
652+
trainable_backbone_layers = _validate_trainable_layers(
653+
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 2)
654+
646655
if pretrained:
647656
pretrained_backbone = False
648-
backbone = mobilenet_fpn_backbone("mobilenet_v3_large", pretrained_backbone, returned_layers=[4, 5])
657+
backbone = mobilenet_fpn_backbone("mobilenet_v3_large", pretrained_backbone, returned_layers=[4, 5],
658+
trainable_layers=trainable_backbone_layers)
649659

650660
anchor_sizes = ((128,), (256,), (512,))
651661
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)

0 commit comments

Comments
 (0)