Skip to content

Adding multiweight support to Quantized MobileNetV2 and MobileNetV3 #4859

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions torchvision/prototype/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
from .googlenet import *
from .inception import *
from .mnasnet import *
from .mobilenetv2 import *
from .mobilenetv3 import *
from .mobilenet import *
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For backward-compatibility (BC) reasons, we need to maintain the `.mobilenet` namespace.

from .regnet import *
from .resnet import *
from .shufflenetv2 import *
Expand Down
6 changes: 6 additions & 0 deletions torchvision/prototype/models/mobilenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Backward-compatibility shim: the legacy `.mobilenet` module re-exports the
# public APIs of both the V2 and V3 submodules so existing imports keep working.
from .mobilenetv2 import * # noqa: F401, F403
from .mobilenetv3 import * # noqa: F401, F403
from .mobilenetv2 import __all__ as mv2_all
from .mobilenetv3 import __all__ as mv3_all

# Public API is the union of both submodules' exported names.
__all__ = mv2_all + mv3_all
1 change: 1 addition & 0 deletions torchvision/prototype/models/quantization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from .googlenet import *
from .inception import *
from .mobilenet import *
from .resnet import *
from .shufflenetv2 import *
6 changes: 6 additions & 0 deletions torchvision/prototype/models/quantization/mobilenet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Backward-compatibility shim for the quantized models: re-exports the public
# APIs of both the quantized V2 and V3 submodules under `.mobilenet`.
from .mobilenetv2 import * # noqa: F401, F403
from .mobilenetv3 import * # noqa: F401, F403
from .mobilenetv2 import __all__ as mv2_all
from .mobilenetv3 import __all__ as mv3_all

# Public API is the union of both submodules' exported names.
__all__ = mv2_all + mv3_all
80 changes: 80 additions & 0 deletions torchvision/prototype/models/quantization/mobilenetv2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import warnings
from functools import partial
from typing import Any, Optional, Union

from torchvision.transforms.functional import InterpolationMode

# Reuse the quantizable architecture and quantization helpers from the stable
# (non-prototype) implementation; only the weight/metadata API is new here.
from ....models.quantization.mobilenetv2 import (
    QuantizableInvertedResidual,
    QuantizableMobileNetV2,
    _replace_relu,
    quantize_model,
)
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
# Unquantized weights are accepted too (when quantize=False), hence this import.
from ..mobilenetv2 import MobileNetV2Weights


__all__ = [
    "QuantizableMobileNetV2",
    "QuantizedMobileNetV2Weights",
    "mobilenet_v2",
]


class QuantizedMobileNetV2Weights(Weights):
    """Pre-trained weight entries for the quantized MobileNetV2 model."""

    # QAT checkpoint trained with the qnnpack backend; `unquantized` points at
    # the float weights this model was fine-tuned from.
    ImageNet1K_QNNPACK_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth",
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            "size": (224, 224),
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
            "backend": "qnnpack",
            "quantization": "qat",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2",
            "unquantized": MobileNetV2Weights.ImageNet1K_RefV1,
            "acc@1": 71.658,
            "acc@5": 90.150,
        },
    )


