
Commit 45d8608

datumbox authored and facebook-github-bot committed
[fbsync] [proto] Added center_crop_bounding_box functional op (#5972)
Summary:
* [proto] Added `center_crop_bounding_box` functional op
* Fixed mypy issue
* Added one more test case
* More test cases

Reviewed By: YosuaMichael
Differential Revision: D36281607
fbshipit-source-id: dd6a822ecb439e07e115d4d854e9a8ce7a53873d
1 parent 41a76ba commit 45d8608

File tree

3 files changed: +77 -1 lines changed
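For context, a minimal usage sketch of the new kernel (not part of the commit; it assumes the prototype import paths visible in the diff, and the sample box values are hypothetical):

from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

# Center-cropping a (32, 32) image to (24, 24) places the crop window at
# top=4, left=4, so box coordinates shift by that anchor.
bbox = features.BoundingBox(
    [10.0, 12.0, 20.0, 22.0],
    format=features.BoundingBoxFormat.XYXY,
    image_size=(32, 32),
)
out = F.center_crop_bounding_box(bbox, bbox.format, output_size=[24, 24], image_size=bbox.image_size)
# out is the box translated by (-4, -4): (6, 8, 16, 18) in XYXY coordinates.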

test/test_prototype_transforms_functional.py

Lines changed: 65 additions & 1 deletion
@@ -95,7 +95,7 @@ def make_bounding_box(*, format, image_size=(32, 32), extra_dims=(), dtype=torch
         cx = torch.randint(1, width - 1, ())
         cy = torch.randint(1, height - 1, ())
         w = randint_with_tensor_bounds(1, torch.minimum(cx, width - cx) + 1)
-        h = randint_with_tensor_bounds(1, torch.minimum(cy, width - cy) + 1)
+        h = randint_with_tensor_bounds(1, torch.minimum(cy, height - cy) + 1)
         parts = (cx, cy, w, h)
     else:
         raise pytest.UsageError()
@@ -413,6 +413,14 @@ def perspective_segmentation_mask():
     )


+@register_kernel_info_from_sample_inputs_fn
+def center_crop_bounding_box():
+    for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), [16, 18], [46, 48], [12]]):
+        yield SampleInput(
+            bounding_box, format=bounding_box.format, output_size=output_size, image_size=bounding_box.image_size
+        )
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
@@ -1273,3 +1281,59 @@ def _compute_expected_mask(mask, pcoeffs_):
     else:
         expected_masks = expected_masks[0]
     torch.testing.assert_close(output_mask, expected_masks)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize(
+    "output_size",
+    [(18, 18), [18, 15], (16, 19), [12], [46, 48]],
+)
+def test_correctness_center_crop_bounding_box(device, output_size):
+    def _compute_expected_bbox(bbox, output_size_):
+        format_ = bbox.format
+        image_size_ = bbox.image_size
+        bbox = convert_bounding_box_format(bbox, format_, features.BoundingBoxFormat.XYWH)
+
+        if len(output_size_) == 1:
+            output_size_.append(output_size_[-1])
+
+        cy = int(round((image_size_[0] - output_size_[0]) * 0.5))
+        cx = int(round((image_size_[1] - output_size_[1]) * 0.5))
+        out_bbox = [
+            bbox[0].item() - cx,
+            bbox[1].item() - cy,
+            bbox[2].item(),
+            bbox[3].item(),
+        ]
+        out_bbox = features.BoundingBox(
+            out_bbox,
+            format=features.BoundingBoxFormat.XYWH,
+            image_size=output_size_,
+            dtype=bbox.dtype,
+            device=bbox.device,
+        )
+        return convert_bounding_box_format(out_bbox, features.BoundingBoxFormat.XYWH, format_, copy=False)
+
+    for bboxes in make_bounding_boxes(
+        image_sizes=[(32, 32), (24, 33), (32, 25)],
+        extra_dims=((4,),),
+    ):
+        bboxes = bboxes.to(device)
+        bboxes_format = bboxes.format
+        bboxes_image_size = bboxes.image_size
+
+        output_boxes = F.center_crop_bounding_box(bboxes, bboxes_format, output_size, bboxes_image_size)
+
+        if bboxes.ndim < 2:
+            bboxes = [bboxes]
+
+        expected_bboxes = []
+        for bbox in bboxes:
+            bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
+            expected_bboxes.append(_compute_expected_bbox(bbox, output_size))
+
+        if len(expected_bboxes) > 1:
+            expected_bboxes = torch.stack(expected_bboxes)
+        else:
+            expected_bboxes = expected_bboxes[0]
+        torch.testing.assert_close(output_boxes, expected_bboxes)
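The reference _compute_expected_bbox above mirrors the kernel's anchor math in XYWH space: the crop window's top-left corner is subtracted from each box origin while width and height stay unchanged. A standalone worked check of that arithmetic (not part of the commit; sample values are hypothetical):

image_size = (24, 33)   # (height, width)
output_size = [18, 15]  # (crop_height, crop_width)

cy = int(round((image_size[0] - output_size[0]) * 0.5))  # top anchor: 3
cx = int(round((image_size[1] - output_size[1]) * 0.5))  # left anchor: 9

# An XYWH box (x, y, w, h) maps into the crop by translating its origin;
# width and height are unchanged.
bbox = (12.0, 5.0, 6.0, 7.0)
expected = (bbox[0] - cx, bbox[1] - cy, bbox[2], bbox[3])
print(expected)  # (3.0, 2.0, 6.0, 7.0)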

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -45,6 +45,7 @@
     resize_image_tensor,
     resize_image_pil,
     resize_segmentation_mask,
+    center_crop_bounding_box,
     center_crop_image_tensor,
     center_crop_image_pil,
     resized_crop_bounding_box,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 11 additions & 0 deletions
@@ -619,6 +619,17 @@ def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.I
     return crop_image_pil(img, crop_top, crop_left, crop_height, crop_width)


+def center_crop_bounding_box(
+    bounding_box: torch.Tensor,
+    format: features.BoundingBoxFormat,
+    output_size: List[int],
+    image_size: Tuple[int, int],
+) -> torch.Tensor:
+    crop_height, crop_width = _center_crop_parse_output_size(output_size)
+    crop_top, crop_left = _center_crop_compute_crop_anchor(crop_height, crop_width, *image_size)
+    return crop_bounding_box(bounding_box, format, top=crop_top, left=crop_left)
+
+
 def resized_crop_image_tensor(
     img: torch.Tensor,
     top: int,
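The new kernel composes the two private helpers already used by center_crop_image_tensor. They are not shown in this diff; the sketch below is an assumption about their behavior, inferred from the test's expected math, not the verbatim source:

from typing import List, Tuple

def _center_crop_parse_output_size(output_size: List[int]) -> List[int]:
    # Assumed behavior: a single value is interpreted as a square crop,
    # matching the test's handling of one-element output sizes like [12].
    if len(output_size) == 1:
        return [output_size[0], output_size[0]]
    return list(output_size)

def _center_crop_compute_crop_anchor(
    crop_height: int, crop_width: int, image_height: int, image_width: int
) -> Tuple[int, int]:
    # Assumed behavior: top-left corner of the centered crop window,
    # matching the cy/cx computation in the test's reference implementation.
    crop_top = int(round((image_height - crop_height) / 2.0))
    crop_left = int(round((image_width - crop_width) / 2.0))
    return crop_top, crop_left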

0 commit comments