@@ -38,44 +38,16 @@ def get_available_video_models():
38
38
return [k for k , v in models .video .__dict__ .items () if callable (v ) and k [0 ].lower () == k [0 ] and k [0 ] != "_" ]
39
39
40
40
41
- # models that are in torch hub, as well as r3d_18. we tried testing all models
42
- # but the test was too slow. not included are detection models, because
43
- # they are not yet supported in JIT.
44
41
# If 'unwrapper' is provided it will be called with the script model outputs
45
42
# before they are compared to the eager model outputs. This is useful if the
46
43
# model outputs are different between TorchScript / Eager mode
47
- script_test_models = {
48
- 'deeplabv3_resnet50' : {},
49
- 'deeplabv3_resnet101' : {},
50
- 'mobilenet_v2' : {},
51
- 'resnext50_32x4d' : {},
52
- 'fcn_resnet50' : {},
53
- 'fcn_resnet101' : {},
54
- 'googlenet' : {
55
- 'unwrapper' : lambda x : x .logits
56
- },
57
- 'densenet121' : {},
58
- 'resnet18' : {},
59
- 'alexnet' : {},
60
- 'shufflenet_v2_x1_0' : {},
61
- 'squeezenet1_0' : {},
62
- 'vgg11' : {},
63
- 'inception_v3' : {
64
- 'unwrapper' : lambda x : x .logits
65
- },
66
- 'r3d_18' : {},
67
- "fasterrcnn_resnet50_fpn" : {
68
- 'unwrapper' : lambda x : x [1 ]
69
- },
70
- "maskrcnn_resnet50_fpn" : {
71
- 'unwrapper' : lambda x : x [1 ]
72
- },
73
- "keypointrcnn_resnet50_fpn" : {
74
- 'unwrapper' : lambda x : x [1 ]
75
- },
76
- "retinanet_resnet50_fpn" : {
77
- 'unwrapper' : lambda x : x [1 ]
78
- }
44
+ script_model_unwrapper = {
45
+ 'googlenet' : lambda x : x .logits ,
46
+ 'inception_v3' : lambda x : x .logits ,
47
+ "fasterrcnn_resnet50_fpn" : lambda x : x [1 ],
48
+ "maskrcnn_resnet50_fpn" : lambda x : x [1 ],
49
+ "keypointrcnn_resnet50_fpn" : lambda x : x [1 ],
50
+ "retinanet_resnet50_fpn" : lambda x : x [1 ],
79
51
}
80
52
81
53
@@ -97,12 +69,6 @@ def get_available_video_models():
97
69
98
70
99
71
class ModelTester (TestCase ):
100
- def checkModule (self , model , name , args ):
101
- if name not in script_test_models :
102
- return
103
- unwrapper = script_test_models [name ].get ('unwrapper' , None )
104
- return super (ModelTester , self ).checkModule (model , args , unwrapper = unwrapper , skip = False )
105
-
106
72
def _test_classification_model (self , name , input_shape , dev ):
107
73
set_rng_seed (0 )
108
74
# passing num_class equal to a number other than 1000 helps in making the test
@@ -114,7 +80,7 @@ def _test_classification_model(self, name, input_shape, dev):
114
80
out = model (x )
115
81
self .assertExpected (out .cpu (), prec = 0.1 , strip_suffix = "_" + dev )
116
82
self .assertEqual (out .shape [- 1 ], 50 )
117
- self .checkModule (model , name , ( x , ))
83
+ self .check_jit_scriptable (model , ( x ,), unwrapper = script_model_unwrapper . get ( name , None ))
118
84
119
85
if dev == "cuda" :
120
86
with torch .cuda .amp .autocast ():
@@ -134,7 +100,7 @@ def _test_segmentation_model(self, name, dev):
134
100
x = torch .rand (input_shape ).to (device = dev )
135
101
out = model (x )
136
102
self .assertEqual (tuple (out ["out" ].shape ), (1 , 50 , 300 , 300 ))
137
- self .checkModule (model , name , ( x , ))
103
+ self .check_jit_scriptable (model , ( x ,), unwrapper = script_model_unwrapper . get ( name , None ))
138
104
139
105
if dev == "cuda" :
140
106
with torch .cuda .amp .autocast ():
@@ -209,18 +175,7 @@ def compute_mean_std(tensor):
209
175
return True # Full validation performed
210
176
211
177
full_validation = check_out (out )
212
-
213
- scripted_model = torch .jit .script (model )
214
- scripted_model .eval ()
215
- scripted_out = scripted_model (model_input )[1 ]
216
- self .assertEqual (scripted_out [0 ]["boxes" ], out [0 ]["boxes" ])
217
- self .assertEqual (scripted_out [0 ]["scores" ], out [0 ]["scores" ])
218
- # labels currently float in script: need to investigate (though same result)
219
- self .assertEqual (scripted_out [0 ]["labels" ].to (dtype = torch .long ), out [0 ]["labels" ])
220
- # don't check script because we are compiling it here:
221
- # TODO: refactor tests
222
- # self.check_script(model, name)
223
- self .checkModule (model , name , ([x ],))
178
+ self .check_jit_scriptable (model , ([x ],), unwrapper = script_model_unwrapper .get (name , None ))
224
179
225
180
if dev == "cuda" :
226
181
with torch .cuda .amp .autocast ():
@@ -270,7 +225,7 @@ def _test_video_model(self, name, dev):
270
225
# RNG always on CPU, to ensure x in cuda tests is bitwise identical to x in cpu tests
271
226
x = torch .rand (input_shape ).to (device = dev )
272
227
out = model (x )
273
- self .checkModule (model , name , ( x , ))
228
+ self .check_jit_scriptable (model , ( x ,), unwrapper = script_model_unwrapper . get ( name , None ))
274
229
self .assertEqual (out .shape [- 1 ], 50 )
275
230
276
231
if dev == "cuda" :
@@ -345,11 +300,13 @@ def test_inceptionv3_eval(self):
345
300
kwargs ['transform_input' ] = True
346
301
kwargs ['aux_logits' ] = True
347
302
kwargs ['init_weights' ] = False
303
+ name = "inception_v3"
348
304
model = models .Inception3 (** kwargs )
349
305
model .aux_logits = False
350
306
model .AuxLogits = None
351
- m = torch .jit .script (model .eval ())
352
- self .checkModule (m , "inception_v3" , torch .rand (1 , 3 , 299 , 299 ))
307
+ model = model .eval ()
308
+ x = torch .rand (1 , 3 , 299 , 299 )
309
+ self .check_jit_scriptable (model , (x ,), unwrapper = script_model_unwrapper .get (name , None ))
353
310
354
311
def test_fasterrcnn_double (self ):
355
312
model = models .detection .fasterrcnn_resnet50_fpn (num_classes = 50 , pretrained_backbone = False )
@@ -371,12 +328,14 @@ def test_googlenet_eval(self):
371
328
kwargs ['transform_input' ] = True
372
329
kwargs ['aux_logits' ] = True
373
330
kwargs ['init_weights' ] = False
331
+ name = "googlenet"
374
332
model = models .GoogLeNet (** kwargs )
375
333
model .aux_logits = False
376
334
model .aux1 = None
377
335
model .aux2 = None
378
- m = torch .jit .script (model .eval ())
379
- self .checkModule (m , "googlenet" , torch .rand (1 , 3 , 224 , 224 ))
336
+ model = model .eval ()
337
+ x = torch .rand (1 , 3 , 224 , 224 )
338
+ self .check_jit_scriptable (model , (x ,), unwrapper = script_model_unwrapper .get (name , None ))
380
339
381
340
@unittest .skipIf (not torch .cuda .is_available (), 'needs GPU' )
382
341
def test_fasterrcnn_switch_devices (self ):
0 commit comments