Skip to content

Commit 6979888

Browse files
authored
[prototype] Speed improvement for adjust gamma op (#6820)
* Speed improvement for adjust gamma op * Adding comments and optimizations. * fixing typo * Remove unnecessary channel check.
1 parent 62da7d4 commit 6979888

File tree

1 file changed

+23
-2
lines changed
  • torchvision/prototype/transforms/functional

1 file changed

+23
-2
lines changed

torchvision/prototype/transforms/functional/_color.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from torchvision.prototype import features
33
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
44

5-
from ._meta import _rgb_to_gray, get_dimensions_image_tensor, get_num_channels_image_tensor
5+
from ._meta import _rgb_to_gray, convert_dtype_image_tensor, get_dimensions_image_tensor, get_num_channels_image_tensor
66

77

88
def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
@@ -257,7 +257,28 @@ def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.Input
257257
return adjust_hue_image_pil(inpt, hue_factor=hue_factor)
258258

259259

260-
adjust_gamma_image_tensor = _FT.adjust_gamma
260+
def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
261+
if not (isinstance(image, torch.Tensor)):
262+
raise TypeError("Input img should be Tensor image")
263+
264+
if gamma < 0:
265+
raise ValueError("Gamma should be a non-negative real number")
266+
267+
# The input image is either assumed to be at [0, 1] scale (if float) or is converted to that scale (if integer).
268+
# Since the gamma is non-negative, the output remains at [0, 1] scale.
269+
if not torch.is_floating_point(image):
270+
output = convert_dtype_image_tensor(image, torch.float32).pow_(gamma)
271+
else:
272+
output = image.pow(gamma)
273+
274+
if gain != 1.0:
275+
# The clamp operation is needed only if multiplication is performed. It's only when gain != 1, that the scale
276+
# of the output can go beyond [0, 1].
277+
output = output.mul_(gain).clamp_(0.0, 1.0)
278+
279+
return convert_dtype_image_tensor(output, image.dtype)
280+
281+
261282
adjust_gamma_image_pil = _FP.adjust_gamma
262283

263284

0 commit comments

Comments
 (0)