[proto] Added functional rotate_bounding_box op #5638

Merged: 6 commits, Mar 23, 2022

177 changes: 168 additions & 9 deletions test/test_prototype_transforms_functional.py
@@ -248,6 +248,21 @@ def affine_bounding_box():
)


@register_kernel_info_from_sample_inputs_fn
def rotate_bounding_box():
for bounding_box, angle, expand, center in itertools.product(
make_bounding_boxes(), [-87, 15, 90], [True, False], [None, [12, 23]] # angle # expand # center
):
yield SampleInput(
bounding_box,
format=bounding_box.format,
image_size=bounding_box.image_size,
angle=angle,
expand=expand,
center=center,
)


@pytest.mark.parametrize(
"kernel",
[
@@ -330,7 +345,7 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
np.max(transformed_points[:, 1]),
]
out_bbox = features.BoundingBox(
out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=(32, 32), dtype=torch.float32
out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=bbox.image_size, dtype=torch.float32
)
out_bbox = convert_bounding_box_format(
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
@@ -345,25 +360,25 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
],
extra_dims=((4,),),
):
bboxes_format = bboxes.format
bboxes_image_size = bboxes.image_size

output_bboxes = F.affine_bounding_box(
bboxes,
bboxes.format,
image_size=image_size,
bboxes_format,
image_size=bboxes_image_size,
angle=angle,
translate=(translate, translate),
scale=scale,
shear=(shear, shear),
center=center,
)

if center is None:
center = [s // 2 for s in image_size[::-1]]
center = [s // 2 for s in bboxes_image_size[::-1]]

bboxes_format = bboxes.format
bboxes_image_size = bboxes.image_size
if bboxes.ndim < 2:
bboxes = [
bboxes,
]
bboxes = [bboxes]

expected_bboxes = []
for bbox in bboxes:
@@ -427,3 +442,147 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
assert len(output_boxes) == len(expected_bboxes)
for a_out_box, out_box in zip(expected_bboxes, output_boxes.cpu()):
np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box)


@pytest.mark.parametrize("angle", range(-90, 90, 56))
@pytest.mark.parametrize("expand", [True, False])
@pytest.mark.parametrize("center", [None, (12, 14)])
def test_correctness_rotate_bounding_box(angle, expand, center):
def _compute_expected_bbox(bbox, angle_, expand_, center_):
affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
affine_matrix = affine_matrix[:2, :]

image_size = bbox.image_size
bbox_xyxy = convert_bounding_box_format(
bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY
)
points = np.array(
[
[bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
[bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
[bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
[bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
# image frame
[0.0, 0.0, 1.0],
[0.0, image_size[0], 1.0],
[image_size[1], image_size[0], 1.0],
[image_size[1], 0.0, 1.0],
]
)
transformed_points = np.matmul(points, affine_matrix.T)
out_bbox = [
np.min(transformed_points[:4, 0]),
np.min(transformed_points[:4, 1]),
np.max(transformed_points[:4, 0]),
np.max(transformed_points[:4, 1]),
]
if expand_:
tr_x = np.min(transformed_points[4:, 0])
tr_y = np.min(transformed_points[4:, 1])
out_bbox[0] -= tr_x
out_bbox[1] -= tr_y
out_bbox[2] -= tr_x
out_bbox[3] -= tr_y

out_bbox = features.BoundingBox(
out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float32
)
out_bbox = convert_bounding_box_format(
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
)
return out_bbox.to(bbox.device)

image_size = (32, 38)

for bboxes in make_bounding_boxes(
image_sizes=[
image_size,
],
extra_dims=((4,),),
):
bboxes_format = bboxes.format
bboxes_image_size = bboxes.image_size

output_bboxes = F.rotate_bounding_box(
bboxes,
bboxes_format,
image_size=bboxes_image_size,
angle=angle,
expand=expand,
center=center,
)

if center is None:
center = [s // 2 for s in bboxes_image_size[::-1]]

if bboxes.ndim < 2:
bboxes = [bboxes]

expected_bboxes = []
for bbox in bboxes:
bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center))
if len(expected_bboxes) > 1:
expected_bboxes = torch.stack(expected_bboxes)
else:
expected_bboxes = expected_bboxes[0]
print("input:", bboxes)
print("output_bboxes:", output_bboxes)
print("expected_bboxes:", expected_bboxes)
torch.testing.assert_close(output_bboxes, expected_bboxes)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("expand", [False]) # expand=True does not match D2, analysis in progress
def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
# Check transformation against known expected output
image_size = (64, 64)
# xyxy format
in_boxes = [
[1, 1, 5, 5],
[1, image_size[0] - 6, 5, image_size[0] - 2],
[image_size[1] - 6, image_size[0] - 6, image_size[1] - 2, image_size[0] - 2],
[image_size[1] // 2 - 10, image_size[0] // 2 - 10, image_size[1] // 2 + 10, image_size[0] // 2 + 10],
]
in_boxes = features.BoundingBox(
in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64
).to(device)
# Tested parameters
angle = 45
center = None if expand else [12, 23]

# # Expected bboxes computed using Detectron2:
# from detectron2.data.transforms import RotationTransform, AugmentationList
# from detectron2.data.transforms import AugInput
# import cv2
# inpt = AugInput(im1, boxes=np.array(in_boxes, dtype="float32"))
# augs = AugmentationList([RotationTransform(*size, angle, expand=expand, center=center, interp=cv2.INTER_NEAREST), ])
# out = augs(inpt)
# print(inpt.boxes)
if expand:
expected_bboxes = [
[1.65937957, 42.67157288, 7.31623382, 48.32842712],
[41.96446609, 82.9766594, 47.62132034, 88.63351365],
[82.26955262, 42.67157288, 87.92640687, 48.32842712],
[31.35786438, 31.35786438, 59.64213562, 59.64213562],
]
else:
expected_bboxes = [
[-11.33452378, 12.39339828, -5.67766953, 18.05025253],
[28.97056275, 52.69848481, 34.627417, 58.35533906],
[69.27564928, 12.39339828, 74.93250353, 18.05025253],
[18.36396103, 1.07968978, 46.64823228, 29.36396103],
]

output_boxes = F.rotate_bounding_box(
in_boxes,
in_boxes.format,
in_boxes.image_size,
angle,
expand=expand,
center=center,
)

assert len(output_boxes) == len(expected_bboxes)
for a_out_box, out_box in zip(expected_bboxes, output_boxes.cpu()):
np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box)
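Reviewer note: the expected boxes in test_correctness_rotate_bounding_box are obtained by rotating the four corners of each box (in homogeneous coordinates), taking the axis-aligned min/max, and, when expand=True, shifting everything by the minimum of the transformed image-frame corners. Below is a minimal NumPy sketch of that computation, not part of the PR: the helper name expected_rotated_bbox is illustrative, the rotation is fixed about the image center for simplicity (the test also supports a custom center), and the sign convention is whatever angle you pass in (the test itself passes -angle to account for the kernels' convention).

import numpy as np

def expected_rotated_bbox(bbox_xyxy, angle_deg, image_size, expand=False):
    # Hypothetical helper, mirroring _compute_expected_bbox above:
    # rotation about the image center, no translation/scale/shear.
    height, width = image_size
    cx, cy = width * 0.5, height * 0.5
    a = np.deg2rad(angle_deg)
    cos, sin = np.cos(a), np.sin(a)
    # 2x3 affine matrix for a rotation about (cx, cy): p' = R @ (p - c) + c
    matrix = np.array(
        [
            [cos, -sin, cx - cx * cos + cy * sin],
            [sin, cos, cy - cx * sin - cy * cos],
        ]
    )
    x1, y1, x2, y2 = bbox_xyxy
    # box corners first, then the image-frame corners, in homogeneous coordinates
    points = np.array(
        [
            [x1, y1, 1.0],
            [x2, y1, 1.0],
            [x1, y2, 1.0],
            [x2, y2, 1.0],
            [0.0, 0.0, 1.0],
            [0.0, height, 1.0],
            [width, height, 1.0],
            [width, 0.0, 1.0],
        ]
    )
    transformed = points @ matrix.T
    out = [
        transformed[:4, 0].min(),
        transformed[:4, 1].min(),
        transformed[:4, 0].max(),
        transformed[:4, 1].max(),
    ]
    if expand:
        # shift boxes so the rotated image frame starts at (0, 0)
        tr_x = transformed[4:, 0].min()
        tr_y = transformed[4:, 1].min()
        out = [out[0] - tr_x, out[1] - tr_y, out[2] - tr_x, out[3] - tr_y]
    return out

# e.g. a 4x4 box near the top-left of a (32, 38) image, rotated by 45 degrees
print(expected_rotated_bbox([1.0, 1.0, 5.0, 5.0], 45.0, (32, 38), expand=True))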
1 change: 1 addition & 0 deletions torchvision/prototype/transforms/functional/__init__.py
@@ -52,6 +52,7 @@
affine_bounding_box,
affine_image_tensor,
affine_image_pil,
rotate_bounding_box,
rotate_image_tensor,
rotate_image_pil,
pad_image_tensor,
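Reviewer note: with the kernel exported from the functional namespace, it can be called like the other prototype bounding-box kernels. A minimal usage sketch, not part of the PR, mirroring the fixed-input test above; it assumes the prototype features / functional modules as used throughout this PR.

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

# XYXY boxes on a 64x64 image, as in test_correctness_rotate_bounding_box_on_fixed_input
boxes = features.BoundingBox(
    [[1.0, 1.0, 5.0, 5.0], [22.0, 22.0, 42.0, 42.0]],
    format=features.BoundingBoxFormat.XYXY,
    image_size=(64, 64),
)
rotated = F.rotate_bounding_box(
    boxes,
    boxes.format,
    boxes.image_size,
    angle=45,
    expand=False,
    center=[12, 23],
)
print(rotated)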
99 changes: 86 additions & 13 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -1,4 +1,5 @@
import numbers
import warnings
from typing import Tuple, List, Optional, Sequence, Union

import PIL.Image
@@ -197,24 +198,28 @@ def affine_image_pil(
return _FP.affine(img, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)


def affine_bounding_box(
def _affine_bounding_box_xyxy(
bounding_box: torch.Tensor,
format: features.BoundingBoxFormat,
image_size: Tuple[int, int],
angle: float,
translate: List[float],
scale: float,
shear: List[float],
translate: Optional[List[float]] = None,
scale: Optional[float] = None,
shear: Optional[List[float]] = None,
center: Optional[List[float]] = None,
expand: bool = False,
) -> torch.Tensor:
original_shape = bounding_box.shape
bounding_box = convert_bounding_box_format(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)

dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
device = bounding_box.device

if translate is None:
translate = [0.0, 0.0]

if scale is None:
scale = 1.0

if shear is None:
shear = [0.0, 0.0]

if center is None:
height, width = image_size
center_f = [width * 0.5, height * 0.5]
@@ -241,6 +246,47 @@ def affine_bounding_box(
out_bbox_mins, _ = torch.min(transformed_points, dim=1)
out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)

if expand:
# Compute minimum point for transformed image frame:
# Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
height, width = image_size
points = torch.tensor(
[
[0.0, 0.0, 1.0],
[0.0, 1.0 * height, 1.0],
[1.0 * width, 1.0 * height, 1.0],
[1.0 * width, 0.0, 1.0],
],
dtype=dtype,
device=device,
)
new_points = torch.matmul(points, affine_matrix.T)
tr, _ = torch.min(new_points, dim=0, keepdim=True)
# Translate bounding boxes
out_bboxes[:, 0::2] = out_bboxes[:, 0::2] - tr[:, 0]
out_bboxes[:, 1::2] = out_bboxes[:, 1::2] - tr[:, 1]

return out_bboxes


def affine_bounding_box(
bounding_box: torch.Tensor,
format: features.BoundingBoxFormat,
image_size: Tuple[int, int],
angle: float,
translate: List[float],
scale: float,
shear: List[float],
center: Optional[List[float]] = None,
) -> torch.Tensor:
original_shape = bounding_box.shape
bounding_box = convert_bounding_box_format(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)

out_bboxes = _affine_bounding_box_xyxy(bounding_box, image_size, angle, translate, scale, shear, center)

# out_bboxes should be of shape [N boxes, 4]

return convert_bounding_box_format(
@@ -258,9 +304,12 @@ def rotate_image_tensor(
) -> torch.Tensor:
center_f = [0.0, 0.0]
if center is not None:
_, height, width = get_dimensions_image_tensor(img)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
if expand:
warnings.warn("The provided center argument is ignored if expand is True")
else:
_, height, width = get_dimensions_image_tensor(img)
# Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]

# due to current incoherence of rotation angle direction between affine and rotate implementations
# we need to set -angle.
@@ -276,11 +325,35 @@ def rotate_image_pil(
fill: Optional[List[float]] = None,
center: Optional[List[float]] = None,
) -> PIL.Image.Image:
if center is not None and expand:
warnings.warn("The provided center argument is ignored if expand is True")
center = None

return _FP.rotate(
img, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center
)


def rotate_bounding_box(
bounding_box: torch.Tensor,
format: features.BoundingBoxFormat,
image_size: Tuple[int, int],
angle: float,
expand: bool = False,
center: Optional[List[float]] = None,
) -> torch.Tensor:
original_shape = bounding_box.shape
bounding_box = convert_bounding_box_format(
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
).view(-1, 4)

out_bboxes = _affine_bounding_box_xyxy(bounding_box, image_size, angle=-angle, center=center, expand=expand)

return convert_bounding_box_format(
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
).view(original_shape)


pad_image_tensor = _FT.pad
pad_image_pil = _FP.pad

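Reviewer note: as the comment in rotate_image_tensor says, the rotate kernels negate the angle before delegating to the shared affine helper, and rotate_bounding_box likewise passes angle=-angle to _affine_bounding_box_xyxy. A small sanity-check sketch of the relationship this implies between the two bounding-box kernels, assuming expand=False and the default center; this snippet is not part of the PR's test suite.

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

boxes = features.BoundingBox(
    [[10.0, 10.0, 20.0, 20.0]],
    format=features.BoundingBoxFormat.XYXY,
    image_size=(32, 32),
)
# rotate_bounding_box(angle) should match affine_bounding_box(-angle) with zero
# translation, unit scale and no shear, since both end up in _affine_bounding_box_xyxy
# and rotate flips the sign of the angle.
via_rotate = F.rotate_bounding_box(boxes, boxes.format, boxes.image_size, angle=30)
via_affine = F.affine_bounding_box(
    boxes,
    boxes.format,
    boxes.image_size,
    angle=-30,
    translate=[0.0, 0.0],
    scale=1.0,
    shear=[0.0, 0.0],
)
torch.testing.assert_close(via_rotate, via_affine)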