Skip to content

Commit 37bc10c

Browse files
NicolasHug authored and datumbox committed
[fbsync] Fix hardcoded 255 (#6830)
Summary: * fix prototype kernels * fix stable kernels * fix tests * make test more robust * improve invert for signed integers * improve invert * fix posterize * Revert "assume that integer images are [0, 255] in equalize (#6859)" This reverts commit 436ff9a. * fix solarize in AA * fix resize * Revert "fix resize" This reverts commit 5f33f4a. * add comment to float max value Reviewed By: datumbox Differential Revision: D41020539 fbshipit-source-id: 1c618ead36a0ae4d93b4ebf07186fd39bd85d915 Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent b8c8954 commit 37bc10c

File tree

4 files changed

+69
-65
lines changed

4 files changed

+69
-65
lines changed

test/test_functional_tensor.py

Lines changed: 29 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -790,32 +790,40 @@ def test_solarize2(device, dtype, config, channels):
790790
)
791791

792792

793+
@pytest.mark.parametrize(
794+
("dtype", "threshold"),
795+
[
796+
*[
797+
(dtype, threshold)
798+
for dtype, threshold in itertools.product(
799+
[torch.float32, torch.float16],
800+
[0.0, 0.25, 0.5, 0.75, 1.0],
801+
)
802+
],
803+
*[(torch.uint8, threshold) for threshold in [0, 64, 128, 192, 255]],
804+
*[(torch.int64, threshold) for threshold in [0, 2**32, 2**63 - 1]],
805+
],
806+
)
793807
@pytest.mark.parametrize("device", cpu_and_gpu())
794-
@pytest.mark.parametrize("threshold", [0.0, 0.25, 0.5, 0.75, 1.0])
795-
def test_solarize_threshold1_bound(threshold, device):
796-
img = torch.rand((3, 12, 23)).to(device)
797-
F_t.solarize(img, threshold)
798-
799-
800-
@pytest.mark.parametrize("device", cpu_and_gpu())
801-
@pytest.mark.parametrize("threshold", [1.5])
802-
def test_solarize_threshold1_upper_bound(threshold, device):
803-
img = torch.rand((3, 12, 23)).to(device)
804-
with pytest.raises(TypeError, match="Threshold should be less than bound of img."):
805-
F_t.solarize(img, threshold)
806-
807-
808-
@pytest.mark.parametrize("device", cpu_and_gpu())
809-
@pytest.mark.parametrize("threshold", [0, 64, 128, 192, 255])
810-
def test_solarize_threshold2_bound(threshold, device):
811-
img = torch.randint(0, 256, (3, 12, 23)).to(device)
808+
def test_solarize_threshold_within_bound(threshold, dtype, device):
809+
make_img = torch.rand if dtype.is_floating_point else partial(torch.randint, 0, torch.iinfo(dtype).max)
810+
img = make_img((3, 12, 23), dtype=dtype, device=device)
812811
F_t.solarize(img, threshold)
813812

814813

814+
@pytest.mark.parametrize(
815+
("dtype", "threshold"),
816+
[
817+
(torch.float32, 1.5),
818+
(torch.float16, 1.5),
819+
(torch.uint8, 260),
820+
(torch.int64, 2**64),
821+
],
822+
)
815823
@pytest.mark.parametrize("device", cpu_and_gpu())
816-
@pytest.mark.parametrize("threshold", [260])
817-
def test_solarize_threshold2_upper_bound(threshold, device):
818-
img = torch.randint(0, 256, (3, 12, 23)).to(device)
824+
def test_solarize_threshold_above_bound(threshold, dtype, device):
825+
make_img = torch.rand if dtype.is_floating_point else partial(torch.randint, 0, torch.iinfo(dtype).max)
826+
img = make_img((3, 12, 23), dtype=dtype, device=device)
819827
with pytest.raises(TypeError, match="Threshold should be less than bound of img."):
820828
F_t.solarize(img, threshold)
821829

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,6 +8,7 @@
88
from torchvision.prototype import features
99
from torchvision.prototype.transforms import AutoAugmentPolicy, functional as F, InterpolationMode, Transform
1010
from torchvision.prototype.transforms.functional._meta import get_spatial_size
11+
from torchvision.transforms import functional_tensor as _FT
1112

1213
from ._utils import _isinstance, _setup_fill_arg
1314

