
Commit ae83c9f

[PoC] move metadata computation from prototype features into kernels (#6646)

* move metadata computation from prototype features into kernels
* fix tests
* fix no_inplace test
* mypy
* add perf TODO

1 parent 2907c49 · commit ae83c9f
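The gist of the PoC, before the diffs: geometry kernels that change the spatial extent (crop, pad, resize, center/resized crop, rotate with expand) now return the recomputed image size as metadata next to the transformed boxes, so the BoundingBox feature no longer duplicates that computation. A minimal sketch of the new contract with a hypothetical pad-style kernel (names, padding order, and XYXY layout are assumptions of this sketch; the real signatures are in the diffs below):

    from typing import List, Tuple

    import torch

    def pad_bounding_box_sketch(
        bounding_box: torch.Tensor,
        image_size: Tuple[int, int],
        padding: List[int],  # [left, top, right, bottom]; ordering assumed for this sketch
    ) -> Tuple[torch.Tensor, Tuple[int, int]]:
        left, top, right, bottom = padding
        out = bounding_box.clone()
        out[..., 0::2] += left  # shift x coordinates (assumes XYXY layout)
        out[..., 1::2] += top   # shift y coordinates
        height, width = image_size
        # The kernel, not the feature, reports the new canvas size.
        return out, (height + top + bottom, width + left + right)

    boxes = torch.tensor([[10.0, 10.0, 20.0, 20.0]])
    out, new_size = pad_bounding_box_sketch(boxes, image_size=(32, 38), padding=[2, 3, 2, 3])
    print(new_size)  # (38, 42)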

4 files changed: +130 −105 lines

test/prototype_transforms_kernel_infos.py

Lines changed: 4 additions & 6 deletions
@@ -709,7 +709,7 @@ def sample_inputs_crop_bounding_box():
     for bounding_box_loader, params in itertools.product(
         make_bounding_box_loaders(), [_CROP_PARAMS[0], _CROP_PARAMS[-1]]
     ):
-        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, top=params["top"], left=params["left"])
+        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params)


 def sample_inputs_crop_mask():
@@ -856,7 +856,9 @@ def sample_inputs_pad_bounding_box():
         if params["padding_mode"] != "constant":
             continue

-        yield ArgsKwargs(bounding_box_loader, format=bounding_box_loader.format, **params)
+        yield ArgsKwargs(
+            bounding_box_loader, format=bounding_box_loader.format, image_size=bounding_box_loader.image_size, **params
+        )


 def sample_inputs_pad_mask():
@@ -1552,8 +1554,6 @@ def reference_inputs_ten_crop_image_tensor():
         skips=[
             skip_integer_size_jit(),
             Skip("test_batched_vs_single", reason="Custom batching needed for five_crop_image_tensor."),
-            Skip("test_no_inplace", reason="Output of five_crop_image_tensor is not a tensor."),
-            Skip("test_dtype_and_device_consistency", reason="Output of five_crop_image_tensor is not a tensor."),
         ],
         closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
     ),
@@ -1565,8 +1565,6 @@ def reference_inputs_ten_crop_image_tensor():
         skips=[
             skip_integer_size_jit(),
             Skip("test_batched_vs_single", reason="Custom batching needed for ten_crop_image_tensor."),
-            Skip("test_no_inplace", reason="Output of ten_crop_image_tensor is not a tensor."),
-            Skip("test_dtype_and_device_consistency", reason="Output of ten_crop_image_tensor is not a tensor."),
         ],
         closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
     ),
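Context for the first hunk: crop_bounding_box used to consume only top and left, so the sample inputs discarded the rest of each params dict; since the kernel now also needs height and width to report the new image size, the whole dict is forwarded. A toy sketch of the resulting call shape (the stand-in kernel and the params values are hypothetical; the real signature appears in the test diff below):

    import torch

    def crop_bounding_box_stub(boxes, format, *, top, left, height, width):
        # Stand-in with the kernel's new keyword set: it can now return the
        # cropped canvas size alongside the (here untouched) boxes.
        return boxes, (height, width)

    params = dict(top=4, left=3, height=12, width=20)  # one hypothetical _CROP_PARAMS entry
    boxes = torch.tensor([[10.0, 10.0, 20.0, 20.0]])
    _, image_size = crop_bounding_box_stub(boxes, "XYXY", **params)
    print(image_size)  # (12, 20)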

test/test_prototype_transforms_functional.py

Lines changed: 60 additions & 31 deletions
@@ -68,17 +68,22 @@ def test_scripted_vs_eager(self, info, args_kwargs, device):

         assert_close(actual, expected, **info.closeness_kwargs)

-    def _unbind_batch_dims(self, batched_tensor, *, data_dims):
-        if batched_tensor.ndim == data_dims:
-            return batched_tensor
-
-        return [self._unbind_batch_dims(t, data_dims=data_dims) for t in batched_tensor.unbind(0)]
+    def _unbatch(self, batch, *, data_dims):
+        if isinstance(batch, torch.Tensor):
+            batched_tensor = batch
+            metadata = ()
+        else:
+            batched_tensor, *metadata = batch

-    def _stack_batch_dims(self, unbound_tensor):
-        if isinstance(unbound_tensor[0], torch.Tensor):
-            return torch.stack(unbound_tensor)
+        if batched_tensor.ndim == data_dims:
+            return batch

-        return torch.stack([self._stack_batch_dims(t) for t in unbound_tensor])
+        return [
+            self._unbatch(unbatched, data_dims=data_dims)
+            for unbatched in (
+                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
+            )
+        ]

     @sample_inputs
     @pytest.mark.parametrize("device", cpu_and_gpu())
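To see what the new _unbatch does with metadata-carrying outputs: the metadata is replicated onto every unbatched element, so a batched kernel output can be compared elementwise against single-sample outputs. A standalone copy of the helper, applied to a batch of two boxes plus an image size (values illustrative):

    import torch

    def unbatch(batch, *, data_dims):
        # Standalone copy of TestKernels._unbatch from the diff above.
        if isinstance(batch, torch.Tensor):
            batched_tensor, metadata = batch, ()
        else:
            batched_tensor, *metadata = batch
        if batched_tensor.ndim == data_dims:
            return batch
        return [
            unbatch(unbatched, data_dims=data_dims)
            for unbatched in (
                batched_tensor.unbind(0) if not metadata else [(t, *metadata) for t in batched_tensor.unbind(0)]
            )
        ]

    boxes = torch.tensor([[0.0, 0.0, 2.0, 2.0], [1.0, 1.0, 3.0, 3.0]])
    print(unbatch((boxes, (32, 38)), data_dims=1))
    # [(tensor([0., 0., 2., 2.]), (32, 38)), (tensor([1., 1., 3., 3.]), (32, 38))]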
@@ -106,11 +111,11 @@ def test_batched_vs_single(self, info, args_kwargs, device):
         elif not all(batched_input.shape[:-data_dims]):
             pytest.skip("Input has a degenerate batch shape.")

-        actual = info.kernel(batched_input, *other_args, **kwargs)
+        batched_output = info.kernel(batched_input, *other_args, **kwargs)
+        actual = self._unbatch(batched_output, data_dims=data_dims)

-        single_inputs = self._unbind_batch_dims(batched_input, data_dims=data_dims)
-        single_outputs = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)
-        expected = self._stack_batch_dims(single_outputs)
+        single_inputs = self._unbatch(batched_input, data_dims=data_dims)
+        expected = tree_map(lambda single_input: info.kernel(single_input, *other_args, **kwargs), single_inputs)

         assert_close(actual, expected, **info.closeness_kwargs)

@@ -123,9 +128,9 @@ def test_no_inplace(self, info, args_kwargs, device):
             pytest.skip("The input has a degenerate shape.")

         input_version = input._version
-        output = info.kernel(input, *other_args, **kwargs)
+        info.kernel(input, *other_args, **kwargs)

-        assert output is not input or output._version == input_version
+        assert input._version == input_version

     @sample_inputs
     @needs_cuda
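The rewritten assertion leans on torch.Tensor._version, an internal counter that PyTorch bumps on every in-place mutation of a tensor; checking the input's version directly also keeps the test meaningful when a kernel returns a tuple instead of a tensor, which is why the five_crop/ten_crop skips above could be dropped. A quick illustration:

    import torch

    t = torch.zeros(3)
    version = t._version   # bumped only by in-place mutation
    t.add(1)               # out-of-place op: counter unchanged
    assert t._version == version
    t.add_(1)              # in-place op: counter bumped
    assert t._version == version + 1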
@@ -144,6 +149,9 @@ def test_dtype_and_device_consistency(self, info, args_kwargs, device):
         (input, *other_args), kwargs = args_kwargs.load(device)

         output = info.kernel(input, *other_args, **kwargs)
+        # Most kernels just return a tensor, but some also return some additional metadata
+        if not isinstance(output, torch.Tensor):
+            output, *_ = output

         assert output.dtype == input.dtype
         assert output.device == input.device
@@ -324,7 +332,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
         affine_matrix = _compute_affine_matrix(angle_, [0.0, 0.0], 1.0, [0.0, 0.0], center_)
         affine_matrix = affine_matrix[:2, :]

-        image_size = bbox.image_size
+        height, width = bbox.image_size
         bbox_xyxy = convert_format_bounding_box(
             bbox, old_format=bbox.format, new_format=features.BoundingBoxFormat.XYXY
         )
@@ -336,9 +344,9 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
                 [bbox_xyxy[2].item(), bbox_xyxy[3].item(), 1.0],
                 # image frame
                 [0.0, 0.0, 1.0],
-                [0.0, image_size[0], 1.0],
-                [image_size[1], image_size[0], 1.0],
-                [image_size[1], 0.0, 1.0],
+                [0.0, height, 1.0],
+                [width, height, 1.0],
+                [width, 0.0, 1.0],
             ]
         )
         transformed_points = np.matmul(points, affine_matrix.T)
@@ -356,18 +364,21 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
             out_bbox[2] -= tr_x
             out_bbox[3] -= tr_y

-        # image_size should be updated, but it is OK here to skip its computation
-        # as we do not compute it in F.rotate_bounding_box
+            height = int(height - 2 * tr_y)
+            width = int(width - 2 * tr_x)

         out_bbox = features.BoundingBox(
             out_bbox,
             format=features.BoundingBoxFormat.XYXY,
-            image_size=image_size,
+            image_size=(height, width),
             dtype=bbox.dtype,
             device=bbox.device,
         )
-        return convert_format_bounding_box(
-            out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
+        return (
+            convert_format_bounding_box(
+                out_bbox, old_format=features.BoundingBoxFormat.XYXY, new_format=bbox.format, copy=False
+            ),
+            (height, width),
         )

     image_size = (32, 38)
@@ -376,7 +387,7 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
         bboxes_format = bboxes.format
         bboxes_image_size = bboxes.image_size

-        output_bboxes = F.rotate_bounding_box(
+        output_bboxes, output_image_size = F.rotate_bounding_box(
             bboxes,
             bboxes_format,
             image_size=bboxes_image_size,
@@ -395,12 +406,14 @@ def _compute_expected_bbox(bbox, angle_, expand_, center_):
         expected_bboxes = []
         for bbox in bboxes:
             bbox = features.BoundingBox(bbox, format=bboxes_format, image_size=bboxes_image_size)
-            expected_bboxes.append(_compute_expected_bbox(bbox, -angle, expand, center_))
+            expected_bbox, expected_image_size = _compute_expected_bbox(bbox, -angle, expand, center_)
+            expected_bboxes.append(expected_bbox)
         if len(expected_bboxes) > 1:
             expected_bboxes = torch.stack(expected_bboxes)
         else:
             expected_bboxes = expected_bboxes[0]
         torch.testing.assert_close(output_bboxes, expected_bboxes, atol=1, rtol=0)
+        torch.testing.assert_close(output_image_size, expected_image_size, atol=1, rtol=0)


 @pytest.mark.parametrize("device", cpu_and_gpu())
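For intuition on the height - 2 * tr_y update above: with rotation about the image center, the translation of the min corner is symmetric, so subtracting it twice recovers the full extent of the rotated frame. A numpy sketch of the same quantity computed directly from the rotated corners (a reference computation under that center-rotation assumption, not torchvision's code):

    import numpy as np

    def expanded_canvas_size(height, width, angle_deg):
        # Rotate the four image-frame corners about the image center and
        # take the axis-aligned extent of the result.
        theta = np.deg2rad(angle_deg)
        rot = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
        corners = np.array([[0, 0], [width, 0], [width, height], [0, height]], dtype=np.float64)
        center = np.array([width / 2, height / 2])
        rotated = (corners - center) @ rot.T + center
        extent = rotated.max(axis=0) - rotated.min(axis=0)
        return int(extent[1]), int(extent[0])  # (new_height, new_width)

    print(expanded_canvas_size(32, 38, 45))  # ≈ (49, 49): both extents are (32 + 38) / sqrt(2)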
@@ -445,7 +458,7 @@ def test_correctness_rotate_bounding_box_on_fixed_input(device, expand):
         [18.36396103, 1.07968978, 46.64823228, 29.36396103],
     ]

-    output_boxes = F.rotate_bounding_box(
+    output_boxes, _ = F.rotate_bounding_box(
         in_boxes,
         in_boxes.format,
         in_boxes.image_size,
@@ -510,17 +523,20 @@ def test_correctness_crop_bounding_box(device, format, top, left, height, width,
     if format != features.BoundingBoxFormat.XYXY:
         in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format)

-    output_boxes = F.crop_bounding_box(
+    output_boxes, output_image_size = F.crop_bounding_box(
         in_boxes,
         format,
         top,
         left,
+        size[0],
+        size[1],
     )

     if format != features.BoundingBoxFormat.XYXY:
         output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY)

     torch.testing.assert_close(output_boxes.tolist(), expected_bboxes)
+    torch.testing.assert_close(output_image_size, size)


 @pytest.mark.parametrize("device", cpu_and_gpu())
@@ -585,12 +601,13 @@ def _compute_expected_bbox(bbox, top_, left_, height_, width_, size_):
     if format != features.BoundingBoxFormat.XYXY:
         in_boxes = convert_format_bounding_box(in_boxes, features.BoundingBoxFormat.XYXY, format)

-    output_boxes = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size)
+    output_boxes, output_image_size = F.resized_crop_bounding_box(in_boxes, format, top, left, height, width, size)

     if format != features.BoundingBoxFormat.XYXY:
         output_boxes = convert_format_bounding_box(output_boxes, format, features.BoundingBoxFormat.XYXY)

     torch.testing.assert_close(output_boxes, expected_bboxes)
+    torch.testing.assert_close(output_image_size, size)


 def _parse_padding(padding):
@@ -627,12 +644,21 @@ def _compute_expected_bbox(bbox, padding_):
         bbox = bbox.to(bbox_dtype)
         return bbox

+    def _compute_expected_image_size(bbox, padding_):
+        pad_left, pad_up, pad_right, pad_down = _parse_padding(padding_)
+        height, width = bbox.image_size
+        return height + pad_up + pad_down, width + pad_left + pad_right
+
     for bboxes in make_bounding_boxes():
         bboxes = bboxes.to(device)
         bboxes_format = bboxes.format
         bboxes_image_size = bboxes.image_size

-        output_boxes = F.pad_bounding_box(bboxes, format=bboxes_format, padding=padding)
+        output_boxes, output_image_size = F.pad_bounding_box(
+            bboxes, format=bboxes_format, image_size=bboxes_image_size, padding=padding
+        )
+
+        torch.testing.assert_close(output_image_size, _compute_expected_image_size(bboxes, padding))

         if bboxes.ndim < 2 or bboxes.shape[0] == 0:
             bboxes = [bboxes]
@@ -781,7 +807,9 @@ def _compute_expected_bbox(bbox, output_size_):
         bboxes_format = bboxes.format
         bboxes_image_size = bboxes.image_size

-        output_boxes = F.center_crop_bounding_box(bboxes, bboxes_format, bboxes_image_size, output_size)
+        output_boxes, output_image_size = F.center_crop_bounding_box(
+            bboxes, bboxes_format, bboxes_image_size, output_size
+        )

         if bboxes.ndim < 2:
             bboxes = [bboxes]
@@ -796,6 +824,7 @@ def _compute_expected_bbox(bbox, output_size_):
         else:
             expected_bboxes = expected_bboxes[0]
         torch.testing.assert_close(output_boxes, expected_bboxes)
+        torch.testing.assert_close(output_image_size, output_size)


 @pytest.mark.parametrize("device", cpu_and_gpu())

torchvision/prototype/features/_bounding_box.py

Lines changed: 15 additions & 42 deletions
@@ -83,23 +83,19 @@ def resize(  # type: ignore[override]
         max_size: Optional[int] = None,
         antialias: bool = False,
     ) -> BoundingBox:
-        output = self._F.resize_bounding_box(self, image_size=self.image_size, size=size, max_size=max_size)
-        if isinstance(size, int):
-            size = [size]
-        image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1])
-        return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype)
+        output, image_size = self._F.resize_bounding_box(self, image_size=self.image_size, size=size, max_size=max_size)
+        return BoundingBox.new_like(self, output, image_size=image_size)

     def crop(self, top: int, left: int, height: int, width: int) -> BoundingBox:
-        output = self._F.crop_bounding_box(self, self.format, top, left)
-        return BoundingBox.new_like(self, output, image_size=(height, width))
+        output, image_size = self._F.crop_bounding_box(
+            self, self.format, top=top, left=left, height=height, width=width
+        )
+        return BoundingBox.new_like(self, output, image_size=image_size)

     def center_crop(self, output_size: List[int]) -> BoundingBox:
-        output = self._F.center_crop_bounding_box(
+        output, image_size = self._F.center_crop_bounding_box(
             self, format=self.format, image_size=self.image_size, output_size=output_size
         )
-        if isinstance(output_size, int):
-            output_size = [output_size]
-        image_size = (output_size[0], output_size[0]) if len(output_size) == 1 else (output_size[0], output_size[1])
         return BoundingBox.new_like(self, output, image_size=image_size)

     def resized_crop(
@@ -112,29 +108,19 @@ def resized_crop(
         interpolation: InterpolationMode = InterpolationMode.BILINEAR,
         antialias: bool = False,
     ) -> BoundingBox:
-        output = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size)
-        image_size = (size[0], size[0]) if len(size) == 1 else (size[0], size[1])
-        return BoundingBox.new_like(self, output, image_size=image_size, dtype=output.dtype)
+        output, image_size = self._F.resized_crop_bounding_box(self, self.format, top, left, height, width, size=size)
+        return BoundingBox.new_like(self, output, image_size=image_size)

     def pad(
         self,
         padding: Union[int, Sequence[int]],
         fill: FillTypeJIT = None,
         padding_mode: str = "constant",
     ) -> BoundingBox:
-        # This cast does Sequence[int] -> List[int] and is required to make mypy happy
-        if not isinstance(padding, int):
-            padding = list(padding)
-
-        output = self._F.pad_bounding_box(self, format=self.format, padding=padding, padding_mode=padding_mode)
-
-        # Update output image size:
-        left, right, top, bottom = self._F._geometry._parse_pad_padding(padding)
-        height, width = self.image_size
-        height += top + bottom
-        width += left + right
-
-        return BoundingBox.new_like(self, output, image_size=(height, width))
+        output, image_size = self._F.pad_bounding_box(
+            self, format=self.format, image_size=self.image_size, padding=padding, padding_mode=padding_mode
+        )
+        return BoundingBox.new_like(self, output, image_size=image_size)

     def rotate(
         self,
@@ -144,23 +130,10 @@ def rotate(
         fill: FillTypeJIT = None,
         center: Optional[List[float]] = None,
     ) -> BoundingBox:
-        output = self._F.rotate_bounding_box(
+        output, image_size = self._F.rotate_bounding_box(
             self, format=self.format, image_size=self.image_size, angle=angle, expand=expand, center=center
         )
-        image_size = self.image_size
-        if expand:
-            # The way we recompute image_size is not optimal due to redundant computations of
-            # - rotation matrix (_get_inverse_affine_matrix)
-            # - points dot matrix (_compute_affine_output_size)
-            # Alternatively, we could return new image size by self._F.rotate_bounding_box
-            height, width = image_size
-            rotation_matrix = self._F._geometry._get_inverse_affine_matrix(
-                [0.0, 0.0], angle, [0.0, 0.0], 1.0, [0.0, 0.0]
-            )
-            new_width, new_height = self._F._geometry._FT._compute_affine_output_size(rotation_matrix, width, height)
-            image_size = (new_height, new_width)
-
-        return BoundingBox.new_like(self, output, dtype=output.dtype, image_size=image_size)
+        return BoundingBox.new_like(self, output, image_size=image_size)

     def affine(
         self,