 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

-from ._meta import get_dimensions_image_tensor, get_num_channels_image_tensor, rgb_to_grayscale
+from ._meta import _rgb_to_gray, get_dimensions_image_tensor, get_num_channels_image_tensor


 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
@@ -52,7 +52,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
     if c == 1:  # Match PIL behaviour
         return image

-    return _blend(image, rgb_to_grayscale(image), saturation_factor)
+    return _blend(image, _rgb_to_gray(image), saturation_factor)


 adjust_saturation_image_pil = _FP.adjust_saturation
@@ -79,7 +79,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
     if c not in [1, 3]:
         raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32
-    grayscale_image = _FT.rgb_to_grayscale(image) if c == 3 else image
+    grayscale_image = _rgb_to_gray(image) if c == 3 else image
     mean = torch.mean(grayscale_image.to(dtype), dim=(-3, -2, -1), keepdim=True)
     return _blend(image, mean, contrast_factor)
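For context, a minimal standalone sketch of the blend-based saturation and contrast adjustments these hunks touch. The body of `_rgb_to_gray` here is an assumption (the real helper is imported from `._meta`); the ITU-R 601 luma weights, the simplified `_blend`, and the wrapper names `adjust_saturation` / `adjust_contrast` are illustrative stand-ins for the `*_image_tensor` kernels in the diff, not the library code itself.

```python
import torch


def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
    # Linearly interpolate between the image and its "degenerate" version,
    # then clamp to the valid range for the input dtype.
    ratio = float(ratio)
    bound = 1.0 if image1.is_floating_point() else 255.0
    return (ratio * image1 + (1.0 - ratio) * image2).clamp(0, bound).to(image1.dtype)


def _rgb_to_gray(image: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in for the helper imported from ._meta:
    # a weighted sum of the RGB channels (ITU-R 601 luma coefficients),
    # keeping a singleton channel dimension so it broadcasts against the input.
    r, g, b = image.unbind(dim=-3)
    gray = 0.2989 * r + 0.587 * g + 0.114 * b
    return gray.unsqueeze(dim=-3).to(image.dtype)


def adjust_saturation(image: torch.Tensor, saturation_factor: float) -> torch.Tensor:
    # factor 0 -> grayscale, 1 -> unchanged, >1 -> more saturated
    return _blend(image, _rgb_to_gray(image), saturation_factor)


def adjust_contrast(image: torch.Tensor, contrast_factor: float) -> torch.Tensor:
    # Contrast blends against the mean gray value rather than the
    # per-pixel gray image, matching the second hunk above.
    dtype = image.dtype if torch.is_floating_point(image) else torch.float32
    mean = torch.mean(_rgb_to_gray(image).to(dtype), dim=(-3, -2, -1), keepdim=True)
    return _blend(image, mean, contrast_factor)


img = torch.rand(3, 4, 4)  # CHW float image in [0, 1]
desaturated = adjust_saturation(img, 0.5)
higher_contrast = adjust_contrast(img, 1.5)
```

The diff itself only swaps the public `rgb_to_grayscale` / `_FT.rgb_to_grayscale` calls for the internal `_rgb_to_gray` helper; the blend logic is unchanged.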