Skip to content

Fix hardcoded 255 #6830

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Nov 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
3e81aef
fix prototype kernels
pmeier Oct 24, 2022
33852be
fix stable kernels
pmeier Oct 24, 2022
3a92412
fix tests
pmeier Oct 25, 2022
e13613a
make test more robust
pmeier Oct 25, 2022
a400225
Merge branch 'main' into fix-hardcoded-255
pmeier Oct 25, 2022
e053125
Merge branch 'main' into fix-hardcoded-255
pmeier Oct 25, 2022
3327e04
improve invert for signed integers
pmeier Oct 27, 2022
91e8c66
Merge branch 'main' into fix-hardcoded-255
datumbox Oct 27, 2022
bdd8127
Merge branch 'main' into fix-hardcoded-255
pmeier Oct 28, 2022
c672425
improve invert
pmeier Oct 28, 2022
6375627
fix posterize
pmeier Oct 28, 2022
6895f71
Merge branch 'main' into fix-hardcoded-255
pmeier Oct 28, 2022
c0236fc
Revert "assume that integer images are [0, 255] in equalize (#6859)"
pmeier Oct 28, 2022
8713528
Merge branch 'fix-hardcoded-255' of https://github.com/pmeier/vision …
pmeier Oct 28, 2022
9acf2f4
Merge branch 'main' into fix-hardcoded-255
pmeier Nov 2, 2022
402b01f
fix solarize in AA
pmeier Nov 2, 2022
d0394b7
Merge branch 'main' into fix-hardcoded-255
pmeier Nov 3, 2022
5f33f4a
fix resize
pmeier Nov 3, 2022
3a13a08
Revert "fix resize"
pmeier Nov 3, 2022
7765a47
Merge branch 'main' into fix-hardcoded-255
pmeier Nov 3, 2022
2d0549d
Merge branch 'main' into fix-hardcoded-255
pmeier Nov 3, 2022
f594ceb
Merge branch 'main' into fix-hardcoded-255
pmeier Nov 3, 2022
48603b0
add comment to float max value
pmeier Nov 3, 2022
a61d44f
Merge branch 'fix-hardcoded-255' of https://github.com/pmeier/vision …
pmeier Nov 3, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 29 additions & 21 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -790,32 +790,40 @@ def test_solarize2(device, dtype, config, channels):
)


@pytest.mark.parametrize(
("dtype", "threshold"),
[
*[
(dtype, threshold)
for dtype, threshold in itertools.product(
[torch.float32, torch.float16],
[0.0, 0.25, 0.5, 0.75, 1.0],
)
],
*[(torch.uint8, threshold) for threshold in [0, 64, 128, 192, 255]],
*[(torch.int64, threshold) for threshold in [0, 2**32, 2**63 - 1]],
],
)
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("threshold", [0.0, 0.25, 0.5, 0.75, 1.0])
def test_solarize_threshold1_bound(threshold, device):
img = torch.rand((3, 12, 23)).to(device)
F_t.solarize(img, threshold)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("threshold", [1.5])
def test_solarize_threshold1_upper_bound(threshold, device):
img = torch.rand((3, 12, 23)).to(device)
with pytest.raises(TypeError, match="Threshold should be less than bound of img."):
F_t.solarize(img, threshold)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("threshold", [0, 64, 128, 192, 255])
def test_solarize_threshold2_bound(threshold, device):
img = torch.randint(0, 256, (3, 12, 23)).to(device)
def test_solarize_threshold_within_bound(threshold, dtype, device):
make_img = torch.rand if dtype.is_floating_point else partial(torch.randint, 0, torch.iinfo(dtype).max)
img = make_img((3, 12, 23), dtype=dtype, device=device)
F_t.solarize(img, threshold)


@pytest.mark.parametrize(
("dtype", "threshold"),
[
(torch.float32, 1.5),
(torch.float16, 1.5),
(torch.uint8, 260),
(torch.int64, 2**64),
],
)
@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("threshold", [260])
def test_solarize_threshold2_upper_bound(threshold, device):
img = torch.randint(0, 256, (3, 12, 23)).to(device)
def test_solarize_threshold_above_bound(threshold, dtype, device):
make_img = torch.rand if dtype.is_floating_point else partial(torch.randint, 0, torch.iinfo(dtype).max)
img = make_img((3, 12, 23), dtype=dtype, device=device)
with pytest.raises(TypeError, match="Threshold should be less than bound of img."):
F_t.solarize(img, threshold)

