
extend support of posterize to all integer and floating dtypes #6847

Merged 5 commits on Oct 28, 2022

6 changes: 2 additions & 4 deletions test/prototype_transforms_kernel_infos.py
@@ -1446,16 +1446,14 @@ def sample_inputs_invert_video():

 def sample_inputs_posterize_image_tensor():
     for image_loader in make_image_loaders(
-        sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), dtypes=[torch.uint8]
+        sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)
     ):
         yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0])


 def reference_inputs_posterize_image_tensor():
     for image_loader, bits in itertools.product(
-        make_image_loaders(
-            color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
-        ),
+        make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]),
         _POSTERIZE_BITS,
     ):
         yield ArgsKwargs(image_loader, bits=bits)
13 changes: 12 additions & 1 deletion torchvision/prototype/transforms/functional/_color.py
@@ -289,7 +289,18 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) ->
     return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)


-posterize_image_tensor = _FT.posterize
+def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
+    if bits > 8:
+        return image
+
+    if image.is_floating_point():
+        levels = 1 << bits
+        return image.mul(levels).floor_().clamp_(0, levels - 1).div_(levels)
datumbox (Contributor) commented on Oct 27, 2022:

Can you explain the clamp at levels - 1? This is the kind of implementation reference I had in mind. Which reference did you use? Also, why are we multiplying by 2^bits instead of 2^bits - 1, which is supposed to be the max for the specific type?

pmeier (Collaborator, Author) replied:

I touched on this in point 3 of my top comment. Since the input range for float images is inclusive on both edges, i.e. [0.0, 1.0], image.mul(levels).floor_() gives us levels + 1 values, i.e. {0.0, 1.0, 2.0, ..., levels - 1.0, levels}.

However, we want the kernel to quantize to exactly levels levels, so we need to remove one. For integer dtypes, the topmost value is the one removed, i.e. the remaining values are {i * 2 ** (bit_depth - bits) for i in range(2 ** bits)}. For example

>>> bits = 3
>>> bit_depth = 8
>>> {i * 2 ** (bit_depth - bits) for i in range(2 ** bits)}
{0, 32, 64, 96, 128, 160, 192, 224}

As you can see, the 255 / 256 that corresponds to 1.0 in floating point images is missing from this set. Thus, we also clamp that top level away for floating point images.
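A minimal sketch of why the clamp is needed on the float path (illustrative values, bits = 3; not part of the original exchange):

>>> import torch
>>> levels = 1 << 3
>>> x = torch.tensor([0.0, 0.5, 1.0])
>>> x.mul(levels).floor()  # without the clamp, the inclusive edge 1.0 yields an extra ninth level
tensor([0., 4., 8.])
>>> x.mul(levels).floor().clamp(0, levels - 1).div(levels)  # the clamp folds it back into the top level
tensor([0.0000, 0.5000, 0.8750])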

datumbox (Contributor) replied:

I read your PR description, but I still had questions on this. Why not do image.mul(levels - 1) to begin with? Multiplying by levels means that the upper bound of 1.0 will go outside of the permitted range of the type. What am I missing here?

pmeier (Collaborator, Author) replied on Oct 27, 2022:

To make this more concrete, let's look at an example:

>>> image = torch.arange(0, 256, dtype=torch.uint8)
>>> bits = 3
>>> output_baseline = image & (2 ** 8 - 2 ** (8 - bits))
>>> torch.unique(output_baseline)
tensor([  0,  32,  64,  96, 128, 160, 192, 224], dtype=torch.uint8)
>>> image = torch.linspace(0, 1, 100)
>>> output1 = image.mul(2 ** bits).floor_().clamp_(0, 2**bits - 1).div_(2 ** bits)
>>> torch.unique(output1.mul(255).byte())
tensor([  0,  31,  63,  95, 127, 159, 191, 223], dtype=torch.uint8)
>>> torch.unique(output1.mul(255))
tensor([  0.0000,  31.8750,  63.7500,  95.6250, 127.5000, 159.3750, 191.2500,
        223.1250])

The proposal in this PR is not perfect either, but the .byte() call above smooths over some of the nuances.

In contrast, if we do what you propose, we get:

>>> output2 = image.mul(2 ** bits - 1).floor_().div(2 ** bits - 1)
>>> torch.unique(output2.mul(255).byte())
tensor([  0,  36,  72, 109, 145, 182, 218, 255], dtype=torch.uint8)

This is of course also a valid way to posterize an image to a bit depth of 3, but the behavior diverges from what we and PIL do for integer dtypes.
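As an aside (a quick check, not from the thread): the baseline mask 2 ** 8 - 2 ** (8 - bits) used above is the same mask the new kernel builds for integer inputs as ((1 << bits) - 1) << (8 - bits) in its else branch further down; both keep the top bits bits and zero out the rest:

>>> all(2 ** 8 - 2 ** (8 - bits) == ((1 << bits) - 1) << (8 - bits) for bits in range(1, 9))
True
>>> format(((1 << 3) - 1) << (8 - 3), "08b")  # bits = 3 gives 224
'11100000'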

datumbox (Contributor) replied:

Thanks for the explanation @pmeier. This makes sense. @vfdev-5 thoughts?

vfdev-5 (Collaborator) commented on Oct 28, 2022:

@pmeier I think you can skip the clamp with the following:

(x * (2 ** bits - 1)).floor() / (2 ** bits)

EDIT: for bits=1 the above method gives something unexpected (x * (2 ** 1 - 1) is just x, so floor() sends everything below 1.0 to zero):

x = torch.linspace(0.0, 1.0, steps=20)
bits = 1
(x * (2 ** bits - 1)).floor() / (2 ** bits)

# tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
#         0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
#         0.0000, 0.5000])
# vs
# tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
#         0.0000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
#         0.5000, 0.5000])

EDIT 2: a better quantization formula that skips the clamp:

(x * (2 ** bits - 0.5)).floor() / (2 ** bits)

bits=1
# tensor([0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
#         0.0000, 0.0000, 0.0000, 0.0000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
#         0.5000, 0.5000])

bits=3
# tensor([0.0000, 0.0000, 0.0000, 0.1250, 0.1250, 0.2500, 0.2500, 0.2500, 0.3750,
#         0.3750, 0.5000, 0.5000, 0.6250, 0.6250, 0.6250, 0.7500, 0.7500, 0.8750,
#         0.8750, 0.8750])

pmeier (Collaborator, Author) replied on Oct 28, 2022:

Unfortunately, this does not work. While the individual level values look good

>>> output3 = (image * (2 ** bits - 1)).floor() / (2 ** bits)
>>> torch.unique(output3.mul(255).byte())
tensor([  0,  31,  63,  95, 127, 159, 191, 223], dtype=torch.uint8)

the posterized values do not match what we do for the integers:

>>> _ = torch.manual_seed(0)
>>> bits = 3
>>> image_uint8 = torch.randint(0, 256, (3, 3), dtype=torch.uint8)
>>> posterized_image_uint8 = image_uint8 & (2 ** 8 - 2 **(8 - bits))
>>> posterized_image_uint8
tensor([[160,  32,  96],
        [192,  64, 224],
        [192,  96,   0]], dtype=torch.uint8)
>>> image_float32 = F.convert_dtype_image_tensor(image_uint8, torch.float32)
>>> posterized_image_float32 = (image_float32 * (2 ** bits - 1)).floor() / (2 ** bits)
>>> posterized_image_float32 = posterized_image_float32.mul(255).byte()
>>> posterized_image_float32
tensor([[127,  31,  95],
        [159,  31, 191],
        [159,  63,   0]], dtype=torch.uint8)
>>> posterized_image_uint8.int() - posterized_image_float32.int()
tensor([[33,  1,  1],
        [33, 33, 33],
        [33, 33,  0]], dtype=torch.int32)

This comes from the asymmetry between the multiplication (by 2 ** bits - 1) and the division (by 2 ** bits), and it is also what you observed for bits=1 above.
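To spell the asymmetry out on a single pixel (a hand-worked check using the first pixel above; not part of the original exchange):

>>> v = 160 / 255            # first pixel, converted to float: ≈ 0.6275
>>> int(v * (2 ** 3 - 1))    # quantize with 2 ** bits - 1 = 7: 4.39 floors to 4
4
>>> 4 / 2 ** 3 * 255         # but reconstruction divides by 2 ** bits = 8
127.5
>>> 160 & (2 ** 8 - 2 ** (8 - 3))  # the integer path keeps 160, hence the 33
160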

pmeier (Collaborator, Author) added:

Re EDIT 2: the formula still produces different values:

>>> posterized_image_float32 = (image_float32 * (2 ** bits - 0.5)).floor() / (2 ** bits)
>>> posterized_image_float32.mul(255).byte()
tensor([[159,  31,  95],
        [159,  31, 223],
        [159,  95,   0]], dtype=torch.uint8)
>>> posterized_image_uint8.int() - posterized_image_float32.int()
tensor([[ 1,  1,  1],
        [33, 33,  1],
        [33,  1,  0]], dtype=torch.int32)

vfdev-5 (Collaborator) replied:

Well, with the clamp you still have a difference of +/- 1, but OK, let's keep the clamp. It is probably not a big deal in terms of runtime performance.

+    else:
+        mask = ((1 << bits) - 1) << (8 - bits)
+        return image & mask


posterize_image_pil = _FP.posterize


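Putting the thread together, a standalone sketch that checks the merged float logic against the uint8 baseline up to the +/- 1 difference noted above (posterize_float mirrors the kernel's float branch; it is an illustration, not the torchvision source):

import torch

def posterize_float(image: torch.Tensor, bits: int) -> torch.Tensor:
    # Same quantization as the float branch of posterize_image_tensor above.
    levels = 1 << bits
    return image.mul(levels).floor_().clamp_(0, levels - 1).div_(levels)

torch.manual_seed(0)
bits = 3
image_uint8 = torch.randint(0, 256, (3, 3), dtype=torch.uint8)
baseline = image_uint8 & (((1 << bits) - 1) << (8 - bits))
result = posterize_float(image_uint8.float() / 255, bits).mul(255).byte()
# The float round trip lands at most one gray value below the integer result.
assert (baseline.int() - result.int()).abs().max() <= 1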