1- from collections import OrderedDict
2-
31import torch
42from torch import nn
53import torch .nn .functional as F
64
7- from torchvision .ops import misc as misc_nn_ops
85from torchvision .ops import MultiScaleRoIAlign
96
107from ._utils import overwrite_eps
1512from .rpn import RPNHead , RegionProposalNetwork
1613from .roi_heads import RoIHeads
1714from .transform import GeneralizedRCNNTransform
18- from .backbone_utils import resnet_fpn_backbone , _validate_resnet_trainable_layers
15+ from .backbone_utils import resnet_fpn_backbone , _validate_trainable_layers , mobilenet_backbone
1916
2017
2118__all__ = [
22- "FasterRCNN" , "fasterrcnn_resnet50_fpn" ,
19+ "FasterRCNN" , "fasterrcnn_resnet50_fpn" , "fasterrcnn_mobilenet_v3_large_fpn"
2320]
2421
2522
@@ -291,6 +288,8 @@ def forward(self, x):
291288model_urls = {
292289 'fasterrcnn_resnet50_fpn_coco' :
293290 'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth' ,
291+ 'fasterrcnn_mobilenet_v3_large_fpn_coco' :
292+ 'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-907ea3f9.pth' ,
294293}
295294
296295
@@ -353,9 +352,8 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
353352 trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
354353 Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
355354 """
356- # check default parameters and by default set it to 3 if possible
357- trainable_backbone_layers = _validate_resnet_trainable_layers (
358- pretrained or pretrained_backbone , trainable_backbone_layers )
355+ trainable_backbone_layers = _validate_trainable_layers (
356+ pretrained or pretrained_backbone , trainable_backbone_layers , 5 , 3 )
359357
360358 if pretrained :
361359 # no need to download the backbone if pretrained is set
@@ -368,3 +366,48 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
368366 model .load_state_dict (state_dict )
369367 overwrite_eps (model , 0.0 )
370368 return model
369+
370+
371+ def fasterrcnn_mobilenet_v3_large_fpn (pretrained = False , progress = True , num_classes = 91 , pretrained_backbone = True ,
372+ trainable_backbone_layers = None , min_size = 320 , max_size = 640 , rpn_score_thresh = 0.05 ,
373+ ** kwargs ):
374+ """
375+ Constructs a Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly
376+ to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.
377+
378+ Example::
379+
380+ >>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
381+ >>> model.eval()
382+ >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
383+ >>> predictions = model(x)
384+
385+ Args:
386+ pretrained (bool): If True, returns a model pre-trained on COCO train2017
387+ progress (bool): If True, displays a progress bar of the download to stderr
388+ num_classes (int): number of output classes of the model (including the background)
389+ pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
390+ trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
391+ Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
392+ min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
393+ max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
394+ rpn_score_thresh (float): during inference, only return proposals with a classification score
395+ greater than rpn_score_thresh
396+ """
397+ trainable_backbone_layers = _validate_trainable_layers (
398+ pretrained or pretrained_backbone , trainable_backbone_layers , 6 , 3 )
399+
400+ if pretrained :
401+ pretrained_backbone = False
402+ backbone = mobilenet_backbone ("mobilenet_v3_large" , pretrained_backbone , True ,
403+ trainable_layers = trainable_backbone_layers )
404+
405+ anchor_sizes = ((32 , 64 , 128 , 256 , 512 , ), ) * 3
406+ aspect_ratios = ((0.5 , 1.0 , 2.0 ),) * len (anchor_sizes )
407+
408+ model = FasterRCNN (backbone , num_classes , rpn_anchor_generator = AnchorGenerator (anchor_sizes , aspect_ratios ),
409+ min_size = min_size , max_size = max_size , rpn_score_thresh = rpn_score_thresh , ** kwargs )
410+ if pretrained :
411+ state_dict = load_state_dict_from_url (model_urls ['fasterrcnn_mobilenet_v3_large_fpn_coco' ], progress = progress )
412+ model .load_state_dict (state_dict )
413+ return model
0 commit comments