@@ -137,7 +138,7 @@ def _apply_image_or_video_transform(
137138
elif transform_id == "Posterize":
138139
return F.posterize(image, bits=int(magnitude))
139140
elif transform_id == "Solarize":
140-
bound = 1.0 if isinstance(image, torch.Tensor) and image.is_floating_point() else 255.0
141+
bound = _FT._max_value(image.dtype) if isinstance(image, torch.Tensor) else 255.0
141142
return F.solarize(image, threshold=bound * magnitude)
142143
elif transform_id == "AutoContrast":
143144
return F.autocontrast(image)

torchvision/prototype/transforms/functional/_color.py

Lines changed: 29 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -2,13 +2,13 @@
22
from torchvision.prototype import features
33
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
44

5-
from ._meta import _rgb_to_gray, convert_dtype_image_tensor
5+
from ._meta import _num_value_bits, _rgb_to_gray, convert_dtype_image_tensor
66

77

88
def _blend(image1: torch.Tensor, image2: torch.Tensor, ratio: float) -> torch.Tensor:
99
ratio = float(ratio)
1010
fp = image1.is_floating_point()
11-
bound = 1.0 if fp else 255.0
11+
bound = _FT._max_value(image1.dtype)
1212
output = image1.mul(ratio).add_(image2, alpha=(1.0 - ratio)).clamp_(0, bound)
1313
return output if fp else output.to(image1.dtype)
1414

@@ -20,7 +20,7 @@ def adjust_brightness_image_tensor(image: torch.Tensor, brightness_factor: float
2020
_FT._assert_channels(image, [1, 3])
2121

2222
fp = image.is_floating_point()
23-
bound = 1.0 if fp else 255.0
23+
bound = _FT._max_value(image.dtype)
2424
output = image.mul(brightness_factor).clamp_(0, bound)
2525
return output if fp else output.to(image.dtype)
2626

@@ -222,19 +222,15 @@ def adjust_hue_image_tensor(image: torch.Tensor, hue_factor: float) -> torch.Ten
222222
return image
223223

224224
orig_dtype = image.dtype
225-
if image.dtype == torch.uint8:
226-
image = image / 255.0
225+
image = convert_dtype_image_tensor(image, torch.float32)
227226

228227
image = _rgb_to_hsv(image)
229228
h, s, v = image.unbind(dim=-3)
230229
h.add_(hue_factor).remainder_(1.0)
231230
image = torch.stack((h, s, v), dim=-3)
232231
image_hue_adj = _hsv_to_rgb(image)
233232

234-
if orig_dtype == torch.uint8:
235-
image_hue_adj = image_hue_adj.mul_(255.0).to(dtype=orig_dtype)
236-
237-
return image_hue_adj
233+
return convert_dtype_image_tensor(image_hue_adj, orig_dtype)
238234

239235

240236
adjust_hue_image_pil = _FP.adjust_hue
@@ -289,14 +285,15 @@ def adjust_gamma(inpt: features.InputTypeJIT, gamma: float, gain: float = 1) ->
289285

290286

291287
def posterize_image_tensor(image: torch.Tensor, bits: int) -> torch.Tensor:
292-
if bits > 8:
293-
return image
294-
295288
if image.is_floating_point():
296289
levels = 1 << bits
297290
return image.mul(levels).floor_().clamp_(0, levels - 1).div_(levels)
298291
else:
299-
mask = ((1 << bits) - 1) << (8 - bits)
292+
num_value_bits = _num_value_bits(image.dtype)
293+
if bits >= num_value_bits:
294+
return image
295+
296+
mask = ((1 << bits) - 1) << (num_value_bits - bits)
300297
return image & mask
301298

302299

@@ -317,8 +314,7 @@ def posterize(inpt: features.InputTypeJIT, bits: int) -> features.InputTypeJIT:
317314

318315

319316
def solarize_image_tensor(image: torch.Tensor, threshold: float) -> torch.Tensor:
320-
bound = 1 if image.is_floating_point() else 255
321-
if threshold > bound:
317+
if threshold > _FT._max_value(image.dtype):
322318
raise TypeError(f"Threshold should be less or equal the maximum value of the dtype, but got {threshold}")
323319

324320
return torch.where(image >= threshold, invert_image_tensor(image), image)
@@ -349,7 +345,7 @@ def autocontrast_image_tensor(image: torch.Tensor) -> torch.Tensor:
349345
# exit earlier on empty images
350346
return image
351347

352-
bound = 1.0 if image.is_floating_point() else 255.0
348+
bound = _FT._max_value(image.dtype)
353349
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
354350

355351
minimum = image.amin(dim=(-2, -1), keepdim=True).to(dtype)
@@ -383,14 +379,18 @@ def equalize_image_tensor(image: torch.Tensor) -> torch.Tensor:
383379
if image.numel() == 0:
384380
return image
385381

382+
# 1. The algorithm below can easily be extended to support arbitrary integer dtypes. However, the histogram that
383+
# would need to be computed will have at least `torch.iinfo(dtype).max + 1` values. That is perfectly fine for
384+
# `torch.int8`, `torch.uint8`, and `torch.int16`, at least questionable for `torch.int32` and completely
385+
# unfeasible for `torch.int64`.
386+
# 2. Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
387+
# could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
388+
# to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it slower
389+
# and more complicated to implement than a simple conversion and a fast histogram implementation for integers.
390+
# Since we need to convert in most cases anyway and out of the acceptable dtypes mentioned in 1. `torch.uint8` is
391+
# by far the most common, we choose it as base.
386392
output_dtype = image.dtype
387-
if image.is_floating_point():
388-
# Floating point inputs need to be binned for this algorithm. Apart from converting them to an integer dtype, we
389-
# could also use PyTorch's builtin histogram functionality. However, that has its own set of issues: in addition
390-
# to being slow in general, PyTorch's implementation also doesn't support batches. In total, that makes it
391-
# slower and more complicated to implement than a simple conversion and a fast histogram implementation for
392-
# integers.
393-
image = convert_dtype_image_tensor(image, torch.uint8)
393+
image = convert_dtype_image_tensor(image, torch.uint8)
394394

395395
# The histogram is computed by using the flattened image as index. For example, a pixel value of 127 in the image
396396
# corresponds to adding 1 to index 127 in the histogram.
@@ -461,10 +461,13 @@ def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
461461

462462

463463
def invert_image_tensor(image: torch.Tensor) -> torch.Tensor:
464-
if image.dtype == torch.uint8:
464+
if image.is_floating_point():
465+
return 1.0 - image # type: ignore[no-any-return]
466+
elif image.dtype == torch.uint8:
465467
return image.bitwise_not()
466-
else:
467-
return (1 if image.is_floating_point() else 255) - image # type: ignore[no-any-return]
468+
else: # signed integer dtypes
469+
# We can't use `Tensor.bitwise_not` here, since we want to retain the leading zero bit that encodes the sign
470+
return image.bitwise_xor((1 << _num_value_bits(image.dtype)) - 1)
468471

469472

470473
invert_image_pil = _FP.invert

torchvision/transforms/functional_tensor.py

Lines changed: 9 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -15,12 +15,6 @@ def _assert_image_tensor(img: Tensor) -> None:
1515
raise TypeError("Tensor is not a torch image.")
1616

1717

18-
def _assert_threshold(img: Tensor, threshold: float) -> None:
19-
bound = 1 if img.is_floating_point() else 255
20-
if threshold > bound:
21-
raise TypeError("Threshold should be less than bound of img.")
22-
23-
2418
def get_dimensions(img: Tensor) -> List[int]:
2519
_assert_image_tensor(img)
2620
channels = 1 if img.ndim == 2 else img.shape[-3]
@@ -56,6 +50,8 @@ def _max_value(dtype: torch.dtype) -> int:
5650
elif dtype == torch.int64:
5751
return 9223372036854775807
5852
else:
53+
# This is only here for completeness. This value is implicitly assumed in a lot of places so changing it is not
54+
# easy.
5955
return 1
6056

6157

@@ -212,19 +208,15 @@ def adjust_hue(img: Tensor, hue_factor: float) -> Tensor:
212208
return img
213209

214210
orig_dtype = img.dtype
215-
if img.dtype == torch.uint8:
216-
img = img.to(dtype=torch.float32) / 255.0
211+
img = convert_image_dtype(img, torch.float32)
217212

218213
img = _rgb2hsv(img)
219214
h, s, v = img.unbind(dim=-3)
220215
h = (h + hue_factor) % 1.0
221216
img = torch.stack((h, s, v), dim=-3)
222217
img_hue_adj = _hsv2rgb(img)
223218

224-
if orig_dtype == torch.uint8:
225-
img_hue_adj = (img_hue_adj * 255.0).to(dtype=orig_dtype)
226-
227-
return img_hue_adj
219+
return convert_image_dtype(img_hue_adj, orig_dtype)
228220

229221

230222
def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
@@ -263,7 +255,7 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
263255

264256
def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
265257
ratio = float(ratio)
266-
bound = 1.0 if img1.is_floating_point() else 255.0
258+
bound = _max_value(img1.dtype)
267259
return (ratio * img1 + (1.0 - ratio) * img2).clamp(0, bound).to(img1.dtype)
268260

269261

@@ -775,8 +767,7 @@ def invert(img: Tensor) -> Tensor:
775767

776768
_assert_channels(img, [1, 3])
777769

778-
bound = torch.tensor(1 if img.is_floating_point() else 255, dtype=img.dtype, device=img.device)
779-
return bound - img
770+
return _max_value(img.dtype) - img
780771

781772

782773
def posterize(img: Tensor, bits: int) -> Tensor:
@@ -802,7 +793,8 @@ def solarize(img: Tensor, threshold: float) -> Tensor:
802793

803794
_assert_channels(img, [1, 3])
804795

805-
_assert_threshold(img, threshold)
796+
if threshold > _max_value(img.dtype):
797+
raise TypeError("Threshold should be less than bound of img.")
806798

807799
inverted_img = invert(img)
808800
return torch.where(img >= threshold, inverted_img, img)
@@ -849,7 +841,7 @@ def autocontrast(img: Tensor) -> Tensor:
849841

850842
_assert_channels(img, [1, 3])
851843

852-
bound = 1.0 if img.is_floating_point() else 255.0
844+
bound = _max_value(img.dtype)
853845
dtype = img.dtype if torch.is_floating_point(img) else torch.float32
854846

855847
minimum = img.amin(dim=(-2, -1), keepdim=True).to(dtype)

0 commit comments

Comments
 (0)