Commit 2e0949e

beat-buesser, fmassa, and datumbox authored
Allow gradient backpropagation through GeneralizedRCNNTransform to inputs (#4327)
* Allow gradient backpropagation through GeneralizedRCNNTransform to inputs
  Signed-off-by: Beat Buesser <[email protected]>
* Add unit tests for gradient backpropagation to inputs
  Signed-off-by: Beat Buesser <[email protected]>
* Update torchvision/models/detection/transform.py
  Co-authored-by: Francisco Massa <[email protected]>
* Update _check_input_backprop
  Signed-off-by: Beat Buesser <[email protected]>
* Account for tests requiring cuda
  Signed-off-by: Beat Buesser <[email protected]>

Co-authored-by: Francisco Massa <[email protected]>
Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 526a69e commit 2e0949e
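The end-to-end effect of the commit, sketched below under stated assumptions: a detection model built on GeneralizedRCNNTransform now lets gradients flow back to the raw input images. The model choice, image size, and the box_score_thresh=0.0 setting (to keep all detections so a randomly initialized model returns a non-empty, differentiable "scores" tensor) are illustrative, not part of the commit.

import torch
import torchvision

# Sketch of the behavior this commit enables: gradients flowing through
# GeneralizedRCNNTransform back to the input images. Model and sizes are
# illustrative; box_score_thresh=0.0 is an assumption to avoid an empty
# "scores" tensor from random weights.
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(
    pretrained=False, box_score_thresh=0.0
)
model.eval()

images = [torch.rand(3, 300, 400, requires_grad=True)]
out = model(images)  # eval mode: list of dicts with "boxes"/"scores"/"labels"

# Before this commit, batching inside the transform broke the autograd path;
# now the backward pass reaches the original input tensors.
out[0]["scores"].sum().backward()
assert images[0].grad is not None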

File tree: 2 files changed (+50, -2 lines)

test/test_models.py
Lines changed: 47 additions & 0 deletions

@@ -148,6 +148,35 @@ def _check_fx_compatible(model, inputs):
     torch.testing.assert_close(out, out_fx)
 
 
+def _check_input_backprop(model, inputs):
+    if isinstance(inputs, list):
+        requires_grad = list()
+        for inp in inputs:
+            requires_grad.append(inp.requires_grad)
+            inp.requires_grad_(True)
+    else:
+        requires_grad = inputs.requires_grad
+        inputs.requires_grad_(True)
+
+    out = model(inputs)
+
+    if isinstance(out, dict):
+        out["out"].sum().backward()
+    else:
+        if isinstance(out[0], dict):
+            out[0]["scores"].sum().backward()
+        else:
+            out[0].sum().backward()
+
+    if isinstance(inputs, list):
+        for i, inp in enumerate(inputs):
+            assert inputs[i].grad is not None
+            inp.requires_grad_(requires_grad[i])
+    else:
+        assert inputs.grad is not None
+        inputs.requires_grad_(requires_grad)
+
+
 # If 'unwrapper' is provided it will be called with the script model outputs
 # before they are compared to the eager model outputs. This is useful if the
 # model outputs are different between TorchScript / Eager mode
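To see what the helper checks in isolation, here is a minimal sketch using a stand-in module (torch.nn.Linear is illustrative; the real tests pass torchvision models):

import torch
import torch.nn as nn

# Minimal stand-in model; the tests below use torchvision models instead.
model = nn.Linear(4, 2)
x = torch.rand(3, 4)

# The helper flips requires_grad on the inputs, runs a forward pass,
# backpropagates a scalar reduction of the output, asserts that input
# gradients were populated, and restores the original requires_grad flags.
_check_input_backprop(model, x)
print(x.grad is not None)  # True: flags were restored, gradients remain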
@@ -263,6 +292,9 @@ def test_memory_efficient_densenet(model_name):
     assert num_params == num_grad
     torch.testing.assert_close(out1, out2, rtol=0.0, atol=1e-5)
 
+    _check_input_backprop(model1, x)
+    _check_input_backprop(model2, x)
+
 
 @pytest.mark.parametrize('dilate_layer_2', (True, False))
 @pytest.mark.parametrize('dilate_layer_3', (True, False))
@@ -312,6 +344,7 @@ def test_inception_v3_eval():
     model = model.eval()
     x = torch.rand(1, 3, 299, 299)
     _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
+    _check_input_backprop(model, x)
 
 
 def test_fasterrcnn_double():
@@ -327,6 +360,7 @@ def test_fasterrcnn_double():
     assert "boxes" in out[0]
     assert "scores" in out[0]
     assert "labels" in out[0]
+    _check_input_backprop(model, model_input)
 
 
 def test_googlenet_eval():
@@ -343,6 +377,7 @@ def test_googlenet_eval():
     model = model.eval()
     x = torch.rand(1, 3, 224, 224)
     _check_jit_scriptable(model, (x,), unwrapper=script_model_unwrapper.get(name, None))
+    _check_input_backprop(model, x)
 
 
 @needs_cuda
@@ -369,13 +404,17 @@ def checkOut(out):
 
     checkOut(out)
 
+    _check_input_backprop(model, model_input)
+
     # now switch to cpu and make sure it works
     model.cpu()
     x = x.cpu()
     out_cpu = model([x])
 
     checkOut(out_cpu)
 
+    _check_input_backprop(model, [x])
+
 
 def test_generalizedrcnn_transform_repr():
 
@@ -426,6 +465,8 @@ def test_classification_model(model_name, dev):
     _assert_expected(out.cpu(), model_name, prec=0.1)
     assert out.shape[-1] == 50
 
+    _check_input_backprop(model, x)
+
 
 @pytest.mark.parametrize('model_name', get_available_segmentation_models())
 @pytest.mark.parametrize('dev', cpu_and_gpu())
@@ -483,6 +524,8 @@ def check_out(out):
         warnings.warn(msg, RuntimeWarning)
         pytest.skip(msg)
 
+    _check_input_backprop(model, x)
+
 
 @pytest.mark.parametrize('model_name', get_available_detection_models())
 @pytest.mark.parametrize('dev', cpu_and_gpu())
@@ -574,6 +617,8 @@ def compute_mean_std(tensor):
         warnings.warn(msg, RuntimeWarning)
         pytest.skip(msg)
 
+    _check_input_backprop(model, model_input)
+
 
 @pytest.mark.parametrize('model_name', get_available_detection_models())
 def test_detection_model_validation(model_name):
@@ -625,6 +670,8 @@ def test_video_model(model_name, dev):
     out = model(x)
     assert out.shape[-1] == 50
 
+    _check_input_backprop(model, x)
+
 
 @pytest.mark.skipif(not ('fbgemm' in torch.backends.quantized.supported_engines and
                          'qnnpack' in torch.backends.quantized.supported_engines),
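The branching on output type inside _check_input_backprop mirrors the torchvision conventions exercised above: segmentation models return a dict with an "out" tensor, detection models take a list of images and return a list of per-image dicts with a "scores" tensor, and classification/video models return a plain tensor. A hedged sketch with hypothetical stand-in modules (not part of the commit):

import torch
import torch.nn as nn

class SegLike(nn.Module):
    # mimics segmentation models, which return {"out": Tensor}
    def forward(self, x):
        return {"out": x * 2.0}

class DetLike(nn.Module):
    # mimics detection models: a list of images in, a list of dicts out
    def forward(self, images):
        return [{"scores": (img * 2.0).flatten()} for img in images]

_check_input_backprop(SegLike(), torch.rand(1, 3, 8, 8))   # dict branch
_check_input_backprop(DetLike(), [torch.rand(3, 8, 8)])    # list-of-dicts branch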

torchvision/models/detection/transform.py
Lines changed: 3 additions & 2 deletions

@@ -214,8 +214,9 @@ def batch_images(self, images: List[Tensor], size_divisible: int = 32) -> Tensor
 
         batch_shape = [len(images)] + max_size
         batched_imgs = images[0].new_full(batch_shape, 0)
-        for img, pad_img in zip(images, batched_imgs):
-            pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
+        for i in range(batched_imgs.shape[0]):
+            img = images[i]
+            batched_imgs[i, : img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
 
         return batched_imgs
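Why the rewrite matters: iterating directly over batched_imgs (which is what zip does under the hood) goes through Tensor.unbind, and autograd forbids in-place writes such as copy_ on views produced by multi-output view ops, raising a RuntimeError once the source images require gradients. Indexing the batch tensor instead yields an ordinary view that accepts in-place copy_ and keeps the graph intact. A minimal sketch of the same pattern (shapes and values are illustrative):

import torch

src = torch.rand(2, 3, requires_grad=True)
batch = torch.zeros(2, 3)

# Broken pattern (pre-commit): iterating a tensor unbinds it, and in-place
# copy_ on unbind outputs raises a RuntimeError when gradients are tracked:
#     for row, s in zip(batch, src):
#         row.copy_(s)
#
# Fixed pattern (this commit): index the batch tensor directly.
for i in range(batch.shape[0]):
    batch[i].copy_(src[i])

batch.sum().backward()
assert src.grad is not None  # gradients flow back through the padded batch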
