Skip to content

Commit c1da4a5

Browse files
authored
Check target boxes input on generalized_rcnn.py (#2207)
* Check target boxes input on generalized_rcnn.py * Fix target box validation in generalized_rcnn.py * Add tests for input validation of detection models
1 parent 7aea80c commit c1da4a5

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

test/test_models.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,24 @@ def compute_mean_std(tensor):
155155
# self.check_script(model, name)
156156
self.checkModule(model, name, ([x],))
157157

158+
def _test_detection_model_validation(self, name):
159+
set_rng_seed(0)
160+
model = models.detection.__dict__[name](num_classes=50, pretrained_backbone=False)
161+
input_shape = (1, 3, 300, 300)
162+
x = [torch.rand(input_shape)]
163+
164+
# validate that targets are present in training
165+
self.assertRaises(ValueError, model, x)
166+
167+
# validate type
168+
targets = [{'boxes': 0.}]
169+
self.assertRaises(ValueError, model, x, targets=targets)
170+
171+
# validate boxes shape
172+
for boxes in (torch.rand((4,)), torch.rand((1, 5))):
173+
targets = [{'boxes': boxes}]
174+
self.assertRaises(ValueError, model, x, targets=targets)
175+
158176
def _test_video_model(self, name):
159177
# the default input shape is
160178
# bs * num_channels * clip_len * h *w
@@ -303,6 +321,11 @@ def do_test(self, model_name=model_name):
303321

304322
setattr(ModelTester, "test_" + model_name, do_test)
305323

324+
def do_validation_test(self, model_name=model_name):
325+
self._test_detection_model_validation(model_name)
326+
327+
setattr(ModelTester, "test_" + model_name + "_validation", do_validation_test)
328+
306329

307330
for model_name in get_available_video_models():
308331

torchvision/models/detection/generalized_rcnn.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,19 @@ def forward(self, images, targets=None):
5757
"""
5858
if self.training and targets is None:
5959
raise ValueError("In training mode, targets should be passed")
60+
if self.training:
61+
assert targets is not None
62+
for target in targets:
63+
boxes = target["boxes"]
64+
if isinstance(boxes, torch.Tensor):
65+
if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
66+
raise ValueError("Expected target boxes to be a tensor"
67+
"of shape [N, 4], got {:}.".format(
68+
boxes.shape))
69+
else:
70+
raise ValueError("Expected target boxes to be of type "
71+
"Tensor, got {:}.".format(type(boxes)))
72+
6073
original_image_sizes = torch.jit.annotate(List[Tuple[int, int]], [])
6174
for img in images:
6275
val = img.shape[-2:]

0 commit comments

Comments
 (0)