
Commit 6e203b4

[prototype] Rewrite the meta dimension methods (#6722)
* Rewrite `get_dimensions`, `get_num_channels` and `get_spatial_size`
* Remove `get_chw`
* Remove comments
* Make `get_spatial_size` support non-image input
* Reduce the unnecessary use of `get_dimensions*`
* Fix linters
* Fix merge bug
* Linter
* Fix linter
1 parent 4c049ca commit 6e203b4

File tree: 6 files changed (+71, -37 lines)


torchvision/prototype/features/_mask.py

Lines changed: 5 additions & 1 deletion
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import Any, List, Optional, Union
+from typing import Any, cast, List, Optional, Tuple, Union

 import torch
 from torchvision.transforms import InterpolationMode
@@ -32,6 +32,10 @@ def wrap_like(
     ) -> Mask:
         return cls._wrap(tensor)

+    @property
+    def image_size(self) -> Tuple[int, int]:
+        return cast(Tuple[int, int], tuple(self.shape[-2:]))
+
     def horizontal_flip(self) -> Mask:
         output = self._F.horizontal_flip_mask(self)
         return Mask.wrap_like(self, output)
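The new `image_size` property reads the height and width straight off the mask's trailing dimensions. A minimal sketch of what it computes, using a plain tensor as a stand-in for a prototype `Mask` (which subclasses `torch.Tensor`); the shape below is only for illustration:

import torch

# A boolean mask batch laid out as (num_objects, height, width).
mask = torch.zeros(3, 480, 640, dtype=torch.bool)

# Same computation as the new Mask.image_size property: take the last two dims.
image_size = tuple(mask.shape[-2:])
print(image_size)  # (480, 640), i.e. (height, width)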

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 5 additions & 5 deletions
@@ -7,7 +7,7 @@
 from torch.utils._pytree import tree_flatten, tree_unflatten
 from torchvision.prototype import features
 from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
-from torchvision.prototype.transforms.functional._meta import get_chw
+from torchvision.prototype.transforms.functional._meta import get_spatial_size

 from ._utils import _isinstance, _setup_fill_arg

@@ -278,7 +278,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]

         id, image_or_video = self._extract_image_or_video(sample)
-        _, height, width = get_chw(image_or_video)
+        height, width = get_spatial_size(image_or_video)

         policy = self._policies[int(torch.randint(len(self._policies), ()))]

@@ -349,7 +349,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]

         id, image_or_video = self._extract_image_or_video(sample)
-        _, height, width = get_chw(image_or_video)
+        height, width = get_spatial_size(image_or_video)

         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)
@@ -403,7 +403,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]

         id, image_or_video = self._extract_image_or_video(sample)
-        _, height, width = get_chw(image_or_video)
+        height, width = get_spatial_size(image_or_video)

         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

@@ -473,7 +473,7 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
     def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]
         id, orig_image_or_video = self._extract_image_or_video(sample)
-        _, height, width = get_chw(orig_image_or_video)
+        height, width = get_spatial_size(orig_image_or_video)

         if isinstance(orig_image_or_video, torch.Tensor):
             image_or_video = orig_image_or_video
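These AutoAugment-style transforms only need the spatial size to build magnitudes, so they now call `get_spatial_size` instead of unpacking and discarding the channel count from `get_chw`. A hedged sketch of the call-site difference, assuming the prototype API at this commit:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

image_or_video = features.Image(torch.rand(3, 224, 224))

# Old pattern: fetch the channel count too, then throw it away.
_, height, width = F.get_dimensions(image_or_video)

# New pattern: ask only for what the transform actually uses.
height, width = F.get_spatial_size(image_or_video)
print(height, width)  # 224 224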

torchvision/prototype/transforms/_utils.py

Lines changed: 4 additions & 3 deletions
@@ -10,7 +10,7 @@
 from torchvision.prototype import features
 from torchvision.prototype.features._feature import FillType

-from torchvision.prototype.transforms.functional._meta import get_chw
+from torchvision.prototype.transforms.functional._meta import get_dimensions
 from torchvision.transforms.transforms import _check_sequence_input, _setup_angle, _setup_size  # noqa: F401

 from typing_extensions import Literal
@@ -80,15 +80,16 @@ def query_bounding_box(sample: Any) -> features.BoundingBox:
 def query_chw(sample: Any) -> Tuple[int, int, int]:
     flat_sample, _ = tree_flatten(sample)
     chws = {
-        get_chw(item)
+        tuple(get_dimensions(item))
         for item in flat_sample
         if isinstance(item, (features.Image, PIL.Image.Image, features.Video)) or features.is_simple_tensor(item)
     }
     if not chws:
         raise TypeError("No image or video was found in the sample")
     elif len(chws) > 1:
         raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sequence_to_str(sorted(chws))}")
-    return chws.pop()
+    c, h, w = chws.pop()
+    return c, h, w


 def _isinstance(obj: Any, types_or_checks: Tuple[Union[Type, Callable[[Any], bool]], ...]) -> bool:
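`get_dimensions` returns a `List[int]`, which is unhashable, so `query_chw` wraps each result in `tuple(...)` before deduplicating in a set; the explicit `c, h, w` unpack then lets the function keep its `Tuple[int, int, int]` return annotation. A simplified stand-in showing the same pattern (the helper name `_query_chw_from_dims` is hypothetical):

from typing import List, Set, Tuple

def _query_chw_from_dims(dims_per_item: List[List[int]]) -> Tuple[int, int, int]:
    # Deduplicate the per-item dimensions; lists must become tuples to be hashable.
    chws: Set[Tuple[int, ...]] = {tuple(dims) for dims in dims_per_item}
    if not chws:
        raise TypeError("No image or video was found in the sample")
    elif len(chws) > 1:
        raise ValueError(f"Found multiple CxHxW dimensions in the sample: {sorted(chws)}")
    # Unpacking the single remaining element yields a typed (c, h, w) triple.
    c, h, w = chws.pop()
    return c, h, w

print(_query_chw_from_dims([[3, 224, 224], [3, 224, 224]]))  # (3, 224, 224)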

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 6 additions & 0 deletions
@@ -8,9 +8,15 @@
     convert_color_space_image_pil,
     convert_color_space_video,
     convert_color_space,
+    get_dimensions_image_tensor,
+    get_dimensions_image_pil,
     get_dimensions,
     get_image_num_channels,
+    get_num_channels_image_tensor,
+    get_num_channels_image_pil,
     get_num_channels,
+    get_spatial_size_image_tensor,
+    get_spatial_size_image_pil,
     get_spatial_size,
 )  # usort: skip
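With the re-exports above, the per-backend kernels sit next to the dispatchers in the public functional namespace. A small sketch, assuming the prototype namespace at this commit:

import torch
from torchvision.prototype.transforms import functional as F

# Dispatcher and tensor kernel should agree for a plain (C, H, W) tensor.
img = torch.rand(3, 480, 640)
print(F.get_spatial_size(img))               # [480, 640]
print(F.get_spatial_size_image_tensor(img))  # [480, 640]
print(F.get_num_channels_image_tensor(img))  # 3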

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 13 additions & 8 deletions
@@ -21,7 +21,12 @@
     interpolate,
 )

-from ._meta import convert_format_bounding_box, get_dimensions_image_pil, get_dimensions_image_tensor
+from ._meta import (
+    convert_format_bounding_box,
+    get_dimensions_image_tensor,
+    get_spatial_size_image_pil,
+    get_spatial_size_image_tensor,
+)

 horizontal_flip_image_tensor = _FT.hflip
 horizontal_flip_image_pil = _FP.hflip
@@ -323,7 +328,7 @@ def affine_image_pil(
     # it is visually better to estimate the center without 0.5 offset
     # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
     if center is None:
-        _, height, width = get_dimensions_image_pil(image)
+        height, width = get_spatial_size_image_pil(image)
         center = [width * 0.5, height * 0.5]
     matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)

@@ -1189,13 +1194,13 @@ def _center_crop_compute_crop_anchor(

 def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    _, image_height, image_width = get_dimensions_image_tensor(image)
+    image_height, image_width = get_spatial_size_image_tensor(image)

     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         image = pad_image_tensor(image, padding_ltrb, fill=0)

-        _, image_height, image_width = get_dimensions_image_tensor(image)
+        image_height, image_width = get_spatial_size_image_tensor(image)
         if crop_width == image_width and crop_height == image_height:
             return image

@@ -1206,13 +1211,13 @@ def center_crop_image_tensor(image: torch.Tensor, output_size: List[int]) -> torch.Tensor:

 @torch.jit.unused
 def center_crop_image_pil(image: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    _, image_height, image_width = get_dimensions_image_pil(image)
+    image_height, image_width = get_spatial_size_image_pil(image)

     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         image = pad_image_pil(image, padding_ltrb, fill=0)

-        _, image_height, image_width = get_dimensions_image_pil(image)
+        image_height, image_width = get_spatial_size_image_pil(image)
         if crop_width == image_width and crop_height == image_height:
             return image

@@ -1365,7 +1370,7 @@ def five_crop_image_tensor(
     image: torch.Tensor, size: List[int]
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    _, image_height, image_width = get_dimensions_image_tensor(image)
+    image_height, image_width = get_spatial_size_image_tensor(image)

     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"
@@ -1385,7 +1390,7 @@ def five_crop_image_pil(
     image: PIL.Image.Image, size: List[int]
 ) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
     crop_height, crop_width = _parse_five_crop_size(size)
-    _, image_height, image_width = get_dimensions_image_pil(image)
+    image_height, image_width = get_spatial_size_image_pil(image)

     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"
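The geometry kernels above (the affine center default, center crop, five crop) only consult height and width, so they now call the spatial-size kernels directly. A hedged sketch of the affine-center default after the change, assuming the prototype namespace at this commit:

import PIL.Image
from torchvision.prototype.transforms import functional as F

image = PIL.Image.new("RGB", (640, 480))              # PIL sizes are (width, height)
height, width = F.get_spatial_size_image_pil(image)   # -> 480, 640
center = [width * 0.5, height * 0.5]                   # same default affine_image_pil computes
print(center)                                           # [320.0, 240.0]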

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 38 additions & 20 deletions
@@ -6,48 +6,66 @@
 from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

+
 get_dimensions_image_tensor = _FT.get_dimensions
 get_dimensions_image_pil = _FP.get_dimensions


-# TODO: Should this be prefixed with `_` similar to other methods that don't get exposed by init?
-def get_chw(image: features.ImageOrVideoTypeJIT) -> Tuple[int, int, int]:
+def get_dimensions(image: features.ImageOrVideoTypeJIT) -> List[int]:
     if isinstance(image, torch.Tensor) and (
         torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
     ):
-        channels, height, width = get_dimensions_image_tensor(image)
+        return get_dimensions_image_tensor(image)
     elif isinstance(image, (features.Image, features.Video)):
         channels = image.num_channels
         height, width = image.image_size
-    else:  # isinstance(image, PIL.Image.Image)
-        channels, height, width = get_dimensions_image_pil(image)
-    return channels, height, width
-
-
-# The three functions below are here for BC. Whether we want to have two different kernels and how they and the
-# compound version should be named is still under discussion: https://github.com/pytorch/vision/issues/6491
-# Given that these kernels should also support boxes, masks, and videos, it is unlikely that there name will stay.
-# They will either be deprecated or simply aliased to the new kernels if we have reached consensus about the issue
-# detailed above.
+        return [channels, height, width]
+    else:
+        return get_dimensions_image_pil(image)


-def get_dimensions(image: features.ImageOrVideoTypeJIT) -> List[int]:
-    return list(get_chw(image))
+get_num_channels_image_tensor = _FT.get_image_num_channels
+get_num_channels_image_pil = _FP.get_image_num_channels


 def get_num_channels(image: features.ImageOrVideoTypeJIT) -> int:
-    num_channels, *_ = get_chw(image)
-    return num_channels
+    if isinstance(image, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(image, (features.Image, features.Video))
+    ):
+        return _FT.get_image_num_channels(image)
+    elif isinstance(image, (features.Image, features.Video)):
+        return image.num_channels
+    else:
+        return _FP.get_image_num_channels(image)


 # We changed the names to ensure it can be used not only for images but also videos. Thus, we just alias it without
 # deprecating the old names.
 get_image_num_channels = get_num_channels


-def get_spatial_size(image: features.ImageOrVideoTypeJIT) -> List[int]:
-    _, *size = get_chw(image)
-    return size
+def get_spatial_size_image_tensor(image: torch.Tensor) -> List[int]:
+    width, height = _FT.get_image_size(image)
+    return [height, width]
+
+
+@torch.jit.unused
+def get_spatial_size_image_pil(image: PIL.Image.Image) -> List[int]:
+    width, height = _FP.get_image_size(image)
+    return [height, width]
+
+
+def get_spatial_size(inpt: features.InputTypeJIT) -> List[int]:
+    if isinstance(inpt, torch.Tensor) and (torch.jit.is_scripting() or not isinstance(inpt, features._Feature)):
+        return get_spatial_size_image_tensor(inpt)
+    elif isinstance(inpt, features._Feature):
+        image_size = getattr(inpt, "image_size", None)
+        if image_size is not None:
+            return list(image_size)
+        else:
+            raise ValueError(f"Type {inpt.__class__} doesn't have spatial size.")
+    else:
+        return get_spatial_size_image_pil(inpt)


 def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
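`get_spatial_size` now dispatches on any prototype feature that exposes an `image_size` attribute (Image, Video, BoundingBox, and, via the new property, Mask), falls back to the per-backend kernels for plain tensors and PIL images, and raises for features without a spatial extent. A hedged sketch, assuming the prototype feature constructors at this commit:

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F

mask = features.Mask(torch.zeros(1, 480, 640, dtype=torch.bool))
box = features.BoundingBox([0.0, 0.0, 10.0, 10.0], format="XYXY", image_size=(480, 640))

print(F.get_spatial_size(mask))                     # [480, 640], via the new Mask.image_size
print(F.get_spatial_size(box))                      # [480, 640], from BoundingBox.image_size
print(F.get_spatial_size(torch.rand(3, 480, 640)))  # [480, 640], plain-tensor kernel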
