Skip to content

Commit 428a54c

Browse files
Rotated bboxes transforms (#9084)
Co-authored-by: Nicolas Hug <[email protected]>
1 parent 297815a commit 428a54c

File tree

12 files changed

+324
-59
lines changed

12 files changed

+324
-59
lines changed
2.9 KB
Loading

test/common_utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -444,13 +444,13 @@ def sample_position(values, max_value):
444444
r_rad = r * torch.pi / 180.0
445445
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
446446
x1, y1 = x, y
447-
x3 = x1 + w * cos
448-
y3 = y1 - w * sin
449-
x2 = x3 + h * sin
450-
y2 = y3 + h * cos
447+
x2 = x1 + w * cos
448+
y2 = y1 - w * sin
449+
x3 = x2 + h * sin
450+
y3 = y2 + h * cos
451451
x4 = x1 + h * sin
452452
y4 = y1 + h * cos
453-
parts = (x1, y1, x3, y3, x2, y2, x4, y4)
453+
parts = (x1, y1, x2, y2, x3, y3, x4, y4)
454454
else:
455455
raise ValueError(f"Format {format} is not supported")
456456

test/test_transforms_v2.py

Lines changed: 90 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -560,6 +560,78 @@ def affine_bounding_boxes(bounding_boxes):
560560
)
561561

562562

563+
def reference_affine_rotated_bounding_boxes_helper(bounding_boxes, *, affine_matrix, new_canvas_size=None, clamp=True):
564+
format = bounding_boxes.format
565+
canvas_size = new_canvas_size or bounding_boxes.canvas_size
566+
567+
def affine_rotated_bounding_boxes(bounding_boxes):
568+
dtype = bounding_boxes.dtype
569+
device = bounding_boxes.device
570+
571+
# Go to float before converting to prevent precision loss in case of CXCYWHR -> XYXYXYXY and W or H is 1
572+
input_xyxyxyxy = F.convert_bounding_box_format(
573+
bounding_boxes.to(dtype=torch.float64, device="cpu", copy=True),
574+
old_format=format,
575+
new_format=tv_tensors.BoundingBoxFormat.XYXYXYXY,
576+
inplace=True,
577+
)
578+
x1, y1, x2, y2, x3, y3, x4, y4 = input_xyxyxyxy.squeeze(0).tolist()
579+
580+
points = np.array(
581+
[
582+
[x1, y1, 1.0],
583+
[x2, y2, 1.0],
584+
[x3, y3, 1.0],
585+
[x4, y4, 1.0],
586+
]
587+
)
588+
transformed_points = np.matmul(points, affine_matrix.astype(points.dtype).T)
589+
output = torch.tensor(
590+
[
591+
float(transformed_points[1, 0]),
592+
float(transformed_points[1, 1]),
593+
float(transformed_points[0, 0]),
594+
float(transformed_points[0, 1]),
595+
float(transformed_points[3, 0]),
596+
float(transformed_points[3, 1]),
597+
float(transformed_points[2, 0]),
598+
float(transformed_points[2, 1]),
599+
]
600+
)
601+
602+
output = F.convert_bounding_box_format(
603+
output, old_format=tv_tensors.BoundingBoxFormat.XYXYXYXY, new_format=format
604+
)
605+
606+
if clamp:
607+
# It is important to clamp before casting, especially for CXCYWHR format, dtype=int64
608+
output = F.clamp_bounding_boxes(
609+
output,
610+
format=format,
611+
canvas_size=canvas_size,
612+
)
613+
else:
614+
# We leave the bounding box as float32 so the caller gets the full precision to perform any additional
615+
# operation
616+
dtype = output.dtype
617+
618+
return output.to(dtype=dtype, device=device)
619+
620+
return tv_tensors.BoundingBoxes(
621+
torch.cat(
622+
[
623+
affine_rotated_bounding_boxes(b)
624+
for b in bounding_boxes.reshape(
625+
-1, 5 if format != tv_tensors.BoundingBoxFormat.XYXYXYXY else 8
626+
).unbind()
627+
],
628+
dim=0,
629+
).reshape(bounding_boxes.shape),
630+
format=format,
631+
canvas_size=canvas_size,
632+
)
633+
634+
563635
class TestResize:
564636
INPUT_SIZE = (17, 11)
565637
OUTPUT_SIZES = [17, [17], (17,), None, [12, 13], (12, 13)]
@@ -1012,7 +1084,7 @@ class TestHorizontalFlip:
10121084
def test_kernel_image(self, dtype, device):
10131085
check_kernel(F.horizontal_flip_image, make_image(dtype=dtype, device=device))
10141086

