@@ -2,7 +2,7 @@
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

-from ._meta import _rgb_to_gray, get_dimensions_image_tensor, get_num_channels_image_tensor
+from ._meta import _rgb_to_gray, convert_dtype_image_tensor, get_dimensions_image_tensor, get_num_channels_image_tensor


 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
@@ -257,7 +257,28 @@ def adjust_hue(inpt: features.InputTypeJIT, hue_factor: float) -> features.Input
     return adjust_hue_image_pil(inpt, hue_factor=hue_factor)


-adjust_gamma_image_tensor = _FT.adjust_gamma
+def adjust_gamma_image_tensor(image: torch.Tensor, gamma: float, gain: float = 1.0) -> torch.Tensor:
+    if not (isinstance(image, torch.Tensor)):
+        raise TypeError("Input img should be Tensor image")
+
+    if gamma < 0:
+        raise ValueError("Gamma should be a non-negative real number")
+
+    # The input image is either assumed to be at [0, 1] scale (if float) or is converted to that scale (if integer).
+    # Since the gamma is non-negative, the output remains at [0, 1] scale.
+    if not torch.is_floating_point(image):
+        output = convert_dtype_image_tensor(image, torch.float32).pow_(gamma)
+    else:
+        output = image.pow(gamma)
+
+    if gain != 1.0:
+        # The clamp operation is needed only if multiplication is performed. It's only when gain != 1 that the scale
+        # of the output can go beyond [0, 1].
+        output = output.mul_(gain).clamp_(0.0, 1.0)
+
+    return convert_dtype_image_tensor(output, image.dtype)
+
+
 adjust_gamma_image_pil = _FP.adjust_gamma


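For reference, a minimal usage sketch of the new tensor kernel introduced above. The import path (torchvision.prototype.transforms.functional) and the sample image are assumptions for illustration and are not part of this diff:

import torch

# Assumed import path for the prototype kernel; adjust if the module layout differs.
from torchvision.prototype.transforms.functional import adjust_gamma_image_tensor

# Integer images are converted to float32 in [0, 1] before the gamma power is applied.
img = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
out = adjust_gamma_image_tensor(img, gamma=2.2, gain=1.0)

# The kernel converts the result back to the input dtype, so a uint8 image stays uint8.
assert out.dtype == img.dtype

With gain=1.0 no clamping is performed, matching the branch in the kernel where only the gamma power is applied.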