diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py
index 7bf412aaf99..2c268fa4085 100644
--- a/torchvision/prototype/transforms/functional/_color.py
+++ b/torchvision/prototype/transforms/functional/_color.py
@@ -143,7 +143,104 @@ def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> fe
     return adjust_sharpness_image_pil(inpt, sharpness_factor=sharpness_factor)
 
 
-adjust_hue_image_tensor = _FT.adjust_hue
+def _rgb_to_hsv(image: torch.Tensor) -> torch.Tensor:
+    r, g, _ = image.unbind(dim=-3)
+
+    # Implementation is based on
+    # https://github.com/python-pillow/Pillow/blob/4174d4267616897df3746d315d5a2d0f82c656ee/src/libImaging/Convert.c#L330
+    minc, maxc = torch.aminmax(image, dim=-3)
+
+    # The algorithm erases the S and H channels where `maxc = minc`. This avoids NaN
+    # from appearing in the results, because
+    # + the S channel has a division by `maxc`, which is zero only if `maxc = minc`
+    # + the H channel has a division by `(maxc - minc)`.
+    #
+    # Instead of overwriting NaN afterwards, we just prevent it from occurring, so
+    # we don't need to deal with NaNs saved in a buffer for backprop, if that is
+    # ever supported; it doesn't hurt to do so either way.
+    eqc = maxc == minc
+
+    channels_range = maxc - minc
+    # Since `eqc => channels_range = 0`, replacing the denominator with 1 where `eqc` holds is fine.
+    ones = torch.ones_like(maxc)
+    s = channels_range / torch.where(eqc, ones, maxc)
+    # Note that `eqc => maxc = minc = r = g = b`. So the following calculation
+    # of `h` would reduce to `bc - gc + 2 + rc - bc + 4 + rc - bc = 6`, so it
+    # would not matter what values `rc`, `gc`, and `bc` have here, and thus
+    # replacing the denominator with 1 where `eqc` holds is fine.
+    channels_range_divisor = torch.where(eqc, ones, channels_range).unsqueeze_(dim=-3)
+    rc, gc, bc = ((maxc.unsqueeze(dim=-3) - image) / channels_range_divisor).unbind(dim=-3)
+
+    mask_maxc_neq_r = maxc != r
+    mask_maxc_eq_g = maxc == g
+    mask_maxc_neq_g = ~mask_maxc_eq_g
+
+    hr = (bc - gc).mul_(~mask_maxc_neq_r)
+    hg = (2.0 + rc).sub_(bc).mul_(mask_maxc_eq_g & mask_maxc_neq_r)
+    hb = (4.0 + gc).sub_(rc).mul_(mask_maxc_neq_g & mask_maxc_neq_r)
+
+    h = hr.add_(hg).add_(hb)
+    h = h.div_(6.0).add_(1.0).fmod_(1.0)
+    return torch.stack((h, s, maxc), dim=-3)
+
+
+def _hsv_to_rgb(img: torch.Tensor) -> torch.Tensor:
+    h, s, v = img.unbind(dim=-3)
+    h6 = h * 6
+    i = torch.floor(h6)
+    f = h6 - i
+    i = i.to(dtype=torch.int32)
+
+    p = (v * (1.0 - s)).clamp_(0.0, 1.0)
+    q = (v * (1.0 - s * f)).clamp_(0.0, 1.0)
+    t = (v * (1.0 - s * (1.0 - f))).clamp_(0.0, 1.0)
+    i.remainder_(6)
+
+    mask = i.unsqueeze(dim=-3) == torch.arange(6, device=i.device).view(-1, 1, 1)
+
+    a1 = torch.stack((v, q, p, p, t, v), dim=-3)
+    a2 = torch.stack((t, v, v, q, p, p), dim=-3)
+    a3 = torch.stack((p, p, t, v, v, q), dim=-3)
+    a4 = torch.stack((a1, a2, a3), dim=-4)
+
+    return (a4.mul_(mask.to(dtype=img.dtype).unsqueeze(dim=-4))).sum(dim=-3)
+
+
+def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
+    if not (-0.5 <= hue_factor <= 0.5):
+        raise ValueError(f"hue_factor ({hue_factor}) is not in [-0.5, 0.5].")
+
+    if not isinstance(image, torch.Tensor):
+        raise TypeError("Input img should be Tensor image")
+
+    c = get_num_channels_image_tensor(image)
+
+    if c not in [1, 3]:
+        raise TypeError(f"Input image tensor permitted channel values are {[1, 3]}, but found {c}")
+
+    if c == 1:  # Match PIL behaviour
+        return image
+
+    if image.numel() == 0:
+        # exit early on empty images
+        return image
+
+    orig_dtype = image.dtype
+    if image.dtype == torch.uint8:
+        image = image / 255.0
+
+    image = _rgb_to_hsv(image)
+    h, s, v = image.unbind(dim=-3)
+    h.add_(hue_factor).remainder_(1.0)
+    image = torch.stack((h, s, v), dim=-3)
+    image_hue_adj = _hsv_to_rgb(image)
+
+    if orig_dtype == torch.uint8:
+        image_hue_adj = image_hue_adj.mul_(255.0).to(dtype=orig_dtype)
+
+    return image_hue_adj
 
 
 adjust_hue_image_pil = _FP.adjust_hue
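Not part of the diff: a minimal sketch of how the new conversion helpers could be sanity-checked, assuming the private names `_rgb_to_hsv` and `_hsv_to_rgb` stay importable from `torchvision.prototype.transforms.functional._color`. Converting a float RGB image to HSV and back should be the identity up to float rounding.

```python
import torch

# Assumed (private) import path; these helpers are internal and may move.
from torchvision.prototype.transforms.functional._color import _hsv_to_rgb, _rgb_to_hsv

# Random float RGB batch in [0, 1], channels-first as the kernels expect (..., C, H, W).
img = torch.rand(4, 3, 32, 32)

# RGB -> HSV -> RGB should reproduce the input up to float rounding.
round_trip = _hsv_to_rgb(_rgb_to_hsv(img))
torch.testing.assert_close(round_trip, img, atol=1e-4, rtol=0)
```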
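As a rough consistency check (also not part of this PR), the new tensor kernel can be compared against the existing PIL-backed path. Exact agreement is not expected because PIL quantizes its HSV representation to uint8, so the sketch below only reports the worst-case per-pixel difference; the import of `adjust_hue_image_tensor` assumes the module layout shown in the diff.

```python
import torch
from torchvision.transforms.functional import adjust_hue, pil_to_tensor, to_pil_image

# Assumed import of the tensor kernel added in this diff.
from torchvision.prototype.transforms.functional._color import adjust_hue_image_tensor

img = torch.randint(0, 256, (3, 64, 64), dtype=torch.uint8)
hue_factor = 0.25

out_tensor = adjust_hue_image_tensor(img, hue_factor)
out_pil = pil_to_tensor(adjust_hue(to_pil_image(img), hue_factor))

# PIL works in quantized uint8 HSV, so per-pixel differences of a few
# intensity levels are expected; report the worst case instead of asserting.
max_diff = (out_tensor.int() - out_pil.int()).abs().max().item()
print(f"max |tensor - PIL| = {max_diff}")
```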