def mobilenet_v2(
    weights: Optional[Union[QuantizedMobileNetV2Weights, MobileNetV2Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV2:
    """Build a quantizable MobileNetV2, optionally quantized and/or pre-trained.

    Args:
        weights: quantized or float weight entry to load, or ``None`` for random init.
        progress: show a progress bar while downloading weights.
        quantize: if ``True``, quantize the model (weights must then be quantized weights).
        **kwargs: forwarded to ``QuantizableMobileNetV2``.
    """
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        # Legacy behavior: pretrained=True picks the default weights for the
        # requested mode, pretrained=False discards any weights passed in.
        if not kwargs.pop("pretrained"):
            weights = None
        elif quantize:
            weights = QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1
        else:
            weights = MobileNetV2Weights.ImageNet1K_RefV1

    # Validate against the weight class matching the requested mode.
    weights_cls = QuantizedMobileNetV2Weights if quantize else MobileNetV2Weights
    weights = weights_cls.verify(weights)

    if weights is not None:
        kwargs["num_classes"] = len(weights.meta["categories"])
        if "backend" in weights.meta:
            kwargs["backend"] = weights.meta["backend"]
    backend = kwargs.pop("backend", "qnnpack")

    model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
    _replace_relu(model)
    if quantize:
        quantize_model(model, backend)

    if weights is not None:
        model.load_state_dict(weights.state_dict(progress=progress))

    return model
100 changes: 100 additions & 0 deletions torchvision/prototype/models/quantization/mobilenetv3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
import warnings
from functools import partial
from typing import Any, List, Optional, Union

import torch
from torchvision.transforms.functional import InterpolationMode

# Reuse the quantizable architecture from the stable (non-prototype)
# implementation; only the weight/metadata API is new here.
from ....models.quantization.mobilenetv3 import (
    InvertedResidualConfig,
    QuantizableInvertedResidual,
    QuantizableMobileNetV3,
    _replace_relu,
)
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
# Unquantized weights are accepted too (when quantize=False), hence this import.
from ..mobilenetv3 import MobileNetV3LargeWeights, _mobilenet_v3_conf


__all__ = [
    "QuantizableMobileNetV3",
    "QuantizedMobileNetV3LargeWeights",
    "mobilenet_v3_large",
]


def _mobilenet_v3_model(
    inverted_residual_setting: List[InvertedResidualConfig],
    last_channel: int,
    weights: Optional[Weights],
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    """Build a QuantizableMobileNetV3 from a config, optionally quantizing it
    and loading weights.

    Kept separate from the public builder so future variants (e.g. the _small
    model) can share it unchanged.
    """
    if weights is not None:
        # Infer the head size from the weights' metadata so loading can't mismatch.
        kwargs["num_classes"] = len(weights.meta["categories"])
        if "backend" in weights.meta:
            kwargs["backend"] = weights.meta["backend"]
    backend = kwargs.pop("backend", "qnnpack")

    model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
    _replace_relu(model)

    if quantize:
        # QAT preparation must happen BEFORE load_state_dict: the checkpoint
        # presumably contains the observer/fake-quant state produced by
        # prepare_qat — TODO confirm against the training recipe.
        model.fuse_model()
        model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
        torch.quantization.prepare_qat(model, inplace=True)

    if weights is not None:
        model.load_state_dict(weights.state_dict(progress=progress))

    if quantize:
        # Convert to a true quantized model only after the weights are in place.
        torch.quantization.convert(model, inplace=True)
        model.eval()

    return model


class QuantizedMobileNetV3LargeWeights(Weights):
    """Pre-trained weight entries for the quantized MobileNetV3-Large model."""

    # QAT checkpoint trained with the qnnpack backend; `unquantized` points at
    # the float weights this model was fine-tuned from.
    ImageNet1K_QNNPACK_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            "size": (224, 224),
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
            "backend": "qnnpack",
            "quantization": "qat",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3",
            "unquantized": MobileNetV3LargeWeights.ImageNet1K_RefV1,
            "acc@1": 73.004,
            "acc@5": 90.858,
        },
    )


def mobilenet_v3_large(
    weights: Optional[Union[QuantizedMobileNetV3LargeWeights, MobileNetV3LargeWeights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    """Build a quantizable MobileNetV3-Large, optionally quantized and/or pre-trained.

    Args:
        weights: quantized or float weight entry to load, or ``None`` for random init.
        progress: show a progress bar while downloading weights.
        quantize: if ``True``, quantize the model (weights must then be quantized weights).
        **kwargs: forwarded to the model builder.
    """
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        # Legacy behavior: pretrained=True picks the default weights for the
        # requested mode, pretrained=False discards any weights passed in.
        if not kwargs.pop("pretrained"):
            weights = None
        elif quantize:
            weights = QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_RefV1
        else:
            weights = MobileNetV3LargeWeights.ImageNet1K_RefV1

    # Validate against the weight class matching the requested mode.
    weights_cls = QuantizedMobileNetV3LargeWeights if quantize else MobileNetV3LargeWeights
    weights = weights_cls.verify(weights)

    inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
    return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs)