diff --git a/torchvision/prototype/transforms/functional/_color.py b/torchvision/prototype/transforms/functional/_color.py index e7977156d2c..67a55cfb1d7 100644 --- a/torchvision/prototype/transforms/functional/_color.py +++ b/torchvision/prototype/transforms/functional/_color.py @@ -385,18 +385,14 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor: if image.numel() == 0: return image - # 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that - # would be needed to computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for - # `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32` and completely - # unfeasible for `torch.int64`. - # 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we - # could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition - # to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower - # and more complicated to implement than a simple conversion and a fast histogram implementation for integers. - # Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is - # by far the most common, we choose it as base. output_dtype = image.dtype - image = convert_dtype_image_tensor(image, torch.uint8) + if image.is_floating_point(): + # Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we + # could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition + # to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it + # slower and more complicated to implement than a simple conversion and a fast histogram implementation for + # integers. + image = convert_dtype_image_tensor(image, torch.uint8) # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image # corresponds to adding 1 to index 127 in the histogram.