Skip to content

Commit 9bbb777

Browse files
authored
[proto] Added functional affine_bounding_box op (#5597)
* Added functional affine_bounding_box op with tests * Updated comments and added another test case * Update _geometry.py * Fixed device mismatch issue Added a cude/cpu test Reduced the number of test samples
1 parent 3aa2a93 commit 9bbb777

File tree

4 files changed

+267
-17
lines changed

4 files changed

+267
-17
lines changed

test/test_prototype_transforms_functional.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
import functools
22
import itertools
3+
import math
34

5+
import numpy as np
46
import pytest
57
import torch.testing
68
import torchvision.prototype.transforms.functional as F
9+
from common_utils import cpu_and_gpu
710
from torch import jit
811
from torch.nn.functional import one_hot
912
from torchvision.prototype import features
13+
from torchvision.prototype.transforms.functional._meta import convert_bounding_box_format
1014
from torchvision.transforms.functional_tensor import _max_value as get_max_value
1115

1216
make_tensor = functools.partial(torch.testing.make_tensor, device="cpu")
@@ -205,6 +209,45 @@ def resize_bounding_box():
205209
yield SampleInput(bounding_box, size=size, image_size=bounding_box.image_size)
206210

207211

212+
@register_kernel_info_from_sample_inputs_fn
213+
def affine_image_tensor():
214+
for image, angle, translate, scale, shear in itertools.product(
215+
make_images(extra_dims=()),
216+
[-87, 15, 90], # angle
217+
[5, -5], # translate
218+
[0.77, 1.27], # scale
219+
[0, 12], # shear
220+
):
221+
yield SampleInput(
222+
image,
223+
angle=angle,
224+
translate=(translate, translate),
225+
scale=scale,
226+
shear=(shear, shear),
227+
interpolation=F.InterpolationMode.NEAREST,
228+
)
229+
230+
231+
@register_kernel_info_from_sample_inputs_fn
232+
def affine_bounding_box():
233+
for bounding_box, angle, translate, scale, shear in itertools.product(
234+
make_bounding_boxes(),
235+
[-87, 15, 90], # angle
236+
[5, -5], # translate
237+
[0.77, 1.27], # scale
238+
[0, 12], # shear
239+
):
240+
yield SampleInput(
241+
bounding_box,
242+
format=bounding_box.format,
243+
image_size=bounding_box.image_size,
244+
angle=angle,
245+
translate=(translate, translate),
246+
scale=scale,
247+
shear=(shear, shear),
248+
)
249+
250+
208251
@pytest.mark.parametrize(
209252
"kernel",
210253
[
@@ -233,3 +276,154 @@ def test_eager_vs_scripted(functional_info, sample_input):
233276
scripted = jit.script(functional_info.functional)(*sample_input.args, **sample_input.kwargs)
234277

235278
torch.testing.assert_close(eager, scripted)
279+
280+
281+
def _compute_affine_matrix(angle_, translate_, scale_, shear_, center_):
282+
rot = math.radians(angle_)
283+
cx, cy = center_
284+
tx, ty = translate_
285+
sx, sy = [math.radians(sh_) for sh_ in shear_]
286+
287+
c_matrix = np.array([[1, 0, cx], [0, 1, cy], [0, 0, 1]])
288+
t_matrix = np.array([[1, 0, tx], [0, 1, ty], [0, 0, 1]])
289+
c_matrix_inv = np.linalg.inv(c_matrix)
290+
rs_matrix = np.array(
291+
[
292+
[scale_ * math.cos(rot), -scale_ * math.sin(rot), 0],
293+
[scale_ * math.sin(rot), scale_ * math.cos(rot), 0],
294+
[0, 0, 1],
295+
]
296+
)
297+
shear_x_matrix = np.array([[1, -math.tan(sx), 0], [0, 1, 0], [0, 0, 1]])
298+
shear_y_matrix = np.array([[1, 0, 0], [-math.tan(sy), 1, 0], [0, 0, 1]])
299+
rss_matrix = np.matmul(rs_matrix, np.matmul(shear_y_matrix, shear_x_matrix))
300+
true_matrix = np.matmul(t_matrix, np.matmul(c_matrix, np.matmul(rss_matrix, c_matrix_inv)))
301+
return true_matrix
302+
303+
304+
@pytest.mark.parametrize("angle", range(-90, 90, 56))
305+
@pytest.mark.parametrize("translate", range(-10, 10, 8))
306+
@pytest.mark.parametrize("scale", [0.77, 1.0, 1.27])
307+
@pytest.mark.parametrize("shear", range(-15, 15, 8))
308+
@pytest.mark.parametrize("center", [None, (12, 14)])
309+
def test_correctness_affine_bounding_box(angle, translate, scale, shear, center):
310+
def _compute_expected_bbox(bbox, angle_, translate_, scale_, shear_, center_):
311+
affine_matrix = _compute_affine_matrix(angle_, translate_, scale_, shear_, center_)
312+
affine_matrix = affine_matrix[:2, :]
313+
314+
bbox_xyxy = convert_bounding_box_format(
315+
bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY
316+
)
317+
points = np.array(
318+
[
319+
[bbox_xyxy[0].item(), bbox_xyxy[1].item(), 1.0],
320+
[bbox_xyxy[2].item(), bbox_xyxy[1].item(), 1.0],
321+
[bbox_xyxy[0].item(), bbox_xyxy[3].item(), 1.0],
322+
[bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
323+
]
324+
)
325+
transformed_points = np.matmul(points, affine_matrix.T)
326+
out_bbox = [
327+
np.min(transformed_points[:, 0]),
328+
np.min(transformed_points[:, 1]),
329+
np.max(transformed_points[:, 0]),
330+
np.max(transformed_points[:, 1]),
331+
]
332+
out_bbox = features.BoundingBox(
333+
out_bbox, format=features.BoundingBoxFormat.XYXY, image_size=(32, 32), dtype=torch.float32
334+
)
335+
out_bbox = convert_bounding_box_format(
336+
out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
337+
)
338+
return out_bbox.to(bbox.device)
339+
340+
image_size = (32, 38)
341+
342+
for bboxes in make_bounding_boxes(
343+
image_sizes=[
344+
image_size,
345+
],
346+
extra_dims=((4,),),
347+
):
348+
output_bboxes = F.affine_bounding_box(
349+
bboxes,
350+
bboxes.format,
351+
image_size=image_size,
352+
angle=angle,
353+
translate=(translate, translate),
354+
scale=scale,
355+
shear=(shear, shear),
356+
center=center,
357+
)
358+
if center is None:
359+
center = [s // 2 for s in image_size[::-1]]
360+
361+
bboxes_format = bboxes.format
362+
bboxes_image_size = bboxes.image_size
363+
if bboxes.ndim < 2:
364+
bboxes = [
365+
bboxes,
366+
]
367+
368+
expected_bboxes = []
369+
for bbox in bboxes:
370+
bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
371+
expected_bboxes.append(
372+
_compute_expected_bbox(bbox, angle, (translate, translate), scale, (shear, shear), center)
373+
)
374+
if len(expected_bboxes) > 1:
375+
expected_bboxes = torch.stack(expected_bboxes)
376+
else:
377+
expected_bboxes = expected_bboxes[0]
378+
torch.testing.assert_close(output_bboxes, expected_bboxes)
379+
380+
381+
@pytest.mark.parametrize("device", cpu_and_gpu())
382+
def test_correctness_affine_bounding_box_on_fixed_input(device):
383+
# Check transformation against known expected output
384+
image_size = (64, 64)
385+
# xyxy format
386+
in_boxes = [
387+
[20, 25, 35, 45],
388+
[50, 5, 70, 22],
389+
[image_size[1] // 2 - 10, image_size[0] // 2 - 10, image_size[1] // 2 + 10, image_size[0] // 2 + 10],
390+
[1, 1, 5, 5],
391+
]
392+
in_boxes = features.BoundingBox(
393+
in_boxes, format=features.BoundingBoxFormat.XYXY, image_size=image_size, dtype=torch.float64
394+
).to(device)
395+
# Tested parameters
396+
angle = 63
397+
scale = 0.89
398+
dx = 0.12
399+
dy = 0.23
400+
401+
# Expected bboxes computed using albumentations:
402+
# from albumentations.augmentations.geometric.functional import bbox_shift_scale_rotate
403+
# from albumentations.augmentations.geometric.functional import normalize_bbox, denormalize_bbox
404+
# expected_bboxes = []
405+
# for in_box in in_boxes:
406+
# n_in_box = normalize_bbox(in_box, *image_size)
407+
# n_out_box = bbox_shift_scale_rotate(n_in_box, -angle, scale, dx, dy, *image_size)
408+
# out_box = denormalize_bbox(n_out_box, *image_size)
409+
# expected_bboxes.append(out_box)
410+
expected_bboxes = [
411+
(24.522435977922218, 34.375689508290854, 46.443125279998114, 54.3516575015695),
412+
(54.88288587110401, 50.08453280875634, 76.44484547743795, 72.81332520036864),
413+
(27.709526487041554, 34.74952648704156, 51.650473512958435, 58.69047351295844),
414+
(48.56528888843238, 9.611532109828834, 53.35347829361575, 14.39972151501221),
415+
]
416+
417+
output_boxes = F.affine_bounding_box(
418+
in_boxes,
419+
in_boxes.format,
420+
in_boxes.image_size,
421+
angle,
422+
(dx * image_size[1], dy * image_size[0]),
423+
scale,
424+
shear=(0, 0),
425+
)
426+
427+
assert len(output_boxes) == len(expected_bboxes)
428+
for a_out_box, out_box in zip(expected_bboxes, output_boxes.cpu()):
429+
np.testing.assert_allclose(out_box.cpu().numpy(), a_out_box)

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
center_crop_image_pil,
5050
resized_crop_image_tensor,
5151
resized_crop_image_pil,
52+
affine_bounding_box,
5253
affine_image_tensor,
5354
affine_image_pil,
5455
rotate_image_tensor,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,57 @@ def affine_image_pil(
178178
return _FP.affine(img, matrix, interpolation=pil_modes_mapping[interpolation], fill=fill)
179179

180180

181+
def affine_bounding_box(
182+
bounding_box: torch.Tensor,
183+
format: features.BoundingBoxFormat,
184+
image_size: Tuple[int, int],
185+
angle: float,
186+
translate: List[float],
187+
scale: float,
188+
shear: List[float],
189+
center: Optional[List[float]] = None,
190+
) -> torch.Tensor:
191+
original_shape = bounding_box.shape
192+
bounding_box = convert_bounding_box_format(
193+
bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
194+
).view(-1, 4)
195+
196+
dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
197+
device = bounding_box.device
198+
199+
if center is None:
200+
height, width = image_size
201+
center_f = [width * 0.5, height * 0.5]
202+
else:
203+
center_f = [float(c) for c in center]
204+
205+
translate_f = [float(t) for t in translate]
206+
affine_matrix = torch.tensor(
207+
_get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear, inverted=False),
208+
dtype=dtype,
209+
device=device,
210+
).view(2, 3)
211+
# 1) Let's transform bboxes into a tensor of 4 points (top-left, top-right, bottom-left, bottom-right corners).
212+
# Tensor of points has shape (N * 4, 3), where N is the number of bboxes
213+
# Single point structure is similar to
214+
# [(xmin, ymin, 1), (xmax, ymin, 1), (xmax, ymax, 1), (xmin, ymax, 1)]
215+
points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2)
216+
points = torch.cat([points, torch.ones(points.shape[0], 1, device=points.device)], dim=-1)
217+
# 2) Now let's transform the points using affine matrix
218+
transformed_points = torch.matmul(points, affine_matrix.T)
219+
# 3) Reshape transformed points to [N boxes, 4 points, x/y coords]
220+
# and compute bounding box from 4 transformed points:
221+
transformed_points = transformed_points.view(-1, 4, 2)
222+
out_bbox_mins, _ = torch.min(transformed_points, dim=1)
223+
out_bbox_maxs, _ = torch.max(transformed_points, dim=1)
224+
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1)
225+
# out_bboxes should be of shape [N boxes, 4]
226+
227+
return convert_bounding_box_format(
228+
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
229+
).view(original_shape)
230+
231+
181232
def rotate_image_tensor(
182233
img: torch.Tensor,
183234
angle: float,

torchvision/transforms/functional.py

Lines changed: 21 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -931,11 +931,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
931931

932932

933933
def _get_inverse_affine_matrix(
934-
center: List[float],
935-
angle: float,
936-
translate: List[float],
937-
scale: float,
938-
shear: List[float],
934+
center: List[float], angle: float, translate: List[float], scale: float, shear: List[float], inverted: bool = True
939935
) -> List[float]:
940936
# Helper method to compute inverse matrix for affine transformation
941937

@@ -970,18 +966,26 @@ def _get_inverse_affine_matrix(
970966
c = math.sin(rot - sy) / math.cos(sy)
971967
d = -math.sin(rot - sy) * math.tan(sx) / math.cos(sy) + math.cos(rot)
972968

973-
# Inverted rotation matrix with scale and shear
974-
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
975-
matrix = [d, -b, 0.0, -c, a, 0.0]
976-
matrix = [x / scale for x in matrix]
977-
978-
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
979-
matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
980-
matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
981-
982-
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
983-
matrix[2] += cx
984-
matrix[5] += cy
969+
if inverted:
970+
# Inverted rotation matrix with scale and shear
971+
# det([[a, b], [c, d]]) == 1, since det(rotation) = 1 and det(shear) = 1
972+
matrix = [d, -b, 0.0, -c, a, 0.0]
973+
matrix = [x / scale for x in matrix]
974+
# Apply inverse of translation and of center translation: RSS^-1 * C^-1 * T^-1
975+
matrix[2] += matrix[0] * (-cx - tx) + matrix[1] * (-cy - ty)
976+
matrix[5] += matrix[3] * (-cx - tx) + matrix[4] * (-cy - ty)
977+
# Apply center translation: C * RSS^-1 * C^-1 * T^-1
978+
matrix[2] += cx
979+
matrix[5] += cy
980+
else:
981+
matrix = [a, b, 0.0, c, d, 0.0]
982+
matrix = [x * scale for x in matrix]
983+
# Apply inverse of center translation: RSS * C^-1
984+
matrix[2] += matrix[0] * (-cx) + matrix[1] * (-cy)
985+
matrix[5] += matrix[3] * (-cx) + matrix[4] * (-cy)
986+
# Apply translation and center : T * C * RSS * C^-1
987+
matrix[2] += cx + tx
988+
matrix[5] += cy + ty
985989

986990
return matrix
987991

0 commit comments

Comments
 (0)