diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index a9fd3bc5ec9..9f6817bb60d 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -4608,6 +4608,14 @@ def test_correctness_image(self, bits, fn): assert_equal(actual, expected) + @pytest.mark.parametrize("bits", [-1, 9, 2.1]) + def test_error_functional(self, bits): + with pytest.raises( + TypeError, + match=re.escape(f"bits must be a positive integer in the range [0, 8], got {bits} instead."), + ): + F.posterize(make_image(dtype=torch.uint8), bits=bits) + class TestSolarize: def _make_threshold(self, input, *, factor=0.5): diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py index eb75f58cb7a..a3f187f84cf 100644 --- a/torchvision/transforms/v2/functional/_color.py +++ b/torchvision/transforms/v2/functional/_color.py @@ -460,6 +460,9 @@ def posterize(inpt: torch.Tensor, bits: int) -> torch.Tensor: @_register_kernel_internal(posterize, torch.Tensor) @_register_kernel_internal(posterize, tv_tensors.Image) def posterize_image(image: torch.Tensor, bits: int) -> torch.Tensor: + if not isinstance(bits, int) or not 0 <= bits <= 8: + raise TypeError(f"bits must be a positive integer in the range [0, 8], got {bits} instead.") + if image.is_floating_point(): levels = 1 << bits return image.mul(levels).floor_().clamp_(0, levels - 1).mul_(1.0 / levels)