11import torch
22from torch .nn .functional import conv2d
33from torchvision .prototype import features
4- from torchvision .transforms import functional_pil as _FP , functional_tensor as _FT
4+ from torchvision .transforms import functional_pil as _FP
5+ from torchvision .transforms .functional_tensor import _max_value
56
67from ._meta import _num_value_bits , _rgb_to_gray , convert_dtype_image_tensor
78
89
910def _blend (image1 : torch .Tensor , image2 : torch .Tensor , ratio : float ) -> torch .Tensor :
1011 ratio = float (ratio )
1112 fp = image1 .is_floating_point ()
12- bound = _FT . _max_value (image1 .dtype )
13+ bound = _max_value (image1 .dtype )
1314 output = image1 .mul (ratio ).add_ (image2 , alpha = (1.0 - ratio )).clamp_ (0 , bound )
1415 return output if fp else output .to (image1 .dtype )
1516
@@ -18,10 +19,12 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
1819 if brightness_factor < 0 :
1920 raise ValueError (f"brightness_factor ({ brightness_factor } ) is not non-negative." )
2021
21- _FT ._assert_channels (image , [1 , 3 ])
22+ c = image .shape [- 3 ]
23+ if c not in [1 , 3 ]:
24+ raise TypeError (f"Input image tensor permitted channel values are 1 or 3, but found { c } " )
2225
2326 fp = image .is_floating_point ()
24- bound = _FT . _max_value (image .dtype )
27+ bound = _max_value (image .dtype )
2528 output = image .mul (brightness_factor ).clamp_ (0 , bound )
2629 return output if fp else output .to (image .dtype )
2730
@@ -48,7 +51,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
4851
4952 c = image .shape [- 3 ]
5053 if c not in [1 , 3 ]:
51- raise TypeError (f"Input image tensor permitted channel values are { [ 1 , 3 ] } , but found { c } " )
54+ raise TypeError (f"Input image tensor permitted channel values are 1 or 3 , but found { c } " )
5255
5356 if c == 1 : # Match PIL behaviour
5457 return image
@@ -82,7 +85,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
8285
8386 c = image .shape [- 3 ]
8487 if c not in [1 , 3 ]:
85- raise TypeError (f"Input image tensor permitted channel values are { [ 1 , 3 ] } , but found { c } " )
88+ raise TypeError (f"Input image tensor permitted channel values are 1 or 3 , but found { c } " )
8689 fp = image .is_floating_point ()
8790 if c == 3 :
8891 grayscale_image = _rgb_to_gray (image , cast = False )
@@ -121,7 +124,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
121124 if image .numel () == 0 or height <= 2 or width <= 2 :
122125 return image
123126
124- bound = _FT . _max_value (image .dtype )
127+ bound = _max_value (image .dtype )
125128 fp = image .is_floating_point ()
126129 shape = image .shape
127130
@@ -248,7 +251,7 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
248251
249252 c = image .shape [- 3 ]
250253 if c not in [1 , 3 ]:
251- raise TypeError (f"Input image tensor permitted channel values are { [ 1 , 3 ] } , but found { c } " )
254+ raise TypeError (f"Input image tensor permitted channel values are 1 or 3 , but found { c } " )
252255
253256 if c == 1 : # Match PIL behaviour
254257 return image
@@ -350,7 +353,7 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
350353
351354
352355def solarize_image_tensor (image : torch .Tensor , threshold : float ) -> torch .Tensor :
353- if threshold > _FT . _max_value (image .dtype ):
356+ if threshold > _max_value (image .dtype ):
354357 raise TypeError (f"Threshold should be less or equal the maximum value of the dtype, but got { threshold } " )
355358
356359 return torch .where (image >= threshold , invert_image_tensor (image ), image )
@@ -375,13 +378,13 @@ def solarize(inpt: features.InputTypeJIT, threshold: float) -> features.InputTyp
375378def autocontrast_image_tensor (image : torch .Tensor ) -> torch .Tensor :
376379 c = image .shape [- 3 ]
377380 if c not in [1 , 3 ]:
378- raise TypeError (f"Input image tensor permitted channel values are { [ 1 , 3 ] } , but found { c } " )
381+ raise TypeError (f"Input image tensor permitted channel values are 1 or 3 , but found { c } " )
379382
380383 if image .numel () == 0 :
381384 # exit earlier on empty images
382385 return image
383386
384- bound = _FT . _max_value (image .dtype )
387+ bound = _max_value (image .dtype )
385388 fp = image .is_floating_point ()
386389 float_image = image if fp else image .to (torch .float32 )
387390
0 commit comments