
Commit e1a66c2

datumbox authored and facebook-github-bot committed
[fbsync] Cleanup conversion transforms (#6801)
Summary:
* remove copy from convert_color_space
* remove copy from convert_format_bounding_box
* remove .to_* methods from features
* remove unnecessary clones
* add perf todos
* refactor convert_color_space
* lint
* remove another clone
* and another clone
* remove a missed copy

Reviewed By: YosuaMichael

Differential Revision: D40722906

fbshipit-source-id: 3b757af1b69fcd85b085a5df9ff88cdaeafca130
1 parent 80ac92c commit e1a66c2
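
The net user-facing change: the conversion kernels (`convert_format_bounding_box` and the `convert_color_space` family) no longer accept a `copy` keyword, and per the summary the `.to_*` convenience methods are dropped from the feature classes (the `to_color_space` removals are visible below; a `to_format` call site is rewritten as well). A minimal sketch of a call site after this commit, assuming this revision's prototype namespaces; note that an identity conversion (old format equals new format) may now return its input rather than a copy:

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    xywh = torch.tensor([[10.0, 10.0, 5.0, 5.0]])

    # Previously: F.convert_format_bounding_box(..., copy=False); the keyword is gone.
    xyxy = F.convert_format_bounding_box(
        xywh, old_format=features.BoundingBoxFormat.XYWH, new_format=features.BoundingBoxFormat.XYXY
    )

    # Identity conversions may alias their input now, so clone before in-place edits.
    safe_to_mutate = xyxy.clone()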

File tree

9 files changed: +92 additions, -124 deletions


test/prototype_transforms_kernel_infos.py

Lines changed: 9 additions & 17 deletions
@@ -461,9 +461,7 @@ def transform(bbox):
             ],
             dtype=bbox.dtype,
         )
-        return F.convert_format_bounding_box(
-            out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
-        )
+        return F.convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=format)

     if bounding_box.ndim < 2:
         bounding_box = [bounding_box]
@@ -556,26 +554,20 @@ def sample_inputs_affine_video():


 def sample_inputs_convert_format_bounding_box():
-    formats = set(features.BoundingBoxFormat)
-    for bounding_box_loader in make_bounding_box_loaders(formats=formats):
-        old_format = bounding_box_loader.format
-        for params in combinations_grid(new_format=formats - {old_format}, copy=(True, False)):
-            yield ArgsKwargs(bounding_box_loader, old_format=old_format, **params)
-
+    formats = list(features.BoundingBoxFormat)
+    for bounding_box_loader, new_format in itertools.product(make_bounding_box_loaders(formats=formats), formats):
+        yield ArgsKwargs(bounding_box_loader, old_format=bounding_box_loader.format, new_format=new_format)

-def reference_convert_format_bounding_box(bounding_box, old_format, new_format, copy):
-    if not copy:
-        raise pytest.UsageError("Reference for `convert_format_bounding_box` only supports `copy=True`")

+def reference_convert_format_bounding_box(bounding_box, old_format, new_format):
     return torchvision.ops.box_convert(
         bounding_box, in_fmt=old_format.kernel_name.lower(), out_fmt=new_format.kernel_name.lower()
     )


 def reference_inputs_convert_format_bounding_box():
     for args_kwargs in sample_inputs_convert_color_space_image_tensor():
-        (image_loader, *other_args), kwargs = args_kwargs
-        if len(image_loader.shape) == 2 and kwargs.setdefault("copy", True):
+        if len(args_kwargs.args[0].shape) == 2:
             yield args_kwargs


@@ -600,19 +592,19 @@ def sample_inputs_convert_color_space_image_tensor():
     for image_loader in make_image_loaders(
         sizes=["random"], color_spaces=[color_space], dtypes=[torch.float32], constant_alpha=True
     ):
-        yield ArgsKwargs(image_loader, old_color_space=color_space, new_color_space=color_space, copy=False)
+        yield ArgsKwargs(image_loader, old_color_space=color_space, new_color_space=color_space)


 @pil_reference_wrapper
-def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space, copy=True):
+def reference_convert_color_space_image_tensor(image_pil, old_color_space, new_color_space):
     color_space_pil = features.ColorSpace.from_pil_mode(image_pil.mode)
     if color_space_pil != old_color_space:
         raise pytest.UsageError(
             f"Converting the tensor image into an PIL image changed the colorspace "
             f"from {old_color_space} to {color_space_pil}"
         )

-    return F.convert_color_space_image_pil(image_pil, color_space=new_color_space, copy=copy)
+    return F.convert_color_space_image_pil(image_pil, color_space=new_color_space)


 def reference_inputs_convert_color_space_image_tensor():
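
The new `sample_inputs_convert_format_bounding_box` sweeps every (loader, target format) pair instead of excluding the identity conversion and sweeping the `copy` flag. A standalone sketch of the enumeration pattern (the enum and loader names here are illustrative stand-ins, not the test helpers):

    import itertools
    from enum import Enum

    class BoundingBoxFormat(Enum):
        XYXY = "XYXY"
        XYWH = "XYWH"
        CXCYWH = "CXCYWH"

    loaders = ["loader_xyxy", "loader_xywh"]  # stand-ins for make_bounding_box_loaders(...)
    for loader, new_format in itertools.product(loaders, list(BoundingBoxFormat)):
        # identity pairs (old_format == new_format) are now sampled as well
        print(loader, new_format.value)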

test/test_prototype_transforms_functional.py

Lines changed: 9 additions & 11 deletions
@@ -478,9 +478,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
             device=bbox.device,
         )
         return (
-            convert_format_bounding_box(
-                out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
-            ),
+            convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format),
             (height, width),
         )

@@ -733,14 +731,16 @@ def _compute_expected_bbox(bbox, padding_):

         bbox_format = bbox.format
         bbox_dtype = bbox.dtype
-        bbox = convert_format_bounding_box(bbox, old_format=bbox_format, new_format=features.BoundingBoxFormat.XYXY)
+        bbox = (
+            bbox.clone()
+            if bbox_format == features.BoundingBoxFormat.XYXY
+            else convert_format_bounding_box(bbox, bbox_format, features.BoundingBoxFormat.XYXY)
+        )

         bbox[0::2] += pad_left
         bbox[1::2] += pad_up

-        bbox = convert_format_bounding_box(
-            bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
-        )
+        bbox = convert_format_bounding_box(bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format)
         if bbox.dtype != bbox_dtype:
             # Temporary cast to original dtype
             # e.g. float32 -> int
@@ -840,9 +840,7 @@ def _compute_expected_bbox(bbox, pcoeffs_):
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return convert_format_bounding_box(
-            out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
-        )
+        return convert_format_bounding_box(out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format)

     spatial_size = (32, 38)

@@ -903,7 +901,7 @@ def _compute_expected_bbox(bbox, output_size_):
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_, copy=False)
+        return convert_format_bounding_box(out_bbox, features.BoundingBoxFormat.XYWH, format_)

     for bboxes in make_bounding_boxes(extra_dims=((4,),)):
         bboxes = bboxes.to(device)
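
The `bbox.clone() if ... else convert_format_bounding_box(...)` branch above is the pattern this commit adopts wherever the result is mutated in place: with the `copy` keyword gone, an identity conversion may alias its input. A hedged standalone version of the idiom (the helper name is ours, and we assume the kernel is importable from the functional namespace, as the test's bare calls suggest):

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms.functional import convert_format_bounding_box

    def to_xyxy_for_inplace_edit(bbox: torch.Tensor, fmt: features.BoundingBoxFormat) -> torch.Tensor:
        # Clone when the format already matches, since the kernel may hand back
        # the caller's tensor; a real conversion already produces a new tensor.
        if fmt == features.BoundingBoxFormat.XYXY:
            return bbox.clone()
        return convert_format_bounding_box(bbox, old_format=fmt, new_format=features.BoundingBoxFormat.XYXY)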

torchvision/prototype/features/_image.py

Lines changed: 0 additions & 12 deletions
@@ -110,18 +110,6 @@ def spatial_size(self) -> Tuple[int, int]:
     def num_channels(self) -> int:
         return self.shape[-3]

-    def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Image:
-        if isinstance(color_space, str):
-            color_space = ColorSpace.from_str(color_space.upper())
-
-        return Image.wrap_like(
-            self,
-            self._F.convert_color_space_image_tensor(
-                self.as_subclass(torch.Tensor), old_color_space=self.color_space, new_color_space=color_space, copy=copy
-            ),
-            color_space=color_space,
-        )
-
     def horizontal_flip(self) -> Image:
         output = self._F.horizontal_flip_image_tensor(self.as_subclass(torch.Tensor))
         return Image.wrap_like(self, output)
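
With the method removed, color space conversion goes through the functional dispatcher instead. A hedged migration sketch, assuming this revision's prototype constructors and `ColorSpace` values; the `Video` counterpart below is analogous via `convert_color_space_video`:

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    image = features.Image(torch.rand(3, 32, 32), color_space=features.ColorSpace.RGB)

    # previously: image.to_color_space("GRAY")
    gray = F.convert_color_space(
        image, color_space=features.ColorSpace.GRAY, old_color_space=features.ColorSpace.RGB
    )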

torchvision/prototype/features/_video.py

Lines changed: 0 additions & 12 deletions
@@ -66,18 +66,6 @@ def num_channels(self) -> int:
     def num_frames(self) -> int:
         return self.shape[-4]

-    def to_color_space(self, color_space: Union[str, ColorSpace], copy: bool = True) -> Video:
-        if isinstance(color_space, str):
-            color_space = ColorSpace.from_str(color_space.upper())
-
-        return Video.wrap_like(
-            self,
-            self._F.convert_color_space_video(
-                self.as_subclass(torch.Tensor), old_color_space=self.color_space, new_color_space=color_space, copy=copy
-            ),
-            color_space=color_space,
-        )
-
     def horizontal_flip(self) -> Video:
         output = self._F.horizontal_flip_video(self.as_subclass(torch.Tensor))
         return Video.wrap_like(self, output)

torchvision/prototype/transforms/_augment.py

Lines changed: 1 addition & 1 deletion
@@ -265,7 +265,7 @@ def _copy_paste(
     # https://github.com/pytorch/vision/blob/b6feccbc4387766b76a3e22b13815dbbbfa87c0f/torchvision/models/detection/roi_heads.py#L418-L422
     xyxy_boxes[:, 2:] += 1
     boxes = F.convert_format_bounding_box(
-        xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format, copy=False
+        xyxy_boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox_format
     )
     out_target["boxes"] = torch.cat([boxes, paste_boxes])

torchvision/prototype/transforms/_geometry.py

Lines changed: 9 additions & 12 deletions
@@ -655,9 +655,7 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
                 continue

             # check for any valid boxes with centers within the crop area
-            xyxy_bboxes = F.convert_format_bounding_box(
-                bboxes, old_format=bboxes.format, new_format=features.BoundingBoxFormat.XYXY, copy=True
-            )
+            xyxy_bboxes = F.convert_format_bounding_box(bboxes, bboxes.format, features.BoundingBoxFormat.XYXY)
             cx = 0.5 * (xyxy_bboxes[..., 0] + xyxy_bboxes[..., 2])
             cy = 0.5 * (xyxy_bboxes[..., 1] + xyxy_bboxes[..., 3])
             is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
@@ -801,22 +799,21 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
         top = int(offset_height * r)
         left = int(offset_width * r)

+        bounding_boxes: Optional[torch.Tensor]
         try:
             bounding_boxes = query_bounding_box(flat_inputs)
         except ValueError:
             bounding_boxes = None

         if needs_crop and bounding_boxes is not None:
-            bounding_boxes = cast(
-                features.BoundingBox, F.crop(bounding_boxes, top=top, left=left, height=new_height, width=new_width)
-            )
-            bounding_boxes = features.BoundingBox.wrap_like(
-                bounding_boxes,
-                F.clamp_bounding_box(
-                    bounding_boxes, format=bounding_boxes.format, spatial_size=bounding_boxes.spatial_size
-                ),
+            format = bounding_boxes.format
+            bounding_boxes, spatial_size = F.crop_bounding_box(
+                bounding_boxes, format=format, top=top, left=left, height=new_height, width=new_width
             )
-            height_and_width = bounding_boxes.to_format(features.BoundingBoxFormat.XYWH)[..., 2:]
+            bounding_boxes = F.clamp_bounding_box(bounding_boxes, format=format, spatial_size=spatial_size)
+            height_and_width = F.convert_format_bounding_box(
+                bounding_boxes, old_format=format, new_format=features.BoundingBoxFormat.XYWH
+            )[..., 2:]
             is_valid = torch.all(height_and_width > 0, dim=-1)
         else:
             is_valid = None
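
The rewritten branch leans on `crop_bounding_box` returning both the cropped boxes and the new spatial size, which then feeds `clamp_bounding_box` directly. A minimal sketch of that contract with plain tensors, using only the signatures visible in this diff (values illustrative):

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    boxes = torch.tensor([[5.0, 5.0, 20.0, 20.0]])
    boxes, spatial_size = F.crop_bounding_box(
        boxes, format=features.BoundingBoxFormat.XYXY, top=2, left=2, height=16, width=16
    )
    boxes = F.clamp_bounding_box(boxes, format=features.BoundingBoxFormat.XYXY, spatial_size=spatial_size)
    height_and_width = F.convert_format_bounding_box(
        boxes, old_format=features.BoundingBoxFormat.XYXY, new_format=features.BoundingBoxFormat.XYWH
    )[..., 2:]
    is_valid = torch.all(height_and_width > 0, dim=-1)  # degenerate boxes filtered out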

torchvision/prototype/transforms/_meta.py

Lines changed: 1 addition & 6 deletions
@@ -50,7 +50,6 @@ def __init__(
         self,
         color_space: Union[str, features.ColorSpace],
         old_color_space: Optional[Union[str, features.ColorSpace]] = None,
-        copy: bool = True,
     ) -> None:
         super().__init__()

@@ -62,14 +61,10 @@ def __init__(
             old_color_space = features.ColorSpace.from_str(old_color_space)
         self.old_color_space = old_color_space

-        self.copy = copy
-
     def _transform(
         self, inpt: Union[features.ImageType, features.VideoType], params: Dict[str, Any]
     ) -> Union[features.ImageType, features.VideoType]:
-        return F.convert_color_space(
-            inpt, color_space=self.color_space, old_color_space=self.old_color_space, copy=self.copy
-        )
+        return F.convert_color_space(inpt, color_space=self.color_space, old_color_space=self.old_color_space)


 class ClampBoundingBoxes(Transform):
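
At the transform level the same cleanup means the color space transform is constructed without `copy`. A hedged usage sketch; the class name `ConvertColorSpace` is assumed from context, since the hunk only shows its `__init__` and `_transform`:

    import torch
    from torchvision.prototype import features, transforms

    transform = transforms.ConvertColorSpace(
        color_space=features.ColorSpace.GRAY, old_color_space=features.ColorSpace.RGB
    )
    gray = transform(features.Image(torch.rand(3, 32, 32), color_space=features.ColorSpace.RGB))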

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 36 additions & 25 deletions
@@ -36,14 +36,18 @@ def horizontal_flip_bounding_box(
 ) -> torch.Tensor:
     shape = bounding_box.shape

-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
+    # BoundingBoxFormat instead of converting back and forth
+    bounding_box = (
+        bounding_box.clone()
+        if format == features.BoundingBoxFormat.XYXY
+        else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

     bounding_box[:, [0, 2]] = spatial_size[1] - bounding_box[:, [2, 0]]

     return convert_format_bounding_box(
-        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
     ).reshape(shape)


@@ -73,14 +77,18 @@ def vertical_flip_bounding_box(
 ) -> torch.Tensor:
     shape = bounding_box.shape

-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
+    # BoundingBoxFormat instead of converting back and forth
+    bounding_box = (
+        bounding_box.clone()
+        if format == features.BoundingBoxFormat.XYXY
+        else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

     bounding_box[:, [1, 3]] = spatial_size[0] - bounding_box[:, [3, 1]]

     return convert_format_bounding_box(
-        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+        bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format
     ).reshape(shape)


@@ -394,16 +402,17 @@ def affine_bounding_box(
     center: Optional[List[float]] = None,
 ) -> torch.Tensor:
     original_shape = bounding_box.shape
-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+
+    bounding_box = (
+        convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

     out_bboxes, _ = _affine_bounding_box_xyxy(bounding_box, spatial_size, angle, translate, scale, shear, center)

     # out_bboxes should be of shape [N boxes, 4]

     return convert_format_bounding_box(
-        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
     ).reshape(original_shape)


@@ -583,8 +592,8 @@ def rotate_bounding_box(
         center = None

     original_shape = bounding_box.shape
-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    bounding_box = (
+        convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

     out_bboxes, spatial_size = _affine_bounding_box_xyxy(
@@ -599,9 +608,9 @@ def rotate_bounding_box(
     )

     return (
-        convert_format_bounding_box(
-            out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
-        ).reshape(original_shape),
+        convert_format_bounding_box(out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format).reshape(
+            original_shape
+        ),
         spatial_size,
     )

@@ -818,18 +827,20 @@ def crop_bounding_box(
     height: int,
     width: int,
 ) -> Tuple[torch.Tensor, Tuple[int, int]]:
-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    # TODO: Investigate if it makes sense from a performance perspective to have an implementation for every
+    # BoundingBoxFormat instead of converting back and forth
+    bounding_box = (
+        bounding_box.clone()
+        if format == features.BoundingBoxFormat.XYXY
+        else convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     )

     # Crop or implicit pad if left and/or top have negative values:
     bounding_box[..., 0::2] -= left
     bounding_box[..., 1::2] -= top

     return (
-        convert_format_bounding_box(
-            bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
-        ),
+        convert_format_bounding_box(bounding_box, old_format=features.BoundingBoxFormat.XYXY, new_format=format),
         (height, width),
     )

@@ -896,8 +907,8 @@ def perspective_bounding_box(
         raise ValueError("Argument perspective_coeffs should have 8 float values")

     original_shape = bounding_box.shape
-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    bounding_box = (
+        convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

     dtype = bounding_box.dtype if torch.is_floating_point(bounding_box) else torch.float32
@@ -967,7 +978,7 @@ def perspective_bounding_box(
     # out_bboxes should be of shape [N boxes, 4]

     return convert_format_bounding_box(
-        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
     ).reshape(original_shape)


@@ -1061,8 +1072,8 @@ def elastic_bounding_box(
     displacement = displacement.to(bounding_box.device)

     original_shape = bounding_box.shape
-    bounding_box = convert_format_bounding_box(
-        bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY
+    bounding_box = (
+        convert_format_bounding_box(bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY)
     ).reshape(-1, 4)

     # Question (vfdev-5): should we rely on good displacement shape and fetch image size from it
@@ -1088,7 +1099,7 @@ def elastic_bounding_box(
     out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1).to(bounding_box.dtype)

     return convert_format_bounding_box(
-        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False
+        out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format
     ).reshape(original_shape)

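The `clone()` on the XYXY fast path above is load-bearing: these kernels mutate the tensor in place, and after the `copy` removal an identity conversion would hand back the caller's own storage. A plain-torch demonstration of the aliasing hazard the clone avoids (width 32, values illustrative):

    import torch

    boxes = torch.tensor([[0.0, 0.0, 10.0, 10.0]])

    aliased = boxes  # what an identity "conversion" without a copy would return
    aliased[:, [0, 2]] = 32 - aliased[:, [2, 0]]  # the in-place horizontal flip from above
    print(boxes)  # the caller's boxes changed too: tensor([[22., 0., 32., 10.]])

    safe = boxes.clone()  # the fast path used by the kernels above
    safe[:, [0, 2]] = 32 - safe[:, [2, 0]]  # only `safe` changes now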