diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index 7cfb9b6a785..287b1acaa27 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -1446,16 +1446,14 @@ def sample_inputs_invert_video(): def sample_inputs_posterize_image_tensor(): for image_loader in make_image_loaders( - sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), dtypes=[torch.uint8] + sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB) ): yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0]) def reference_inputs_posterize_image_tensor(): for image_loader, bits in itertools.product( - make_image_loaders( - color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8] - ), + make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]), _POSTERIZE_BITS, ): yield ArgsKwargs(image_loader, bits=bits) diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index fb238510242..e7977156d2c 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -289,7 +289,18 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain) -posterize_image_tensor = _FT.posterize +def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor: + if bits > 8: + return image + + if image.is_floating_point(): + levels = 1 << bits + return image.mul(levels).floor_().clamp_(0, levels - 1).div_(levels) + else: + mask = ((1 << bits) - 1) << (8 - bits) + return image & mask + + posterize_image_pil = _FP.posterize