Commit 5d1e568

prabhat00155 authored and datumbox committed
[fbsync] Add a test that compares the output of our quantized models against expected cached values (#4597)
Summary: * adding tests to check output of quantized models * adding test quantized model weights * merge test_new_quantized_classification_model with test_quantized_classification_model * adding skipif removed by mistake * addressing comments from PR * removing unused argument * fixing lint errors * changing model to eval model and updating weights * Update test/test_models.py * enforce single test in circleci * changing random seed * updating weights for new seed * adding missing empty line * try 128 random seed * try 256 random seed * try 16 random seed * disable inception_v3 input/output quantization tests * removing ModelTester.test_inception_v3_quantized_expect.pkl * reverting temporary ci run_test.sh changes Reviewed By: fmassa Differential Revision: D31649962 fbshipit-source-id: 35a0cb4d8d3564c88dabc09e750d5ad0a281431a Co-authored-by: Vasilis Vryniotis <[email protected]> Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 1486c54 commit 5d1e568

11 files changed (+15 −1 lines)
10 binary files changed (not shown).

test/test_models.py

Lines changed: 15 additions & 1 deletion
@@ -220,6 +220,11 @@ def _check_input_backprop(model, inputs):
     "maskrcnn_resnet50_fpn",
 )
 
+# The tests for the following quantized models are flaky possibly due to inconsistent
+# rounding errors in different platforms. For this reason the input/output consistency
+# tests under test_quantized_classification_model will be skipped for the following models.
+quantized_flaky_models = ("inception_v3",)
+
 # The following contains configuration parameters for all models which are used by
 # the _test_*_model methods.
@@ -687,7 +692,9 @@ def test_video_model(model_name, dev):
 )
 @pytest.mark.parametrize("model_name", get_available_quantizable_models())
 def test_quantized_classification_model(model_name):
+    set_rng_seed(0)
     defaults = {
+        "num_classes": 5,
         "input_shape": (1, 3, 224, 224),
         "pretrained": False,
         "quantize": True,
@@ -697,8 +704,15 @@ def test_quantized_classification_model(model_name):
 
     # First check if quantize=True provides models that can run with input data
     model = torchvision.models.quantization.__dict__[model_name](**kwargs)
+    model.eval()
     x = torch.rand(input_shape)
-    model(x)
+    out = model(x)
+
+    if model_name not in quantized_flaky_models:
+        _assert_expected(out, model_name + "_quantized", prec=0.1)
+        assert out.shape[-1] == 5
+    _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(model_name, None))
+    _check_fx_compatible(model, x)
 
     kwargs["quantize"] = False
     for eval_mode in [True, False]:
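The repeated seed changes in the commit history ("try 128 random seed", "try 256", "try 16") reflect that cached expected values only hold if every RNG the test touches is reseeded before each run. A minimal stdlib-only sketch of the idea behind `set_rng_seed` — torchvision's helper also seeds torch (and this sketch's scope is an assumption), while this version covers only Python's built-in `random`:

```python
import random


def set_rng_seed(seed):
    # Reseed the RNG the test depends on so that generated model inputs
    # (and therefore the cached expected outputs) are reproducible from
    # run to run. torchvision's real helper also seeds torch; this
    # illustrative version seeds only the standard-library RNG.
    random.seed(seed)


set_rng_seed(0)
first = [random.random() for _ in range(4)]
set_rng_seed(0)
second = [random.random() for _ in range(4)]
assert first == second  # identical draws after reseeding
```

Without a fixed seed, `torch.rand(input_shape)` in the test above would produce different inputs on every run, making comparison against a cached output meaningless.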
