Skip to content

Commit a82cf8c

Browse files
committed
adjust_contrast convert to float32 earlier
1 parent 4117957 commit a82cf8c

File tree

1 file changed

+2
-2
lines changed
  • torchvision/prototype/transforms/functional

1 file changed

+2
-2
lines changed

torchvision/prototype/transforms/functional/_color.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,8 +79,8 @@ def adjust_contrast_image_tensor(image: torch.Tensor, contrast_factor: float) ->
7979
if c not in [1, 3]:
8080
raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
8181
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
82-
grayscale_image = _rgb_to_gray(image) if c == 3 else image
83-
mean = torch.mean(grayscale_image.to(dtype), dim=(-3, -2, -1), keepdim=True)
82+
grayscale_image = _rgb_to_gray(image.to(dtype)) if c == 3 else image.to(dtype)
83+
mean = torch.mean(grayscale_image, dim=(-3, -2, -1), keepdim=True)
8484
return _blend(image, mean, contrast_factor)
8585

8686

0 commit comments

Comments
 (0)