Skip to content

Commit 70faba9

Browse files
authored
[prototype] Add support of inplace on convert_format_bounding_box (#6858)
* Add support for `inplace` in `convert_format_bounding_box`. * Move `as_subclass` calls to `F` invocations. * Fix bug. * Fix `_cxcywh_to_xyxy`. * Fix `_xyxy_to_cxcywh`. * Add comments.
1 parent cba1c01 commit 70faba9

File tree

5 files changed

+65
-53
lines changed

5 files changed

+65
-53
lines changed

torchvision/prototype/transforms/_augment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def _copy_paste(
262262
# https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
263263
xyxy_boxes[:, 2:] += 1
264264
boxes = F.convert_format_bounding_box(
265-
xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format
265+
xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, inplace=True
266266
)
267267
out_target["boxes"] = torch.cat([boxes, paste_boxes])
268268

torchvision/prototype/transforms/_geometry.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -646,7 +646,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
646646
continue
647647

648648
# check for any valid boxes with centers within the crop area
649-
xyxy_bboxes = F.convert_format_bounding_box(bboxes, bboxes.format, features.BoundingBoxFormat.XYXY)
649+
xyxy_bboxes = F.convert_format_bounding_box(
650+
bboxes.as_subclass(torch.Tensor), bboxes.format, features.BoundingBoxFormat.XYXY
651+
)
650652
cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
651653
cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
652654
is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
@@ -799,7 +801,12 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
799801
if needs_crop and bounding_boxes is not None:
800802
format = bounding_boxes.format
801803
bounding_boxes, spatial_size = F.crop_bounding_box(
802-
bounding_boxes, format=format, top=top, left=left, height=new_height, width=new_width
804+
bounding_boxes.as_subclass(torch.Tensor),
805+
format=format,
806+
top=top,
807+
left=left,
808+
height=new_height,
809+
width=new_width,
803810
)
804811
bounding_boxes = F.clamp_bounding_box(bounding_boxes, format=format, spatial_size=spatial_size)
805812
height_and_width = F.convert_format_bounding_box(

torchvision/prototype/transforms/_misc.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -207,7 +207,9 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
207207
# format, we need to convert first just to afterwards compute the width and height again, although they were
208208
# there in the first place for these formats.
209209
bounding_box = F.convert_format_bounding_box(
210-
bounding_box, old_format=bounding_box.format, new_format=features.BoundingBoxFormat.XYXY
210+
bounding_box.as_subclass(torch.Tensor),
211+
old_format=bounding_box.format,
212+
new_format=features.BoundingBoxFormat.XYXY,
211213
)
212214
valid_indices = remove_small_boxes(bounding_box, min_size=self.min_size)
213215

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,14 @@ def horizontal_flip_bounding_box(
3838

3939
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
4040
# BoundingBoxFormat instead of converting back and forth
41-
bounding_box = (
42-
bounding_box.clone()
43-
if format == features.BoundingBoxFormat.XYXY
44-
else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
41+
bounding_box = convert_format_bounding_box(
42+
bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
4543
).reshape(-1, 4)
4644

4745
bounding_box[:, [0, 2]] = spatial_size[1] - bounding_box[:, [2, 0]]
4846

4947
return convert_format_bounding_box(
50-
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
48+
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
5149
).reshape(shape)
5250

5351

@@ -79,16 +77,14 @@ def vertical_flip_bounding_box(
7977

8078
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
8179
# BoundingBoxFormat instead of converting back and forth
82-
bounding_box = (
83-
bounding_box.clone()
84-
if format == features.BoundingBoxFormat.XYXY
85-
else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
80+
bounding_box = convert_format_bounding_box(
81+
bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
8682
).reshape(-1, 4)
8783

8884
bounding_box[:, [1, 3]] = spatial_size[0] - bounding_box[:, [3, 1]]
8985

9086
return convert_format_bounding_box(
91-
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
87+
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
9288
).reshape(shape)
9389

9490

@@ -412,7 +408,7 @@ def affine_bounding_box(
412408
# out_bboxes should be of shape [N boxes, 4]
413409

414410
return convert_format_bounding_box(
415-
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
411+
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
416412
).reshape(original_shape)
417413

418414

@@ -594,9 +590,9 @@ def rotate_bounding_box(
594590
)
595591

596592
return (
597-
convert_format_bounding_box(out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format).reshape(
598-
original_shape
599-
),
593+
convert_format_bounding_box(
594+
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
595+
).reshape(original_shape),
600596
spatial_size,
601597
)
602598

@@ -815,18 +811,18 @@ def crop_bounding_box(
815811
) -> Tuple[torch.Tensor, Tuple[int, int]]:
816812
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
817813
# BoundingBoxFormat instead of converting back and forth
818-
bounding_box = (
819-
bounding_box.clone()
820-
if format == features.BoundingBoxFormat.XYXY
821-
else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
814+
bounding_box = convert_format_bounding_box(
815+
bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
822816
)
823817

824818
# Crop or implicit pad if left and/or top have negative values:
825819
bounding_box[..., 0::2] -= left
826820
bounding_box[..., 1::2] -= top
827821

828822
return (
829-
convert_format_bounding_box(bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format),
823+
convert_format_bounding_box(
824+
bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
825+
),
830826
(height, width),
831827
)
832828

@@ -964,7 +960,7 @@ def perspective_bounding_box(
964960
# out_bboxes should be of shape [N boxes, 4]
965961

966962
return convert_format_bounding_box(
967-
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
963+
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
968964
).reshape(original_shape)
969965

970966

@@ -1085,7 +1081,7 @@ def elastic_bounding_box(
10851081
out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)
10861082

10871083
return convert_format_bounding_box(
1088-
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
1084+
out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, inplace=True
10891085
).reshape(original_shape)
10901086

10911087

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 35 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -119,51 +119,60 @@ def get_num_frames(inpt: features.VideoTypeJIT) -> int:
119119
raise TypeError(f"The video should be a Tensor. Got {type(inpt)}")
120120

121121

122-
def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
123-
xyxy = xywh.clone()
122+
def _xywh_to_xyxy(xywh: torch.Tensor, inplace: bool) -> torch.Tensor:
123+
xyxy = xywh if inplace else xywh.clone()
124124
xyxy[..., 2:] += xyxy[..., :2]
125125
return xyxy
126126

127127

128-
def _xyxy_to_xywh(xyxy: torch.Tensor) -> torch.Tensor:
129-
xywh = xyxy.clone()
128+
def _xyxy_to_xywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
129+
xywh = xyxy if inplace else xyxy.clone()
130130
xywh[..., 2:] -= xywh[..., :2]
131131
return xywh
132132

133133

134-
def _cxcywh_to_xyxy(cxcywh: torch.Tensor) -> torch.Tensor:
135-
cx, cy, w, h = torch.unbind(cxcywh, dim=-1)
136-
x1 = cx - 0.5 * w
137-
y1 = cy - 0.5 * h
138-
x2 = cx + 0.5 * w
139-
y2 = cy + 0.5 * h
140-
return torch.stack((x1, y1, x2, y2), dim=-1).to(cxcywh.dtype)
134+
def _cxcywh_to_xyxy(cxcywh: torch.Tensor, inplace: bool) -> torch.Tensor:
135+
if not inplace:
136+
cxcywh = cxcywh.clone()
141137

138+
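# Trick to do fast division by 2 with ceil rounding, without casting: for integer inputs, floor-dividing by -2 and taking the absolute value gives ceil(w / 2) for non-negative w. It produces the same result as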
139+
# `torchvision.ops._box_convert._box_cxcywh_to_xyxy`.
140+
half_wh = cxcywh[..., 2:].div(-2, rounding_mode=None if cxcywh.is_floating_point() else "floor").abs_()
141+
# (cx - width / 2) = x1, same for y1
142+
cxcywh[..., :2].sub_(half_wh)
143+
# (x1 + width) = x2, same for y2
144+
cxcywh[..., 2:].add_(cxcywh[..., :2])
142145

143-
def _xyxy_to_cxcywh(xyxy: torch.Tensor) -> torch.Tensor:
144-
x1, y1, x2, y2 = torch.unbind(xyxy, dim=-1)
145-
cx = (x1 + x2) / 2
146-
cy = (y1 + y2) / 2
147-
w = x2 - x1
148-
h = y2 - y1
149-
return torch.stack((cx, cy, w, h), dim=-1).to(xyxy.dtype)
146+
return cxcywh
147+
148+
149+
def _xyxy_to_cxcywh(xyxy: torch.Tensor, inplace: bool) -> torch.Tensor:
150+
if not inplace:
151+
xyxy = xyxy.clone()
152+
153+
# (x2 - x1) = width, same for height
154+
xyxy[..., 2:].sub_(xyxy[..., :2])
155+
# (x1 * 2 + width) / 2 = x1 + width / 2 = x1 + (x2-x1)/2 = (x1 + x2)/2 = cx, same for cy
156+
xyxy[..., :2].mul_(2).add_(xyxy[..., 2:]).div_(2, rounding_mode=None if xyxy.is_floating_point() else "floor")
157+
158+
return xyxy
150159

151160

152161
def convert_format_bounding_box(
153-
bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat
162+
bounding_box: torch.Tensor, old_format: BoundingBoxFormat, new_format: BoundingBoxFormat, inplace: bool = False
154163
) -> torch.Tensor:
155164
if new_format == old_format:
156165
return bounding_box
157166

158167
if old_format == BoundingBoxFormat.XYWH:
159-
bounding_box = _xywh_to_xyxy(bounding_box)
168+
bounding_box = _xywh_to_xyxy(bounding_box, inplace)
160169
elif old_format == BoundingBoxFormat.CXCYWH:
161-
bounding_box = _cxcywh_to_xyxy(bounding_box)
170+
bounding_box = _cxcywh_to_xyxy(bounding_box, inplace)
162171

163172
if new_format == BoundingBoxFormat.XYWH:
164-
bounding_box = _xyxy_to_xywh(bounding_box)
173+
bounding_box = _xyxy_to_xywh(bounding_box, inplace)
165174
elif new_format == BoundingBoxFormat.CXCYWH:
166-
bounding_box = _xyxy_to_cxcywh(bounding_box)
175+
bounding_box = _xyxy_to_cxcywh(bounding_box, inplace)
167176

168177
return bounding_box
169178

@@ -173,14 +182,12 @@ def clamp_bounding_box(
173182
) -> torch.Tensor:
174183
# TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
175184
# BoundingBoxFormat instead of converting back and forth
176-
xyxy_boxes = (
177-
bounding_box.clone()
178-
if format == BoundingBoxFormat.XYXY
179-
else convert_format_bounding_box(bounding_box, format, BoundingBoxFormat.XYXY)
185+
xyxy_boxes = convert_format_bounding_box(
186+
bounding_box.clone(), old_format=format, new_format=features.BoundingBoxFormat.XYXY, inplace=True
180187
)
181188
xyxy_boxes[..., 0::2].clamp_(min=0, max=spatial_size[1])
182189
xyxy_boxes[..., 1::2].clamp_(min=0, max=spatial_size[0])
183-
return convert_format_bounding_box(xyxy_boxes, BoundingBoxFormat.XYXY, format)
190+
return convert_format_bounding_box(xyxy_boxes, old_format=BoundingBoxFormat.XYXY, new_format=format, inplace=True)
184191

185192

186193
def _strip_alpha(image: torch.Tensor) -> torch.Tensor:

0 commit comments

Comments
 (0)