Adding multiweight support to Quantized MobileNetV2 and MobileNetV3 #4859
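This PR extends the prototype multi-weight API to the quantized MobileNetV2 and MobileNetV3-Large builders: instead of the boolean pretrained flag, each builder accepts a weights enum whose entries bundle the checkpoint URL, the preprocessing preset, and metadata. A minimal usage sketch, assuming the package path implied by the relative imports in the diffs below (torchvision.prototype.models.quantization):

```python
from torchvision.prototype.models.quantization import (
    QuantizedMobileNetV2Weights,
    QuantizedMobileNetV3LargeWeights,
    mobilenet_v2,
    mobilenet_v3_large,
)

# Quantized (QNNPACK, QAT) checkpoints selected explicitly via the new enums.
qm2 = mobilenet_v2(weights=QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1, quantize=True)
qm3 = mobilenet_v3_large(weights=QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_RefV1, quantize=True)

# weights=None still builds a randomly initialized (but quantizable) model.
scratch = mobilenet_v2(weights=None, quantize=True)
```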
@@ -0,0 +1,6 @@
from .mobilenetv2 import *  # noqa: F401, F403
from .mobilenetv3 import *  # noqa: F401, F403
from .mobilenetv2 import __all__ as mv2_all
from .mobilenetv3 import __all__ as mv3_all

__all__ = mv2_all + mv3_all
@@ -1,4 +1,5 @@
from .googlenet import *
from .inception import *
from .mobilenet import *
from .resnet import *
from .shufflenetv2 import *
@@ -0,0 +1,6 @@
from .mobilenetv2 import *  # noqa: F401, F403
from .mobilenetv3 import *  # noqa: F401, F403
from .mobilenetv2 import __all__ as mv2_all
from .mobilenetv3 import __all__ as mv3_all

__all__ = mv2_all + mv3_all
@@ -0,0 +1,80 @@
import warnings
from functools import partial
from typing import Any, Optional, Union

from torchvision.transforms.functional import InterpolationMode

from ....models.quantization.mobilenetv2 import (
    QuantizableInvertedResidual,
    QuantizableMobileNetV2,
    _replace_relu,
    quantize_model,
)
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from ..mobilenetv2 import MobileNetV2Weights


__all__ = [
    "QuantizableMobileNetV2",
    "QuantizedMobileNetV2Weights",
    "mobilenet_v2",
]


class QuantizedMobileNetV2Weights(Weights):
    ImageNet1K_QNNPACK_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/mobilenet_v2_qnnpack_37f702c5.pth",
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            "size": (224, 224),
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
            "backend": "qnnpack",
            "quantization": "qat",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv2",
            "unquantized": MobileNetV2Weights.ImageNet1K_RefV1,
            "acc@1": 71.658,
            "acc@5": 90.150,
        },
    )


def mobilenet_v2(
    weights: Optional[Union[QuantizedMobileNetV2Weights, MobileNetV2Weights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV2:
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        if kwargs.pop("pretrained"):
            weights = (
                QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1
                if quantize
                else MobileNetV2Weights.ImageNet1K_RefV1
            )
        else:
            weights = None

    if quantize:
        weights = QuantizedMobileNetV2Weights.verify(weights)
    else:
        weights = MobileNetV2Weights.verify(weights)

    if weights is not None:
        kwargs["num_classes"] = len(weights.meta["categories"])
        if "backend" in weights.meta:
            kwargs["backend"] = weights.meta["backend"]
    backend = kwargs.pop("backend", "qnnpack")
Review comment: The backend here is […]
    model = QuantizableMobileNetV2(block=QuantizableInvertedResidual, **kwargs)
    _replace_relu(model)
    if quantize:
        quantize_model(model, backend)

    if weights is not None:
        model.load_state_dict(weights.state_dict(progress=progress))

    return model
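As the builder above shows, the legacy pretrained flag still works but now only emits a deprecation warning and is translated to the new enums, with quantize deciding whether the quantized or the unquantized entry is picked. A short illustration of that mapping, under the same assumed import path as above:

```python
from torchvision.prototype.models.quantization import mobilenet_v2

# Warns about the deprecated argument, then resolves to
# QuantizedMobileNetV2Weights.ImageNet1K_QNNPACK_RefV1 because quantize=True.
quantized = mobilenet_v2(pretrained=True, quantize=True)

# Same warning, but resolves to the unquantized MobileNetV2Weights.ImageNet1K_RefV1;
# the float checkpoint is loaded into the quantizable architecture.
fp32 = mobilenet_v2(pretrained=True, quantize=False)
```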
@@ -0,0 +1,100 @@
import warnings
from functools import partial
from typing import Any, List, Optional, Union

import torch
from torchvision.transforms.functional import InterpolationMode

from ....models.quantization.mobilenetv3 import (
    InvertedResidualConfig,
    QuantizableInvertedResidual,
    QuantizableMobileNetV3,
    _replace_relu,
)
from ...transforms.presets import ImageNetEval
from .._api import Weights, WeightEntry
from .._meta import _IMAGENET_CATEGORIES
from ..mobilenetv3 import MobileNetV3LargeWeights, _mobilenet_v3_conf


__all__ = [
    "QuantizableMobileNetV3",
    "QuantizedMobileNetV3LargeWeights",
    "mobilenet_v3_large",
]
Review comment on _mobilenet_v3_model: I kept this method instead of dumping everything in the public method because in the future we might want to also support the […]

def _mobilenet_v3_model(
    inverted_residual_setting: List[InvertedResidualConfig],
    last_channel: int,
    weights: Optional[Weights],
    progress: bool,
    quantize: bool,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    if weights is not None:
        kwargs["num_classes"] = len(weights.meta["categories"])
        if "backend" in weights.meta:
            kwargs["backend"] = weights.meta["backend"]
    backend = kwargs.pop("backend", "qnnpack")

    model = QuantizableMobileNetV3(inverted_residual_setting, last_channel, block=QuantizableInvertedResidual, **kwargs)
    _replace_relu(model)

    if quantize:
        model.fuse_model()
        model.qconfig = torch.quantization.get_default_qat_qconfig(backend)
        torch.quantization.prepare_qat(model, inplace=True)

    if weights is not None:
        model.load_state_dict(weights.state_dict(progress=progress))

    if quantize:
        torch.quantization.convert(model, inplace=True)
        model.eval()
Review comment: Check carefully lines 44-54 compared to the original. This is a simplification of the code.
    return model


class QuantizedMobileNetV3LargeWeights(Weights):
    ImageNet1K_QNNPACK_RefV1 = WeightEntry(
        url="https://download.pytorch.org/models/quantized/mobilenet_v3_large_qnnpack-5bcacf28.pth",
        transforms=partial(ImageNetEval, crop_size=224),
        meta={
            "size": (224, 224),
            "categories": _IMAGENET_CATEGORIES,
            "interpolation": InterpolationMode.BILINEAR,
            "backend": "qnnpack",
            "quantization": "qat",
            "recipe": "https://github.com/pytorch/vision/tree/main/references/classification#qat-mobilenetv3",
            "unquantized": MobileNetV3LargeWeights.ImageNet1K_RefV1,
            "acc@1": 73.004,
            "acc@5": 90.858,
        },
    )


def mobilenet_v3_large(
    weights: Optional[Union[QuantizedMobileNetV3LargeWeights, MobileNetV3LargeWeights]] = None,
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    if "pretrained" in kwargs:
        warnings.warn("The argument pretrained is deprecated, please use weights instead.")
        if kwargs.pop("pretrained"):
            weights = (
                QuantizedMobileNetV3LargeWeights.ImageNet1K_QNNPACK_RefV1
                if quantize
                else MobileNetV3LargeWeights.ImageNet1K_RefV1
            )
        else:
            weights = None

    if quantize:
        weights = QuantizedMobileNetV3LargeWeights.verify(weights)
    else:
        weights = MobileNetV3LargeWeights.verify(weights)

    inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_large", **kwargs)
    return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs)
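The review comment on _mobilenet_v3_model above explains that the private helper is kept so that further variants can reuse it later. Purely to illustrate that intent, here is a hypothetical sketch of what such reuse could look like; this builder and the Small weight enums referenced in it are not part of this PR:

```python
# Hypothetical sketch only -- not part of this PR. It mirrors mobilenet_v3_large
# above and would live in the same module, reusing _mobilenet_v3_model.
def mobilenet_v3_small(
    weights=None,  # e.g. a future QuantizedMobileNetV3SmallWeights / MobileNetV3SmallWeights entry
    progress: bool = True,
    quantize: bool = False,
    **kwargs: Any,
) -> QuantizableMobileNetV3:
    if quantize:
        weights = QuantizedMobileNetV3SmallWeights.verify(weights)  # hypothetical enum
    else:
        weights = MobileNetV3SmallWeights.verify(weights)  # hypothetical enum

    inverted_residual_setting, last_channel = _mobilenet_v3_conf("mobilenet_v3_small", **kwargs)
    return _mobilenet_v3_model(inverted_residual_setting, last_channel, weights, progress, quantize, **kwargs)
```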
Review comment: For BC reasons, we need to maintain the .mobilenet space.
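The small re-export modules at the top of the diff exist for exactly this reason: the legacy .mobilenet module keeps exposing the same names as the per-variant modules. A quick check, again assuming the prototype package path used above:

```python
# Both import paths should resolve to the very same builder objects, which keeps
# backwards compatibility for code that still imports from the .mobilenet module.
from torchvision.prototype.models.quantization.mobilenet import mobilenet_v2, mobilenet_v3_large
from torchvision.prototype.models.quantization import mobilenet_v2 as mv2, mobilenet_v3_large as mv3

assert mobilenet_v2 is mv2
assert mobilenet_v3_large is mv3
```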