12
12
from . import _utils as det_utils
13
13
from .anchor_utils import AnchorGenerator
14
14
from .transform import GeneralizedRCNNTransform
15
- from .backbone_utils import resnet_fpn_backbone , _validate_resnet_trainable_layers , mobilenet_fpn_backbone
15
+ from .backbone_utils import resnet_fpn_backbone , _validate_trainable_layers , mobilenet_fpn_backbone
16
16
from ...ops .feature_pyramid_network import LastLevelP6P7
17
17
from ...ops import sigmoid_focal_loss
18
18
from ...ops import boxes as box_ops
@@ -556,6 +556,7 @@ def forward(self, images, targets=None):
556
556
return losses , detections
557
557
return self .eager_outputs (losses , detections )
558
558
559
+
559
560
# TODO: replace with pytorch links
560
561
model_urls = {
561
562
'retinanet_mobilenet_v3_large_fpn_coco' :
@@ -608,8 +609,8 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
608
609
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
609
610
"""
610
611
# check default parameters and by default set it to 3 if possible
611
- trainable_backbone_layers = _validate_resnet_trainable_layers (
612
- pretrained or pretrained_backbone , trainable_backbone_layers )
612
+ trainable_backbone_layers = _validate_trainable_layers (
613
+ pretrained or pretrained_backbone , trainable_backbone_layers , 5 , 3 )
613
614
614
615
if pretrained :
615
616
# no need to download the backbone if pretrained is set
@@ -626,8 +627,8 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
626
627
return model
627
628
628
629
629
- def retinanet_mobilenet_v3_large_fpn (pretrained = False , progress = True ,
630
- num_classes = 91 , pretrained_backbone = True , ** kwargs ):
630
+ def retinanet_mobilenet_v3_large_fpn (pretrained = False , progress = True , num_classes = 91 , pretrained_backbone = True ,
631
+ trainable_backbone_layers = None , ** kwargs ):
631
632
"""
632
633
Constructs a RetinaNet model with a MobileNetV3-Large-FPN backbone. It works similarly
633
634
to RetinaNet with ResNet-50-FPN backbone. See `retinanet_resnet50_fpn` for more details.
@@ -642,10 +643,19 @@ def retinanet_mobilenet_v3_large_fpn(pretrained=False, progress=True,
642
643
Args:
643
644
pretrained (bool): If True, returns a model pre-trained on COCO train2017
644
645
progress (bool): If True, displays a progress bar of the download to stderr
646
+ num_classes (int): number of output classes of the model (including the background)
647
+ pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
648
+ trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
649
+ Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
645
650
"""
651
+ # check default parameters and by default set it to 3 if possible
652
+ trainable_backbone_layers = _validate_trainable_layers (
653
+ pretrained or pretrained_backbone , trainable_backbone_layers , 6 , 2 )
654
+
646
655
if pretrained :
647
656
pretrained_backbone = False
648
- backbone = mobilenet_fpn_backbone ("mobilenet_v3_large" , pretrained_backbone , returned_layers = [4 , 5 ])
657
+ backbone = mobilenet_fpn_backbone ("mobilenet_v3_large" , pretrained_backbone , returned_layers = [4 , 5 ],
658
+ trainable_layers = trainable_backbone_layers )
649
659
650
660
anchor_sizes = ((128 ,), (256 ,), (512 ,))
651
661
aspect_ratios = ((0.5 , 1.0 , 2.0 ),) * len (anchor_sizes )
0 commit comments