diff --git a/docs/source/models.rst b/docs/source/models.rst index 66ebf0e211d..9443565bda0 100644 --- a/docs/source/models.rst +++ b/docs/source/models.rst @@ -22,7 +22,8 @@ architectures for image classification: - `Inception`_ v3 - `GoogLeNet`_ - `ShuffleNet`_ v2 -- `MobileNet`_ v2 +- `MobileNetV2`_ +- `MobileNetV3`_ - `ResNeXt`_ - `Wide ResNet`_ - `MNASNet`_ @@ -40,7 +41,9 @@ You can construct a model with random weights by calling its constructor: inception = models.inception_v3() googlenet = models.googlenet() shufflenet = models.shufflenet_v2_x1_0() - mobilenet = models.mobilenet_v2() + mobilenet_v2 = models.mobilenet_v2() + mobilenet_v3_large = models.mobilenet_v3_large() + mobilenet_v3_small = models.mobilenet_v3_small() resnext50_32x4d = models.resnext50_32x4d() wide_resnet50_2 = models.wide_resnet50_2() mnasnet = models.mnasnet1_0() @@ -59,7 +62,8 @@ These can be constructed by passing ``pretrained=True``: inception = models.inception_v3(pretrained=True) googlenet = models.googlenet(pretrained=True) shufflenet = models.shufflenet_v2_x1_0(pretrained=True) - mobilenet = models.mobilenet_v2(pretrained=True) + mobilenet_v2 = models.mobilenet_v2(pretrained=True) + mobilenet_v3_large = models.mobilenet_v3_large(pretrained=True) resnext50_32x4d = models.resnext50_32x4d(pretrained=True) wide_resnet50_2 = models.wide_resnet50_2(pretrained=True) mnasnet = models.mnasnet1_0(pretrained=True) @@ -137,6 +141,7 @@ Inception v3 22.55 6.44 GoogleNet 30.22 10.47 ShuffleNet V2 30.64 11.68 MobileNet V2 28.12 9.71 +MobileNet V3 Large 25.96 8.66 ResNeXt-50-32x4d 22.38 6.30 ResNeXt-101-32x8d 20.69 5.47 Wide ResNet-50-2 21.49 5.91 @@ -153,7 +158,8 @@ MNASNet 1.0 26.49 8.456 .. _Inception: https://arxiv.org/abs/1512.00567 .. _GoogLeNet: https://arxiv.org/abs/1409.4842 .. _ShuffleNet: https://arxiv.org/abs/1807.11164 -.. _MobileNet: https://arxiv.org/abs/1801.04381 +.. _MobileNetV2: https://arxiv.org/abs/1801.04381 +.. _MobileNetV3: https://arxiv.org/abs/1905.02244 .. _ResNeXt: https://arxiv.org/abs/1611.05431 .. _MNASNet: https://arxiv.org/abs/1807.11626 @@ -231,6 +237,12 @@ MobileNet v2 .. autofunction:: mobilenet_v2 +MobileNet v3 +------------- + +.. autofunction:: mobilenet_v3_large +.. autofunction:: mobilenet_v3_small + ResNext ------- @@ -351,6 +363,7 @@ Network box AP mask AP keypoint AP ================================ ======= ======== =========== Faster R-CNN ResNet-50 FPN 37.0 - - RetinaNet ResNet-50 FPN 36.4 - - +RetinaNet MobileNetV3-Large FPN 25.6 - - Mask R-CNN ResNet-50 FPN 37.9 34.6 - ================================ ======= ======== =========== @@ -407,6 +420,7 @@ Network train time (s / it) test time (s / it) memory ============================== =================== ================== =========== Faster R-CNN ResNet-50 FPN 0.2288 0.0590 5.2 RetinaNet ResNet-50 FPN 0.2514 0.0939 4.1 +RetinaNet MobileNetV3-Large FPN 0.0928 0.0547 1.4 Mask R-CNN ResNet-50 FPN 0.2728 0.0903 5.4 Keypoint R-CNN ResNet-50 FPN 0.3789 0.1242 6.8 ============================== =================== ================== =========== @@ -422,6 +436,7 @@ RetinaNet ------------ .. autofunction:: torchvision.models.detection.retinanet_resnet50_fpn +.. 
autofunction:: torchvision.models.detection.retinanet_mobilenet_v3_large_fpn
 
 
 Mask R-CNN
diff --git a/hubconf.py b/hubconf.py
index 79c22bd938b..dec4a7fb196 100644
--- a/hubconf.py
+++ b/hubconf.py
@@ -11,7 +11,8 @@ from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
 from torchvision.models.googlenet import googlenet
 from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
-from torchvision.models.mobilenet import mobilenet_v2
+from torchvision.models.mobilenetv2 import mobilenet_v2
+from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small
 from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
     mnasnet1_3
diff --git a/references/classification/README.md b/references/classification/README.md
index bd00f2c7dd8..d18ab17bf73 100644
--- a/references/classification/README.md
+++ b/references/classification/README.md
@@ -53,6 +53,16 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
     --lr-step-size 1 --lr-gamma 0.98
 ```
 
+
+### MobileNetV3 Large
+```
+python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
+    --model mobilenet_v3_large --epochs 600 --opt rmsprop --batch-size 128 --lr 0.064\
+    --wd 0.00001 --lr-step-size 2 --lr-gamma 0.973 --auto-augment imagenet --random-erase 0.2
+```
+
+Then we averaged the parameters of the last 3 checkpoints that improved the Acc@1. See [#3182](https://github.com/pytorch/vision/pull/3182) for details.
+
 ## Mixed precision training
 Automatic Mixed Precision (AMP) training on GPU for Pytorch can be enabled with the [NVIDIA Apex extension](https://github.com/NVIDIA/apex).
diff --git a/references/classification/train.py b/references/classification/train.py
index 789bb8134ff..47a7e5955e6 100644
--- a/references/classification/train.py
+++ b/references/classification/train.py
@@ -79,7 +79,7 @@ def _get_cache_path(filepath):
     return cache_path
 
 
-def load_data(traindir, valdir, cache_dataset, distributed):
+def load_data(traindir, valdir, args):
     # Data loading code
     print("Loading data")
     normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
@@ -88,20 +88,28 @@ def load_data(traindir, valdir, cache_dataset, distributed):
     print("Loading training data")
     st = time.time()
     cache_path = _get_cache_path(traindir)
-    if cache_dataset and os.path.exists(cache_path):
+    if args.cache_dataset and os.path.exists(cache_path):
         # Attention, as the transforms are also cached!
         
print("Loading dataset_train from {}".format(cache_path)) dataset, _ = torch.load(cache_path) else: + trans = [ + transforms.RandomResizedCrop(224), + transforms.RandomHorizontalFlip(), + ] + if args.auto_augment is not None: + aa_policy = transforms.AutoAugmentPolicy(args.auto_augment) + trans.append(transforms.AutoAugment(policy=aa_policy)) + trans.extend([ + transforms.ToTensor(), + normalize, + ]) + if args.random_erase > 0: + trans.append(transforms.RandomErasing(p=args.random_erase)) dataset = torchvision.datasets.ImageFolder( traindir, - transforms.Compose([ - transforms.RandomResizedCrop(224), - transforms.RandomHorizontalFlip(), - transforms.ToTensor(), - normalize, - ])) - if cache_dataset: + transforms.Compose(trans)) + if args.cache_dataset: print("Saving dataset_train to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset, traindir), cache_path) @@ -109,7 +117,7 @@ def load_data(traindir, valdir, cache_dataset, distributed): print("Loading validation data") cache_path = _get_cache_path(valdir) - if cache_dataset and os.path.exists(cache_path): + if args.cache_dataset and os.path.exists(cache_path): # Attention, as the transforms are also cached! print("Loading dataset_test from {}".format(cache_path)) dataset_test, _ = torch.load(cache_path) @@ -122,13 +130,13 @@ def load_data(traindir, valdir, cache_dataset, distributed): transforms.ToTensor(), normalize, ])) - if cache_dataset: + if args.cache_dataset: print("Saving dataset_test to {}".format(cache_path)) utils.mkdir(os.path.dirname(cache_path)) utils.save_on_master((dataset_test, valdir), cache_path) print("Creating data loaders") - if distributed: + if args.distributed: train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) else: @@ -155,8 +163,7 @@ def main(args): train_dir = os.path.join(args.data_path, 'train') val_dir = os.path.join(args.data_path, 'val') - dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, - args.cache_dataset, args.distributed) + dataset, dataset_test, train_sampler, test_sampler = load_data(train_dir, val_dir, args) data_loader = torch.utils.data.DataLoader( dataset, batch_size=args.batch_size, sampler=train_sampler, num_workers=args.workers, pin_memory=True) @@ -173,8 +180,15 @@ def main(args): criterion = nn.CrossEntropyLoss() - optimizer = torch.optim.SGD( - model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + opt_name = args.opt.lower() + if opt_name == 'sgd': + optimizer = torch.optim.SGD( + model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + elif opt_name == 'rmsprop': + optimizer = torch.optim.RMSprop(model.parameters(), lr=args.lr, momentum=args.momentum, + weight_decay=args.weight_decay, eps=0.0316, alpha=0.9) + else: + raise RuntimeError("Invalid optimizer {}. 
Only SGD and RMSprop are supported.".format(args.opt))
 
     if args.apex:
         model, optimizer = amp.initialize(model, optimizer,
@@ -238,6 +252,7 @@ def parse_args():
                         help='number of total epochs to run')
     parser.add_argument('-j', '--workers', default=16, type=int, metavar='N',
                         help='number of data loading workers (default: 16)')
+    parser.add_argument('--opt', default='sgd', type=str, help='optimizer')
     parser.add_argument('--lr', default=0.1, type=float, help='initial learning rate')
     parser.add_argument('--momentum', default=0.9, type=float, metavar='M',
                         help='momentum')
@@ -275,6 +290,8 @@ def parse_args():
         help="Use pre-trained models from the modelzoo",
         action="store_true",
     )
+    parser.add_argument('--auto-augment', default=None, help='auto augment policy (default: None)')
+    parser.add_argument('--random-erase', default=0.0, type=float, help='random erasing probability (default: 0.0)')
 
     # Mixed precision training parameters
     parser.add_argument('--apex', action='store_true',
diff --git a/references/detection/README.md b/references/detection/README.md
index f89e8149a71..495a775df19 100644
--- a/references/detection/README.md
+++ b/references/detection/README.md
@@ -27,7 +27,8 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
     --lr-steps 16 22 --aspect-ratio-group-factor 3
 ```
 
-### RetinaNet
+
+### RetinaNet with ResNet50 FPN
 ```
 python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
     --dataset coco --model retinanet_resnet50_fpn --epochs 26\
     --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01
 ```
 
@@ -35,6 +36,16 @@ python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
 ```
 
+### RetinaNet with MobileNetV3 Large FPN
+```
+python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\
+    --dataset coco --model retinanet_mobilenet_v3_large_fpn --epochs 26 --lr-steps 16 22\
+    --aspect-ratio-group-factor 3 --lr 0.01
+```
+
+Then we averaged the parameters of the last 2 checkpoints that improved the AP. See [#3223](https://github.com/pytorch/vision/pull/3223) for details.
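+
+Checkpoint averaging can be reproduced offline. A minimal sketch, assuming the checkpoints
+were written by this reference script as `model_<epoch>.pth` files with the weights stored
+under a `"model"` key (both names are assumptions for illustration):
+
+```python
+import torch
+
+# hypothetical file names of the last 2 checkpoints that improved the AP
+paths = ["model_25.pth", "model_22.pth"]
+weights = [torch.load(p, map_location="cpu")["model"] for p in paths]
+
+averaged = {}
+for key in weights[0]:
+    values = [w[key] for w in weights]
+    # average floating-point tensors; keep non-float buffers (e.g. step counters) from the first checkpoint
+    averaged[key] = torch.stack(values).mean(0) if values[0].is_floating_point() else values[0]
+
+torch.save({"model": averaged}, "model_averaged.pth")
+```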
+ + ### Mask R-CNN ``` python -m torch.distributed.launch --nproc_per_node=8 --use_env train.py\ diff --git a/test/expect/ModelTester.test_mobilenet_v3_large_expect.pkl b/test/expect/ModelTester.test_mobilenet_v3_large_expect.pkl new file mode 100644 index 00000000000..9691daf18c7 Binary files /dev/null and b/test/expect/ModelTester.test_mobilenet_v3_large_expect.pkl differ diff --git a/test/expect/ModelTester.test_mobilenet_v3_small_expect.pkl b/test/expect/ModelTester.test_mobilenet_v3_small_expect.pkl new file mode 100644 index 00000000000..5373739b53c Binary files /dev/null and b/test/expect/ModelTester.test_mobilenet_v3_small_expect.pkl differ diff --git a/test/expect/ModelTester.test_retinanet_mobilenet_v3_large_fpn_expect.pkl b/test/expect/ModelTester.test_retinanet_mobilenet_v3_large_fpn_expect.pkl new file mode 100644 index 00000000000..2f7ca93d30b Binary files /dev/null and b/test/expect/ModelTester.test_retinanet_mobilenet_v3_large_fpn_expect.pkl differ diff --git a/test/test_models.py b/test/test_models.py index d40649ffb65..165c2c07184 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -40,6 +40,7 @@ def get_available_video_models(): "maskrcnn_resnet50_fpn": lambda x: x[1], "keypointrcnn_resnet50_fpn": lambda x: x[1], "retinanet_resnet50_fpn": lambda x: x[1], + "retinanet_mobilenet_v3_large_fpn": lambda x: x[1], } @@ -104,7 +105,7 @@ def _test_detection_model(self, name, dev): kwargs = {} if "retinanet" in name: # Reduce the default threshold to ensure the returned boxes are not empty. - kwargs["score_thresh"] = 0.01 + kwargs["score_thresh"] = 0.0099999 model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False, **kwargs) model.eval().to(device=dev) input_shape = (3, 300, 300) @@ -275,16 +276,17 @@ def test_mobilenet_v2_residual_setting(self): out = model(x) self.assertEqual(out.shape[-1], 1000) - def test_mobilenetv2_norm_layer(self): - model = models.__dict__["mobilenet_v2"]() - self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) + def test_mobilenet_norm_layer(self): + for name in ["mobilenet_v2", "mobilenet_v3_large", "mobilenet_v3_small"]: + model = models.__dict__[name]() + self.assertTrue(any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) - def get_gn(num_channels): - return nn.GroupNorm(32, num_channels) + def get_gn(num_channels): + return nn.GroupNorm(32, num_channels) - model = models.__dict__["mobilenet_v2"](norm_layer=get_gn) - self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) - self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules())) + model = models.__dict__[name](norm_layer=get_gn) + self.assertFalse(any(isinstance(x, nn.BatchNorm2d) for x in model.modules())) + self.assertTrue(any(isinstance(x, nn.GroupNorm) for x in model.modules())) def test_inception_v3_eval(self): # replacement for models.inception_v3(pretrained=True) that does not download weights diff --git a/test/test_models_detection_negative_samples.py b/test/test_models_detection_negative_samples.py index 6d767971f72..ac551598a96 100644 --- a/test/test_models_detection_negative_samples.py +++ b/test/test_models_detection_negative_samples.py @@ -129,13 +129,14 @@ def test_forward_negative_sample_krcnn(self): self.assertEqual(loss_dict["loss_keypoint"], torch.tensor(0.)) def test_forward_negative_sample_retinanet(self): - model = torchvision.models.detection.retinanet_resnet50_fpn( - num_classes=2, min_size=100, max_size=100) + for name in ["retinanet_resnet50_fpn", 
"retinanet_mobilenet_v3_large_fpn"]: + model = torchvision.models.detection.__dict__[name]( + num_classes=2, min_size=100, max_size=100, pretrained_backbone=False) - images, targets = self._make_empty_sample() - loss_dict = model(images, targets) + images, targets = self._make_empty_sample() + loss_dict = model(images, targets) - self.assertEqual(loss_dict["bbox_regression"], torch.tensor(0.)) + self.assertEqual(loss_dict["bbox_regression"], torch.tensor(0.)) if __name__ == '__main__': diff --git a/test/test_models_detection_utils.py b/test/test_models_detection_utils.py index bfb26f24eae..8af5c09b097 100644 --- a/test/test_models_detection_utils.py +++ b/test/test_models_detection_utils.py @@ -36,17 +36,17 @@ def test_resnet_fpn_backbone_frozen_layers(self): def test_validate_resnet_inputs_detection(self): # default number of backbone layers to train - ret = backbone_utils._validate_resnet_trainable_layers( - pretrained=True, trainable_backbone_layers=None) + ret = backbone_utils._validate_trainable_layers( + pretrained=True, trainable_backbone_layers=None, max_value=5, default_value=3) self.assertEqual(ret, 3) # can't go beyond 5 with self.assertRaises(AssertionError): - ret = backbone_utils._validate_resnet_trainable_layers( - pretrained=True, trainable_backbone_layers=6) + ret = backbone_utils._validate_trainable_layers( + pretrained=True, trainable_backbone_layers=6, max_value=5, default_value=3) # if not pretrained, should use all trainable layers and warn with self.assertWarns(UserWarning): - ret = backbone_utils._validate_resnet_trainable_layers( - pretrained=False, trainable_backbone_layers=0) + ret = backbone_utils._validate_trainable_layers( + pretrained=False, trainable_backbone_layers=0, max_value=5, default_value=3) self.assertEqual(ret, 5) def test_transform_copy_targets(self): diff --git a/torchvision/models/detection/backbone_utils.py b/torchvision/models/detection/backbone_utils.py index 746e0ee2f59..b88da647a71 100644 --- a/torchvision/models/detection/backbone_utils.py +++ b/torchvision/models/detection/backbone_utils.py @@ -1,10 +1,10 @@ import warnings -from collections import OrderedDict from torch import nn from torchvision.ops.feature_pyramid_network import FeaturePyramidNetwork, LastLevelMaxPool from torchvision.ops import misc as misc_nn_ops from .._utils import IntermediateLayerGetter +from .. import mobilenet from .. 
import resnet
 
 
@@ -108,17 +141,55 @@ def resnet_fpn_backbone(
     return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
 
 
-def _validate_resnet_trainable_layers(pretrained, trainable_backbone_layers):
+def _validate_trainable_layers(pretrained, trainable_backbone_layers, max_value, default_value):
     # dont freeze any layers if pretrained model or backbone is not used
     if not pretrained:
         if trainable_backbone_layers is not None:
             warnings.warn(
                 "Changing trainable_backbone_layers has not effect if "
                 "neither pretrained nor pretrained_backbone have been set to True, "
-                "falling back to trainable_backbone_layers=5 so that all layers are trainable")
-        trainable_backbone_layers = 5
-    # by default, freeze first 2 blocks following Faster R-CNN
+                "falling back to trainable_backbone_layers={} so that all layers are trainable".format(max_value))
+        trainable_backbone_layers = max_value
+
+    # by default freeze first blocks
     if trainable_backbone_layers is None:
-        trainable_backbone_layers = 3
-    assert trainable_backbone_layers <= 5 and trainable_backbone_layers >= 0
+        trainable_backbone_layers = default_value
+    assert 0 <= trainable_backbone_layers <= max_value
     return trainable_backbone_layers
+
+
+def mobilenet_fpn_backbone(
+    backbone_name,
+    pretrained,
+    norm_layer=misc_nn_ops.FrozenBatchNorm2d,
+    trainable_layers=2,
+    returned_layers=None,
+    extra_blocks=None
+):
+    backbone = mobilenet.__dict__[backbone_name](pretrained=pretrained, norm_layer=norm_layer).features
+
+    # Gather the indices of blocks which are strided. These are the locations of C1, ..., Cn-1 blocks.
+    # The first and last blocks are always included because they are the C0 (conv1) and Cn.
+    stage_indices = [0] + [i for i, b in enumerate(backbone) if getattr(b, "is_strided", False)] + [len(backbone) - 1]
+    num_stages = len(stage_indices)
+
+    # find the index of the layer from which we won't freeze
+    assert 0 <= trainable_layers <= num_stages
+    freeze_before = len(backbone) if trainable_layers == 0 else stage_indices[num_stages - trainable_layers]
+
+    # freeze layers only if pretrained backbone is used
+    for b in backbone[:freeze_before]:
+        for parameter in b.parameters():
+            parameter.requires_grad_(False)
+
+    if extra_blocks is None:
+        extra_blocks = LastLevelMaxPool()
+
+    if returned_layers is None:
+        returned_layers = [num_stages - 2, num_stages - 1]
+    assert min(returned_layers) >= 0 and max(returned_layers) < num_stages
+    return_layers = {f'{stage_indices[k]}': str(v) for v, k in enumerate(returned_layers)}
+
+    in_channels_list = [backbone[stage_indices[i]].out_channels for i in returned_layers]
+    out_channels = 256
+    return BackboneWithFPN(backbone, return_layers, in_channels_list, out_channels, extra_blocks=extra_blocks)
diff --git a/torchvision/models/detection/faster_rcnn.py b/torchvision/models/detection/faster_rcnn.py
index 7d896d5ec95..80ccc129f8f 100644
--- a/torchvision/models/detection/faster_rcnn.py
+++ b/torchvision/models/detection/faster_rcnn.py
@@ -15,7 +15,7 @@ from .rpn import RPNHead, RegionProposalNetwork
 from .roi_heads import RoIHeads
 from .transform import GeneralizedRCNNTransform
-from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
+from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers
 
 
 __all__ = [
@@ -350,8 +350,8 @@ def fasterrcnn_resnet50_fpn(pretrained=False, progress=True,
             Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
     
""" # check default parameters and by default set it to 3 if possible - trainable_backbone_layers = _validate_resnet_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers) + trainable_backbone_layers = _validate_trainable_layers( + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3) if pretrained: # no need to download the backbone if pretrained is set diff --git a/torchvision/models/detection/keypoint_rcnn.py b/torchvision/models/detection/keypoint_rcnn.py index 44df04819ff..4f375f818c0 100644 --- a/torchvision/models/detection/keypoint_rcnn.py +++ b/torchvision/models/detection/keypoint_rcnn.py @@ -7,7 +7,7 @@ from ..utils import load_state_dict_from_url from .faster_rcnn import FasterRCNN -from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers +from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers __all__ = [ @@ -319,8 +319,8 @@ def keypointrcnn_resnet50_fpn(pretrained=False, progress=True, Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. """ # check default parameters and by default set it to 3 if possible - trainable_backbone_layers = _validate_resnet_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers) + trainable_backbone_layers = _validate_trainable_layers( + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3) if pretrained: # no need to download the backbone if pretrained is set diff --git a/torchvision/models/detection/mask_rcnn.py b/torchvision/models/detection/mask_rcnn.py index 565ef05f4cc..8f982ef02d5 100644 --- a/torchvision/models/detection/mask_rcnn.py +++ b/torchvision/models/detection/mask_rcnn.py @@ -8,7 +8,7 @@ from ..utils import load_state_dict_from_url from .faster_rcnn import FasterRCNN -from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers +from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers __all__ = [ "MaskRCNN", "maskrcnn_resnet50_fpn", @@ -314,8 +314,8 @@ def maskrcnn_resnet50_fpn(pretrained=False, progress=True, Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable. """ # check default parameters and by default set it to 3 if possible - trainable_backbone_layers = _validate_resnet_trainable_layers( - pretrained or pretrained_backbone, trainable_backbone_layers) + trainable_backbone_layers = _validate_trainable_layers( + pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3) if pretrained: # no need to download the backbone if pretrained is set diff --git a/torchvision/models/detection/retinanet.py b/torchvision/models/detection/retinanet.py index 9836b3316d8..8c33b74c29c 100644 --- a/torchvision/models/detection/retinanet.py +++ b/torchvision/models/detection/retinanet.py @@ -12,14 +12,14 @@ from . 
import _utils as det_utils
 from .anchor_utils import AnchorGenerator
 from .transform import GeneralizedRCNNTransform
-from .backbone_utils import resnet_fpn_backbone, _validate_resnet_trainable_layers
+from .backbone_utils import resnet_fpn_backbone, _validate_trainable_layers, mobilenet_fpn_backbone
 from ...ops.feature_pyramid_network import LastLevelP6P7
 from ...ops import sigmoid_focal_loss
 from ...ops import boxes as box_ops
 
 
 __all__ = [
-    "RetinaNet", "retinanet_resnet50_fpn",
+    "RetinaNet", "retinanet_resnet50_fpn", "retinanet_mobilenet_v3_large_fpn"
 ]
 
 
@@ -557,7 +557,10 @@ def forward(self, images, targets=None):
         return self.eager_outputs(losses, detections)
 
 
+# TODO: replace with pytorch links
 model_urls = {
+    'retinanet_mobilenet_v3_large_fpn_coco':
+        'https://download.pytorch.org/models/retinanet_mobilenet_v3_large_fpn-41c847a4.pth',
     'retinanet_resnet50_fpn_coco':
         'https://download.pytorch.org/models/retinanet_resnet50_fpn_coco-eeacb38b.pth',
 }
@@ -606,8 +609,8 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
         Valid values are between 0 and 5, with 5 meaning all backbone layers are trainable.
     """
     # check default parameters and by default set it to 3 if possible
-    trainable_backbone_layers = _validate_resnet_trainable_layers(
-        pretrained or pretrained_backbone, trainable_backbone_layers)
+    trainable_backbone_layers = _validate_trainable_layers(
+        pretrained or pretrained_backbone, trainable_backbone_layers, 5, 3)
 
     if pretrained:
         # no need to download the backbone if pretrained is set
@@ -622,3 +625,44 @@ def retinanet_resnet50_fpn(pretrained=False, progress=True,
         model.load_state_dict(state_dict)
         overwrite_eps(model, 0.0)
     return model
+
+
+def retinanet_mobilenet_v3_large_fpn(pretrained=False, progress=True, num_classes=91, pretrained_backbone=True,
+                                     trainable_backbone_layers=None, **kwargs):
+    """
+    Constructs a RetinaNet model with a MobileNetV3-Large-FPN backbone. It works similarly
+    to RetinaNet with ResNet-50-FPN backbone. See `retinanet_resnet50_fpn` for more details.
+
+    Example::
+
+        >>> model = torchvision.models.detection.retinanet_mobilenet_v3_large_fpn(pretrained=True)
+        >>> model.eval()
+        >>> x = [torch.rand(3, 300, 400), torch.rand(3, 500, 400)]
+        >>> predictions = model(x)
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on COCO train2017
+        progress (bool): If True, displays a progress bar of the download to stderr
+        num_classes (int): number of output classes of the model (including the background)
+        pretrained_backbone (bool): If True, returns a model with backbone pre-trained on Imagenet
+        trainable_backbone_layers (int): number of trainable (not frozen) backbone layers starting from final block.
+            Valid values are between 0 and 6, with 6 meaning all backbone layers are trainable.
+ """ + # check default parameters and by default set it to 3 if possible + trainable_backbone_layers = _validate_trainable_layers( + pretrained or pretrained_backbone, trainable_backbone_layers, 6, 3) + + if pretrained: + pretrained_backbone = False + backbone = mobilenet_fpn_backbone("mobilenet_v3_large", pretrained_backbone, returned_layers=[4, 5], + trainable_layers=trainable_backbone_layers) + + anchor_sizes = ((128,), (256,), (512,)) + aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes) + + model = RetinaNet(backbone, num_classes, anchor_generator=AnchorGenerator(anchor_sizes, aspect_ratios), **kwargs) + if pretrained: + state_dict = load_state_dict_from_url(model_urls['retinanet_mobilenet_v3_large_fpn_coco'], + progress=progress) + model.load_state_dict(state_dict) + return model diff --git a/torchvision/models/mobilenet.py b/torchvision/models/mobilenet.py index 0c7cc10df5b..4108305d3f5 100644 --- a/torchvision/models/mobilenet.py +++ b/torchvision/models/mobilenet.py @@ -1,3 +1,4 @@ from .mobilenetv2 import MobileNetV2, mobilenet_v2, __all__ as mv2_all +from .mobilenetv3 import MobileNetV3, mobilenet_v3_large, mobilenet_v3_small, __all__ as mv3_all -__all__ = mv2_all +__all__ = mv2_all + mv3_all diff --git a/torchvision/models/mobilenetv2.py b/torchvision/models/mobilenetv2.py index 990429bacf9..12f25ef495c 100644 --- a/torchvision/models/mobilenetv2.py +++ b/torchvision/models/mobilenetv2.py @@ -53,6 +53,7 @@ def __init__( norm_layer(out_planes), activation_layer(inplace=True) ) + self.out_channels = out_planes # necessary for backwards compatibility @@ -90,6 +91,8 @@ def __init__( norm_layer(oup), ]) self.conv = nn.Sequential(*layers) + self.out_channels = oup + self.is_strided = stride > 1 def forward(self, x: Tensor) -> Tensor: if self.use_res_connect: diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py new file mode 100644 index 00000000000..671acbc4a57 --- /dev/null +++ b/torchvision/models/mobilenetv3.py @@ -0,0 +1,281 @@ +import torch + +from functools import partial +from torch import nn, Tensor +from torch.nn import functional as F +from typing import Any, Callable, List, Optional, Sequence + +from torchvision.models.utils import load_state_dict_from_url +from torchvision.models.mobilenetv2 import _make_divisible, ConvBNActivation + + +__all__ = ["MobileNetV3", "mobilenet_v3_large", "mobilenet_v3_small"] + + +# TODO: add pretrained +model_urls = { + "mobilenet_v3_large": "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", + "mobilenet_v3_small": None, +} + + +class Identity(nn.Module): + + def __init__(self, inplace: bool = False): + super().__init__() + self.inplace = inplace + + def forward(self, input: Tensor) -> Tensor: + return input + + +class SqueezeExcitation(nn.Module): + + def __init__(self, input_channels: int, squeeze_factor: int = 4): + super().__init__() + squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8) + self.fc1 = nn.Conv2d(input_channels, squeeze_channels, 1) + self.fc2 = nn.Conv2d(squeeze_channels, input_channels, 1) + + def forward(self, input: Tensor) -> Tensor: + scale = F.adaptive_avg_pool2d(input, 1) + scale = self.fc1(scale) + scale = F.relu(scale, inplace=True) + scale = self.fc2(scale) + scale = F.hardsigmoid(scale, inplace=True) + return scale * input + + +class InvertedResidualConfig: + + def __init__(self, input_channels: int, kernel: int, expanded_channels: int, out_channels: int, use_se: bool, + activation: str, stride: int, width_mult: float): + 
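# every channel count below is scaled by width_mult and rounded to a multiple of 8 via _make_divisible
+        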
self.input_channels = self.adjust_channels(input_channels, width_mult) + self.kernel = kernel + self.expanded_channels = self.adjust_channels(expanded_channels, width_mult) + self.out_channels = self.adjust_channels(out_channels, width_mult) + self.use_se = use_se + self.use_hs = activation == "HS" + self.stride = stride + + @staticmethod + def adjust_channels(channels: int, width_mult: float): + return _make_divisible(channels * width_mult, 8) + + +class InvertedResidual(nn.Module): + + def __init__(self, cnf: InvertedResidualConfig, norm_layer: Callable[..., nn.Module]): + super().__init__() + if not (1 <= cnf.stride <= 2): + raise ValueError('illegal stride value') + + self.use_res_connect = cnf.stride == 1 and cnf.input_channels == cnf.out_channels + + layers: List[nn.Module] = [] + activation_layer = nn.Hardswish if cnf.use_hs else nn.ReLU + + # expand + if cnf.expanded_channels != cnf.input_channels: + layers.append(ConvBNActivation(cnf.input_channels, cnf.expanded_channels, kernel_size=1, + norm_layer=norm_layer, activation_layer=activation_layer)) + + # depthwise + layers.append(ConvBNActivation(cnf.expanded_channels, cnf.expanded_channels, kernel_size=cnf.kernel, + stride=cnf.stride, groups=cnf.expanded_channels, norm_layer=norm_layer, + activation_layer=activation_layer)) + if cnf.use_se: + layers.append(SqueezeExcitation(cnf.expanded_channels)) + + # project + layers.append(ConvBNActivation(cnf.expanded_channels, cnf.out_channels, kernel_size=1, norm_layer=norm_layer, + activation_layer=Identity)) + + self.block = nn.Sequential(*layers) + self.out_channels = cnf.out_channels + self.is_strided = cnf.stride > 1 + + def forward(self, input: Tensor) -> Tensor: + result = self.block(input) + if self.use_res_connect: + result += input + return result + + +class MobileNetV3(nn.Module): + + def __init__( + self, + inverted_residual_setting: List[InvertedResidualConfig], + last_channel: int, + num_classes: int = 1000, + block: Optional[Callable[..., nn.Module]] = None, + norm_layer: Optional[Callable[..., nn.Module]] = None + ) -> None: + """ + MobileNet V3 main class + + Args: + inverted_residual_setting (List[InvertedResidualConfig]): Network structure + last_channel (int): The number of channels on the penultimate layer + num_classes (int): Number of classes + block (Optional[Callable[..., nn.Module]]): Module specifying inverted residual building block for mobilenet + norm_layer (Optional[Callable[..., nn.Module]]): Module specifying the normalization layer to use + """ + super().__init__() + + if not inverted_residual_setting: + raise ValueError("The inverted_residual_setting should not be empty") + elif not (isinstance(inverted_residual_setting, Sequence) and + all([isinstance(s, InvertedResidualConfig) for s in inverted_residual_setting])): + raise TypeError("The inverted_residual_setting should be List[InvertedResidualConfig]") + + if block is None: + block = InvertedResidual + + if norm_layer is None: + norm_layer = partial(nn.BatchNorm2d, eps=0.001, momentum=0.01) + + layers: List[nn.Module] = [] + + # building first layer + firstconv_output_channels = inverted_residual_setting[0].input_channels + layers.append(ConvBNActivation(3, firstconv_output_channels, kernel_size=3, stride=2, norm_layer=norm_layer, + activation_layer=nn.Hardswish)) + + # building inverted residual blocks + for cnf in inverted_residual_setting: + layers.append(block(cnf, norm_layer)) + + # building last several layers + lastconv_input_channels = inverted_residual_setting[-1].out_channels + 
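# the last conv expands the channels 6x, matching the paper: 160 -> 960 for Large, 96 -> 576 for Small
+        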
lastconv_output_channels = 6 * lastconv_input_channels
+        layers.append(ConvBNActivation(lastconv_input_channels, lastconv_output_channels, kernel_size=1,
+                                       norm_layer=norm_layer, activation_layer=nn.Hardswish))
+
+        self.features = nn.Sequential(*layers)
+        self.avgpool = nn.AdaptiveAvgPool2d(1)
+        self.classifier = nn.Sequential(
+            nn.Linear(lastconv_output_channels, last_channel),
+            nn.Hardswish(inplace=True),
+            nn.Dropout(p=0.2, inplace=True),
+            nn.Linear(last_channel, num_classes),
+        )
+
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                nn.init.kaiming_normal_(m.weight, mode='fan_out')
+                if m.bias is not None:
+                    nn.init.zeros_(m.bias)
+            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
+                nn.init.ones_(m.weight)
+                nn.init.zeros_(m.bias)
+            elif isinstance(m, nn.Linear):
+                nn.init.normal_(m.weight, 0, 0.01)
+                nn.init.zeros_(m.bias)
+
+    def _forward_impl(self, x: Tensor) -> Tensor:
+        x = self.features(x)
+
+        x = self.avgpool(x)
+        x = torch.flatten(x, 1)
+
+        x = self.classifier(x)
+
+        return x
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self._forward_impl(x)
+
+
+def _mobilenet_v3(
+    arch: str,
+    inverted_residual_setting: List[InvertedResidualConfig],
+    last_channel: int,
+    pretrained: bool,
+    progress: bool,
+    **kwargs: Any
+):
+    model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs)
+    if pretrained and model_urls[arch] is not None:
+        state_dict = load_state_dict_from_url(model_urls[arch], progress=progress)
+        model.load_state_dict(state_dict)
+    return model
+
+
+def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
+                       **kwargs: Any) -> MobileNetV3:
+    """
+    Constructs a large MobileNetV3 architecture from
+    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+        reduced_tail (bool): If True, reduces the channel counts of all feature layers
+            between C4 and C5 by 2. It is used to reduce the channel redundancy in the
+            backbone for Detection and Segmentation.
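+
+    Example (a minimal usage sketch)::
+
+        >>> import torch, torchvision
+        >>> model = torchvision.models.mobilenet_v3_large()
+        >>> model.eval()
+        >>> out = model(torch.rand(1, 3, 224, 224))  # logits of shape (1, 1000)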
+    """
+    width_mult = 1.0
+    bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
+    adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
+
+    reduce_divider = 2 if reduced_tail else 1
+
+    inverted_residual_setting = [
+        bneck_conf(16, 3, 16, 16, False, "RE", 1),
+        bneck_conf(16, 3, 64, 24, False, "RE", 2),  # C1
+        bneck_conf(24, 3, 72, 24, False, "RE", 1),
+        bneck_conf(24, 5, 72, 40, True, "RE", 2),  # C2
+        bneck_conf(40, 5, 120, 40, True, "RE", 1),
+        bneck_conf(40, 5, 120, 40, True, "RE", 1),
+        bneck_conf(40, 3, 240, 80, False, "HS", 2),  # C3
+        bneck_conf(80, 3, 200, 80, False, "HS", 1),
+        bneck_conf(80, 3, 184, 80, False, "HS", 1),
+        bneck_conf(80, 3, 184, 80, False, "HS", 1),
+        bneck_conf(80, 3, 480, 112, True, "HS", 1),
+        bneck_conf(112, 3, 672, 112, True, "HS", 1),
+        bneck_conf(112, 5, 672, 160 // reduce_divider, True, "HS", 2),  # C4
+        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1),
+        bneck_conf(160 // reduce_divider, 5, 960 // reduce_divider, 160 // reduce_divider, True, "HS", 1),
+    ]
+    last_channel = adjust_channels(1280 // reduce_divider)  # C5
+
+    return _mobilenet_v3("mobilenet_v3_large", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)
+
+
+def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, reduced_tail: bool = False,
+                       **kwargs: Any) -> MobileNetV3:
+    """
+    Constructs a small MobileNetV3 architecture from
+    `"Searching for MobileNetV3" <https://arxiv.org/abs/1905.02244>`_.
+
+    Args:
+        pretrained (bool): If True, returns a model pre-trained on ImageNet
+        progress (bool): If True, displays a progress bar of the download to stderr
+        reduced_tail (bool): If True, reduces the channel counts of all feature layers
+            between C4 and C5 by 2. It is used to reduce the channel redundancy in the
+            backbone for Detection and Segmentation.
+    """
+    width_mult = 1.0
+    bneck_conf = partial(InvertedResidualConfig, width_mult=width_mult)
+    adjust_channels = partial(InvertedResidualConfig.adjust_channels, width_mult=width_mult)
+
+    reduce_divider = 2 if reduced_tail else 1
+
+    inverted_residual_setting = [
+        bneck_conf(16, 3, 16, 16, True, "RE", 2),  # C1
+        bneck_conf(16, 3, 72, 24, False, "RE", 2),  # C2
+        bneck_conf(24, 3, 88, 24, False, "RE", 1),
+        bneck_conf(24, 5, 96, 40, True, "HS", 2),  # C3
+        bneck_conf(40, 5, 240, 40, True, "HS", 1),
+        bneck_conf(40, 5, 240, 40, True, "HS", 1),
+        bneck_conf(40, 5, 120, 48, True, "HS", 1),
+        bneck_conf(48, 5, 144, 48, True, "HS", 1),
+        bneck_conf(48, 5, 288, 96 // reduce_divider, True, "HS", 2),  # C4
+        bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1),
+        bneck_conf(96 // reduce_divider, 5, 576 // reduce_divider, 96 // reduce_divider, True, "HS", 1),
+    ]
+    last_channel = adjust_channels(1024 // reduce_divider)  # C5
+
+    return _mobilenet_v3("mobilenet_v3_small", inverted_residual_setting, last_channel, pretrained, progress, **kwargs)