@@ -220,6 +220,11 @@ def _check_input_backprop(model, inputs):
     "maskrcnn_resnet50_fpn",
 )
 
+# The tests for the following quantized models are flaky possibly due to inconsistent
+# rounding errors in different platforms. For this reason the input/output consistency
+# tests under test_quantized_classification_model will be skipped for the following models.
+quantized_flaky_models = ("inception_v3",)
+
 
 # The following contains configuration parameters for all models which are used by
 # the _test_*_model methods.
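The skip list above exists because _assert_expected compares a model's output against a tensor checked into the repository, and quantized kernels can round slightly differently from platform to platform, so a reference value captured on one machine may not reproduce on another within the chosen precision. A minimal sketch of that style of regression check follows; the helper name, file layout, and serialization format are illustrative assumptions, not torchvision's actual implementation.

import os

import torch


def assert_expected_sketch(output, name, prec, expect_dir="expect"):
    # Hypothetical stand-in for an expected-output check: compare `output`
    # against a tensor recorded on a reference platform, creating the file
    # on the first run.
    path = os.path.join(expect_dir, f"{name}.pt")
    if not os.path.exists(path):
        os.makedirs(expect_dir, exist_ok=True)
        torch.save(output, path)  # record the reference output once
        return
    expected = torch.load(path)
    # The absolute tolerance mirrors the prec=0.1 used by the test further down.
    torch.testing.assert_close(output, expected, rtol=0, atol=prec)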
@@ -687,7 +692,9 @@ def test_video_model(model_name, dev):
 )
 @pytest.mark.parametrize("model_name", get_available_quantizable_models())
 def test_quantized_classification_model(model_name):
+    set_rng_seed(0)
     defaults = {
+        "num_classes": 5,
         "input_shape": (1, 3, 224, 224),
         "pretrained": False,
         "quantize": True,
@@ -697,8 +704,15 @@ def test_quantized_classification_model(model_name):
 
     # First check if quantize=True provides models that can run with input data
     model = torchvision.models.quantization.__dict__[model_name](**kwargs)
+    model.eval()
     x = torch.rand(input_shape)
-    model(x)
+    out = model(x)
+
+    if model_name not in quantized_flaky_models:
+        _assert_expected(out, model_name + "_quantized", prec=0.1)
+        assert out.shape[-1] == 5
+        _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
+        _check_fx_compatible(model, x)
 
     kwargs["quantize"] = False
     for eval_mode in [True, False]:
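The two _check_* helpers at the end of the hunk verify that the quantized model can be scripted with TorchScript and symbolically traced with torch.fx without its output changing. Roughly, such checks boil down to the sketch below; torchvision's real helpers are more involved (for example, the unwrapper argument seen above post-processes scripted outputs before comparison), so treat this as an approximation rather than the library's API.

import torch
import torch.fx


def check_jit_scriptable_sketch(model, args):
    # Script the eager model and confirm the scripted version produces the same output.
    eager_out = model(*args)
    scripted = torch.jit.script(model)
    torch.testing.assert_close(scripted(*args), eager_out)


def check_fx_compatible_sketch(model, x):
    # Symbolically trace the model and confirm the traced graph matches eager mode.
    eager_out = model(x)
    graph_module = torch.fx.symbolic_trace(model)
    torch.testing.assert_close(graph_module(x), eager_out)

To exercise only this test locally, an invocation along the lines of pytest test_models.py -k test_quantized_classification_model does the job.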