
Commit 69d462e

Passing the right activation on quantization.
1 parent 72cecb1 commit 69d462e

2 files changed: +10 −4 lines changed

torchvision/models/mobilenetv3.py

Lines changed: 1 addition & 0 deletions

@@ -26,6 +26,7 @@ def __init__(self, input_channels: int, squeeze_factor: int = 4):
         squeeze_channels = _make_divisible(input_channels // squeeze_factor, 8)
         super().__init__(input_channels, squeeze_channels, scale_activation=nn.Hardsigmoid)
         self.relu = self.activation
+        delattr(self, 'activation')
         warnings.warn(
             "This SqueezeExcitation class is deprecated and will be removed in future versions.", FutureWarning)
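A brief aside on the new `delattr` line (a toy sketch, not torchvision code): assigning an existing `nn.Module` to a second attribute registers the same object as a child under the new name, and `delattr` then removes the original entry, so the deprecated shim exposes the activation only under its legacy `relu` name.

import torch.nn as nn

class Shim(nn.Module):
    def __init__(self):
        super().__init__()
        self.activation = nn.ReLU()   # registered as the child 'activation'
        self.relu = self.activation   # same object, now also registered as 'relu'
        delattr(self, 'activation')   # drop the original name; only 'relu' remains

print(dict(Shim().named_children()))  # {'relu': ReLU()}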

torchvision/models/quantization/mobilenetv3.py

Lines changed: 9 additions & 4 deletions

@@ -19,7 +19,11 @@

 class QuantizableSqueezeExcitation(SElayer):
     def __init__(self, *args: Any, **kwargs: Any) -> None:
-        super().__init__(*args, **kwargs)
+        super().__init__(
+            scale_activation=nn.Hardswish,
+            *args,
+            **kwargs
+        )
         self.skip_mul = nn.quantized.FloatFunctional()

     def forward(self, input: Tensor) -> Tensor:
@@ -80,11 +84,12 @@ def _load_weights(
     model: QuantizableMobileNetV3,
     model_url: Optional[str],
     progress: bool,
+    strict: bool
 ) -> None:
     if model_url is None:
         raise ValueError("No checkpoint is available for {}".format(arch))
     state_dict = load_state_dict_from_url(model_url, progress=progress)
-    model.load_state_dict(state_dict)
+    model.load_state_dict(state_dict, strict=strict)


 def _mobilenet_v3_model(
@@ -108,13 +113,13 @@ def _mobilenet_v3_model(
         torch.quantization.prepare_qat(model, inplace=True)

         if pretrained:
-            _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress)
+            _load_weights(arch, model, quant_model_urls.get(arch + '_' + backend, None), progress, False)

         torch.quantization.convert(model, inplace=True)
         model.eval()
     else:
         if pretrained:
-            _load_weights(arch, model, model_urls.get(arch, None), progress)
+            _load_weights(arch, model, model_urls.get(arch, None), progress, True)

     return model
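For orientation, a hedged usage sketch (assuming the public `mobilenet_v3_large` builder exposed by this module, which routes through `_mobilenet_v3_model` above): after this change the quantized branch loads its checkpoint with `strict=False`, presumably to tolerate key differences in the QAT-prepared graph, while the float branch keeps a strict load.

from torchvision.models.quantization import mobilenet_v3_large

# Quantized path: per the diff, _load_weights is called with strict=False.
qat_model = mobilenet_v3_large(pretrained=True, quantize=True)

# Float path: per the diff, _load_weights is called with strict=True.
float_model = mobilenet_v3_large(pretrained=True, quantize=False)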
