
Commit 63c714d

YosuaMichael authored and vfdev-5 committed
[fbsync] add tests for F.pad_bounding_box (#6038)
Summary:
* add tests for F.pad_bounding_box
* Added correctness tests for pad and reimplemented bbox op to keep dtype
* Update _geometry.py

Reviewed By: NicolasHug

Differential Revision: D36760926

fbshipit-source-id: 6c2a5430790e2db778911dbf62ce8e38f6e9b603

Co-authored-by: vfdev <[email protected]>
1 parent 7df2d08 commit 63c714d


2 files changed: +77, -24 lines changed


test/test_prototype_transforms_functional.py

Lines changed: 68 additions & 14 deletions
@@ -382,6 +382,15 @@ def pad_segmentation_mask():
         yield SampleInput(mask, padding=padding, padding_mode=padding_mode)
 
 
+@register_kernel_info_from_sample_inputs_fn
+def pad_bounding_box():
+    for bounding_box, padding in itertools.product(
+        make_bounding_boxes(),
+        [[1], [1, 1], [1, 1, 2, 2]],
+    ):
+        yield SampleInput(bounding_box, padding=padding, format=bounding_box.format)
+
+
 @register_kernel_info_from_sample_inputs_fn
 def perspective_bounding_box():
     for bounding_box, perspective_coeffs in itertools.product(
@@ -1103,22 +1112,67 @@ def test_correctness_pad_segmentation_mask_on_fixed_input(device):
     torch.testing.assert_close(out_mask, expected_mask)
 
 
+def _parse_padding(padding):
+    if isinstance(padding, int):
+        return [padding] * 4
+    if isinstance(padding, list):
+        if len(padding) == 1:
+            return padding * 4
+        if len(padding) == 2:
+            return padding * 2  # [left, up, right, down]
+
+    return padding
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize("padding", [[1], [1, 1], [1, 1, 2, 2]])
+def test_correctness_pad_bounding_box(device, padding):
+    def _compute_expected_bbox(bbox, padding_):
+        pad_left, pad_up, _, _ = _parse_padding(padding_)
+
+        bbox_format = bbox.format
+        bbox_dtype = bbox.dtype
+        bbox = convert_bounding_box_format(bbox, old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY)
+
+        bbox[0::2] += pad_left
+        bbox[1::2] += pad_up
+
+        bbox = convert_bounding_box_format(
+            bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
+        )
+        if bbox.dtype != bbox_dtype:
+            # Temporary cast to original dtype
+            # e.g. float32 -> int
+            bbox = bbox.to(bbox_dtype)
+        return bbox
+
+    for bboxes in make_bounding_boxes():
+        bboxes = bboxes.to(device)
+        bboxes_format = bboxes.format
+        bboxes_image_size = bboxes.image_size
+
+        output_boxes = F.pad_bounding_box(bboxes, padding, format=bboxes_format)
+
+        if bboxes.ndim < 2:
+            bboxes = [bboxes]
+
+        expected_bboxes = []
+        for bbox in bboxes:
+            bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
+            expected_bboxes.append(_compute_expected_bbox(bbox, padding))
+
+        if len(expected_bboxes) > 1:
+            expected_bboxes = torch.stack(expected_bboxes)
+        else:
+            expected_bboxes = expected_bboxes[0]
+        torch.testing.assert_close(output_boxes, expected_bboxes)
+
+
 @pytest.mark.parametrize("padding", [[1, 2, 3, 4], [1], 1, [1, 2]])
 def test_correctness_pad_segmentation_mask(padding):
-    def _compute_expected_mask():
-        def parse_padding():
-            if isinstance(padding, int):
-                return [padding] * 4
-            if isinstance(padding, list):
-                if len(padding) == 1:
-                    return padding * 4
-                if len(padding) == 2:
-                    return padding * 2  # [left, up, right, down]
-
-            return padding
-
+    def _compute_expected_mask(mask, padding_):
         h, w = mask.shape[-2], mask.shape[-1]
-        pad_left, pad_up, pad_right, pad_down = parse_padding()
+        pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
 
         new_h = h + pad_up + pad_down
         new_w = w + pad_left + pad_right
@@ -1132,7 +1186,7 @@ def parse_padding():
     for mask in make_segmentation_masks():
         out_mask = F.pad_segmentation_mask(mask, padding, "constant")
 
-        expected_mask = _compute_expected_mask()
+        expected_mask = _compute_expected_mask(mask, padding)
         torch.testing.assert_close(out_mask, expected_mask)
 
 
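For reference, a quick illustration (not part of the commit) of what the shared _parse_padding helper introduced above returns for the padding values used in the parametrizations, read as [left, up, right, down]:

# Illustrative only; relies on the _parse_padding helper from the diff above.
assert _parse_padding(1) == [1, 1, 1, 1]
assert _parse_padding([1]) == [1, 1, 1, 1]
assert _parse_padding([1, 2]) == [1, 2, 1, 2]
assert _parse_padding([1, 1, 2, 2]) == [1, 1, 2, 2]
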
torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 9 additions & 10 deletions
@@ -415,16 +415,15 @@ def pad_bounding_box(
 ) -> torch.Tensor:
     left, _, top, _ = _FT._parse_pad_padding(padding)
 
-    bounding_box = convert_bounding_box_format(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
-    )
-
-    bounding_box[..., 0::2] += left
-    bounding_box[..., 1::2] += top
-
-    return convert_bounding_box_format(
-        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
-    )
+    bounding_box = bounding_box.clone()
+
+    # this works without conversion since padding only affects xy coordinates
+    bounding_box[..., 0] += left
+    bounding_box[..., 1] += top
+    if format == features.BoundingBoxFormat.XYXY:
+        bounding_box[..., 2] += left
+        bounding_box[..., 3] += top
+    return bounding_box
 
 
 crop_image_tensor = _FT.crop
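
An illustrative sketch (not part of the commit) of what the rewritten pad_bounding_box does for a single box in XYXY format: both corner points are shifted by the left/top padding offsets, and the integer dtype survives because there is no conversion round-trip, matching the "keep dtype" note in the summary. The concrete box and padding offsets below are made up for illustration.

import torch

# Hypothetical values: a single XYXY box and padding that parses to left=1, top=2.
bbox = torch.tensor([10, 20, 30, 40], dtype=torch.int64)
left, top = 1, 2

padded = bbox.clone()
padded[..., 0] += left  # x1
padded[..., 1] += top   # y1
padded[..., 2] += left  # x2 (shifted only for XYXY; other formats keep width/height untouched)
padded[..., 3] += top   # y2

print(padded)        # tensor([11, 22, 31, 42])
print(padded.dtype)  # torch.int64 -- the original dtype is preserved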
