Fix hardcoded 255 #6830
File 1 of 2: prototype transforms functional color kernels
@@ -2,13 +2,13 @@
 from torchvision.prototype import features
 from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

-from ._meta import _rgb_to_gray, convert_dtype_image_tensor
+from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor


 def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
     ratio = float(ratio)
     fp = image1.is_floating_point()
-    bound = 1.0 if fp else 255.0
+    bound = _FT._max_value(image1.dtype)
     output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
     return output if fp else output.to(image1.dtype)
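The change above makes the clamp bound track the input dtype instead of assuming uint8. `_FT._max_value` is the helper touched in the second file of this diff; as a rough illustration of the bound it produces (my sketch, not the torchvision implementation, which uses explicit per-dtype branches):

```python
import torch

# Hypothetical stand-in for functional_tensor._max_value, just to show the
# bounds the kernels now clamp against.
def max_value_sketch(dtype: torch.dtype) -> int:
    if not dtype.is_floating_point:
        return torch.iinfo(dtype).max  # 255 for uint8, 127 for int8, 32767 for int16, ...
    return 1  # floating point images are assumed to live in [0.0, 1.0]

for dtype in (torch.uint8, torch.int8, torch.int16, torch.float32):
    print(dtype, max_value_sketch(dtype))
```

For floating point and uint8 images the behavior is unchanged; the fix only shows up for the other integer dtypes, where 255 was the wrong bound.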
@@ -20,7 +20,7 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float) -> torch.Tensor:
     _FT._assert_channels(image, [1, 3])

     fp = image.is_floating_point()
-    bound = 1.0 if fp else 255.0
+    bound = _FT._max_value(image.dtype)
     output = image.mul(brightness_factor).clamp_(0, bound)
     return output if fp else output.to(image.dtype)
@@ -222,19 +222,15 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Tensor:
         return image

     orig_dtype = image.dtype
-    if image.dtype == torch.uint8:
-        image = image / 255.0
+    image = convert_dtype_image_tensor(image, torch.float32)

     image = _rgb_to_hsv(image)
     h, s, v = image.unbind(dim=-3)
     h.add_(hue_factor).remainder_(1.0)
     image = torch.stack((h, s, v), dim=-3)
     image_hue_adj = _hsv_to_rgb(image)

-    if orig_dtype == torch.uint8:
-        image_hue_adj = image_hue_adj.mul_(255.0).to(dtype=orig_dtype)
-
-    return image_hue_adj
+    return convert_dtype_image_tensor(image_hue_adj, orig_dtype)


 adjust_hue_image_pil = _FP.adjust_hue
@@ -289,14 +285,15 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) -> features.InputTypeJIT:


 def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
-    if bits > 8:
-        return image
-
     if image.is_floating_point():
         levels = 1 << bits
         return image.mul(levels).floor_().clamp_(0, levels - 1).div_(levels)
     else:
-        mask = ((1 << bits) - 1) << (8 - bits)
+        num_value_bits = _num_value_bits(image.dtype)
+        if bits >= num_value_bits:
+            return image
+
+        mask = ((1 << bits) - 1) << (num_value_bits - bits)
         return image & mask
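`_num_value_bits` is imported from `._meta`, and its definition is not part of this excerpt. Presumably it reports how many value (non-sign) bits an integer dtype carries, which is what replaces the previously hardcoded 8. A hypothetical sketch of the helper together with the resulting posterize mask:

```python
import torch

# Hypothetical stand-in for the _num_value_bits helper from ._meta:
# all bits of the dtype, minus one sign bit for signed integer dtypes.
def num_value_bits_sketch(dtype: torch.dtype) -> int:
    bits = torch.iinfo(dtype).bits
    return bits if dtype == torch.uint8 else bits - 1

image = torch.randint(0, 256, (8, 8), dtype=torch.uint8)
bits = 3
mask = ((1 << bits) - 1) << (num_value_bits_sketch(image.dtype) - bits)
print(bin(mask))                # 0b11100000: keep only the 3 most significant value bits
print((image & mask).unique())  # only multiples of 32 remain
```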
@@ -317,8 +314,7 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:


 def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
-    bound = 1 if image.is_floating_point() else 255
-    if threshold > bound:
+    if threshold > _FT._max_value(image.dtype):
         raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")

     return torch.where(image >= threshold, invert_image_tensor(image), image)
@@ -349,7 +345,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
         # exit earlier on empty images
         return image

-    bound = 1.0 if image.is_floating_point() else 255.0
+    bound = _FT._max_value(image.dtype)
     dtype = image.dtype if torch.is_floating_point(image) else torch.float32

     minimum = image.amin(dim=(-2, -1), keepdim=True).to(dtype)
@@ -383,14 +379,18 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
     if image.numel() == 0:
         return image

+    # 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that
+    #    would need to be computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for
+    #    `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32` and completely
+    #    unfeasible for `torch.int64`.
+    # 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
+    #    could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
+    #    to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower
+    #    and more complicated to implement than a simple conversion and a fast histogram implementation for integers.
+    #    Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is
+    #    by far the most common, we choose it as base.
     output_dtype = image.dtype
-    if image.is_floating_point():
-        # Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
-        # could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
-        # to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it
-        # slower and more complicated to implement than a simple conversion and a fast histogram implementation for
-        # integers.
-        image = convert_dtype_image_tensor(image, torch.uint8)
+    image = convert_dtype_image_tensor(image, torch.uint8)

     # The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
     # corresponds to adding 1 to index 127 in the histogram.
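The retained comment about using "the flattened image as index" describes a histogram built by scatter-adding ones at the pixel values. A rough sketch of that trick for batched uint8 images (my illustration; the actual kernel may differ in details):

```python
import torch

def batched_histogram_uint8(image: torch.Tensor) -> torch.Tensor:
    # image: (..., H, W) uint8 -> one 256-bin histogram per image, shape (..., 256)
    flat = image.flatten(start_dim=-2).to(torch.int64)  # (..., H * W)
    hist = flat.new_zeros(flat.shape[:-1] + (256,))     # (..., 256)
    # A pixel value of 127 adds 1 to index 127 of that image's histogram.
    return hist.scatter_add_(-1, flat, torch.ones_like(flat))

img = torch.randint(0, 256, (2, 3, 8, 8), dtype=torch.uint8)
print(batched_histogram_uint8(img).sum(dim=-1))  # every entry is 64 (= 8 * 8)
```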
@@ -461,10 +461,13 @@ def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:


 def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
-    if image.dtype == torch.uint8:
+    if image.is_floating_point():
+        return 1.0 - image  # type: ignore[no-any-return]
+    elif image.dtype == torch.uint8:
         return image.bitwise_not()
-    else:
-        return (1 if image.is_floating_point() else 255) - image  # type: ignore[no-any-return]
+    else:  # signed integer dtypes
+        # We can't use `Tensor.bitwise_not` here, since we want to retain the leading zero bit that encodes the sign
+        return image.bitwise_xor((1 << _num_value_bits(image.dtype)) - 1)
Review comment: Can you provide benchmarks for this?

Review comment: Nice! I like it when bug/code-quality fixing leads to speed improvements. What more can we ask? 😄
 invert_image_pil = _FP.invert
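For signed integer dtypes the new branch XORs only the value bits and leaves the sign bit alone, which is why plain `bitwise_not` (which also flips the sign bit) is avoided. A quick int8 check of the difference (my example, assuming `_num_value_bits(torch.int8) == 7`):

```python
import torch

x = torch.tensor([0, 1, 3, 100, 127], dtype=torch.int8)
mask = (1 << 7) - 1          # 0b0111_1111: all value bits of int8
print(x.bitwise_xor(mask))   # tensor([127, 126, 124, 27, 0], dtype=torch.int8)
print(x.bitwise_not())       # tensor([-1, -2, -4, -101, -128], ...): the sign bit flips too
```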
File 2 of 2: torchvision.transforms.functional_tensor
@@ -15,12 +15,6 @@ def _assert_image_tensor(img: Tensor) -> None:
         raise TypeError("Tensor is not a torch image.")


-def _assert_threshold(img: Tensor, threshold: float) -> None:
-    bound = 1 if img.is_floating_point() else 255
-    if threshold > bound:
-        raise TypeError("Threshold should be less than bound of img.")
-
-
 def get_dimensions(img: Tensor) -> List[int]:
     _assert_image_tensor(img)
     channels = 1 if img.ndim == 2 else img.shape[-3]

Review comment: This was only used once so I inlined it.
@@ -56,6 +50,8 @@ def _max_value(dtype: torch.dtype) -> int:
     elif dtype == torch.int64:
         return 9223372036854775807
     else:
+        # This is only here for completeness. This value is implicitly assumed in a lot of places so changing it is not
+        # easy.
         return 1
@@ -212,19 +208,15 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
         return img

     orig_dtype = img.dtype
-    if img.dtype == torch.uint8:
-        img = img.to(dtype=torch.float32) / 255.0
+    img = convert_image_dtype(img, torch.float32)

     img = _rgb2hsv(img)
     h, s, v = img.unbind(dim=-3)
     h = (h + hue_factor) % 1.0
     img = torch.stack((h, s, v), dim=-3)
     img_hue_adj = _hsv2rgb(img)

-    if orig_dtype == torch.uint8:
-        img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype)
-
-    return img_hue_adj
+    return convert_image_dtype(img_hue_adj, orig_dtype)


 def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
@@ -263,7 +255,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:

 def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
     ratio = float(ratio)
-    bound = 1.0 if img1.is_floating_point() else 255.0
+    bound = _max_value(img1.dtype)
     return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
@@ -775,8 +767,7 @@ def invert(img: Tensor) -> Tensor:

     _assert_channels(img, [1, 3])

-    bound = torch.tensor(1 if img.is_floating_point() else 255, dtype=img.dtype, device=img.device)
-    return bound - img
+    return _max_value(img.dtype) - img


 def posterize(img: Tensor, bits: int) -> Tensor:
@@ -802,7 +793,8 @@ def solarize(img: Tensor, threshold: float) -> Tensor:

     _assert_channels(img, [1, 3])

-    _assert_threshold(img, threshold)
+    if threshold > _max_value(img.dtype):
+        raise TypeError("Threshold should be less than bound of img.")

     inverted_img = invert(img)
     return torch.where(img >= threshold, inverted_img, img)
@@ -849,7 +841,7 @@ def autocontrast(img: Tensor) -> Tensor:

     _assert_channels(img, [1, 3])

-    bound = 1.0 if img.is_floating_point() else 255.0
+    bound = _max_value(img.dtype)
     dtype = img.dtype if torch.is_floating_point(img) else torch.float32

     minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
Review comment: Instead of doing the conversion manually, I've opted to use our kernel for this. Note that this also implicitly converts to float32 since the divisor is a float.
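To make the note above concrete: for a uint8 input, letting the kernel do the conversion should match the old manual scaling exactly. A small sanity check (my example, not part of the PR):

```python
import torch
from torchvision.transforms.functional import convert_image_dtype

img = torch.randint(0, 256, (3, 4, 4), dtype=torch.uint8)
manual = img.to(torch.float32) / 255.0
kernel = convert_image_dtype(img, torch.float32)
torch.testing.assert_close(manual, kernel)  # same values, same float32 dtype
print(kernel.min().item() >= 0.0, kernel.max().item() <= 1.0)  # True True
```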