Skip to content

Commit 4117957

Browse files
committed
PR review
1 parent b7fdd39 commit 4117957

File tree

2 files changed

+7
-27
lines changed

2 files changed

+7
-27
lines changed

torchvision/prototype/transforms/functional/_color.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from torchvision.prototype import features
33
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
44

5-
from ._meta import get_dimensions_image_tensor, get_num_channels_image_tensor, rgb_to_grayscale
5+
from ._meta import _rgb_to_gray, get_dimensions_image_tensor, get_num_channels_image_tensor
66

77

88
def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
@@ -52,7 +52,7 @@ def adjust_saturation_image_tensor(image: torch.Tensor, saturation_factor: float
5252
if c == 1: # Match PIL behaviour
5353
return image
5454

55-
return _blend(image, rgb_to_grayscale(image), saturation_factor)
55+
return _blend(image, _rgb_to_gray(image), saturation_factor)
5656

5757

5858
adjust_saturation_image_pil = _FP.adjust_saturation
@@ -79,7 +79,7 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
7979
if c not in [1, 3]:
8080
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
8181
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
82-
grayscale_image = _FT.rgb_to_grayscale(image) if c == 3 else image
82+
grayscale_image = _rgb_to_gray(image) if c == 3 else image
8383
mean = torch.mean(grayscale_image.to(dtype), dim=(-3, -2, -1), keepdim=True)
8484
return _blend(image, mean, contrast_factor)
8585

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -184,33 +184,13 @@ def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
184184
return grayscale.repeat(repeats)
185185

186186

187-
def rgb_to_grayscale(image: torch.Tensor, num_output_channels: int = 1) -> torch.Tensor:
188-
if image.ndim < 3:
189-
raise TypeError(f"Input image tensor should have at least 3 dimensions, but found {image.ndim}")
190-
191-
c = image.shape[-3]
192-
if c not in [1, 3]:
193-
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
194-
195-
if num_output_channels not in (1, 3):
196-
raise ValueError("num_output_channels should be either 1 or 3")
197-
198-
if c == 3:
199-
r, g, b = image.unbind(dim=-3)
200-
l_img = (0.2989 * r).add_(g, alpha=0.587).add_(b, alpha=0.114)
201-
l_img = l_img.to(image.dtype).unsqueeze(dim=-3)
202-
else:
203-
l_img = image.clone()
204-
205-
if num_output_channels == 3:
206-
return l_img.expand(image.shape)
207-
187+
def _rgb_to_gray(image: torch.Tensor) -> torch.Tensor:
188+
r, g, b = image.unbind(dim=-3)
189+
l_img = (0.2989 * r).add_(g, alpha=0.587).add_(b, alpha=0.114)
190+
l_img = l_img.to(image.dtype).unsqueeze(dim=-3)
208191
return l_img
209192

210193

211-
_rgb_to_gray = rgb_to_grayscale
212-
213-
214194
def convert_color_space_image_tensor(
215195
image: torch.Tensor, old_color_space: ColorSpace, new_color_space: ColorSpace, copy: bool = True
216196
) -> torch.Tensor:

0 commit comments

Comments (0)