Commit eb48a1d

Adding multiweight support to Quantized MobileNetV2 and MobileNetV3 (#4859)
* Adding multiweight support on Quant MobileNetV2 and MobileNetV3.
* Fixing enum name.
* Fixing lint.
1 parent 50a3571 commit eb48a1d
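
For context, the change swaps the boolean pretrained flag for explicit weight enums. Below is a minimal usage sketch, not part of the commit, assuming the prototype package layout added in the diffs that follow:

# Sketch only: exercising the new quantized builder with an explicit weights enum.
from torchvision.prototype.models.quantization import mobilenet_v2, QuantizedMobileNetV2Weights

# New-style call: the enum selects the checkpoint, its preprocessing and its backend.
model = mobilenet_v2(weights=QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1, quantize=True)

# Legacy call still works but emits a deprecation warning for `pretrained`.
legacy = mobilenet_v2(pretrained=True, quantize=True)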

File tree

6 files changed: +194 additions, −2 deletions


torchvision/prototype/models/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -4,8 +4,7 @@
 from .googlenet import *
 from .inception import *
 from .mnasnet import *
-from .mobilenetv2 import *
-from .mobilenetv3 import *
+from .mobilenet import *
 from .regnet import *
 from .resnet import *
 from .shufflenetv2 import *
torchvision/prototype/models/mobilenet.py (new file)

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
from .mobilenetv2 import *  # noqa: F401, F403
from .mobilenetv3 import *  # noqa: F401, F403
from .mobilenetv2 import __all__ as mv2_all
from .mobilenetv3 import __all__ as mv3_all

__all__ = mv2_all + mv3_all
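
As a rough illustration, and assuming the re-exported modules list these builders in their __all__, the combining module keeps both families importable from the parent package:

# Sketch only: both MobileNet families resolve through the new mobilenet module.
from torchvision.prototype.models import mobilenet_v2, MobileNetV2Weights

model = mobilenet_v2(weights=MobileNetV2Weights.ImageNet1K_RefV1)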
torchvision/prototype/models/quantization/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 from .googlenet import *
 from .inception import *
+from .mobilenet import *
 from .resnet import *
 from .shufflenetv2 import *
torchvision/prototype/models/quantization/mobilenet.py (new file)

Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
from .mobilenetv2 import *  # noqa: F401, F403
from .mobilenetv3 import *  # noqa: F401, F403
from .mobilenetv2 import __all__ as mv2_all
from .mobilenetv3 import __all__ as mv3_all

__all__ = mv2_all + mv3_all
torchvision/prototype/models/quantization/mobilenetv2.py (new file)

Lines changed: 80 additions & 0 deletions
@@ -0,0 +1,80 @@
import warnings
from functools import partial
from typing import Any, Optional, Union

from torchvision.transforms.functional import InterpolationMode

from ....models.quantization.mobilenetv2 import (
    QuantizableInvertedResidual,
    QuantizableMobileNetV2,
    _replace_relu,
    quantize_model,
)
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from ..mobilenetv2 import MobileNetV2Weights


__all__ = [
    "QuantizableMobileNetV2",
    "QuantizedMobileNetV2Weights",
    "mobilenet_v2",
]


class QuantizedMobileNetV2Weights(Weights):
    ImageNet1K_QNNPACK_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth",
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            "size": (224, 224),
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
            "backend": "qnnpack",
            "quantization": "qat",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2",
            "unquantized": MobileNetV2Weights.ImageNet1K_RefV1,
            "acc@1": 71.658,
            "acc@5": 90.150,
        },
    )


def mobilenet_v2(
    weights: Optional[Union[QuantizedMobileNetV2Weights, MobileNetV2Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV2:
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        if kwargs.pop("pretrained"):
            weights = (
                QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1
                if quantize
                else MobileNetV2Weights.ImageNet1K_RefV1
            )
        else:
            weights = None

    if quantize:
        weights = QuantizedMobileNetV2Weights.verify(weights)
    else:
        weights = MobileNetV2Weights.verify(weights)

    if weights is not None:
        kwargs["num_classes"] = len(weights.meta["categories"])
        if "backend" in weights.meta:
            kwargs["backend"] = weights.meta["backend"]
    backend = kwargs.pop("backend", "qnnpack")

    model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
    _replace_relu(model)
    if quantize:
        quantize_model(model, backend)

    if weights is not None:
        model.load_state_dict(weights.state_dict(progress=progress))

    return model
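
A short sketch of the intended call paths for the builder above (not part of the diff; the module paths are inferred from the relative imports):

from torchvision.prototype.models.quantization.mobilenetv2 import mobilenet_v2, QuantizedMobileNetV2Weights
from torchvision.prototype.models.mobilenetv2 import MobileNetV2Weights

# quantize=True: weights are verified against the quantized enum and the "qnnpack"
# backend recorded in the weight metadata is passed to quantize_model().
qmodel = mobilenet_v2(weights=QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1, quantize=True)

# quantize=False: the same builder accepts the unquantized enum and skips quantization.
fmodel = mobilenet_v2(weights=MobileNetV2Weights.ImageNet1K_RefV1, quantize=False)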
torchvision/prototype/models/quantization/mobilenetv3.py (new file)

Lines changed: 100 additions & 0 deletions
@@ -0,0 +1,100 @@
import warnings
from functools import partial
from typing import Any, List, Optional, Union

import torch
from torchvision.transforms.functional import InterpolationMode

from ....models.quantization.mobilenetv3 import (
    InvertedResidualConfig,
    QuantizableInvertedResidual,
    QuantizableMobileNetV3,
    _replace_relu,
)
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from ..mobilenetv3 import MobileNetV3LargeWeights, _mobilenet_v3_conf


__all__ = [
    "QuantizableMobileNetV3",
    "QuantizedMobileNetV3LargeWeights",
    "mobilenet_v3_large",
]


def _mobilenet_v3_model(
    inverted_residual_setting: List[InvertedResidualConfig],
    last_channel: int,
    weights: Optional[Weights],
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    if weights is not None:
        kwargs["num_classes"] = len(weights.meta["categories"])
        if "backend" in weights.meta:
            kwargs["backend"] = weights.meta["backend"]
    backend = kwargs.pop("backend", "qnnpack")

    model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
    _replace_relu(model)

    if quantize:
        model.fuse_model()
        model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
        torch.quantization.prepare_qat(model, inplace=True)

    if weights is not None:
        model.load_state_dict(weights.state_dict(progress=progress))

    if quantize:
        torch.quantization.convert(model, inplace=True)
        model.eval()

    return model


class QuantizedMobileNetV3LargeWeights(Weights):
    ImageNet1K_QNNPACK_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            "size": (224, 224),
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
            "backend": "qnnpack",
            "quantization": "qat",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3",
            "unquantized": MobileNetV3LargeWeights.ImageNet1K_RefV1,
            "acc@1": 73.004,
            "acc@5": 90.858,
        },
    )


def mobilenet_v3_large(
    weights: Optional[Union[QuantizedMobileNetV3LargeWeights, MobileNetV3LargeWeights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        if kwargs.pop("pretrained"):
            weights = (
                QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_RefV1
                if quantize
                else MobileNetV3LargeWeights.ImageNet1K_RefV1
            )
        else:
            weights = None

    if quantize:
        weights = QuantizedMobileNetV3LargeWeights.verify(weights)
    else:
        weights = MobileNetV3LargeWeights.verify(weights)

    inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
    return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs)
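
Likewise for the large variant, a usage sketch (not part of the diff; module path inferred as above):

from torchvision.prototype.models.quantization.mobilenetv3 import (
    mobilenet_v3_large,
    QuantizedMobileNetV3LargeWeights,
)

# With quantize=True the helper fuses modules, prepares them with the QAT qconfig for
# the backend from the weight metadata, loads the checkpoint, converts, and calls eval().
qmodel = mobilenet_v3_large(
    weights=QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_RefV1,
    quantize=True,
)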
