Skip to content

Commit 234f113

Browse files
committed
Added functional affine_bounding_box op with tests
1 parent 7bb8186 commit 234f113

File tree

4 files changed

+207
-17
lines changed

4 files changed

+207
-17
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
import functools
22
import itertools
3+
import math
34

5+
import numpy as np
46
import pytest
57
import torch.testing
68
import torchvision.prototype.transforms.functional as F
79
from torch import jit
810
from torch.nn.functional import one_hot
911
from torchvision.prototype import features
12+
from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format
1013
from torchvision.transforms.functional_tensor import _max_value as get_max_value
1114

15+
1216
make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
1317

1418

@@ -205,6 +209,45 @@ def resize_bounding_box():
205209
yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size)
206210

207211

212+
@register_kernel_info_from_sample_inputs_fn
213+
def affine_image_tensor():
214+
for image, angle, translate, scale, shear in itertools.product(
215+
make_images(extra_dims=()),
216+
[-87, 15, 90], # angle
217+
[5, -5], # translate
218+
[0.77, 1.27], # scale
219+
[0, 12], # shear
220+
):
221+
yield SampleInput(
222+
image,
223+
angle=angle,
224+
translate=(translate, translate),
225+
scale=scale,
226+
shear=(shear, shear),
227+
interpolation=F.InterpolationMode.NEAREST,
228+
)
229+
230+
231+
@register_kernel_info_from_sample_inputs_fn
232+
def affine_bounding_box():
233+
for bounding_box, angle, translate, scale, shear in itertools.product(
234+
make_bounding_boxes(),
235+
[-87, 15, 90], # angle
236+
[5, -5], # translate
237+
[0.77, 1.27], # scale
238+
[0, 12], # shear
239+
):
240+
yield SampleInput(
241+
bounding_box,
242+
format=bounding_box.format,
243+
image_size=bounding_box.image_size,
244+
angle=angle,
245+
translate=(translate, translate),
246+
scale=scale,
247+
shear=(shear, shear),
248+
)
249+
250+
208251
@pytest.mark.parametrize(
209252
"kernel",
210253
[
@@ -233,3 +276,98 @@ def test_eager_vs_scripted(functional_info, sample_input):
233276
scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs)
234277

235278
torch.testing.assert_close(eager, scripted)
279+
280+
281+
@pytest.mark.parametrize("angle", range(-90, 90, 36))
282+
@pytest.mark.parametrize("translate", range(-10, 10, 5))
283+
@pytest.mark.parametrize("scale", [0.77, 1.0, 1.27])
284+
@pytest.mark.parametrize("shear", range(-15, 15, 5))
285+
@pytest.mark.parametrize("center", [None, (12, 14)])
286+
def test_correctness_affine_bounding_box(angle, translate, scale, shear, center):
287+
def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
288+
rot = math.radians(angle_)
289+
cx, cy = center_
290+
tx, ty = translate_
291+
sx, sy = [math.radians(sh_) for sh_ in shear_]
292+
293+
c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
294+
t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
295+
c_matrix_inv = np.linalg.inv(c_matrix)
296+
rs_matrix = np.array(
297+
[
298+
[scale_ * math.cos(rot), -scale_ * math.sin(rot), 0],
299+
[scale_ * math.sin(rot), scale_ * math.cos(rot), 0],
300+
[0, 0, 1],
301+
]
302+
)
303+
shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]])
304+
shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])
305+
rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix))
306+
true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv)))
307+
true_matrix = true_matrix[:2, :]
308+
309+
bbox_xyxy = convert_bounding_box_format(
310+
bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY
311+
)
312+
points = np.array(
313+
[
314+
[bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
315+
[bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
316+
[bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
317+
[bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
318+
]
319+
)
320+
transformed_points = points @ true_matrix.T
321+
out_bbox = [
322+
np.min(transformed_points[:, 0]),
323+
np.min(transformed_points[:, 1]),
324+
np.max(transformed_points[:, 0]),
325+
np.max(transformed_points[:, 1]),
326+
]
327+
out_bbox = features.BoundingBox(
328+
out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=(32, 32), dtype=torch.float32
329+
)
330+
out_bbox = convert_bounding_box_format(
331+
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format
332+
)
333+
return out_bbox
334+
335+
image_size = (32, 32)
336+
337+
for bboxes in make_bounding_boxes(
338+
image_sizes=[
339+
image_size,
340+
],
341+
extra_dims=((4,),),
342+
):
343+
output_bboxes = F.affine_bounding_box(
344+
bboxes,
345+
bboxes.format,
346+
image_size=image_size,
347+
angle=angle,
348+
translate=(translate, translate),
349+
scale=scale,
350+
shear=(shear, shear),
351+
center=center,
352+
)
353+
if center is None:
354+
center = [s // 2 for s in image_size]
355+
356+
bboxes_format = bboxes.format
357+
bboxes_image_size = bboxes.image_size
358+
if bboxes.ndim < 2:
359+
bboxes = [
360+
bboxes,
361+
]
362+
363+
expected_bboxes = []
364+
for bbox in bboxes:
365+
bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
366+
expected_bboxes.append(
367+
_compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center)
368+
)
369+
expected_bboxes = torch.stack(expected_bboxes)
370+
if expected_bboxes.shape[0] < 2:
371+
expected_bboxes = expected_bboxes.squeeze(0)
372+
373+
torch.testing.assert_close(output_bboxes, expected_bboxes)

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
center_crop_image_pil,
4949
resized_crop_image_tensor,
5050
resized_crop_image_pil,
51+
affine_bounding_box,
5152
affine_image_tensor,
5253
affine_image_pil,
5354
rotate_image_tensor,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,53 @@ def affine_image_pil(
174174
return _FP.affine(img, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)
175175

176176

177+
def affine_bounding_box(
178+
bounding_box: torch.Tensor,
179+
format: features.BoundingBoxFormat,
180+
image_size: Tuple[int, int],
181+
angle: float,
182+
translate: List[float],
183+
scale: float,
184+
shear: List[float],
185+
center: Optional[List[float]] = None,
186+
) -> torch.Tensor:
187+
original_shape = bounding_box.shape
188+
bounding_box = convert_bounding_box_format(
189+
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
190+
).view(-1, 4)
191+
192+
dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
193+
device = bounding_box.device
194+
195+
if center is None:
196+
height, width = image_size
197+
center_f = [width * 0.5, height * 0.5]
198+
else:
199+
center_f = [float(c) for c in center]
200+
201+
translate_f = [float(t) for t in translate]
202+
affine_matrix = torch.tensor(
203+
_get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear, inverted=False),
204+
dtype=dtype,
205+
device=device,
206+
).view(2, 3)
207+
# bboxes to 4 points like:
208+
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1), ...]
209+
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
210+
points = torch.cat([points, torch.ones(points.shape[0], 1)], dim=-1)
211+
transformed_points = points @ affine_matrix.T
212+
# reshape transformed points to [N boxes, 4 points, x/y coords]
213+
transformed_points = transformed_points.view(-1, 4, 2)
214+
# compute bounding box from 4 transformed points:
215+
out_bbox_mins, _ = torch.min(transformed_points, dim=1)
216+
out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
217+
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
218+
# out_bboxes should be of shape [N boxes, 4]
219+
return convert_bounding_box_format(out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format).view(
220+
original_shape
221+
)
222+
223+
177224
def rotate_image_tensor(
178225
img: torch.Tensor,
179226
angle: float,

torchvision/transforms/functional.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -931,11 +931,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
931931

932932

933933
def _get_inverse_affine_matrix(
934-
center: List[float],
935-
angle: float,
936-
translate: List[float],
937-
scale: float,
938-
shear: List[float],
934+
center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
939935
) -> List[float]:
940936
# Helper method to compute inverse matrix for affine transformation
941937

@@ -970,18 +966,26 @@ def _get_inverse_affine_matrix(
970966
c = math.sin(rot - sy) / math.cos(sy)
971967
d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)
972968

973-
# Inverted rotation matrix with scale and shear
974-
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
975-
matrix = [d, -b, 0.0, -c, a, 0.0]
976-
matrix = [x / scale for x in matrix]
977-
978-
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
979-
matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
980-
matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
981-
982-
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
983-
matrix[2] += cx
984-
matrix[5] += cy
969+
if inverted:
970+
# Inverted rotation matrix with scale and shear
971+
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
972+
matrix = [d, -b, 0.0, -c, a, 0.0]
973+
matrix = [x / scale for x in matrix]
974+
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
975+
matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
976+
matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
977+
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
978+
matrix[2] += cx
979+
matrix[5] += cy
980+
else:
981+
matrix = [a, b, 0.0, c, d, 0.0]
982+
matrix = [x * scale for x in matrix]
983+
# Apply inverse of center translation: RSS * C^-1
984+
matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy)
985+
matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy)
986+
# Apply translation and center : T * C * RSS * C^-1
987+
matrix[2] += cx + tx
988+
matrix[5] += cy + ty
985989

986990
return matrix
987991

0 commit comments

Comments
 (0)