Skip to content

Commit 23112f8

Browse files
authored
Added image_size computation for BoundingBox.rotate if expand (#6319)
* Added image_size computation for BoundingBox.rotate if expand
* Added tests
1 parent a8f970e commit 23112f8

File tree

5 files changed

+32
-7
lines changed

5 files changed

+32
-7
lines changed

test/test_prototype_transforms.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -467,6 +467,20 @@ def test__transform(self, degrees, expand, fill, center, mocker):
467467

468468
fn.assert_called_once_with(inpt, **params, interpolation=interpolation, expand=expand, fill=fill, center=center)
469469

470+
@pytest.mark.parametrize("angle", [34, -87])
471+
@pytest.mark.parametrize("expand", [False, True])
472+
def test_boundingbox_image_size(self, angle, expand):
473+
# Specific test for BoundingBox.rotate
474+
bbox = features.BoundingBox(
475+
torch.tensor([1, 2, 3, 4]), format=features.BoundingBoxFormat.XYXY, image_size=(32, 32)
476+
)
477+
img = features.Image(torch.rand(1, 3, 32, 32))
478+
479+
out_img = img.rotate(angle, expand=expand)
480+
out_bbox = bbox.rotate(angle, expand=expand)
481+
482+
assert out_img.image_size == out_bbox.image_size
483+
470484

471485
class TestRandomAffine:
472486
def test_assertions(self):

test/test_prototype_transforms_functional.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -693,13 +693,11 @@ def test_scriptable(kernel):
693693
"InterpolationMode",
694694
"decode_video_with_av",
695695
"crop",
696-
"rotate",
697696
"perspective",
698697
"elastic_transform",
699698
"elastic",
700699
}
701700
# We skip 'crop' due to missing 'height' and 'width'
702-
# We skip 'rotate' due to non implemented yet expand=True case for bboxes
703701
# We skip 'perspective' as it requires different input args than perspective_image_tensor etc
704702
# Skip 'elastic', TODO: inspect why test is failing
705703
],
@@ -999,6 +997,9 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
999997
out_bbox[2] -= tr_x
1000998
out_bbox[3] -= tr_y
1001999

1000+
# image_size should be updated, but it is OK here to skip its computation
1001+
# as we do not compute it in F.rotate_bounding_box
1002+
10021003
out_bbox = features.BoundingBox(
10031004
out_bbox,
10041005
format=features.BoundingBoxFormat.XYXY,

torchvision/prototype/features/_bounding_box.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import torch
66
from torchvision._utils import StrEnum
77
from torchvision.transforms import InterpolationMode
8+
from torchvision.transforms.functional import _get_inverse_affine_matrix
9+
from torchvision.transforms.functional_tensor import _compute_output_size
810

911
from ._feature import _Feature
1012

@@ -168,10 +170,18 @@ def rotate(
168170
output = _F.rotate_bounding_box(
169171
self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center
170172
)
171-
# TODO: update output image size if expand is True
173+
image_size = self.image_size
172174
if expand:
173-
raise RuntimeError("Not yet implemented")
174-
return BoundingBox.new_like(self, output, dtype=output.dtype)
175+
# The way we recompute image_size is not optimal due to redundant computations of
176+
# - rotation matrix (_get_inverse_affine_matrix)
177+
# - points dot matrix (_compute_output_size)
178+
# Alternatively, we could return new image size by _F.rotate_bounding_box
179+
height, width = image_size
180+
rotation_matrix = _get_inverse_affine_matrix([0.0, 0.0], angle, [0.0, 0.0], 1.0, [0.0, 0.0])
181+
new_width, new_height = _compute_output_size(rotation_matrix, width, height)
182+
image_size = (new_height, new_width)
183+
184+
return BoundingBox.new_like(self, output, dtype=output.dtype, image_size=image_size)
175185

176186
def affine(
177187
self,

torchvision/prototype/features/_image.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def new_like(
7474

7575
@property
7676
def image_size(self) -> Tuple[int, int]:
77-
return cast(Tuple[int, int], self.shape[-2:])
77+
return cast(Tuple[int, int], tuple(self.shape[-2:]))
7878

7979
@property
8080
def num_channels(self) -> int:

torchvision/transforms/functional_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -634,7 +634,7 @@ def _compute_output_size(matrix: List[float], w: int, h: int) -> Tuple[int, int]
634634
cmax = torch.ceil((max_vals / tol).trunc_() * tol)
635635
cmin = torch.floor((min_vals / tol).trunc_() * tol)
636636
size = cmax - cmin
637-
return int(size[0]), int(size[1])
637+
return int(size[0]), int(size[1]) # w, h
638638

639639

640640
def rotate(

0 commit comments

Comments (0)