Expand Down
3 changes: 2 additions & 1 deletion torchvision/prototype/transforms/_auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from torchvision.prototype import features
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
from torchvision.prototype.transforms.functional._meta import get_spatial_size
from torchvision.transforms import functional_tensor as _FT

from ._utils import _isinstance, _setup_fill_arg

Expand Down Expand Up @@ -137,7 +138,7 @@ def _apply_image_or_video_transform(
elif transform_id == "Posterize":
return F.posterize(image, bits=int(magnitude))
elif transform_id == "Solarize":
bound = 1.0 if isinstance(image, torch.Tensor) and image.is_floating_point() else 255.0
bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0
return F.solarize(image, threshold=bound * magnitude)
elif transform_id == "AutoContrast":
return F.autocontrast(image)
Expand Down
55 changes: 29 additions & 26 deletions torchvision/prototype/transforms/functional/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@
from torchvision.prototype import features
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT

from ._meta import _rgb_to_gray, convert_dtype_image_tensor
from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor


def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
ratio = float(ratio)
fp = image1.is_floating_point()
bound = 1.0 if fp else 255.0
bound = _FT._max_value(image1.dtype)
output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
return output if fp else output.to(image1.dtype)

Expand All @@ -20,7 +20,7 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
_FT._assert_channels(image, [1, 3])

fp = image.is_floating_point()
bound = 1.0 if fp else 255.0
bound = _FT._max_value(image.dtype)
output = image.mul(brightness_factor).clamp_(0, bound)
return output if fp else output.to(image.dtype)

Expand Down Expand Up @@ -222,19 +222,15 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
return image

orig_dtype = image.dtype
if image.dtype == torch.uint8:
image = image / 255.0
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of doing the conversion manually, I've opted to use our kernel for this. Note that this also implicitly converts to float32 since the divisor is a float.

image = convert_dtype_image_tensor(image, torch.float32)

image = _rgb_to_hsv(image)
h, s, v = image.unbind(dim=-3)
h.add_(hue_factor).remainder_(1.0)
image = torch.stack((h, s, v), dim=-3)
image_hue_adj = _hsv_to_rgb(image)

if orig_dtype == torch.uint8:
image_hue_adj = image_hue_adj.mul_(255.0).to(dtype=orig_dtype)

return image_hue_adj
return convert_dtype_image_tensor(image_hue_adj, orig_dtype)


adjust_hue_image_pil = _FP.adjust_hue
Expand Down Expand Up @@ -289,14 +285,15 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) ->


def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
if bits > 8:
return image

if image.is_floating_point():
levels = 1 << bits
return image.mul(levels).floor_().clamp_(0, levels - 1).div_(levels)
else:
mask = ((1 << bits) - 1) << (8 - bits)
num_value_bits = _num_value_bits(image.dtype)
if bits >= num_value_bits:
return image

mask = ((1 << bits) - 1) << (num_value_bits - bits)
return image & mask


Expand All @@ -317,8 +314,7 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:


def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
bound = 1 if image.is_floating_point() else 255
if threshold > bound:
if threshold > _FT._max_value(image.dtype):
raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")

return torch.where(image >= threshold, invert_image_tensor(image), image)
Expand Down Expand Up @@ -349,7 +345,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
# exit earlier on empty images
return image

bound = 1.0 if image.is_floating_point() else 255.0
bound = _FT._max_value(image.dtype)
dtype = image.dtype if torch.is_floating_point(image) else torch.float32

minimum = image.amin(dim=(-2, -1), keepdim=True).to(dtype)
Expand Down Expand Up @@ -383,14 +379,18 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
if image.numel() == 0:
return image

# 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that
# would be needed to computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for
# `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32` and completely
# unfeasible for `torch.int64`.
# 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
# could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
# to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower
# and more complicated to implement than a simple conversion and a fast histogram implementation for integers.
# Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is
# by far the most common, we choose it as base.
output_dtype = image.dtype
if image.is_floating_point():
# Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
# could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
# to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it
# slower and more complicated to implement than a simple conversion and a fast histogram implementation for
# integers.
image = convert_dtype_image_tensor(image, torch.uint8)
image = convert_dtype_image_tensor(image, torch.uint8)

# The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
# corresponds to adding 1 to index 127 in the histogram.
Expand Down Expand Up @@ -461,10 +461,13 @@ def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:


