19
19
20
20
class QuantizableSqueezeExcitation(SElayer):
    """Squeeze-and-Excitation block prepared for quantization.

    Forces ``scale_activation`` to ``nn.Hardswish`` and exposes a
    ``FloatFunctional`` multiply so quantization observers can be attached
    during quantization-aware training.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Overwrite any caller-supplied scale_activation instead of passing
        # our own keyword alongside *args/**kwargs: the latter raises
        # "TypeError: got multiple values for keyword argument
        # 'scale_activation'" whenever the caller also provides it.
        kwargs["scale_activation"] = nn.Hardswish
        super().__init__(*args, **kwargs)
        # FloatFunctional stands in for the plain `*` operator so activation
        # statistics of the multiply can be observed/quantized.
        # NOTE(review): presumably consumed in forward() for the scale
        # multiply — forward's body is outside this view, confirm there.
        self.skip_mul = nn.quantized.FloatFunctional()
def forward (self , input : Tensor ) -> Tensor :
@@ -80,11 +84,12 @@ def _load_weights(
80
84
model : QuantizableMobileNetV3 ,
81
85
model_url : Optional [str ],
82
86
progress : bool ,
87
+ strict : bool
83
88
) -> None :
84
89
if model_url is None :
85
90
raise ValueError ("No checkpoint is available for {}" .format (arch ))
86
91
state_dict = load_state_dict_from_url (model_url , progress = progress )
87
- model .load_state_dict (state_dict )
92
+ model .load_state_dict (state_dict , strict = strict )
88
93
89
94
90
95
def _mobilenet_v3_model (
@@ -108,13 +113,13 @@ def _mobilenet_v3_model(
108
113
torch .quantization .prepare_qat (model , inplace = True )
109
114
110
115
if pretrained :
111
- _load_weights (arch , model , quant_model_urls .get (arch + '_' + backend , None ), progress )
116
+ _load_weights (arch , model , quant_model_urls .get (arch + '_' + backend , None ), progress , False )
112
117
113
118
torch .quantization .convert (model , inplace = True )
114
119
model .eval ()
115
120
else :
116
121
if pretrained :
117
- _load_weights (arch , model , model_urls .get (arch , None ), progress )
122
+ _load_weights (arch , model , model_urls .get (arch , None ), progress , True )
118
123
119
124
return model
120
125
0 commit comments