@@ -148,6 +148,35 @@ def _check_fx_compatible(model, inputs):
     torch.testing.assert_close(out, out_fx)


+def _check_input_backprop(model, inputs):
+    if isinstance(inputs, list):
+        requires_grad = list()
+        for inp in inputs:
+            requires_grad.append(inp.requires_grad)
+            inp.requires_grad_(True)
+    else:
+        requires_grad = inputs.requires_grad
+        inputs.requires_grad_(True)
+
+    out = model(inputs)
+
+    if isinstance(out, dict):
+        out["out"].sum().backward()
+    else:
+        if isinstance(out[0], dict):
+            out[0]["scores"].sum().backward()
+        else:
+            out[0].sum().backward()
+
+    if isinstance(inputs, list):
+        for i, inp in enumerate(inputs):
+            assert inputs[i].grad is not None
+            inp.requires_grad_(requires_grad[i])
+    else:
+        assert inputs.grad is not None
+        inputs.requires_grad_(requires_grad)
+
+
 # If 'unwrapper' is provided it will be called with the script model outputs
 # before they are compared to the eager model outputs. This is useful if the
 # model outputs are different between TorchScript / Eager mode
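
Note: _check_input_backprop enables gradients on the input(s), runs a forward pass, backpropagates through the model's primary output (out["out"] for segmentation-style dict outputs, out[0]["scores"] for detection-style list-of-dict outputs, out[0] otherwise), asserts that a gradient reached every input, and restores the original requires_grad flags. A minimal standalone sketch of the plain-tensor path — the toy model below is hypothetical and not part of the test suite:

    import torch

    model = torch.nn.Linear(8, 4)    # hypothetical toy model; output is a plain tensor
    x = torch.rand(2, 8)

    _check_input_backprop(model, x)  # takes the out[0].sum().backward() branch
    assert x.grad is not None        # gradient reached the input
    assert not x.requires_grad       # original flag was restored by the helper
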
@@ -263,6 +292,9 @@ def test_memory_efficient_densenet(model_name):
     assert num_params == num_grad
     torch.testing.assert_close(out1, out2, rtol=0.0, atol=1e-5)

+    _check_input_backprop(model1, x)
+    _check_input_backprop(model2, x)
+

 @pytest.mark.parametrize('dilate_layer_2', (True, False))
 @pytest.mark.parametrize('dilate_layer_3', (True, False))
@@ -312,6 +344,7 @@ def test_inception_v3_eval():
     model = model.eval()
     x = torch.rand(1, 3, 299, 299)
     _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
+    _check_input_backprop(model, x)


 def test_fasterrcnn_double():
@@ -327,6 +360,7 @@ def test_fasterrcnn_double():
     assert "boxes" in out[0]
     assert "scores" in out[0]
     assert "labels" in out[0]
+    _check_input_backprop(model, model_input)


 def test_googlenet_eval():
@@ -343,6 +377,7 @@ def test_googlenet_eval():
     model = model.eval()
     x = torch.rand(1, 3, 224, 224)
     _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
+    _check_input_backprop(model, x)


 @needs_cuda
@@ -369,13 +404,17 @@ def checkOut(out):

     checkOut(out)

+    _check_input_backprop(model, model_input)
+
     # now switch to cpu and make sure it works
     model.cpu()
     x = x.cpu()
     out_cpu = model([x])

     checkOut(out_cpu)

+    _check_input_backprop(model, [x])
+

 def test_generalizedrcnn_transform_repr():

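Note: detection models take a list of image tensors and return a list of dicts, which is why the test above passes [x]; _check_input_backprop then backpropagates through out[0]["scores"] and checks the gradient of each tensor in the list. A rough sketch of that branch, using a hypothetical stand-in module (FakeDetector is illustrative only, not part of the suite):

    import torch

    class FakeDetector(torch.nn.Module):
        # Mimics the output shape of a detection model:
        # one dict with a "scores" tensor per input image.
        def forward(self, images):
            return [{"scores": img.sum().unsqueeze(0)} for img in images]

    x = torch.rand(3, 300, 300)
    _check_input_backprop(FakeDetector(), [x])  # list input -> out[0]["scores"] branch
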
@@ -426,6 +465,8 @@ def test_classification_model(model_name, dev):
     _assert_expected(out.cpu(), model_name, prec=0.1)
     assert out.shape[-1] == 50

+    _check_input_backprop(model, x)
+

 @pytest.mark.parametrize('model_name', get_available_segmentation_models())
 @pytest.mark.parametrize('dev', cpu_and_gpu())
@@ -483,6 +524,8 @@ def check_out(out):
         warnings.warn(msg, RuntimeWarning)
         pytest.skip(msg)

+    _check_input_backprop(model, x)
+

 @pytest.mark.parametrize('model_name', get_available_detection_models())
 @pytest.mark.parametrize('dev', cpu_and_gpu())
@@ -574,6 +617,8 @@ def compute_mean_std(tensor):
         warnings.warn(msg, RuntimeWarning)
         pytest.skip(msg)

+    _check_input_backprop(model, model_input)
+

 @pytest.mark.parametrize('model_name', get_available_detection_models())
 def test_detection_model_validation(model_name):
@@ -625,6 +670,8 @@ def test_video_model(model_name, dev):
     out = model(x)
     assert out.shape[-1] == 50

+    _check_input_backprop(model, x)
+

 @pytest.mark.skipif(not ('fbgemm' in torch.backends.quantized.supported_engines and
                          'qnnpack' in torch.backends.quantized.supported_engines),