def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
if image.dtype == torch.uint8:
if image.is_floating_point():
return 1.0 - image # type: ignore[no-any-return]
elif image.dtype == torch.uint8:
return image.bitwise_not()
else:
return (1 if image.is_floating_point() else 255) - image # type: ignore[no-any-return]
else: # signed integer dtypes
# We can't use `Tensor.bitwise_not` here, since we want to retain the leading zero bit that encodes the sign
return image.bitwise_xor((1 << _num_value_bits(image.dtype)) - 1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you provide benchmarks for this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[--------------------------- invert_image_tensor ---------------------------]
                                      |       main       |  fix-hardcoded-255
1 threads: ------------------------------------------------------------------
      (3, 512, 512), float32, cpu     |   61 (+-  0) us  |     57 (+-  0) us 
      (3, 512, 512), uint8, cpu       |   17 (+-  0) us  |     17 (+-  0) us 
      (3, 512, 512), int32, cpu       |   78 (+-  0) us  |     63 (+-  0) us 
      (5, 3, 512, 512), float32, cpu  |  461 (+- 33) us  |    445 (+- 31) us 
      (5, 3, 512, 512), uint8, cpu    |   98 (+-  1) us  |     79 (+-  1) us 
      (5, 3, 512, 512), int32, cpu    |  538 (+- 67) us  |    514 (+-  8) us 

Times are in microseconds (us).

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! I like it when bug/code-quality fixing leads to speed improvements. What more can we ask? 😄



invert_image_pil = _FP.invert
Expand Down
26 changes: 9 additions & 17 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,6 @@ def _assert_image_tensor(img: Tensor) -> None:
raise TypeError("Tensor is not a torch image.")


def _assert_threshold(img: Tensor, threshold: float) -> None:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This was only used once so I inlined it.

bound = 1 if img.is_floating_point() else 255
if threshold > bound:
raise TypeError("Threshold should be less than bound of img.")


def get_dimensions(img: Tensor) -> List[int]:
_assert_image_tensor(img)
channels = 1 if img.ndim == 2 else img.shape[-3]
Expand Down Expand Up @@ -56,6 +50,8 @@ def _max_value(dtype: torch.dtype) -> int:
elif dtype == torch.int64:
return 9223372036854775807
else:
# This is only here for completeness. This value is implicitly assumed in a lot of places so changing it is not
# easy.
return 1


Expand Down Expand Up @@ -212,19 +208,15 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
return img

orig_dtype = img.dtype
if img.dtype == torch.uint8:
img = img.to(dtype=torch.float32) / 255.0
img = convert_image_dtype(img, torch.float32)

img = _rgb2hsv(img)
h, s, v = img.unbind(dim=-3)
h = (h + hue_factor) % 1.0
img = torch.stack((h, s, v), dim=-3)
img_hue_adj = _hsv2rgb(img)

if orig_dtype == torch.uint8:
img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype)

return img_hue_adj
return convert_image_dtype(img_hue_adj, orig_dtype)


def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
Expand Down Expand Up @@ -263,7 +255,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:

def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
ratio = float(ratio)
bound = 1.0 if img1.is_floating_point() else 255.0
bound = _max_value(img1.dtype)
return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)


Expand Down Expand Up @@ -775,8 +767,7 @@ def invert(img: Tensor) -> Tensor:

_assert_channels(img, [1, 3])

bound = torch.tensor(1 if img.is_floating_point() else 255, dtype=img.dtype, device=img.device)
return bound - img
return _max_value(img.dtype) - img


def posterize(img: Tensor, bits: int) -> Tensor:
Expand All @@ -802,7 +793,8 @@ def solarize(img: Tensor, threshold: float) -> Tensor:

_assert_channels(img, [1, 3])

_assert_threshold(img, threshold)
if threshold > _max_value(img.dtype):
raise TypeError("Threshold should be less than bound of img.")

inverted_img = invert(img)
return torch.where(img >= threshold, inverted_img, img)
Expand Down Expand Up @@ -849,7 +841,7 @@ def autocontrast(img: Tensor) -> Tensor:

_assert_channels(img, [1, 3])

bound = 1.0 if img.is_floating_point() else 255.0
bound = _max_value(img.dtype)
dtype = img.dtype if torch.is_floating_point(img) else torch.float32

minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)
Expand Down