Add MobileNetV3 architecture for Detection #3253

Merged (12 commits) on Jan 18, 2021
33 changes: 18 additions & 15 deletions docs/source/models.rst
@@ -358,13 +358,14 @@ models return the predictions of the following classes:
Here is a summary of the accuracies for the models trained on
the instances set of COCO train2017 and evaluated on COCO val2017.

================================ ======= ======== ===========
Network                          box AP  mask AP  keypoint AP
================================ ======= ======== ===========
Faster R-CNN ResNet-50 FPN       37.0    -        -
RetinaNet ResNet-50 FPN          36.4    -        -
Mask R-CNN ResNet-50 FPN         37.9    34.6     -
================================ ======= ======== ===========
================================== ======= ======== ===========
Network                            box AP  mask AP  keypoint AP
================================== ======= ======== ===========
Faster R-CNN ResNet-50 FPN         37.0    -        -
Faster R-CNN MobileNetV3-Large FPN 23.0    -        -
RetinaNet ResNet-50 FPN            36.4    -        -
Mask R-CNN ResNet-50 FPN           37.9    34.6     -
================================== ======= ======== ===========

For person keypoint detection, the accuracies for the pre-trained
models are as follows
@@ -414,20 +415,22 @@ For test time, we report the time for the model evaluation and postprocessing
(including mask pasting in image), but not the time for computing the
precision-recall.

============================== =================== ================== ===========
Network                        train time (s / it) test time (s / it) memory (GB)
============================== =================== ================== ===========
Faster R-CNN ResNet-50 FPN     0.2288              0.0590             5.2
RetinaNet ResNet-50 FPN        0.2514              0.0939             4.1
Mask R-CNN ResNet-50 FPN       0.2728              0.0903             5.4
Keypoint R-CNN ResNet-50 FPN   0.3789              0.1242             6.8
============================== =================== ================== ===========
================================== =================== ================== ===========
Network                            train time (s / it) test time (s / it) memory (GB)
================================== =================== ================== ===========
Faster R-CNN ResNet-50 FPN         0.2288              0.0590             5.2
Faster R-CNN MobileNetV3-Large FPN 0.0978              0.0376             0.6
RetinaNet ResNet-50 FPN            0.2514              0.0939             4.1
Mask R-CNN ResNet-50 FPN           0.2728              0.0903             5.4
Keypoint R-CNN ResNet-50 FPN       0.3789              0.1242             6.8
================================== =================== ================== ===========


Faster R-CNN
------------

.. autofunction:: torchvision.models.detection.fasterrcnn_resnet50_fpn
.. autofunction:: torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn


RetinaNet
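For context, a minimal inference sketch for the new entry point documented above. It mirrors the docstring example added in `faster_rcnn.py` further down; the output format (one dict per image with `boxes`, `labels`, `scores`) is the standard torchvision detection contract.

```python
import torch
import torchvision

# Load the new detector with its COCO-pretrained weights (downloaded on first use).
model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
model.eval()

# The model takes a list of 3xHxW tensors; the images may have different sizes.
images = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
with torch.no_grad():
    predictions = model(images)

# One dict per input image with 'boxes', 'labels' and 'scores'.
print(predictions[0]["boxes"].shape, predictions[0]["scores"][:5])
```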
9 changes: 8 additions & 1 deletion references/detection/README.md
@@ -20,13 +20,20 @@ You must modify the following flags:

Unless otherwise noted, all models have been trained on 8x V100 GPUs.

### Faster R-CNN
### Faster R-CNN ResNet-50 FPN
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3
```

### Faster R-CNN MobileNetV3-Large FPN
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
--dataset coco --model fasterrcnn_mobilenet_v3_large_fpn --epochs 26\
--lr-steps 16 22 --aspect-ratio-group-factor 3
```

### RetinaNet
```
python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
10 changes: 8 additions & 2 deletions references/detection/train.py
@@ -92,9 +92,12 @@ def main(args):
collate_fn=utils.collate_fn)

print("Creating model")
kwargs = {}
kwargs = {
"trainable_backbone_layers": args.trainable_backbone_layers
}
if "rcnn" in args.model:
kwargs["rpn_score_thresh"] = 0.0
if args.rpn_score_thresh is not None:
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
model = torchvision.models.detection.__dict__[args.model](num_classes=num_classes, pretrained=args.pretrained,
**kwargs)
model.to(device)
@@ -177,6 +180,9 @@ def main(args):
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, help='start epoch')
parser.add_argument('--aspect-ratio-group-factor', default=3, type=int)
parser.add_argument('--rpn-score-thresh', default=None, type=float, help='rpn score threshold for faster-rcnn')
parser.add_argument('--trainable-backbone-layers', default=None, type=int,
help='number of trainable layers of backbone')
parser.add_argument(
"--test-only",
dest="test_only",
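To make the new flags concrete, here is a sketch of the constructor call that `train.py` assembles from them. The flag values shown are illustrative, not script defaults, and `num_classes=91` assumes the COCO setup used by these reference scripts.

```python
import torchvision

# Roughly what train.py builds for, e.g.:
#   --model fasterrcnn_mobilenet_v3_large_fpn --trainable-backbone-layers 3 --rpn-score-thresh 0.05
kwargs = {"trainable_backbone_layers": 3}  # always forwarded
kwargs["rpn_score_thresh"] = 0.05          # forwarded only for *rcnn* models when the flag is set
model = torchvision.models.detection.__dict__["fasterrcnn_mobilenet_v3_large_fpn"](
    num_classes=91, pretrained=False, **kwargs)
```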
3 changes: 3 additions & 0 deletions test/test_models.py
@@ -37,6 +37,7 @@ def get_available_video_models():
'googlenet': lambda x: x.logits,
'inception_v3': lambda x: x.logits,
"fasterrcnn_resnet50_fpn": lambda x: x[1],
"fasterrcnn_mobilenet_v3_large_fpn": lambda x: x[1],
"maskrcnn_resnet50_fpn": lambda x: x[1],
"keypointrcnn_resnet50_fpn": lambda x: x[1],
"retinanet_resnet50_fpn": lambda x: x[1],
@@ -105,6 +106,8 @@ def _test_detection_model(self, name, dev):
if "retinanet" in name:
# Reduce the default threshold to ensure the returned boxes are not empty.
kwargs["score_thresh"] = 0.01
elif "fasterrcnn_mobilenet" in name:
kwargs["box_score_thresh"] = 0.02076
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs)
model.eval().to(device=dev)
input_shape = (3, 300, 300)
15 changes: 8 additions & 7 deletions test/test_models_detection_negative_samples.py
@@ -97,14 +97,15 @@ def test_assign_targets_to_proposals(self):
self.assertEqual(labels[0].dtype, torch.int64)

def test_forward_negative_sample_frcnn(self):
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
num_classes=2, min_size=100, max_size=100)
for name in ["fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn"]:
model = torchvision.models.detection.__dict__[name](
num_classes=2, min_size=100, max_size=100)

images, targets = self._make_empty_sample()
loss_dict = model(images, targets)
images, targets = self._make_empty_sample()
loss_dict = model(images, targets)

self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_box_reg"], torch.tensor(0.))
self.assertEqual(loss_dict["loss_rpn_box_reg"], torch.tensor(0.))

def test_forward_negative_sample_mrcnn(self):
model = torchvision.models.detection.maskrcnn_resnet50_fpn(
@@ -130,7 +131,7 @@ def test_forward_negative_sample_krcnn(self):

def test_forward_negative_sample_retinanet(self):
model = torchvision.models.detection.retinanet_resnet50_fpn(
num_classes=2, min_size=100, max_size=100)
num_classes=2, min_size=100, max_size=100, pretrained_backbone=False)

images, targets = self._make_empty_sample()
loss_dict = model(images, targets)
12 changes: 6 additions & 6 deletions test/test_models_detection_utils.py
@@ -36,17 +36,17 @@ def test_resnet_fpn_backbone_frozen_layers(self):

def test_validate_resnet_inputs_detection(self):
# default number of backbone layers to train
ret = backbone_utils._validate_resnet_trainable_layers(
pretrained=True, trainable_backbone_layers=None)
ret = backbone_utils._validate_trainable_layers(
pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3)
self.assertEqual(ret, 3)
# can't go beyond 5
with self.assertRaises(AssertionError):
ret = backbone_utils._validate_resnet_trainable_layers(
pretrained=True, trainable_backbone_layers=6)
ret = backbone_utils._validate_trainable_layers(
pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3)
# if not pretrained, should use all trainable layers and warn
with self.assertWarns(UserWarning):
ret = backbone_utils._validate_resnet_trainable_layers(
pretrained=False, trainable_backbone_layers=0)
ret = backbone_utils._validate_trainable_layers(
pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3)
self.assertEqual(ret, 5)

def test_transform_copy_targets(self):
62 changes: 55 additions & 7 deletions torchvision/models/detection/backbone_utils.py
@@ -1,10 +1,10 @@
import warnings
from collections import OrderedDict
from torch import nn
from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool

from torchvision.ops import misc as misc_nn_ops
from .._utils import IntermediateLayerGetter
from .. import mobilenet
from .. import resnet


@@ -108,17 +108,65 @@ def resnet_fpn_backbone(
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)


def _validate_resnet_trainable_layers(pretrained, trainable_backbone_layers):
def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value):
    # don't freeze any layers if pretrained model or backbone is not used
if not pretrained:
if trainable_backbone_layers is not None:
            warnings.warn(
                "Changing trainable_backbone_layers has no effect if "
"neither pretrained nor pretrained_backbone have been set to True, "
"falling back to trainable_backbone_layers=5 so that all layers are trainable")
trainable_backbone_layers = 5
# by default, freeze first 2 blocks following Faster R-CNN
"falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value))
trainable_backbone_layers = max_value

# by default freeze first blocks
if trainable_backbone_layers is None:
trainable_backbone_layers = 3
assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0
trainable_backbone_layers = default_value
assert 0 <= trainable_backbone_layers <= max_value
return trainable_backbone_layers


def mobilenet_backbone(
backbone_name,
pretrained,
fpn,
norm_layer=misc_nn_ops.FrozenBatchNorm2d,
trainable_layers=2,
returned_layers=None,
extra_blocks=None
):
backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features

    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
    # The first and last blocks are always included because they are C0 (conv1) and Cn.
stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1]
num_stages = len(stage_indices)

    # find the index of the layer from which we won't freeze
assert 0 <= trainable_layers <= num_stages
freeze_before = num_stages if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]

# freeze layers only if pretrained backbone is used
for b in backbone[:freeze_before]:
for parameter in b.parameters():
parameter.requires_grad_(False)

out_channels = 256
if fpn:
if extra_blocks is None:
extra_blocks = LastLevelMaxPool()

if returned_layers is None:
returned_layers = [num_stages - 2, num_stages - 1]
assert min(returned_layers) >= 0 and max(returned_layers) < num_stages
return_layers = {f'{stage_indices[k]}': str(v) for v, k in enumerate(returned_layers)}

in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
else:
m = nn.Sequential(
backbone,
            # pointwise (1x1) convolution: a linear combination of channels to reduce their number
nn.Conv2d(backbone[-1].out_channels, out_channels, 1),
)
m.out_channels = out_channels
return m
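A small sketch of how the new `mobilenet_backbone` helper can be exercised on its own; the input size and `trainable_layers` value are assumptions chosen for illustration, and `pretrained=False` keeps the backbone randomly initialized.

```python
import torch
from torchvision.models.detection.backbone_utils import mobilenet_backbone

# FPN path: a BackboneWithFPN whose forward returns an OrderedDict of 256-channel
# feature maps (the two selected stages plus the extra max-pool level).
fpn_backbone = mobilenet_backbone("mobilenet_v3_large", pretrained=False, fpn=True,
                                  trainable_layers=3)
features = fpn_backbone(torch.rand(1, 3, 320, 320))
print({name: tuple(f.shape) for name, f in features.items()})

# Plain path: a single feature map, projected to out_channels=256 by a 1x1 conv.
plain_backbone = mobilenet_backbone("mobilenet_v3_large", pretrained=False, fpn=False)
print(plain_backbone.out_channels)  # 256
```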
59 changes: 51 additions & 8 deletions torchvision/models/detection/faster_rcnn.py
@@ -1,10 +1,7 @@
from collections import OrderedDict

import torch
from torch import nn
import torch.nn.functional as F

from torchvision.ops import misc as misc_nn_ops
from torchvision.ops import MultiScaleRoIAlign

from ._utils import overwrite_eps
@@ -15,11 +12,11 @@
from .rpn import RPNHead, RegionProposalNetwork
from .roi_heads import RoIHeads
from .transform import GeneralizedRCNNTransform
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_backbone


__all__ = [
"FasterRCNN", "fasterrcnn_resnet50_fpn",
"FasterRCNN", "fasterrcnn_resnet50_fpn", "fasterrcnn_mobilenet_v3_large_fpn"
]


@@ -291,6 +288,8 @@ def forward(self, x):
model_urls = {
'fasterrcnn_resnet50_fpn_coco':
'https://download.pytorch.org/models/fasterrcnn_resnet50_fpn_coco-258fb6c6.pth',
'fasterrcnn_mobilenet_v3_large_fpn_coco':
'https://download.pytorch.org/models/fasterrcnn_mobilenet_v3_large_fpn-907ea3f9.pth',
}


@@ -353,9 +352,8 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
# check default parameters and by default set it to 3 if possible
trainable_backbone_layers = _validate_resnet_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers)
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)

if pretrained:
# no need to download the backbone if pretrained is set
@@ -368,3 +366,48 @@
model.load_state_dict(state_dict)
overwrite_eps(model, 0.0)
return model


def fasterrcnn_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
trainable_backbone_layers=None, min_size=320, max_size=640, rpn_score_thresh=0.05,
**kwargs):
"""
Constructs a Faster R-CNN model with a MobileNetV3-Large FPN backbone. It works similarly
to Faster R-CNN with ResNet-50 FPN backbone. See `fasterrcnn_resnet50_fpn` for more details.

Example::

>>> model = torchvision.models.detection.fasterrcnn_mobilenet_v3_large_fpn(pretrained=True)
>>> model.eval()
>>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
>>> predictions = model(x)

Args:
pretrained (bool): If True, returns a model pre-trained on COCO train2017
progress (bool): If True, displays a progress bar of the download to stderr
num_classes (int): number of output classes of the model (including the background)
pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
        trainable_backbone_layers (int): number of trainable (not frozen) layers starting from final block.
Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
min_size (int): minimum size of the image to be rescaled before feeding it to the backbone
max_size (int): maximum size of the image to be rescaled before feeding it to the backbone
rpn_score_thresh (float): during inference, only return proposals with a classification score
greater than rpn_score_thresh
"""
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3)

if pretrained:
pretrained_backbone = False
backbone = mobilenet_backbone("mobilenet_v3_large", pretrained_backbone, True,
trainable_layers=trainable_backbone_layers)

anchor_sizes = ((32, 64, 128, 256, 512, ), ) * 3
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)

model = FasterRCNN(backbone, num_classes, rpn_anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios),
min_size=min_size, max_size=max_size, rpn_score_thresh=rpn_score_thresh, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['fasterrcnn_mobilenet_v3_large_fpn_coco'], progress=progress)
model.load_state_dict(state_dict)
return model
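One detail worth noting in the constructor above: the ResNet-50 FPN model by default spreads one anchor size per level over five FPN levels, while this backbone exposes three feature maps and the RPN places all five sizes on each of them. A quick check of that configuration (a sketch using the anchor sizes and aspect ratios from the code above):

```python
from torchvision.models.detection.anchor_utils import AnchorGenerator

# Same configuration as in fasterrcnn_mobilenet_v3_large_fpn:
# 5 sizes x 3 aspect ratios = 15 anchors per location on each of the 3 feature maps.
anchor_sizes = ((32, 64, 128, 256, 512,),) * 3
aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
print(anchor_generator.num_anchors_per_location())  # [15, 15, 15]
```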
7 changes: 3 additions & 4 deletions torchvision/models/detection/keypoint_rcnn.py
@@ -7,7 +7,7 @@
from ..utils import load_state_dict_from_url

from .faster_rcnn import FasterRCNN
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers


__all__ = [
@@ -322,9 +322,8 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True,
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
# check default parameters and by default set it to 3 if possible
trainable_backbone_layers = _validate_resnet_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers)
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)

if pretrained:
# no need to download the backbone if pretrained is set
7 changes: 3 additions & 4 deletions torchvision/models/detection/mask_rcnn.py
@@ -8,7 +8,7 @@
from ..utils import load_state_dict_from_url

from .faster_rcnn import FasterRCNN
from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers

__all__ = [
"MaskRCNN", "maskrcnn_resnet50_fpn",
@@ -317,9 +317,8 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True,
trainable_backbone_layers (int): number of trainable (not frozen) resnet layers starting from final block.
Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
"""
# check default parameters and by default set it to 3 if possible
trainable_backbone_layers = _validate_resnet_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers)
trainable_backbone_layers = _validate_trainable_layers(
pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)

if pretrained:
# no need to download the backbone if pretrained is set