Skip to content

Commit 10d47a6

Browse files
authored
[prototype] Speed up adjust_contrast_image_tensor (#6933)
* Avoid double casting on adjust_contrast * Handle properly ints.
1 parent f32600b commit 10d47a6

File tree

2 files changed

+13
-6
lines changed

2 files changed

+13
-6
lines changed

torchvision/prototype/transforms/functional/_color.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,14 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
7979
c = image.shape[-3]
8080
if c not in [1, 3]:
8181
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
82-
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
83-
grayscale_image = _rgb_to_gray(image) if c == 3 else image
84-
mean = torch.mean(grayscale_image.to(dtype), dim=(-3, -2, -1), keepdim=True)
82+
fp = image.is_floating_point()
83+
if c == 3:
84+
grayscale_image = _rgb_to_gray(image, cast=False)
85+
if not fp:
86+
grayscale_image = grayscale_image.floor_()
87+
else:
88+
grayscale_image = image if fp else image.to(torch.float32)
89+
mean = torch.mean(grayscale_image, dim=(-3, -2, -1), keepdim=True)
8590
return _blend(image, mean, contrast_factor)
8691

8792

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,12 @@ def _gray_to_rgb(grayscale: torch.Tensor) -> torch.Tensor:
213213
return grayscale.repeat(repeats)
214214

215215

216-
def _rgb_to_gray(image: torch.Tensor) -> torch.Tensor:
216+
def _rgb_to_gray(image: torch.Tensor, cast: bool = True) -> torch.Tensor:
217217
r, g, b = image.unbind(dim=-3)
218-
l_img = (0.2989 * r).add_(g, alpha=0.587).add_(b, alpha=0.114)
219-
l_img = l_img.to(image.dtype).unsqueeze(dim=-3)
218+
l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
219+
if cast:
220+
l_img = l_img.to(image.dtype)
221+
l_img = l_img.unsqueeze(dim=-3)
220222
return l_img
221223

222224

0 commit comments

Comments
 (0)