22from torchvision .prototype import features
33from torchvision .transforms import functional_pil as _FP , functional_tensor as _FT
44
5- from ._meta import get_dimensions_image_tensor
5+ from ._meta import _rgb_to_gray , get_dimensions_image_tensor , get_num_channels_image_tensor
6+
7+
8+ def _blend (image1 : torch .Tensor , image2 : torch .Tensor , ratio : float ) -> torch .Tensor :
9+ ratio = float (ratio )
10+ fp = image1 .is_floating_point ()
11+ bound = 1.0 if fp else 255.0
12+ output = image1 .mul (ratio ).add_ (image2 , alpha = (1.0 - ratio )).clamp_ (0 , bound )
13+ return output if fp else output .to (image1 .dtype )
14+
15+
16+ def adjust_brightness_image_tensor (image : torch .Tensor , brightness_factor : float ) -> torch .Tensor :
17+ if brightness_factor < 0 :
18+ raise ValueError (f"brightness_factor ({ brightness_factor } ) is not non-negative." )
19+
20+ _FT ._assert_channels (image , [1 , 3 ])
21+
22+ fp = image .is_floating_point ()
23+ bound = 1.0 if fp else 255.0
24+ output = image .mul (brightness_factor ).clamp_ (0 , bound )
25+ return output if fp else output .to (image .dtype )
26+
627
7- adjust_brightness_image_tensor = _FT .adjust_brightness
828adjust_brightness_image_pil = _FP .adjust_brightness
929
1030
@@ -21,7 +41,20 @@ def adjust_brightness(inpt: features.InputTypeJIT, brightness_factor: float) ->
2141 return adjust_brightness_image_pil (inpt , brightness_factor = brightness_factor )
2242
2343
24- adjust_saturation_image_tensor = _FT .adjust_saturation
44+ def adjust_saturation_image_tensor (image : torch .Tensor , saturation_factor : float ) -> torch .Tensor :
45+ if saturation_factor < 0 :
46+ raise ValueError (f"saturation_factor ({ saturation_factor } ) is not non-negative." )
47+
48+ c = get_num_channels_image_tensor (image )
49+ if c not in [1 , 3 ]:
50+ raise TypeError (f"Input image tensor permitted channel values are { [1 , 3 ]} , but found { c } " )
51+
52+ if c == 1 : # Match PIL behaviour
53+ return image
54+
55+ return _blend (image , _rgb_to_gray (image ), saturation_factor )
56+
57+
2558adjust_saturation_image_pil = _FP .adjust_saturation
2659
2760
@@ -38,7 +71,19 @@ def adjust_saturation(inpt: features.InputTypeJIT, saturation_factor: float) ->
3871 return adjust_saturation_image_pil (inpt , saturation_factor = saturation_factor )
3972
4073
41- adjust_contrast_image_tensor = _FT .adjust_contrast
74+ def adjust_contrast_image_tensor (image : torch .Tensor , contrast_factor : float ) -> torch .Tensor :
75+ if contrast_factor < 0 :
76+ raise ValueError (f"contrast_factor ({ contrast_factor } ) is not non-negative." )
77+
78+ c = get_num_channels_image_tensor (image )
79+ if c not in [1 , 3 ]:
80+ raise TypeError (f"Input image tensor permitted channel values are { [1 , 3 ]} , but found { c } " )
81+ dtype = image .dtype if torch .is_floating_point (image ) else torch .float32
82+ grayscale_image = _rgb_to_gray (image ) if c == 3 else image
83+ mean = torch .mean (grayscale_image .to (dtype ), dim = (- 3 , - 2 , - 1 ), keepdim = True )
84+ return _blend (image , mean , contrast_factor )
85+
86+
4287adjust_contrast_image_pil = _FP .adjust_contrast
4388
4489
@@ -74,7 +119,7 @@ def adjust_sharpness_image_tensor(image: torch.Tensor, sharpness_factor: float)
74119 else :
75120 needs_unsquash = False
76121
77- output = _FT . _blend (image , _FT ._blurred_degenerate_image (image ), sharpness_factor )
122+ output = _blend (image , _FT ._blurred_degenerate_image (image ), sharpness_factor )
78123
79124 if needs_unsquash :
80125 output = output .reshape (shape )
@@ -183,13 +228,13 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
183228 return autocontrast_image_pil (inpt )
184229
185230
186- def _equalize_image_tensor_vec (img : torch .Tensor ) -> torch .Tensor :
187- # input img shape should be [N, H, W]
188- shape = img .shape
231+ def _equalize_image_tensor_vec (image : torch .Tensor ) -> torch .Tensor :
232+ # input image shape should be [N, H, W]
233+ shape = image .shape
189234 # Compute image histogram:
190- flat_img = img .flatten (start_dim = 1 ).to (torch .long ) # -> [N, H * W]
191- hist = flat_img .new_zeros (shape [0 ], 256 )
192- hist .scatter_add_ (dim = 1 , index = flat_img , src = flat_img .new_ones (1 ).expand_as (flat_img ))
235+ flat_image = image .flatten (start_dim = 1 ).to (torch .long ) # -> [N, H * W]
236+ hist = flat_image .new_zeros (shape [0 ], 256 )
237+ hist .scatter_add_ (dim = 1 , index = flat_image , src = flat_image .new_ones (1 ).expand_as (flat_image ))
193238
194239 # Compute image cdf
195240 chist = hist .cumsum_ (dim = 1 )
@@ -213,7 +258,7 @@ def _equalize_image_tensor_vec(img: torch.Tensor) -> torch.Tensor:
213258 zeros = lut .new_zeros ((1 , 1 )).expand (shape [0 ], 1 )
214259 lut = torch .cat ([zeros , lut [:, :- 1 ]], dim = 1 )
215260
216- return torch .where ((step == 0 ).unsqueeze (- 1 ), img , lut .gather (dim = 1 , index = flat_img ).reshape_as (img ))
261+ return torch .where ((step == 0 ).unsqueeze (- 1 ), image , lut .gather (dim = 1 , index = flat_image ).reshape_as (image ))
217262
218263
219264def equalize_image_tensor (image : torch .Tensor ) -> torch .Tensor :
0 commit comments