diff --git a/torchvision/models/mobilenetv3.py b/torchvision/models/mobilenetv3.py index 5689b896d16..fbcd894353f 100644 --- a/torchvision/models/mobilenetv3.py +++ b/torchvision/models/mobilenetv3.py @@ -281,7 +281,7 @@ def _mobilenet_v3_conf( return inverted_residual_setting, last_channel -def _mobilenet_v3_model( +def _mobilenet_v3( arch: str, inverted_residual_setting: List[InvertedResidualConfig], last_channel: int, @@ -309,7 +309,7 @@ def mobilenet_v3_large(pretrained: bool = False, progress: bool = True, **kwargs """ arch = "mobilenet_v3_large" inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) - return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) + return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> MobileNetV3: @@ -323,4 +323,4 @@ def mobilenet_v3_small(pretrained: bool = False, progress: bool = True, **kwargs """ arch = "mobilenet_v3_small" inverted_residual_setting, last_channel = _mobilenet_v3_conf(arch, **kwargs) - return _mobilenet_v3_model(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) + return _mobilenet_v3(arch, inverted_residual_setting, last_channel, pretrained, progress, **kwargs) diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py index 00fcb828d55..ed4cfadfa63 100644 --- a/torchvision/prototype/models/__init__.py +++ b/torchvision/prototype/models/__init__.py @@ -3,5 +3,6 @@ from .densenet import * from .vgg import * from .efficientnet import * +from .mobilenetv3 import * from . import detection from . import quantization diff --git a/torchvision/prototype/models/mobilenetv3.py b/torchvision/prototype/models/mobilenetv3.py new file mode 100644 index 00000000000..aa66ef6e314 --- /dev/null +++ b/torchvision/prototype/models/mobilenetv3.py @@ -0,0 +1,90 @@ +import warnings +from functools import partial +from typing import Any, Optional, List + +from torchvision.transforms.functional import InterpolationMode + +from ...models.mobilenetv3 import MobileNetV3, _mobilenet_v3_conf, InvertedResidualConfig +from ..transforms.presets import ImageNetEval +from ._api import Weights, WeightEntry +from ._meta import _IMAGENET_CATEGORIES + + +__all__ = [ + "MobileNetV3", + "MobileNetV3LargeWeights", + "MobileNetV3SmallWeights", + "mobilenet_v3_large", + "mobilenet_v3_small", +] + + +def _mobilenet_v3( + inverted_residual_setting: List[InvertedResidualConfig], + last_channel: int, + weights: Optional[Weights], + progress: bool, + **kwargs: Any, +) -> MobileNetV3: + if weights is not None: + kwargs["num_classes"] = len(weights.meta["categories"]) + + model = MobileNetV3(inverted_residual_setting, last_channel, **kwargs) + + if weights is not None: + model.load_state_dict(weights.state_dict(progress=progress)) + + return model + + +_common_meta = {"size": (224, 224), "categories": _IMAGENET_CATEGORIES, "interpolation": InterpolationMode.BILINEAR} + + +class MobileNetV3LargeWeights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small", + "acc@1": 74.042, + "acc@5": 91.340, + }, + ) + + +class MobileNetV3SmallWeights(Weights): + ImageNet1K_RefV1 = WeightEntry( + url="https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth", + transforms=partial(ImageNetEval, crop_size=224), + meta={ + **_common_meta, + "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#mobilenetv3-large--small", + "acc@1": 67.668, + "acc@5": 87.402, + }, + ) + + +def mobilenet_v3_large( + weights: Optional[MobileNetV3LargeWeights] = None, progress: bool = True, **kwargs: Any +) -> MobileNetV3: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = MobileNetV3LargeWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + weights = MobileNetV3LargeWeights.verify(weights) + + inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs) + return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs) + + +def mobilenet_v3_small( + weights: Optional[MobileNetV3SmallWeights] = None, progress: bool = True, **kwargs: Any +) -> MobileNetV3: + if "pretrained" in kwargs: + warnings.warn("The argument pretrained is deprecated, please use weights instead.") + weights = MobileNetV3SmallWeights.ImageNet1K_RefV1 if kwargs.pop("pretrained") else None + weights = MobileNetV3SmallWeights.verify(weights) + + inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs) + return _mobilenet_v3(inverted_residual_setting, last_channel, weights, progress, **kwargs)