Skip to content

Commit c0236fc

Browse files
committed
Revert "assume that integer images are [0, 255] in equalize (pytorch#6859)"
This reverts commit 436ff9a.
1 parent 6895f71 commit c0236fc

File tree

1 file changed

+11
-7
lines changed
  • torchvision/prototype/transforms/functional

1 file changed

+11
-7
lines changed

torchvision/prototype/transforms/functional/_color.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -387,14 +387,18 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
387387
if image.numel() == 0:
388388
return image
389389

390+
# 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that
391+
# would be needed to computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for
392+
# `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32` and completely
393+
# unfeasible for `torch.int64`.
394+
# 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
395+
# could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
396+
# to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower
397+
# and more complicated to implement than a simple conversion and a fast histogram implementation for integers.
398+
# Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is
399+
# by far the most common, we choose it as base.
390400
output_dtype = image.dtype
391-
if image.is_floating_point():
392-
# Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
393-
# could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
394-
# to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it
395-
# slower and more complicated to implement than a simple conversion and a fast histogram implementation for
396-
# integers.
397-
image = convert_dtype_image_tensor(image, torch.uint8)
401+
image = convert_dtype_image_tensor(image, torch.uint8)
398402

399403
# The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
400404
# corresponds to adding 1 to index 127 in the histogram.

0 commit comments

Comments
 (0)