Skip to content

Commit 7bf6314

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] extend support of posterize to all integer and floating dtypes (#6847)
Summary: * extend support of posterize to all integer and floating dtypes * remove raise * revert to fixed value range for integer dtypes Reviewed By: datumbox Differential Revision: D40851028 fbshipit-source-id: ebb0460ce9eb414515701303688b16a10dab0dee
1 parent 88584df commit 7bf6314

File tree

2 files changed

+14
-5
lines changed

2 files changed

+14
-5
lines changed

test/prototype_transforms_kernel_infos.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1446,16 +1446,14 @@ def sample_inputs_invert_video():
14461446

14471447
def sample_inputs_posterize_image_tensor():
14481448
for image_loader in make_image_loaders(
1449-
sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), dtypes=[torch.uint8]
1449+
sizes=["random"], color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB)
14501450
):
14511451
yield ArgsKwargs(image_loader, bits=_POSTERIZE_BITS[0])
14521452

14531453

14541454
def reference_inputs_posterize_image_tensor():
14551455
for image_loader, bits in itertools.product(
1456-
make_image_loaders(
1457-
color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()], dtypes=[torch.uint8]
1458-
),
1456+
make_image_loaders(color_spaces=(features.ColorSpace.GRAY, features.ColorSpace.RGB), extra_dims=[()]),
14591457
_POSTERIZE_BITS,
14601458
):
14611459
yield ArgsKwargs(image_loader, bits=bits)

torchvision/prototype/transforms/functional/_color.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,18 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) ->
289289
return adjust_gamma_image_pil(inpt, gamma=gamma, gain=gain)
290290

291291

292-
posterize_image_tensor = _FT.posterize
292+
def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
293+
if bits > 8:
294+
return image
295+
296+
if image.is_floating_point():
297+
levels = 1 << bits
298+
return image.mul(levels).floor_().clamp_(0, levels - 1).div_(levels)
299+
else:
300+
mask = ((1 << bits) - 1) << (8 - bits)
301+
return image & mask
302+
303+
293304
posterize_image_pil = _FP.posterize
294305

295306

0 commit comments

Comments
 (0)