1015-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
1087+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
10161088
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
10171089
@pytest.mark.parametrize("device", cpu_and_cuda())
10181090
def test_kernel_bounding_boxes(self, format, dtype, device):
@@ -1071,17 +1143,22 @@ def test_image_correctness(self, fn):
10711143

10721144
torch.testing.assert_close(actual, expected)
10731145

1074-
def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes):
1146+
def _reference_horizontal_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes):
10751147
affine_matrix = np.array(
10761148
[
10771149
[-1, 0, bounding_boxes.canvas_size[1]],
10781150
[0, 1, 0],
10791151
],
10801152
)
10811153

1082-
return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
1154+
helper = (
1155+
reference_affine_rotated_bounding_boxes_helper
1156+
if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
1157+
else reference_affine_bounding_boxes_helper
1158+
)
1159+
return helper(bounding_boxes, affine_matrix=affine_matrix)
10831160

1084-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
1161+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
10851162
@pytest.mark.parametrize(
10861163
"fn", [F.horizontal_flip, transform_cls_to_functional(transforms.RandomHorizontalFlip, p=1)]
10871164
)
@@ -1464,7 +1541,7 @@ class TestVerticalFlip:
14641541
def test_kernel_image(self, dtype, device):
14651542
check_kernel(F.vertical_flip_image, make_image(dtype=dtype, device=device))
14661543

1467-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
1544+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
14681545
@pytest.mark.parametrize("dtype", [torch.float32, torch.int64])
14691546
@pytest.mark.parametrize("device", cpu_and_cuda())
14701547
def test_kernel_bounding_boxes(self, format, dtype, device):
@@ -1521,17 +1598,22 @@ def test_image_correctness(self, fn):
15211598

15221599
torch.testing.assert_close(actual, expected)
15231600

1524-
def _reference_vertical_flip_bounding_boxes(self, bounding_boxes):
1601+
def _reference_vertical_flip_bounding_boxes(self, bounding_boxes: tv_tensors.BoundingBoxes):
15251602
affine_matrix = np.array(
15261603
[
15271604
[1, 0, 0],
15281605
[0, -1, bounding_boxes.canvas_size[0]],
15291606
],
15301607
)
15311608

1532-
return reference_affine_bounding_boxes_helper(bounding_boxes, affine_matrix=affine_matrix)
1609+
helper = (
1610+
reference_affine_rotated_bounding_boxes_helper
1611+
if tv_tensors.is_rotated_bounding_format(bounding_boxes.format)
1612+
else reference_affine_bounding_boxes_helper
1613+
)
1614+
return helper(bounding_boxes, affine_matrix=affine_matrix)
15331615

1534-
@pytest.mark.parametrize("format", SUPPORTED_BOX_FORMATS)
1616+
@pytest.mark.parametrize("format", list(tv_tensors.BoundingBoxFormat))
15351617
@pytest.mark.parametrize("fn", [F.vertical_flip, transform_cls_to_functional(transforms.RandomVerticalFlip, p=1)])
15361618
def test_bounding_boxes_correctness(self, format, fn):
15371619
bounding_boxes = make_bounding_boxes(format=format)

test/test_tv_tensors.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,34 @@ def test_bbox_instance(data, format):
4343
assert bboxes.format == format
4444

4545

