
Commit 8ef9f79

Merge branch 'main' into fix-audio-pts-of-video-api
2 parents: 1bdc732 + 095437a

19 files changed: +143 −114 lines

docs/source/transforms.rst

Lines changed: 1 addition & 0 deletions
@@ -270,6 +270,7 @@ you can use a functional transform to build transform classes with custom behavior
     erase
     five_crop
     gaussian_blur
+    get_dimensions
     get_image_num_channels
     get_image_size
     hflip

references/classification/transforms.py

Lines changed: 1 addition & 1 deletion
@@ -141,7 +141,7 @@ def forward(self, batch: Tensor, target: Tensor) -> Tuple[Tensor, Tensor]:

         # Implemented as on cutmix paper, page 12 (with minor corrections on typos).
         lambda_param = float(torch._sample_dirichlet(torch.tensor([self.alpha, self.alpha]))[0])
-        W, H = F.get_image_size(batch)
+        _, H, W = F.get_dimensions(batch)

         r_x = torch.randint(W, (1,))
         r_y = torch.randint(H, (1,))
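For orientation, the same pattern is applied throughout the reference scripts and prototype transforms: the two-call idiom built on F.get_image_size (which returns width, height) plus F.get_image_num_channels is collapsed into a single F.get_dimensions call returning channels, height, width. A minimal sketch, with a made-up batch tensor:

import torch
from torchvision.transforms import functional as F

batch = torch.rand(4, 3, 224, 224)  # hypothetical NCHW batch

# Old API: two calls, and the size comes back as (width, height).
w, h = F.get_image_size(batch)
c = F.get_image_num_channels(batch)

# New API: one call, returning (channels, height, width).
c2, h2, w2 = F.get_dimensions(batch)
assert (c2, h2, w2) == (c, h, w)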

references/detection/transforms.py

Lines changed: 5 additions & 5 deletions
@@ -34,7 +34,7 @@ def forward(
         if torch.rand(1) < self.p:
             image = F.hflip(image)
             if target is not None:
-                width, _ = F.get_image_size(image)
+                _, _, width = F.get_dimensions(image)
                 target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
                 if "masks" in target:
                     target["masks"] = target["masks"].flip(-1)

@@ -107,7 +107,7 @@ def forward(
         elif image.ndimension() == 2:
             image = image.unsqueeze(0)

-        orig_w, orig_h = F.get_image_size(image)
+        _, orig_h, orig_w = F.get_dimensions(image)

         while True:
             # sample an option

@@ -192,7 +192,7 @@ def forward(
         if torch.rand(1) >= self.p:
             return image, target

-        orig_w, orig_h = F.get_image_size(image)
+        _, orig_h, orig_w = F.get_dimensions(image)

         r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
         canvas_width = int(orig_w * r)

@@ -270,7 +270,7 @@ def forward(
             image = self._contrast(image)

         if r[6] < self.p:
-            channels = F.get_image_num_channels(image)
+            channels, _, _ = F.get_dimensions(image)
             permutation = torch.randperm(channels)

             is_pil = F._is_pil_image(image)

@@ -317,7 +317,7 @@ def forward(
         elif image.ndimension() == 2:
             image = image.unsqueeze(0)

-        orig_width, orig_height = F.get_image_size(image)
+        _, orig_height, orig_width = F.get_dimensions(image)

         r = self.scale_range[0] + torch.rand(1) * (self.scale_range[1] - self.scale_range[0])
         new_width = int(self.target_size[1] * r)

test/test_functional_tensor.py

Lines changed: 3 additions & 1 deletion
@@ -29,7 +29,7 @@


 @pytest.mark.parametrize("device", cpu_and_gpu())
-@pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels])
+@pytest.mark.parametrize("fn", [F.get_image_size, F.get_image_num_channels, F.get_dimensions])
 def test_image_sizes(device, fn):
     script_F = torch.jit.script(fn)

@@ -1020,7 +1020,9 @@ def test_resized_crop(device, mode):
 @pytest.mark.parametrize(
     "func, args",
     [
+        (F_t.get_dimensions, ()),
         (F_t.get_image_size, ()),
+        (F_t.get_image_num_channels, ()),
         (F_t.vflip, ()),
         (F_t.hflip, ()),
         (F_t.crop, (1, 2, 4, 5)),
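The added parametrizations check that F.get_dimensions behaves like the two older helpers, including under TorchScript. A rough sketch of what that covers (the tensor below is a placeholder, not the actual test data):

import torch
from torchvision.transforms import functional as F

scripted = torch.jit.script(F.get_dimensions)  # must be scriptable, like get_image_size
img = torch.zeros(3, 16, 32)                   # hypothetical CHW image
assert scripted(img) == F.get_dimensions(img) == [3, 16, 32]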

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -8,7 +8,7 @@
 from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment
 from ._container import Compose, RandomApply, RandomChoice, RandomOrder
 from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
-from ._meta_conversion import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
+from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
 from ._misc import Identity, Normalize, ToDtype, Lambda
 from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
 from ._type_conversion import DecodeImage, LabelToOneHot

torchvision/prototype/transforms/_augment.py

Lines changed: 3 additions & 4 deletions
@@ -7,7 +7,7 @@
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, functional as F

-from ._utils import query_image
+from ._utils import query_image, get_image_dimensions


 class RandomErasing(Transform):

@@ -41,8 +41,7 @@ def __init__(

     def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
-        img_c = F.get_image_num_channels(image)
-        img_w, img_h = F.get_image_size(image)
+        img_c, img_h, img_w = get_image_dimensions(image)

         if isinstance(self.value, (int, float)):
             value = [self.value]

@@ -138,7 +137,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         lam = float(self._dist.sample(()))

         image = query_image(sample)
-        W, H = F.get_image_size(image)
+        _, H, W = get_image_dimensions(image)

         r_x = torch.randint(W, ())
         r_y = torch.randint(H, ())

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 12 additions & 12 deletions
@@ -7,7 +7,7 @@
 from torchvision.prototype.transforms import Transform, InterpolationMode, AutoAugmentPolicy, functional as F
 from torchvision.prototype.utils._internal import apply_recursively

-from ._utils import query_image
+from ._utils import query_image, get_image_dimensions

 K = TypeVar("K")
 V = TypeVar("V")

@@ -47,7 +47,7 @@ def dispatch(
             return input

         image = query_image(sample)
-        num_channels = F.get_image_num_channels(image)
+        num_channels, *_ = get_image_dimensions(image)

         fill = self.fill
         if isinstance(fill, (int, float)):

@@ -160,8 +160,8 @@ class AutoAugment(_AutoAugmentBase):
     _AUGMENTATION_SPACE = {
         "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
         "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
-        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
         "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
         "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
         "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),

@@ -278,7 +278,7 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]

         image = query_image(sample)
-        image_size = F.get_image_size(image)
+        _, height, width = get_image_dimensions(image)

         policy = self._policies[int(torch.randint(len(self._policies), ()))]


@@ -288,7 +288,7 @@ def forward(self, *inputs: Any) -> Any:

             magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

-            magnitudes = magnitudes_fn(10, image_size)
+            magnitudes = magnitudes_fn(10, (height, width))
             if magnitudes is not None:
                 magnitude = float(magnitudes[magnitude_idx])
                 if signed and torch.rand(()) <= 0.5:

@@ -306,8 +306,8 @@ class RandAugment(_AutoAugmentBase):
         "Identity": (lambda num_bins, image_size: None, False),
         "ShearX": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
         "ShearY": (lambda num_bins, image_size: torch.linspace(0.0, 0.3, num_bins), True),
-        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
-        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+        "TranslateX": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[1], num_bins), True),
+        "TranslateY": (lambda num_bins, image_size: torch.linspace(0.0, 150.0 / 331.0 * image_size[0], num_bins), True),
         "Rotate": (lambda num_bins, image_size: torch.linspace(0.0, 30.0, num_bins), True),
         "Brightness": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),
         "Color": (lambda num_bins, image_size: torch.linspace(0.0, 0.9, num_bins), True),

@@ -334,12 +334,12 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]

         image = query_image(sample)
-        image_size = F.get_image_size(image)
+        _, height, width = get_image_dimensions(image)

         for _ in range(self.num_ops):
             transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

-            magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size)
+            magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
             if magnitudes is not None:
                 magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
                 if signed and torch.rand(()) <= 0.5:

@@ -383,11 +383,11 @@ def forward(self, *inputs: Any) -> Any:
         sample = inputs if len(inputs) > 1 else inputs[0]

         image = query_image(sample)
-        image_size = F.get_image_size(image)
+        _, height, width = get_image_dimensions(image)

         transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

-        magnitudes = magnitudes_fn(self.num_magnitude_bins, image_size)
+        magnitudes = magnitudes_fn(self.num_magnitude_bins, (height, width))
         if magnitudes is not None:
             magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
             if signed and torch.rand(()) <= 0.5:
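The swapped indices in the TranslateX/TranslateY entries above follow from the new size ordering: the helper now yields (height, width) rather than (width, height), so the horizontal range must scale with image_size[1] and the vertical range with image_size[0]. A quick check with made-up numbers:

height, width = 224, 448  # hypothetical image, wider than tall
image_size = (height, width)

max_translate_x = 150.0 / 331.0 * image_size[1]  # ~203 px, proportional to width
max_translate_y = 150.0 / 331.0 * image_size[0]  # ~101 px, proportional to height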

torchvision/prototype/transforms/_geometry.py

Lines changed: 2 additions & 2 deletions
@@ -8,7 +8,7 @@
 from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
 from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int

-from ._utils import query_image
+from ._utils import query_image, get_image_dimensions


 class HorizontalFlip(Transform):

@@ -109,7 +109,7 @@ def __init__(

     def _get_params(self, sample: Any) -> Dict[str, Any]:
         image = query_image(sample)
-        width, height = F.get_image_size(image)
+        _, height, width = get_image_dimensions(image)
         area = height * width

         log_ratio = torch.log(torch.tensor(self.ratio))

torchvision/prototype/transforms/_utils.py

Lines changed: 16 additions & 1 deletion
@@ -1,10 +1,12 @@
-from typing import Any, Optional, Union
+from typing import Any, Optional, Tuple, Union

 import PIL.Image
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.utils._internal import query_recursively

+from .functional._meta import get_dimensions_image_tensor, get_dimensions_image_pil
+

 def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
     def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]:

@@ -17,3 +19,16 @@ def fn(input: Any) -> Optional[Union[PIL.Image.Image, torch.Tensor, features.Image]]:
         return next(query_recursively(fn, sample))
     except StopIteration:
         raise TypeError("No image was found in the sample")
+
+
+def get_image_dimensions(image: Union[PIL.Image.Image, torch.Tensor, features.Image]) -> Tuple[int, int, int]:
+    if isinstance(image, features.Image):
+        channels = image.num_channels
+        height, width = image.image_size
+    elif isinstance(image, torch.Tensor):
+        channels, height, width = get_dimensions_image_tensor(image)
+    elif isinstance(image, PIL.Image.Image):
+        channels, height, width = get_dimensions_image_pil(image)
+    else:
+        raise TypeError(f"unable to get image dimensions from object of type {type(image).__name__}")
+    return channels, height, width
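A minimal usage sketch for the new helper (the inputs are placeholders); it normalizes every supported image type to a (channels, height, width) tuple:

import PIL.Image
import torch

from torchvision.prototype.transforms._utils import get_image_dimensions

tensor_image = torch.rand(3, 32, 64)
pil_image = PIL.Image.new("RGB", (64, 32))  # PIL sizes are given as (width, height)

assert get_image_dimensions(tensor_image) == (3, 32, 64)
assert get_image_dimensions(pil_image) == (3, 32, 64)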

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -1,6 +1,5 @@
 from torchvision.transforms import InterpolationMode  # usort: skip
-from ._utils import get_image_size, get_image_num_channels  # usort: skip
-from ._meta_conversion import (
+from ._meta import (
     convert_bounding_box_format,
     convert_image_color_space_tensor,
     convert_image_color_space_pil,

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 11 additions & 13 deletions
@@ -5,11 +5,10 @@
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import InterpolationMode
-from torchvision.prototype.transforms.functional import get_image_size
 from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP
 from torchvision.transforms.functional import pil_modes_mapping, _get_inverse_affine_matrix

-from ._meta_conversion import convert_bounding_box_format
+from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil


 horizontal_flip_image_tensor = _FT.hflip

@@ -40,8 +39,7 @@ def resize_image_tensor(
     antialias: Optional[bool] = None,
 ) -> torch.Tensor:
     new_height, new_width = size
-    old_width, old_height = _FT.get_image_size(image)
-    num_channels = _FT.get_image_num_channels(image)
+    num_channels, old_height, old_width = get_dimensions_image_tensor(image)
     batch_shape = image.shape[:-3]
     return _FT.resize(
         image.reshape((-1, num_channels, old_height, old_width)),

@@ -143,9 +141,9 @@ def affine_image_tensor(

     center_f = [0.0, 0.0]
     if center is not None:
-        width, height = get_image_size(img)
+        _, height, width = get_dimensions_image_tensor(img)
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]
+        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]

     translate_f = [1.0 * t for t in translate]
     matrix = _get_inverse_affine_matrix(center_f, angle, translate_f, scale, shear)

@@ -169,7 +167,7 @@ def affine_image_pil(
     # it is visually better to estimate the center without 0.5 offset
     # otherwise image rotated by 90 degrees is shifted vs output image of torch.rot90 or F_t.affine
     if center is None:
-        width, height = get_image_size(img)
+        _, height, width = get_dimensions_image_pil(img)
         center = [width * 0.5, height * 0.5]
     matrix = _get_inverse_affine_matrix(center, angle, translate, scale, shear)


@@ -186,9 +184,9 @@ def rotate_image_tensor(
 ) -> torch.Tensor:
     center_f = [0.0, 0.0]
     if center is not None:
-        width, height = get_image_size(img)
+        _, height, width = get_dimensions_image_tensor(img)
         # Center values should be in pixel coordinates but translated such that (0, 0) corresponds to image center.
-        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, (width, height))]
+        center_f = [1.0 * (c - s * 0.5) for c, s in zip(center, [width, height])]

     # due to current incoherence of rotation angle direction between affine and rotate implementations
     # we need to set -angle.

@@ -262,13 +260,13 @@ def _center_crop_compute_crop_anchor(

 def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    image_width, image_height = get_image_size(img)
+    _, image_height, image_width = get_dimensions_image_tensor(img)

     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         img = pad_image_tensor(img, padding_ltrb, fill=0)

-    image_width, image_height = get_image_size(img)
+    _, image_height, image_width = get_dimensions_image_tensor(img)
     if crop_width == image_width and crop_height == image_height:
         return img


@@ -278,13 +276,13 @@ def center_crop_image_tensor(img: torch.Tensor, output_size: List[int]) -> torch.Tensor:

 def center_crop_image_pil(img: PIL.Image.Image, output_size: List[int]) -> PIL.Image.Image:
     crop_height, crop_width = _center_crop_parse_output_size(output_size)
-    image_width, image_height = get_image_size(img)
+    _, image_height, image_width = get_dimensions_image_pil(img)

     if crop_height > image_height or crop_width > image_width:
         padding_ltrb = _center_crop_compute_padding(crop_height, crop_width, image_height, image_width)
         img = pad_image_pil(img, padding_ltrb, fill=0)

-    image_width, image_height = get_image_size(img)
+    _, image_height, image_width = get_dimensions_image_pil(img)
     if crop_width == image_width and crop_height == image_height:
         return img
torchvision/prototype/transforms/functional/_meta_conversion.py renamed to torchvision/prototype/transforms/functional/_meta.py

Lines changed: 4 additions & 0 deletions
@@ -4,6 +4,10 @@
 from torchvision.transforms import functional_tensor as _FT, functional_pil as _FP


+get_dimensions_image_tensor = _FT.get_dimensions
+get_dimensions_image_pil = _FP.get_dimensions
+
+
 def _xywh_to_xyxy(xywh: torch.Tensor) -> torch.Tensor:
     xyxy = xywh.clone()
     xyxy[..., 2:] += xyxy[..., :2]
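Both aliases point at the existing low-level kernels, so the tensor and PIL paths share one (channels, height, width) convention. A small illustrative check (inputs are made up):

import PIL.Image
import torch
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

assert _FT.get_dimensions(torch.rand(1, 3, 10, 20)) == [3, 10, 20]  # extra leading dims are ignored
assert _FP.get_dimensions(PIL.Image.new("L", (20, 10))) == [1, 10, 20]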

torchvision/prototype/transforms/functional/_utils.py

Lines changed: 0 additions & 29 deletions
This file was deleted.
