1
1
import torch
2
2
from torch .nn .functional import conv2d
3
3
from torchvision .prototype import features
4
- from torchvision .transforms import functional_pil as _FP , functional_tensor as _FT
4
+ from torchvision .transforms import functional_pil as _FP
5
+ from torchvision .transforms .functional_tensor import _max_value
5
6
6
7
from ._meta import _num_value_bits , _rgb_to_gray , convert_dtype_image_tensor
7
8
8
9
9
10
def _blend (image1 : torch .Tensor , image2 : torch .Tensor , ratio : float ) -> torch .Tensor :
10
11
ratio = float (ratio )
11
12
fp = image1 .is_floating_point ()
12
- bound = _FT . _max_value (image1 .dtype )
13
+ bound = _max_value (image1 .dtype )
13
14
output = image1 .mul (ratio ).add_ (image2 , alpha = (1.0 - ratio )).clamp_ (0 , bound )
14
15
return output if fp else output .to (image1 .dtype )
15
16
@@ -18,10 +19,12 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
18
19
if brightness_factor < 0 :
19
20
raise ValueError (f"brightness_factor ({ brightness_factor } ) is not non-negative." )
20
21
21
- _FT ._assert_channels (image , [1 , 3 ])
22
+ c = image .shape [- 3 ]
23
+ if c not in [1 , 3 ]:
24
+ raise TypeError (f"Input image tensor permitted channel values are 1 or 3, but found { c } " )
22
25
23
26
fp = image .is_floating_point ()
24
- bound = _FT . _max_value (image .dtype )
27
+ bound = _max_value (image .dtype )
25
28
output = image .mul (brightness_factor ).clamp_ (0 , bound )
26
29
return output if fp else output .to (image .dtype )
27
30
@@ -48,7 +51,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
48
51
49
52
c = image .shape [- 3 ]
50
53
if c not in [1 , 3 ]:
51
- raise TypeError (f"Input image tensor permitted channel values are { [ 1 , 3 ] } , but found { c } " )
54
+ raise TypeError (f"Input image tensor permitted channel values are 1 or 3 , but found { c } " )
52
55
53
56
if c == 1 : # Match PIL behaviour
54
57
return image
@@ -82,7 +85,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
82
85
83
86
c = image .shape [- 3 ]
84
87
if c not in [1 , 3 ]:
85
- raise TypeError (f"Input image tensor permitted channel values are { [ 1 , 3 ] } , but found { c } " )
88
+ raise TypeError (f"Input image tensor permitted channel values are 1 or 3 , but found { c } " )
86
89
fp = image .is_floating_point ()
87
90
if c == 3 :
88
91
grayscale_image = _rgb_to_gray (image , cast = False )
@@ -121,7 +124,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
121
124
if image .numel () == 0 or height <= 2 or width <= 2 :
122
125
return image
123
126
124
- bound = _FT . _max_value (image .dtype )
127
+ bound = _max_value (image .dtype )
125
128
fp = image .is_floating_point ()
126
129
shape = image .shape
127
130
@@ -248,7 +251,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
248
251
249
252
c = image .shape [- 3 ]
250
253
if c not in [1 , 3 ]:
251
- raise TypeError (f"Input image tensor permitted channel values are { [ 1 , 3 ] } , but found { c } " )
254
+ raise TypeError (f"Input image tensor permitted channel values are 1 or 3 , but found { c } " )
252
255
253
256
if c == 1 : # Match PIL behaviour
254
257
return image
@@ -350,7 +353,7 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
350
353
351
354
352
355
def solarize_image_tensor (image : torch .Tensor , threshold : float ) -> torch .Tensor :
353
- if threshold > _FT . _max_value (image .dtype ):
356
+ if threshold > _max_value (image .dtype ):
354
357
raise TypeError (f"Threshold should be less or equal the maximum value of the dtype, but got { threshold } " )
355
358
356
359
return torch .where (image >= threshold , invert_image_tensor (image ), image )
@@ -375,13 +378,13 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp
375
378
def autocontrast_image_tensor (image : torch .Tensor ) -> torch .Tensor :
376
379
c = image .shape [- 3 ]
377
380
if c not in [1 , 3 ]:
378
- raise TypeError (f"Input image tensor permitted channel values are { [ 1 , 3 ] } , but found { c } " )
381
+ raise TypeError (f"Input image tensor permitted channel values are 1 or 3 , but found { c } " )
379
382
380
383
if image .numel () == 0 :
381
384
# exit earlier on empty images
382
385
return image
383
386
384
- bound = _FT . _max_value (image .dtype )
387
+ bound = _max_value (image .dtype )
385
388
fp = image .is_floating_point ()
386
389
float_image = image if fp else image .to (torch .float32 )
387
390
0 commit comments