46+
@pytest.mark.parametrize(
47+
"format, is_rotated_expected",
48+
[
49+
("XYXY", False),
50+
("XYWH", False),
51+
("CXCYWH", False),
52+
("XYXYXYXY", True),
53+
("XYWHR", True),
54+
("CXCYWHR", True),
55+
(tv_tensors.BoundingBoxFormat.XYXY, False),
56+
(tv_tensors.BoundingBoxFormat.XYWH, False),
57+
(tv_tensors.BoundingBoxFormat.CXCYWH, False),
58+
(tv_tensors.BoundingBoxFormat.XYXYXYXY, True),
59+
(tv_tensors.BoundingBoxFormat.XYWHR, True),
60+
(tv_tensors.BoundingBoxFormat.CXCYWHR, True),
61+
],
62+
)
63+
@pytest.mark.parametrize("scripted", (False, True))
64+
def test_bbox_format(format, is_rotated_expected, scripted):
65+
if isinstance(format, str):
66+
format = tv_tensors.BoundingBoxFormat[(format.upper())]
67+
68+
fn = tv_tensors.is_rotated_bounding_format
69+
if scripted:
70+
fn = torch.jit.script(fn)
71+
assert fn(format) == is_rotated_expected
72+
73+
4674
def test_bbox_dim_error():
4775
data_3d = [[[1, 2, 3, 4]]]
4876
with pytest.raises(ValueError, match="Expected a 1D or 2D tensor, got 3D"):

test/test_utils.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,25 @@
1717
PILLOW_VERSION = tuple(int(x) for x in PILLOW_VERSION.split("."))
1818

1919
boxes = torch.tensor([[0, 0, 20, 20], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float)
20-
20+
rotated_boxes = torch.tensor(
21+
[
22+
[100, 150, 150, 150, 150, 250, 100, 250],
23+
[200, 350, 250, 350, 250, 250, 200, 250],
24+
[300, 200, 200, 200, 200, 250, 300, 250],
25+
# Not really a rectangle, but it doesn't matter
26+
[
27+
100,
28+
100,
29+
200,
30+
50,
31+
290,
32+
350,
33+
200,
34+
400,
35+
],
36+
],
37+
dtype=torch.float,
38+
)
2139
keypoints = torch.tensor([[[10, 10], [5, 5], [2, 2]], [[20, 20], [30, 30], [3, 3]]], dtype=torch.float)
2240

2341

@@ -148,6 +166,17 @@ def test_draw_boxes_with_coloured_label_backgrounds():
148166
assert_equal(result, expected)
149167

150168

169+
@pytest.mark.skipif(PILLOW_VERSION < (10, 1), reason="The reference image is only valid for PIL >= 10.1")
170+
def test_draw_rotated_boxes():
171+
img = torch.full((3, 500, 500), 255, dtype=torch.uint8)
172+
colors = ["blue", "yellow", (0, 255, 0), "black"]
173+
174+
result = utils.draw_bounding_boxes(img, rotated_boxes, colors=colors)
175+
path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "assets", "fakedata", "draw_rotated_boxes.png")
176+
expected = torch.as_tensor(np.array(Image.open(path))).permute(2, 0, 1)
177+
assert_equal(result, expected)
178+
179+
151180
@pytest.mark.parametrize("fill", [True, False])
152181
def test_draw_boxes_dtypes(fill):
153182
img_uint8 = torch.full((3, 100, 100), 255, dtype=torch.uint8)

torchvision/ops/_box_convert.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -130,56 +130,56 @@ def _box_xywhr_to_cxcywhr(boxes: Tensor) -> Tensor:
130130

131131
def _box_xywhr_to_xyxyxyxy(boxes: Tensor) -> Tensor:
132132
"""
133-
Converts rotated bounding boxes from (x1, y1, w, h, r) format to (x1, y1, x3, y3, x2, y2, x4, y4) format.
133+
Converts rotated bounding boxes from (x1, y1, w, h, r) format to (x1, y1, x2, y2, x3, y3, x4, y4) format.
134134
(x1, y1) refer to top left of bounding box
135135
(w, h) are width and height of the rotated bounding box
136136
r is the rotation angle w.r.t. the box center, by :math:`|r|` degrees counter-clockwise in the image plane
137137
138138
(x1, y1) refer to top left of rotated bounding box
139-
(x3, y3) refer to top right of rotated bounding box
140-
(x2, y2) refer to bottom right of rotated bounding box
139+
(x2, y2) refer to top right of rotated bounding box
140+
(x3, y3) refer to bottom right of rotated bounding box
141141
(x4, y4) refer to bottom left of rotated bounding box
142142
Args:
143143
boxes (Tensor[N, 5]): rotated boxes in (x1, y1, w, h, r) format which will be converted.
144144
145145
Returns:
146-
boxes (Tensor(N, 8)): rotated boxes in (x1, y1, x3, y3, x2, y2, x4, y4) format.
146+
boxes (Tensor(N, 8)): rotated boxes in (x1, y1, x2, y2, x3, y3, x4, y4) format.
147147
"""
148148
x1, y1, w, h, r = boxes.unbind(-1)
149149
r_rad = r * torch.pi / 180.0
150150
cos, sin = torch.cos(r_rad), torch.sin(r_rad)
151151

