diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py
index d11dd3c3b9f..63fa8a28cfe 100644
--- a/torchvision/prototype/transforms/functional/_color.py
+++ b/torchvision/prototype/transforms/functional/_color.py
@@ -183,6 +183,30 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
         return autocontrast_image_pil(inpt)
 
 
+def _scale_channel(img_chan: torch.Tensor) -> torch.Tensor:
+    # TODO: we should expect bincount to always be faster than histc, but this
+    # isn't always the case. Once
+    # https://github.com/pytorch/pytorch/issues/53194 is fixed, remove the if
+    # block and only use bincount.
+    if img_chan.is_cuda:
+        hist = torch.histc(img_chan.to(torch.float32), bins=256, min=0, max=255)
+    else:
+        hist = torch.bincount(img_chan.view(-1), minlength=256)
+
+    nonzero_hist = hist[hist != 0]
+    step = torch.div(nonzero_hist[:-1].sum(), 255, rounding_mode="floor")
+    if step == 0:
+        return img_chan
+
+    lut = torch.div(torch.cumsum(hist, 0) + torch.div(step, 2, rounding_mode="floor"), step, rounding_mode="floor")
+    # Doing the clamp in place and converting the lut to uint8 improves performance
+    lut.clamp_(0, 255)
+    lut = lut.to(torch.uint8)
+    lut = torch.nn.functional.pad(lut[:-1], [1, 0])
+
+    return lut[img_chan.to(torch.int64)]
+
+
 def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.dtype != torch.uint8:
         raise TypeError(f"Only torch.uint8 image tensors are supported, but found {image.dtype}")
@@ -194,15 +218,9 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.numel() == 0:
         return image
     elif image.ndim == 2:
-        return _FT._scale_channel(image)
+        return _scale_channel(image)
     else:
-        return torch.stack(
-            [
-                # TODO: when merging transforms v1 and v2, we can inline this function call
-                _FT._equalize_single_image(single_image)
-                for single_image in image.view(-1, num_channels, height, width)
-            ]
-        ).view(image.shape)
+        return torch.stack([_scale_channel(x) for x in image.view(-1, height, width)]).view(image.shape)
 
 
 equalize_image_pil = _FP.equalize
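
A minimal usage sketch of the rewritten kernel follows. The import path is an assumption based on the file this diff touches (prototype APIs move between releases), and the input shape is arbitrary:

    import torch
    from torchvision.prototype.transforms.functional import equalize_image_tensor

    # Batch of two 3-channel uint8 images. equalize_image_tensor flattens the
    # leading batch/channel dimensions via image.view(-1, height, width) and
    # runs _scale_channel on each H x W plane independently.
    images = torch.randint(0, 256, (2, 3, 32, 32), dtype=torch.uint8)
    equalized = equalize_image_tensor(images)
    assert equalized.shape == images.shape and equalized.dtype == torch.uint8

Equalizing per H x W plane rather than per image is what lets the new else branch drop the `_equalize_single_image` helper: every channel of every image in the batch goes through the same `_scale_channel` lookup-table path.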