diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index 2ba6b1115d7..86e4b88fe63 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -248,6 +248,21 @@ def affine_bounding_box():
     )
 
 
+@register_kernel_info_from_sample_inputs_fn
+def rotate_bounding_box():
+    for bounding_box, angle, expand, center in itertools.product(
+        make_bounding_boxes(), [-87, 15, 90], [True, False], [None, [12, 23]]  # angle  # expand  # center
+    ):
+        yield SampleInput(
+            bounding_box,
+            format=bounding_box.format,
+            image_size=bounding_box.image_size,
+            angle=angle,
+            expand=expand,
+            center=center,
+        )
+
+
 @pytest.mark.parametrize(
     "kernel",
     [
@@ -330,7 +345,7 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
             np.max(transformed_points[:, 1]),
         ]
         out_bbox = features.BoundingBox(
-            out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=(32, 32), dtype=torch.float32
+            out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=bbox.image_size, dtype=torch.float32
         )
         out_bbox = convert_bounding_box_format(
             out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
@@ -345,25 +360,25 @@ def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
         ],
         extra_dims=((4,),),
     ):
+        bboxes_format = bboxes.format
+        bboxes_image_size = bboxes.image_size
+
         output_bboxes = F.affine_bounding_box(
             bboxes,
-            bboxes.format,
-            image_size=image_size,
+            bboxes_format,
+            image_size=bboxes_image_size,
             angle=angle,
             translate=(translate, translate),
             scale=scale,
             shear=(shear, shear),
             center=center,
         )
+
         if center is None:
-            center = [s // 2 for s in image_size[::-1]]
+            center = [s // 2 for s in bboxes_image_size[::-1]]
 
-        bboxes_format = bboxes.format
-        bboxes_image_size = bboxes.image_size
         if bboxes.ndim < 2:
-            bboxes = [
-                bboxes,
-            ]
+            bboxes = [bboxes]
 
         expected_bboxes = []
         for bbox in bboxes:
@@ -427,3 +442,147 @@ def test_correctness_affine_bounding_box_on_fixed_input(device):
     assert len(output_boxes) == len(expected_bboxes)
     for a_out_box, out_box in zip(expected_bboxes, output_boxes.cpu()):
         np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box)
+
+
+@pytest.mark.parametrize("angle", range(-90, 90, 56))
+@pytest.mark.parametrize("expand", [True, False])
+@pytest.mark.parametrize("center", [None, (12, 14)])
+def test_correctness_rotate_bounding_box(angle, expand, center):
+    def _compute_expected_bbox(bbox, angle_, expand_, center_):
+        affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
+        affine_matrix = affine_matrix[:2, :]
+
+        image_size = bbox.image_size
+        bbox_xyxy = convert_bounding_box_format(
+            bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY
+        )
+        points = np.array(
+            [
+                [bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
+                [bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
+                [bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
+                [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
+                # image frame
+                [0.0, 0.0, 1.0],
+                [0.0, image_size[0], 1.0],
+                [image_size[1], image_size[0], 1.0],
+                [image_size[1], 0.0, 1.0],
+            ]
+        )
+        transformed_points = np.matmul(points, affine_matrix.T)
+        out_bbox = [
+            np.min(transformed_points[:4, 0]),
+            np.min(transformed_points[:4, 1]),
+            np.max(transformed_points[:4, 0]),
+            np.max(transformed_points[:4, 1]),
+        ]
+        if expand_:
+            tr_x = np.min(transformed_points[4:, 0])
+            tr_y = np.min(transformed_points[4:, 1])
+            out_bbox[0] -= tr_x
+            out_bbox[1] -= tr_y
+            out_bbox[2] -= tr_x
+            out_bbox[3] -= tr_y
+
+        out_bbox = features.BoundingBox(
+            out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float32
+        )
+        out_bbox = convert_bounding_box_format(
+            out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
+        )
+        return out_bbox.to(bbox.device)
+
+    image_size = (32, 38)
+
+    for bboxes in make_bounding_boxes(
+        image_sizes=[
+            image_size,
+        ],
+        extra_dims=((4,),),
+    ):
+        bboxes_format = bboxes.format
+        bboxes_image_size = bboxes.image_size
+
+        output_bboxes = F.rotate_bounding_box(
+            bboxes,
+            bboxes_format,
+            image_size=bboxes_image_size,
+            angle=angle,
+            expand=expand,
+            center=center,
+        )
+
+        if center is None:
+            center = [s // 2 for s in bboxes_image_size[::-1]]
+
+        if bboxes.ndim < 2:
+            bboxes = [bboxes]
+
+        expected_bboxes = []
+        for bbox in bboxes:
+            bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
+            expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center))
+        if len(expected_bboxes) > 1:
+            expected_bboxes = torch.stack(expected_bboxes)
+        else:
+            expected_bboxes = expected_bboxes[0]
+        print("input:", bboxes)
+        print("output_bboxes:", output_bboxes)
+        print("expected_bboxes:", expected_bboxes)
+        torch.testing.assert_close(output_bboxes, expected_bboxes)
+
+
+@pytest.mark.parametrize("device", cpu_and_gpu())
+@pytest.mark.parametrize("expand", [False])  # expand=True does not match D2, analysis in progress
+def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
+    # Check transformation against known expected output
+    image_size = (64, 64)
+    # xyxy format
+    in_boxes = [
+        [1, 1, 5, 5],
+        [1, image_size[0] - 6, 5, image_size[0] - 2],
+        [image_size[1] - 6, image_size[0] - 6, image_size[1] - 2, image_size[0] - 2],
+        [image_size[1] // 2 - 10, image_size[0] // 2 - 10, image_size[1] // 2 + 10, image_size[0] // 2 + 10],
+    ]
+    in_boxes = features.BoundingBox(
+        in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64
+    ).to(device)
+    # Tested parameters
+    angle = 45
+    center = None if expand else [12, 23]
+
+    # # Expected bboxes computed using Detectron2:
+    # from detectron2.data.transforms import RotationTransform, AugmentationList
+    # from detectron2.data.transforms import AugInput
+    # import cv2
+    # inpt = AugInput(im1, boxes=np.array(in_boxes, dtype="float32"))
+    # augs = AugmentationList([RotationTransform(*size, angle, expand=expand, center=center, interp=cv2.INTER_NEAREST), ])
+    # out = augs(inpt)
+    # print(inpt.boxes)
+    if expand:
+        expected_bboxes = [
+            [1.65937957, 42.67157288, 7.31623382, 48.32842712],
+            [41.96446609, 82.9766594, 47.62132034, 88.63351365],
+            [82.26955262, 42.67157288, 87.92640687, 48.32842712],
+            [31.35786438, 31.35786438, 59.64213562, 59.64213562],
+        ]
+    else:
+        expected_bboxes = [
+            [-11.33452378, 12.39339828, -5.67766953, 18.05025253],
+            [28.97056275, 52.69848481, 34.627417, 58.35533906],
+            [69.27564928, 12.39339828, 74.93250353, 18.05025253],
+            [18.36396103, 1.07968978, 46.64823228, 29.36396103],
+        ]
+
+    output_boxes = F.rotate_bounding_box(
+        in_boxes,
+        in_boxes.format,
+        in_boxes.image_size,
+        angle,
+        expand=expand,
+        center=center,
+    )
+
+    assert len(output_boxes) == len(expected_bboxes)
+    for a_out_box, out_box in zip(expected_bboxes, output_boxes.cpu()):
+        np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box)
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index 469768ba9c2..ace1f585d82 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -52,6 +52,7 @@
     affine_bounding_box,
     affine_image_tensor,
     affine_image_pil,
+    rotate_bounding_box,
     rotate_image_tensor,
     rotate_image_pil,
     pad_image_tensor,
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index ecf0d31df3a..c3f294a8546 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -1,4 +1,5 @@
 import numbers
+import warnings
 from typing import Tuple, List, Optional, Sequence, Union
 
 import PIL.Image
@@ -197,24 +198,28 @@ def affine_image_pil(
     return _FP.affine(img, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)
 
 
-def affine_bounding_box(
+def _affine_bounding_box_xyxy(
     bounding_box: torch.Tensor,
-    format: features.BoundingBoxFormat,
     image_size: Tuple[int, int],
     angle: float,
-    translate: List[float],
-    scale: float,
-    shear: List[float],
+    translate: Optional[List[float]] = None,
+    scale: Optional[float] = None,
+    shear: Optional[List[float]] = None,
     center: Optional[List[float]] = None,
+    expand: bool = False,
 ) -> torch.Tensor:
-    original_shape = bounding_box.shape
-    bounding_box = convert_bounding_box_format(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
-    ).view(-1, 4)
-
     dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
     device = bounding_box.device
 
+    if translate is None:
+        translate = [0.0, 0.0]
+
+    if scale is None:
+        scale = 1.0
+
+    if shear is None:
+        shear = [0.0, 0.0]
+
     if center is None:
         height, width = image_size
         center_f = [width * 0.5, height * 0.5]
@@ -241,6 +246,47 @@
     out_bbox_mins, _ = torch.min(transformed_points, dim=1)
     out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
     out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
+
+    if expand:
+        # Compute minimum point for transformed image frame:
+        # Points are Top-Left, Top-Right, Bottom-Left, Bottom-Right points.
+        height, width = image_size
+        points = torch.tensor(
+            [
+                [0.0, 0.0, 1.0],
+                [0.0, 1.0 * height, 1.0],
+                [1.0 * width, 1.0 * height, 1.0],
+                [1.0 * width, 0.0, 1.0],
+            ],
+            dtype=dtype,
+            device=device,
+        )
+        new_points = torch.matmul(points, affine_matrix.T)
+        tr, _ = torch.min(new_points, dim=0, keepdim=True)
+        # Translate bounding boxes
+        out_bboxes[:, 0::2] = out_bboxes[:, 0::2] - tr[:, 0]
+        out_bboxes[:, 1::2] = out_bboxes[:, 1::2] - tr[:, 1]
+
+    return out_bboxes
+
+
+def affine_bounding_box(
+    bounding_box: torch.Tensor,
+    format: features.BoundingBoxFormat,
+    image_size: Tuple[int, int],
+    angle: float,
+    translate: List[float],
+    scale: float,
+    shear: List[float],
+    center: Optional[List[float]] = None,
+) -> torch.Tensor:
+    original_shape = bounding_box.shape
+    bounding_box = convert_bounding_box_format(
+        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    ).view(-1, 4)
+
+    out_bboxes = _affine_bounding_box_xyxy(bounding_box, image_size, angle, translate, scale, shear, center)
+
     # out_bboxes should be of shape [N boxes, 4]
 
     return convert_bounding_box_format(
@@ -258,9 +304,12 @@ def rotate_image_tensor(
 ) -> torch.Tensor:
     center_f = [0.0, 0.0]
     if center is not None:
-        _, height, width = get_dimensions_image_tensor(img)
-        # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
+        if expand:
+            warnings.warn("The provided center argument is ignored if expand is True")
+        else:
+            _, height, width = get_dimensions_image_tensor(img)
+            # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
+            center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]
 
     # due to current incoherence of rotation angle direction between affine and rotate implementations
     # we need to set -angle.
@@ -276,11 +325,35 @@ def rotate_image_pil(
     fill: Optional[List[float]] = None,
     center: Optional[List[float]] = None,
 ) -> PIL.Image.Image:
+    if center is not None and expand:
+        warnings.warn("The provided center argument is ignored if expand is True")
+        center = None
+
     return _FP.rotate(
         img, angle, interpolation=pil_modes_mapping[interpolation], expand=expand, fill=fill, center=center
     )
 
 
+def rotate_bounding_box(
+    bounding_box: torch.Tensor,
+    format: features.BoundingBoxFormat,
+    image_size: Tuple[int, int],
+    angle: float,
+    expand: bool = False,
+    center: Optional[List[float]] = None,
+) -> torch.Tensor:
+    original_shape = bounding_box.shape
+    bounding_box = convert_bounding_box_format(
+        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    ).view(-1, 4)
+
+    out_bboxes = _affine_bounding_box_xyxy(bounding_box, image_size, angle=-angle, center=center, expand=expand)
+
+    return convert_bounding_box_format(
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+    ).view(original_shape)
+
+
 pad_image_tensor = _FT.pad
 pad_image_pil = _FP.pad
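
Note (not part of the patch): below is a minimal usage sketch of the rotate_bounding_box kernel introduced in this diff, assuming only the names that appear in the changes above (features.BoundingBox, features.BoundingBoxFormat, F.rotate_bounding_box); the box coordinates, image size, and angle are illustrative values, not taken from the PR.

    import torch

    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    # One box in XYXY format on a 64x64 image (arbitrary example values).
    box = features.BoundingBox(
        [[10.0, 10.0, 20.0, 20.0]],
        format=features.BoundingBoxFormat.XYXY,
        image_size=(64, 64),
    )

    # Rotate by 45 degrees about the image center without expanding the canvas;
    # the output is returned in the input's format and with the input's shape.
    rotated = F.rotate_bounding_box(
        box,
        box.format,
        image_size=box.image_size,
        angle=45.0,
        expand=False,
        center=None,
    )
    print(rotated)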