152-
x3 = x1 + w * cos
153-
y3 = y1 - w * sin
154-
x2 = x3 + h * sin
155-
y2 = y3 + h * cos
152+
x2 = x1 + w * cos
153+
y2 = y1 - w * sin
154+
x3 = x2 + h * sin
155+
y3 = y2 + h * cos
156156
x4 = x1 + h * sin
157157
y4 = y1 + h * cos
158158

159-
return torch.stack((x1, y1, x3, y3, x2, y2, x4, y4), dim=-1)
159+
return torch.stack((x1, y1, x2, y2, x3, y3, x4, y4), dim=-1)
160160

161161

162162
def _box_xyxyxyxy_to_xywhr(boxes: Tensor) -> Tensor:
163163
"""
164-
Converts rotated bounding boxes from (x1, y1, x3, y3, x2, y2, x4, y4) format to (x1, y1, w, h, r) format.
164+
Converts rotated bounding boxes from (x1, y1, x2, y2, x3, y3, x4, y4) format to (x1, y1, w, h, r) format.
165165
(x1, y1) refer to top left of the rotated bounding box
166-
(x3, y3) refer to bottom left of the rotated bounding box
167-
(x2, y2) refer to bottom right of the rotated bounding box
166+
(x2, y2) refer to top right of the rotated bounding box
167+
(x3, y3) refer to bottom right of the rotated bounding box
168168
(x4, y4) refer to bottom left of the rotated bounding box
169169
(w, h) refers to width and height of rotated bounding box
170170
r is the rotation angle w.r.t. the box center, by :math:`|r|` degrees counter-clockwise in the image plane
171171
172172
Args:
173-
boxes (Tensor(N, 8)): rotated boxes in (x1, y1, x3, y3, x2, y2, x4, y4) format.
173+
boxes (Tensor(N, 8)): rotated boxes in (x1, y1, x2, y2, x3, y3, x4, y4) format.
174174
175175
Returns:
176176
boxes (Tensor[N, 5]): rotated boxes in (x1, y1, w, h, r) format.
177177
"""
178-
x1, y1, x3, y3, x2, y2, x4, y4 = boxes.unbind(-1)
179-
r_rad = torch.atan2(y1 - y3, x3 - x1)
178+
x1, y1, x2, y2, x3, y3, x4, y4 = boxes.unbind(-1)
179+
r_rad = torch.atan2(y1 - y2, x2 - x1)
180180
r = r_rad * 180 / torch.pi
181181

182-
w = ((x3 - x1) ** 2 + (y1 - y3) ** 2).sqrt()
182+
w = ((x2 - x1) ** 2 + (y1 - y2) ** 2).sqrt()
183183
h = ((x3 - x2) ** 2 + (y3 - y2) ** 2).sqrt()
184184

185185
boxes = torch.stack((x1, y1, w, h, r), dim=-1)

torchvision/ops/boxes.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,8 +209,8 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor:
209209
being width and height.
210210
r is the rotation angle w.r.t. the box center, by :math:`|r|` degrees counter-clockwise in the image plane
211211
212-
``'xyxyxyxy'``: boxes are represented via corners, x1, y1 being top left, x2, y2 bottom right,
213-
x3, y3 bottom left, and x4, y4 top right.
212+
``'xyxyxyxy'``: boxes are represented via corners, x1, y1 being top left, x2, y2 top right,
213+
x3, y3 bottom right, and x4, y4 bottom left.
214214
215215
Args:
216216
boxes (Tensor[N, K]): boxes which will be converted. K is the number of coordinates (4 for unrotated bounding boxes, 5 or 8 for rotated bounding boxes)

0 commit comments

Comments
 (0)