diff --git a/torchvision/prototype/models/__init__.py b/torchvision/prototype/models/__init__.py
index ef0288f60b0..5077b7fd178 100644
--- a/torchvision/prototype/models/__init__.py
+++ b/torchvision/prototype/models/__init__.py
@@ -4,8 +4,7 @@
 from .googlenet import *
 from .inception import *
 from .mnasnet import *
-from .mobilenetv2 import *
-from .mobilenetv3 import *
+from .mobilenet import *
 from .regnet import *
 from .resnet import *
 from .shufflenetv2 import *
diff --git a/torchvision/prototype/models/mobilenet.py b/torchvision/prototype/models/mobilenet.py
new file mode 100644
index 00000000000..0a270d14d3a
--- /dev/null
+++ b/torchvision/prototype/models/mobilenet.py
@@ -0,0 +1,6 @@
+from .mobilenetv2 import *  # noqa: F401, F403
+from .mobilenetv3 import *  # noqa: F401, F403
+from .mobilenetv2 import __all__ as mv2_all
+from .mobilenetv3 import __all__ as mv3_all
+
+__all__ = mv2_all + mv3_all
diff --git a/torchvision/prototype/models/quantization/__init__.py b/torchvision/prototype/models/quantization/__init__.py
index c8a3f886d8d..da8bbba3567 100644
--- a/torchvision/prototype/models/quantization/__init__.py
+++ b/torchvision/prototype/models/quantization/__init__.py
@@ -1,4 +1,5 @@
 from .googlenet import *
 from .inception import *
+from .mobilenet import *
 from .resnet import *
 from .shufflenetv2 import *
