
Commit 4521f6d

Refactor & enable JIT tests in all models and add warnings if skipped (#3033)
* Enable jit tests in all models and add warning if checkModule() tests are skipped.
* Turning on JIT tests on CI.
* Fixing broken unit-tests.
* Refactoring and cleaning up duplicate code.
1 parent a51c49e commit 4521f6d

File tree

4 files changed (+34, -66 lines)


.circleci/unittest/linux/scripts/run_test.sh

Lines changed: 2 additions & 1 deletion
@@ -5,5 +5,6 @@ set -e
 eval "$(./conda/bin/conda shell.bash hook)"
 conda activate ./env

+export PYTORCH_TEST_WITH_SLOW='1'
 python -m torch.utils.collect_env
-pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test --ignore=test/test_datasets_download.py
+pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test --ignore=test/test_datasets_download.py

.circleci/unittest/windows/scripts/run_test.sh

Lines changed: 2 additions & 1 deletion
@@ -5,5 +5,6 @@ set -e
 eval "$(./conda/Scripts/conda.exe 'shell.bash' 'hook')"
 conda activate ./env

+export PYTORCH_TEST_WITH_SLOW='1'
 python -m torch.utils.collect_env
-pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test --ignore=test/test_datasets_download.py
+pytest --cov=torchvision --junitxml=test-results/junit.xml -v --durations 20 test --ignore=test/test_datasets_download.py
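The JIT checks added in the files below are gated on this variable. As a hedged sketch of the convention (the exact parsing used by test/common_utils.py is not shown in this diff, so the default value here is an assumption):

import os

# Assumed convention: slow/JIT tests run only when PYTORCH_TEST_WITH_SLOW is '1'.
TEST_WITH_SLOW = os.getenv('PYTORCH_TEST_WITH_SLOW', '0') == '1'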

test/common_utils.py

Lines changed: 11 additions & 4 deletions
@@ -7,7 +7,7 @@
 import sys
 import io
 import torch
-import errno
+import warnings
 import __main__

 from numbers import Number
@@ -265,14 +265,21 @@ def assertTensorsEqual(a, b):
         else:
             super(TestCase, self).assertEqual(x, y, message)

-    def checkModule(self, nn_module, args, unwrapper=None, skip=False):
+    def check_jit_scriptable(self, nn_module, args, unwrapper=None, skip=False):
         """
         Check that a nn.Module's results in TorchScript match eager and that it
         can be exported
         """
         if not TEST_WITH_SLOW or skip:
             # TorchScript is not enabled, skip these tests
-            return
+            msg = "The check_jit_scriptable test for {} was skipped. " \
+                  "This test checks if the module's results in TorchScript " \
+                  "match eager and that it can be exported. To run these " \
+                  "tests make sure you set the environment variable " \
+                  "PYTORCH_TEST_WITH_SLOW=1 and that the test is not " \
+                  "manually skipped.".format(nn_module.__class__.__name__)
+            warnings.warn(msg, RuntimeWarning)
+            return None

         sm = torch.jit.script(nn_module)

@@ -284,7 +291,7 @@ def checkModule(self, nn_module, args, unwrapper=None, skip=False):
         if unwrapper:
             script_out = unwrapper(script_out)

-        self.assertEqual(eager_out, script_out)
+        self.assertEqual(eager_out, script_out, prec=1e-4)
         self.assertExportImportModule(sm, args)

         return sm
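For the whole picture of what the renamed helper does, here is a minimal standalone sketch of the eager-vs-scripted comparison and the export round-trip it performs. This is an illustrative reconstruction, not the repository helper; the function name and the use of torch.testing.assert_allclose are assumptions for the example.

import io

import torch

def check_jit_scriptable_sketch(nn_module, args, unwrapper=None):
    # Script the module and run it alongside the eager module on the same inputs.
    sm = torch.jit.script(nn_module)
    eager_out = nn_module(*args)
    script_out = sm(*args)
    if unwrapper is not None:
        # Some models return a different structure under script (e.g. a
        # namedtuple); unwrap it before comparing against the eager output.
        script_out = unwrapper(script_out)
    torch.testing.assert_allclose(eager_out, script_out)
    # Check exportability by round-tripping the scripted module through save/load.
    buffer = io.BytesIO()
    torch.jit.save(sm, buffer)
    buffer.seek(0)
    torch.jit.load(buffer)
    return sm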

test/test_models.py

Lines changed: 19 additions & 60 deletions
@@ -38,44 +38,16 @@ def get_available_video_models():
     return [k for k, v in models.video.__dict__.items() if callable(v) and k[0].lower() == k[0] and k[0] != "_"]


-# models that are in torch hub, as well as r3d_18. we tried testing all models
-# but the test was too slow. not included are detection models, because
-# they are not yet supported in JIT.
 # If 'unwrapper' is provided it will be called with the script model outputs
 # before they are compared to the eager model outputs. This is useful if the
 # model outputs are different between TorchScript / Eager mode
-script_test_models = {
-    'deeplabv3_resnet50': {},
-    'deeplabv3_resnet101': {},
-    'mobilenet_v2': {},
-    'resnext50_32x4d': {},
-    'fcn_resnet50': {},
-    'fcn_resnet101': {},
-    'googlenet': {
-        'unwrapper': lambda x: x.logits
-    },
-    'densenet121': {},
-    'resnet18': {},
-    'alexnet': {},
-    'shufflenet_v2_x1_0': {},
-    'squeezenet1_0': {},
-    'vgg11': {},
-    'inception_v3': {
-        'unwrapper': lambda x: x.logits
-    },
-    'r3d_18': {},
-    "fasterrcnn_resnet50_fpn": {
-        'unwrapper': lambda x: x[1]
-    },
-    "maskrcnn_resnet50_fpn": {
-        'unwrapper': lambda x: x[1]
-    },
-    "keypointrcnn_resnet50_fpn": {
-        'unwrapper': lambda x: x[1]
-    },
-    "retinanet_resnet50_fpn": {
-        'unwrapper': lambda x: x[1]
-    }
+script_model_unwrapper = {
+    'googlenet': lambda x: x.logits,
+    'inception_v3': lambda x: x.logits,
+    "fasterrcnn_resnet50_fpn": lambda x: x[1],
+    "maskrcnn_resnet50_fpn": lambda x: x[1],
+    "keypointrcnn_resnet50_fpn": lambda x: x[1],
+    "retinanet_resnet50_fpn": lambda x: x[1],
 }


@@ -97,12 +69,6 @@ def get_available_video_models():


 class ModelTester(TestCase):
-    def checkModule(self, model, name, args):
-        if name not in script_test_models:
-            return
-        unwrapper = script_test_models[name].get('unwrapper', None)
-        return super(ModelTester, self).checkModule(model, args, unwrapper=unwrapper, skip=False)
-
     def _test_classification_model(self, name, input_shape, dev):
         set_rng_seed(0)
         # passing num_class equal to a number other than 1000 helps in making the test
@@ -114,7 +80,7 @@ def _test_classification_model(self, name, input_shape, dev):
         out = model(x)
         self.assertExpected(out.cpu(), prec=0.1, strip_suffix="_" + dev)
         self.assertEqual(out.shape[-1], 50)
-        self.checkModule(model, name, (x,))
+        self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))

         if dev == "cuda":
             with torch.cuda.amp.autocast():
@@ -134,7 +100,7 @@ def _test_segmentation_model(self, name, dev):
         x = torch.rand(input_shape).to(device=dev)
         out = model(x)
         self.assertEqual(tuple(out["out"].shape), (1, 50, 300, 300))
-        self.checkModule(model, name, (x,))
+        self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))

         if dev == "cuda":
             with torch.cuda.amp.autocast():
@@ -209,18 +175,7 @@ def compute_mean_std(tensor):
             return True  # Full validation performed

         full_validation = check_out(out)
-
-        scripted_model = torch.jit.script(model)
-        scripted_model.eval()
-        scripted_out = scripted_model(model_input)[1]
-        self.assertEqual(scripted_out[0]["boxes"], out[0]["boxes"])
-        self.assertEqual(scripted_out[0]["scores"], out[0]["scores"])
-        # labels currently float in script: need to investigate (though same result)
-        self.assertEqual(scripted_out[0]["labels"].to(dtype=torch.long), out[0]["labels"])
-        # don't check script because we are compiling it here:
-        # TODO: refactor tests
-        # self.check_script(model, name)
-        self.checkModule(model, name, ([x],))
+        self.check_jit_scriptable(model, ([x],), unwrapper=script_model_unwrapper.get(name, None))

         if dev == "cuda":
             with torch.cuda.amp.autocast():
@@ -270,7 +225,7 @@ def _test_video_model(self, name, dev):
         # RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
         x = torch.rand(input_shape).to(device=dev)
         out = model(x)
-        self.checkModule(model, name, (x,))
+        self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
         self.assertEqual(out.shape[-1], 50)

         if dev == "cuda":
@@ -345,11 +300,13 @@ def test_inceptionv3_eval(self):
         kwargs['transform_input'] = True
         kwargs['aux_logits'] = True
         kwargs['init_weights'] = False
+        name = "inception_v3"
         model = models.Inception3(**kwargs)
         model.aux_logits = False
         model.AuxLogits = None
-        m = torch.jit.script(model.eval())
-        self.checkModule(m, "inception_v3", torch.rand(1, 3, 299, 299))
+        model = model.eval()
+        x = torch.rand(1, 3, 299, 299)
+        self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))

     def test_fasterrcnn_double(self):
         model = models.detection.fasterrcnn_resnet50_fpn(num_classes=50, pretrained_backbone=False)
@@ -371,12 +328,14 @@ def test_googlenet_eval(self):
         kwargs['transform_input'] = True
         kwargs['aux_logits'] = True
         kwargs['init_weights'] = False
+        name = "googlenet"
         model = models.GoogLeNet(**kwargs)
         model.aux_logits = False
         model.aux1 = None
         model.aux2 = None
-        m = torch.jit.script(model.eval())
-        self.checkModule(m, "googlenet", torch.rand(1, 3, 224, 224))
+        model = model.eval()
+        x = torch.rand(1, 3, 224, 224)
+        self.check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))

     @unittest.skipIf(not torch.cuda.is_available(), 'needs GPU')
     def test_fasterrcnn_switch_devices(self):
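A small usage note on the refactor above: the old nested script_test_models dict both selected which models were scripted and stored their optional unwrapper, whereas every model is now checked and the flat mapping only records the exceptions. A minimal, runnable sketch of that lookup behaviour, with a few model names reused from the diff:

# Models absent from the mapping fall back to None, i.e. their scripted output
# is compared to the eager output as-is.
script_model_unwrapper = {
    'googlenet': lambda x: x.logits,
    'inception_v3': lambda x: x.logits,
    "fasterrcnn_resnet50_fpn": lambda x: x[1],
}

assert script_model_unwrapper.get("resnet18", None) is None
assert callable(script_model_unwrapper.get("googlenet", None))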
