Skip to content

Commit 4df1a85

Browse files
authored
[prototype] Remove _FT aliases from functional (#6983)
* Remove `_FT` usages from functional
* Update error messages
1 parent 50b77fa commit 4df1a85

File tree

4 files changed, with 41 additions and 22 deletions.

torchvision/prototype/transforms/functional/_augment.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,17 @@
44

55
import torch
66
from torchvision.prototype import features
7-
from torchvision.transforms import functional_tensor as _FT
87
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
98

10-
erase_image_tensor = _FT.erase
9+
10+
def erase_image_tensor(
    image: torch.Tensor, i: int, j: int, h: int, w: int, v: torch.Tensor, inplace: bool = False
) -> torch.Tensor:
    """Overwrite a rectangular window of ``image`` with ``v``.

    The window spans indices ``i`` to ``i + h`` (exclusive) on the second-to-last
    dimension and ``j`` to ``j + w`` (exclusive) on the last dimension; all
    leading dimensions are covered through the ellipsis.

    Args:
        image: Tensor to erase from.
        i: Start index of the window on dimension ``-2``.
        j: Start index of the window on dimension ``-1``.
        h: Extent of the window on dimension ``-2``.
        w: Extent of the window on dimension ``-1``.
        v: Value(s) assigned to the window.
        inplace: If ``True``, write into ``image`` itself; otherwise a clone
            is modified and ``image`` is left untouched.

    Returns:
        The tensor that was written to (``image`` itself when ``inplace`` is set).
    """
    target = image if inplace else image.clone()

    bottom = i + h
    right = j + w
    target[..., i:bottom, j:right] = v
    return target
1118

1219

1320
@torch.jit.unused

torchvision/prototype/transforms/functional/_color.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,16 @@
11
import torch
22
from torch.nn.functional import conv2d
33
from torchvision.prototype import features
4-
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
4+
from torchvision.transforms import functional_pil as _FP
5+
from torchvision.transforms.functional_tensor import _max_value
56

67
from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
78

89

910
def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
1011
ratio = float(ratio)
1112
fp = image1.is_floating_point()
12-
bound = _FT._max_value(image1.dtype)
13+
bound = _max_value(image1.dtype)
1314
output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
1415
return output if fp else output.to(image1.dtype)
1516

@@ -18,10 +19,12 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
1819
if brightness_factor < 0:
1920
raise ValueError(f"brightness_factor ({brightness_factor}) is not non-negative.")
2021

21-
_FT._assert_channels(image, [1, 3])
22+
c = image.shape[-3]
23+
if c not in [1, 3]:
24+
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
2225

2326
fp = image.is_floating_point()
24-
bound = _FT._max_value(image.dtype)
27+
bound = _max_value(image.dtype)
2528
output = image.mul(brightness_factor).clamp_(0, bound)
2629
return output if fp else output.to(image.dtype)
2730

@@ -48,7 +51,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
4851

4952
c = image.shape[-3]
5053
if c not in [1, 3]:
51-
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
54+
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
5255

5356
if c == 1: # Match PIL behaviour
5457
return image
@@ -82,7 +85,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
8285

8386
c = image.shape[-3]
8487
if c not in [1, 3]:
85-
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
88+
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
8689
fp = image.is_floating_point()
8790
if c == 3:
8891
grayscale_image = _rgb_to_gray(image, cast=False)
@@ -121,7 +124,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
121124
if image.numel() == 0 or height <= 2 or width <= 2:
122125
return image
123126

124-
bound = _FT._max_value(image.dtype)
127+
bound = _max_value(image.dtype)
125128
fp = image.is_floating_point()
126129
shape = image.shape
127130

@@ -248,7 +251,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
248251

249252
c = image.shape[-3]
250253
if c not in [1, 3]:
251-
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
254+
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
252255

253256
if c == 1: # Match PIL behaviour
254257
return image
@@ -350,7 +353,7 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
350353

351354

352355
def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
353-
if threshold > _FT._max_value(image.dtype):
356+
if threshold > _max_value(image.dtype):
354357
raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
355358

356359
return torch.where(image >= threshold, invert_image_tensor(image), image)
@@ -375,13 +378,13 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp
375378
def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
376379
c = image.shape[-3]
377380
if c not in [1, 3]:
378-
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
381+
raise TypeError(f"Input image tensor permitted channel values are 1 or 3, but found {c}")
379382

380383
if image.numel() == 0:
381384
# exit earlier on empty images
382385
return image
383386

384-
bound = _FT._max_value(image.dtype)
387+
bound = _max_value(image.dtype)
385388
fp = image.is_floating_point()
386389
float_image = image if fp else image.to(torch.float32)
387390

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from torch.nn.functional import grid_sample, interpolate, pad as torch_pad
99

1010
from torchvision.prototype import features
11-
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
11+
from torchvision.transforms import functional_pil as _FP
1212
from torchvision.transforms.functional import (
1313
_compute_resized_output_size as __compute_resized_output_size,
1414
_get_perspective_coeffs,
@@ -17,10 +17,15 @@
1717
pil_to_tensor,
1818
to_pil_image,
1919
)
20+
from torchvision.transforms.functional_tensor import _pad_symmetric
2021

2122
from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
2223

23-
horizontal_flip_image_tensor = _FT.hflip
24+
25+
def horizontal_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    """Return ``image`` with the order of its last dimension reversed."""
    return torch.flip(image, dims=[-1])
27+
28+
2429
horizontal_flip_image_pil = _FP.hflip
2530

2631

@@ -58,7 +63,10 @@ def horizontal_flip(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
5863
return horizontal_flip_image_pil(inpt)
5964

6065

61-
vertical_flip_image_tensor = _FT.vflip
66+
def vertical_flip_image_tensor(image: torch.Tensor) -> torch.Tensor:
    """Return ``image`` with the order of its second-to-last dimension reversed."""
    return torch.flip(image, dims=[-2])
68+
69+
6270
vertical_flip_image_pil = _FP.vflip
6371

6472

@@ -975,7 +983,7 @@ def _pad_with_scalar_fill(
975983
if needs_cast:
976984
image = image.to(dtype)
977985
else: # padding_mode == "symmetric"
978-
image = _FT._pad_symmetric(image, torch_padding)
986+
image = _pad_symmetric(image, torch_padding)
979987

980988
new_height, new_width = image.shape[-2:]
981989

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44
import torch
55
from torchvision.prototype import features
66
from torchvision.prototype.features import BoundingBoxFormat, ColorSpace
7-
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
7+
from torchvision.transforms import functional_pil as _FP
8+
from torchvision.transforms.functional_tensor import _max_value
89

910

1011
def get_dimensions_image_tensor(image: torch.Tensor) -> List[int]:
@@ -193,7 +194,7 @@ def clamp_bounding_box(
193194

194195
def _strip_alpha(image: torch.Tensor) -> torch.Tensor:
195196
image, alpha = torch.tensor_split(image, indices=(-1,), dim=-3)
196-
if not torch.all(alpha == _FT._max_value(alpha.dtype)):
197+
if not torch.all(alpha == _max_value(alpha.dtype)):
197198
raise RuntimeError(
198199
"Stripping the alpha channel if it contains values other than the max value is not supported."
199200
)
@@ -204,7 +205,7 @@ def _add_alpha(image: torch.Tensor, alpha: Optional[torch.Tensor] = None) -> tor
204205
if alpha is None:
205206
shape = list(image.shape)
206207
shape[-3] = 1
207-
alpha = torch.full(shape, _FT._max_value(image.dtype), dtype=image.dtype, device=image.device)
208+
alpha = torch.full(shape, _max_value(image.dtype), dtype=image.dtype, device=image.device)
208209
return torch.cat((image, alpha), dim=-3)
209210

210211

@@ -363,14 +364,14 @@ def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.f
363364
# Instead, we can also multiply by the maximum value plus something close to `1`. See
364365
# https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
365366
eps = 1e-3
366-
max_value = float(_FT._max_value(dtype))
367+
max_value = float(_max_value(dtype))
367368
# We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
368369
# discrete set `{0, 1}`.
369370
return image.mul(max_value + 1.0 - eps).to(dtype)
370371
else:
371372
# int to float
372373
if float_output:
373-
return image.to(dtype).mul_(1.0 / _FT._max_value(image.dtype))
374+
return image.to(dtype).mul_(1.0 / _max_value(image.dtype))
374375

375376
# int to int
376377
num_value_bits_input = _num_value_bits(image.dtype)

0 commit comments

Comments
 (0)