
Commit be462be

vfdev-5 and datumbox authored

[proto] Added functional rotate_bounding_box op (#5638)

* [proto] Added functional rotate_bounding_box op
* Fix mypy
* Apply suggestions from code review

Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent fdd3aab commit be462be

File tree

3 files changed (+255, -22 lines)

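A quick usage sketch, not part of the diff below: based on the new kernel's signature and the tests added in this commit, the op can be called roughly as follows. Import paths mirror the prototype API used in the test file; the box values are illustrative.

from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

# A single box in XYXY format on a 64x64 image; image_size is (height, width).
bbox = features.BoundingBox(
    [1.0, 1.0, 5.0, 5.0],
    format=features.BoundingBoxFormat.XYXY,
    image_size=(64, 64),
)

# Rotate by 45 degrees about the image center (center=None); with expand=True
# the output coordinates are shifted into the enlarged canvas (see _geometry.py below).
out = F.rotate_bounding_box(bbox, bbox.format, bbox.image_size, angle=45.0, expand=True)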

test/test_prototype_transforms_functional.py

Lines changed: 168 additions & 9 deletions
@@ -248,6 +248,21 @@ def affine_bounding_box():
         )


+@register_kernel_info_from_sample_inputs_fn
+def rotate_bounding_box():
+    for bounding_box, angle, expand, center in itertools.product(
+        make_bounding_boxes(), [-87, 15, 90], [True, False], [None, [12, 23]]  # angle  # expand  # center
+    ):
+        yield SampleInput(
+            bounding_box,
+            format=bounding_box.format,
+            image_size=bounding_box.image_size,
+            angle=angle,
+            expand=expand,
+            center=center,
+        )
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
@@ -330,7 +345,7 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
             np.max(transformed_points[:, 1]),
         ]
         out_bbox = features.BoundingBox(
-            out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=(32, 32), dtype=torch.float32
+            out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=bbox.image_size, dtype=torch.float32
         )
         out_bbox = convert_bounding_box_format(
             out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
@@ -345,25 +360,25 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
         ],
         extra_dims=((4,),),
     ):
+        bboxes_format = bboxes.format
+        bboxes_image_size = bboxes.image_size
+
         output_bboxes = F.affine_bounding_box(
             bboxes,
-            bboxes.format,
-            image_size=image_size,
+            bboxes_format,
+            image_size=bboxes_image_size,
             angle=angle,
             translate=(translate, translate),
             scale=scale,
             shear=(shear, shear),
             center=center,
         )
+
         if center is None:
-            center = [s // 2 for s in image_size[::-1]]
+            center = [s // 2 for s in bboxes_image_size[::-1]]

-        bboxes_format = bboxes.format
-        bboxes_image_size = bboxes.image_size
         if bboxes.ndim < 2:
-            bboxes = [
-                bboxes,
-            ]
+            bboxes = [bboxes]

         expected_bboxes = []
         for bbox in bboxes:
@@ -427,3 +442,147 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
     assert len(output_boxes) == len(expected_bboxes)
     for a_out_box, out_box in zip(expected_bboxes, output_boxes.cpu()):
         np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box)
+
+
+@pytest.mark.parametrize("angle", range(-90, 90, 56))
+@pytest.mark.parametrize("expand", [True, False])
+@pytest.mark.parametrize("center", [None, (12, 14)])
+def test_correctness_rotate_bounding_box(angle, expand, center):
+    def _compute_expected_bbox(bbox, angle_, expand_, center_):
+        affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
+        affine_matrix = affine_matrix[:2, :]
+
+        image_size = bbox.image_size
+        bbox_xyxy = convert_bounding_box_format(
+            bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY
+        )
+        points = np.array(
+            [
+                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
+                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
+                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
+                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
+                # image frame
+                [0.0, 0.0, 1.0],
+                [0.0, image_size[0], 1.0],
+                [image_size[1], image_size[0], 1.0],
+                [image_size[1], 0.0, 1.0],
+            ]
+        )
+        transformed_points = np.matmul(points, affine_matrix.T)
+        out_bbox = [
+            np.min(transformed_points[:4, 0]),
+            np.min(transformed_points[:4, 1]),
+            np.max(transformed_points[:4, 0]),
+            np.max(transformed_points[:4, 1]),
+        ]
+        if expand_:
+            tr_x = np.min(transformed_points[4:, 0])
+            tr_y = np.min(transformed_points[4:, 1])
+            out_bbox[0] -= tr_x
+            out_bbox[1] -= tr_y
+            out_bbox[2] -= tr_x
+            out_bbox[3] -= tr_y
+
+        out_bbox = features.BoundingBox(
+            out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float32
+        )
+        out_bbox = convert_bounding_box_format(
+            out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
+        )
+        return out_bbox.to(bbox.device)
+
+    image_size = (32, 38)
+
+    for bboxes in make_bounding_boxes(
+        image_sizes=[
+            image_size,
+        ],
+        extra_dims=((4,),),
+    ):
+        bboxes_format = bboxes.format
+        bboxes_image_size = bboxes.image_size
+
+        output_bboxes = F.rotate_bounding_box(
+            bboxes,
+            bboxes_format,
+            image_size=bboxes_image_size,
+            angle=angle,
+            expand=expand,
+            center=center,
+        )
+
+        if center is None:
+            center = [s // 2 for s in bboxes_image_size[::-1]]
+
+        if bboxes.ndim < 2:
+            bboxes = [bboxes]
+
+        expected_bboxes = []
+        for bbox in bboxes:
+            bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
+            expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center))
+        if len(expected_bboxes) > 1:
+            expected_bboxes = torch.stack(expected_bboxes)
+        else:
+            expected_bboxes = expected_bboxes[0]
+        print("input:", bboxes)
+        print("output_bboxes:", output_bboxes)
+        print("expected_bboxes:", expected_bboxes)
+        torch.testing.assert_close(output_bboxes, expected_bboxes)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize("expand", [False])  # expand=True does not match D2, analysis in progress
+def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
+    # Check transformation against known expected output
+    image_size = (64, 64)
+    # xyxy format
+    in_boxes = [
+        [1, 1, 5, 5],
+        [1, image_size[0] - 6, 5, image_size[0] - 2],
+        [image_size[1] - 6, image_size[0] - 6, image_size[1] - 2, image_size[0] - 2],
+        [image_size[1] // 2 - 10, image_size[0] // 2 - 10, image_size[1] // 2 + 10, image_size[0] // 2 + 10],
+    ]
+    in_boxes = features.BoundingBox(
+        in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64
+    ).to(device)
+    # Tested parameters
+    angle = 45
+    center = None if expand else [12, 23]
+
+    # Expected bboxes computed using Detectron2:
+    # from detectron2.data.transforms import RotationTransform, AugmentationList
+    # from detectron2.data.transforms import AugInput
+    # import cv2
+    # inpt = AugInput(im1, boxes=np.array(in_boxes, dtype="float32"))
+    # augs = AugmentationList([RotationTransform(*size, angle, expand=expand, center=center, interp=cv2.INTER_NEAREST), ])
+    # out = augs(inpt)
+    # print(inpt.boxes)
+    if expand:
+        expected_bboxes = [
+            [1.65937957, 42.67157288, 7.31623382, 48.32842712],
+            [41.96446609, 82.9766594, 47.62132034, 88.63351365],
+            [82.26955262, 42.67157288, 87.92640687, 48.32842712],
+            [31.35786438, 31.35786438, 59.64213562, 59.64213562],
+        ]
+    else:
+        expected_bboxes = [
+            [-11.33452378, 12.39339828, -5.67766953, 18.05025253],
+            [28.97056275, 52.69848481, 34.627417, 58.35533906],
+            [69.27564928, 12.39339828, 74.93250353, 18.05025253],
+            [18.36396103, 1.07968978, 46.64823228, 29.36396103],
+        ]
+
+    output_boxes = F.rotate_bounding_box(
+        in_boxes,
+        in_boxes.format,
+        in_boxes.image_size,
+        angle,
+        expand=expand,
+        center=center,
+    )
+
+    assert len(output_boxes) == len(expected_bboxes)
+    for a_out_box, out_box in zip(expected_bboxes, output_boxes.cpu()):
+        np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box)
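Aside: the core of `_compute_expected_bbox` above is to push the four box corners through a rotation matrix and take the axis-aligned min/max (the test passes `-angle` to match the direction convention noted in `_geometry.py`). A self-contained NumPy sketch of that step; `expected_rotated_bbox` and the hand-rolled rotation-about-center matrix are illustrative stand-ins for the test's `_compute_affine_matrix` helper:

import numpy as np

def expected_rotated_bbox(bbox_xyxy, angle_deg, center):
    # Rotation about `center` (x, y): p' = R p + (c - R c), written as a 2x3 matrix.
    a = np.deg2rad(angle_deg)
    cx, cy = center
    rot = np.array(
        [
            [np.cos(a), -np.sin(a), cx - cx * np.cos(a) + cy * np.sin(a)],
            [np.sin(a), np.cos(a), cy - cx * np.sin(a) - cy * np.cos(a)],
        ]
    )
    x1, y1, x2, y2 = bbox_xyxy
    corners = np.array([[x1, y1, 1.0], [x2, y1, 1.0], [x1, y2, 1.0], [x2, y2, 1.0]])
    pts = corners @ rot.T
    # Axis-aligned hull of the rotated corners.
    return [pts[:, 0].min(), pts[:, 1].min(), pts[:, 0].max(), pts[:, 1].max()]

print(expected_rotated_bbox([1, 1, 5, 5], 45, center=(32, 32)))

With expand=True the test additionally transforms the four image-frame corners and subtracts their per-axis minimum, mirroring the kernel's expand branch in `_geometry.py`.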

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@
     affine_bounding_box,
     affine_image_tensor,
     affine_image_pil,
+    rotate_bounding_box,
     rotate_image_tensor,
     rotate_image_pil,
     pad_image_tensor,
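With this export in place, the kernel should be reachable from the prototype functional namespace, e.g.:

from torchvision.prototype.transforms import functional as F

F.rotate_bounding_box  # now exposed alongside rotate_image_tensor and rotate_image_pil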

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 86 additions & 13 deletions
@@ -1,4 +1,5 @@
 import numbers
+import warnings
 from typing import Tuple, List, Optional, Sequence, Union

 import PIL.Image
@@ -197,24 +198,28 @@ def affine_image_pil(
     return _FP.affine(img, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)


-def affine_bounding_box(
+def _affine_bounding_box_xyxy(
     bounding_box: torch.Tensor,
-    format: features.BoundingBoxFormat,
     image_size: Tuple[int, int],
     angle: float,
-    translate: List[float],
-    scale: float,
-    shear: List[float],
+    translate: Optional[List[float]] = None,
+    scale: Optional[float] = None,
+    shear: Optional[List[float]] = None,
     center: Optional[List[float]] = None,
+    expand: bool = False,
 ) -> torch.Tensor:
-    original_shape = bounding_box.shape
-    bounding_box = convert_bounding_box_format(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
-    ).view(-1, 4)
-
     dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
     device = bounding_box.device

+    if translate is None:
+        translate = [0.0, 0.0]
+
+    if scale is None:
+        scale = 1.0
+
+    if shear is None:
+        shear = [0.0, 0.0]
+
     if center is None:
         height, width = image_size
         center_f = [width * 0.5, height * 0.5]
@@ -241,6 +246,47 @@ def affine_bounding_box(
     out_bbox_mins, _ = torch.min(transformed_points, dim=1)
     out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
     out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
+
+    if expand:
+        # Compute minimum point for transformed image frame:
+        # Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
+        height, width = image_size
+        points = torch.tensor(
+            [
+                [0.0, 0.0, 1.0],
+                [0.0, 1.0 * height, 1.0],
+                [1.0 * width, 1.0 * height, 1.0],
+                [1.0 * width, 0.0, 1.0],
+            ],
+            dtype=dtype,
+            device=device,
+        )
+        new_points = torch.matmul(points, affine_matrix.T)
+        tr, _ = torch.min(new_points, dim=0, keepdim=True)
+        # Translate bounding boxes
+        out_bboxes[:, 0::2] = out_bboxes[:, 0::2] - tr[:, 0]
+        out_bboxes[:, 1::2] = out_bboxes[:, 1::2] - tr[:, 1]
+
+    return out_bboxes
+
+
+def affine_bounding_box(
+    bounding_box: torch.Tensor,
+    format: features.BoundingBoxFormat,
+    image_size: Tuple[int, int],
+    angle: float,
+    translate: List[float],
+    scale: float,
+    shear: List[float],
+    center: Optional[List[float]] = None,
+) -> torch.Tensor:
+    original_shape = bounding_box.shape
+    bounding_box = convert_bounding_box_format(
+        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    ).view(-1, 4)
+
+    out_bboxes = _affine_bounding_box_xyxy(bounding_box, image_size, angle, translate, scale, shear, center)
+
     # out_bboxes should be of shape [N boxes, 4]

     return convert_bounding_box_format(
@@ -258,9 +304,12 @@ def rotate_image_tensor(
 ) -> torch.Tensor:
     center_f = [0.0, 0.0]
     if center is not None:
-        _, height, width = get_dimensions_image_tensor(img)
-        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
+        if expand:
+            warnings.warn("The provided center argument is ignored if expand is True")
+        else:
+            _, height, width = get_dimensions_image_tensor(img)
+            # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
+            center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]

     # due to current incoherence of rotation angle direction between affine and rotate implementations
     # we need to set -angle.
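A small sanity check of the new warning path, as a sketch (assumes the prototype functional import used earlier; keyword arguments sidestep the interpolation/fill defaults):

import warnings

import torch
from torchvision.prototype.transforms import functional as F

img = torch.zeros(3, 32, 32)
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # center is provided together with expand=True, so it should be ignored with a warning
    F.rotate_image_tensor(img, angle=30.0, expand=True, center=[5, 5])
print([str(w.message) for w in caught])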
@@ -276,11 +325,35 @@ def rotate_image_pil(
     fill: Optional[List[float]] = None,
     center: Optional[List[float]] = None,
 ) -> PIL.Image.Image:
+    if center is not None and expand:
+        warnings.warn("The provided center argument is ignored if expand is True")
+        center = None
+
     return _FP.rotate(
         img, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center
     )


+def rotate_bounding_box(
+    bounding_box: torch.Tensor,
+    format: features.BoundingBoxFormat,
+    image_size: Tuple[int, int],
+    angle: float,
+    expand: bool = False,
+    center: Optional[List[float]] = None,
+) -> torch.Tensor:
+    original_shape = bounding_box.shape
+    bounding_box = convert_bounding_box_format(
+        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    ).view(-1, 4)
+
+    out_bboxes = _affine_bounding_box_xyxy(bounding_box, image_size, angle=-angle, center=center, expand=expand)
+
+    return convert_bounding_box_format(
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+    ).view(original_shape)
+
+
 pad_image_tensor = _FT.pad
 pad_image_pil = _FP.pad
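To make the `expand` logic in `_affine_bounding_box_xyxy` concrete: the four image-frame corners go through the same affine matrix, and the per-axis minimum of their images becomes the new origin, so boxes land in the expanded output canvas. A standalone sketch of just that translation step; `translate_into_expanded_frame` is an illustrative name, not a library function:

import math

import torch


def translate_into_expanded_frame(out_bboxes, affine_matrix, image_size):
    # out_bboxes: [N, 4] XYXY boxes already mapped by `affine_matrix` (a [2, 3] matrix).
    height, width = image_size
    corners = torch.tensor(
        [
            [0.0, 0.0, 1.0],
            [0.0, float(height), 1.0],
            [float(width), float(height), 1.0],
            [float(width), 0.0, 1.0],
        ],
        dtype=out_bboxes.dtype,
    )
    # Where the four image-frame corners land after the transform.
    new_corners = corners @ affine_matrix.T
    # The top-left of the expanded canvas becomes the new origin.
    tr = new_corners.min(dim=0, keepdim=True).values
    out = out_bboxes.clone()
    out[:, 0::2] -= tr[:, 0]  # shift all x coordinates
    out[:, 1::2] -= tr[:, 1]  # shift all y coordinates
    return out


# Demo: a 45-degree rotation about the origin of a 64x64 image.
a = math.radians(45.0)
m = torch.tensor([[math.cos(a), -math.sin(a), 0.0], [math.sin(a), math.cos(a), 0.0]])
print(translate_into_expanded_frame(torch.tensor([[1.0, 1.0, 5.0, 5.0]]), m, (64, 64)))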