diff --git a/torchvision/prototype/models/quantization/mobilenet.py b/torchvision/prototype/models/quantization/mobilenet.py
new file mode 100644
index 00000000000..0a270d14d3a
--- /dev/null
+++ b/torchvision/prototype/models/quantization/mobilenet.py
@@ -0,0 +1,6 @@
+from .mobilenetv2 import *  # noqa: F401, F403
+from .mobilenetv3 import *  # noqa: F401, F403
+from .mobilenetv2 import __all__ as mv2_all
+from .mobilenetv3 import __all__ as mv3_all
+
+__all__ = mv2_all + mv3_all
diff --git a/torchvision/prototype/models/quantization/mobilenetv2.py b/torchvision/prototype/models/quantization/mobilenetv2.py
new file mode 100644
index 00000000000..578b3e5e37f
--- /dev/null
+++ b/torchvision/prototype/models/quantization/mobilenetv2.py
@@ -0,0 +1,80 @@
+import warnings
+from functools import partial
+from typing import Any, Optional, Union
+
+from torchvision.transforms.functional import InterpolationMode
+
+from ....models.quantization.mobilenetv2 import (
+    QuantizableInvertedResidual,
+    QuantizableMobileNetV2,
+    _replace_relu,
+    quantize_model,
+)
+from ...transforms.presets import ImageNetEval
+from .._api import Weights, WeightEntry
+from .._meta import _IMAGENET_CATEGORIES
+from ..mobilenetv2 import MobileNetV2Weights
+
+
+__all__ = [
+    "QuantizableMobileNetV2",
+    "QuantizedMobileNetV2Weights",
+    "mobilenet_v2",
+]
+
+
+class QuantizedMobileNetV2Weights(Weights):
+    ImageNet1K_QNNPACK_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            "size": (224, 224),
+            "categories": _IMAGENET_CATEGORIES,
+            "interpolation": InterpolationMode.BILINEAR,
+            "backend": "qnnpack",
+            "quantization": "qat",
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2",
+            "unquantized": MobileNetV2Weights.ImageNet1K_RefV1,
+            "acc@1": 71.658,
+            "acc@5": 90.150,
+        },
+    )
+
+
+def mobilenet_v2(
+    weights: Optional[Union[QuantizedMobileNetV2Weights, MobileNetV2Weights]] = None,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableMobileNetV2:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        if kwargs.pop("pretrained"):
+            weights = (
+                QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1
+                if quantize
+                else MobileNetV2Weights.ImageNet1K_RefV1
+            )
+        else:
+            weights = None
+
+    if quantize:
+        weights = QuantizedMobileNetV2Weights.verify(weights)
+    else:
+        weights = MobileNetV2Weights.verify(weights)
+
+    if weights is not None:
+        kwargs["num_classes"] = len(weights.meta["categories"])
+        if "backend" in weights.meta:
+            kwargs["backend"] = weights.meta["backend"]
+    backend = kwargs.pop("backend", "qnnpack")
+
+    model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
+    _replace_relu(model)
+    if quantize:
+        quantize_model(model, backend)
+
+    if weights is not None:
+        model.load_state_dict(weights.state_dict(progress=progress))
+
+    return model
diff --git a/torchvision/prototype/models/quantization/mobilenetv3.py b/torchvision/prototype/models/quantization/mobilenetv3.py
new file mode 100644
index 00000000000..924ee91852e
--- /dev/null
+++ b/torchvision/prototype/models/quantization/mobilenetv3.py
@@ -0,0 +1,100 @@
+import warnings
+from functools import partial
+from typing import Any, List, Optional, Union
+
+import torch
+from torchvision.transforms.functional import InterpolationMode
+
+from ....models.quantization.mobilenetv3 import (
+    InvertedResidualConfig,
+    QuantizableInvertedResidual,
+    QuantizableMobileNetV3,
+    _replace_relu,
+)
+from ...transforms.presets import ImageNetEval
+from .._api import Weights, WeightEntry
+from .._meta import _IMAGENET_CATEGORIES
+from ..mobilenetv3 import MobileNetV3LargeWeights, _mobilenet_v3_conf
+
+
+__all__ = [
+    "QuantizableMobileNetV3",
+    "QuantizedMobileNetV3LargeWeights",
+    "mobilenet_v3_large",
+]
+
+
+def _mobilenet_v3_model(
+    inverted_residual_setting: List[InvertedResidualConfig],
+    last_channel: int,
+    weights: Optional[Weights],
+    progress: bool,
+    quantize: bool,
+    **kwargs: Any,
+) -> QuantizableMobileNetV3:
+    if weights is not None:
+        kwargs["num_classes"] = len(weights.meta["categories"])
+        if "backend" in weights.meta:
+            kwargs["backend"] = weights.meta["backend"]
+    backend = kwargs.pop("backend", "qnnpack")
+
+    model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
+    _replace_relu(model)
+
+    if quantize:
+        model.fuse_model()
+        model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
+        torch.quantization.prepare_qat(model, inplace=True)
+
+    if weights is not None:
+        model.load_state_dict(weights.state_dict(progress=progress))
+
+    if quantize:
+        torch.quantization.convert(model, inplace=True)
+        model.eval()
+
+    return model
+
+
+class QuantizedMobileNetV3LargeWeights(Weights):
+    ImageNet1K_QNNPACK_RefV1 = WeightEntry(
+        url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
+        transforms=partial(ImageNetEval, crop_size=224),
+        meta={
+            "size": (224, 224),
+            "categories": _IMAGENET_CATEGORIES,
+            "interpolation": InterpolationMode.BILINEAR,
+            "backend": "qnnpack",
+            "quantization": "qat",
+            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3",
+            "unquantized": MobileNetV3LargeWeights.ImageNet1K_RefV1,
+            "acc@1": 73.004,
+            "acc@5": 90.858,
+        },
+    )
+
+
+def mobilenet_v3_large(
+    weights: Optional[Union[QuantizedMobileNetV3LargeWeights, MobileNetV3LargeWeights]] = None,
+    progress: bool = True,
+    quantize: bool = False,
+    **kwargs: Any,
+) -> QuantizableMobileNetV3:
+    if "pretrained" in kwargs:
+        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
+        if kwargs.pop("pretrained"):
+            weights = (
+                QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_RefV1
+                if quantize
+                else MobileNetV3LargeWeights.ImageNet1K_RefV1
+            )
+        else:
+            weights = None
+
+    if quantize:
+        weights = QuantizedMobileNetV3LargeWeights.verify(weights)
+    else:
+        weights = MobileNetV3LargeWeights.verify(weights)
+
+    inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
+    return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs)