Skip to content

Commit 57c8de7

Browse files
authored
Recoded _max_value method using a dictionary (#5566)
* Removed _max_value method and added a dictionary Related to #5502 * Addressed failing tests and restored _max_value method * Added xfailing test to switch quicker * Switch to if/else impl
1 parent d8654bb commit 57c8de7

File tree

2 files changed

+25
-19
lines changed

2 files changed

+25
-19
lines changed

test/test_transforms.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,6 +1486,15 @@ def test_max_value(dtype):
14861486
# self.assertGreater(F_t._max_value(dtype), torch.finfo(dtype).max)
14871487

14881488

1489+
@pytest.mark.xfail(
1490+
reason="torch.iinfo() is not supported by torchscript. See https://github.com/pytorch/pytorch/issues/41492."
1491+
)
1492+
def test_max_value_iinfo():
1493+
@torch.jit.script
1494+
def max_value(image: torch.Tensor) -> int:
1495+
return 1 if image.is_floating_point() else torch.iinfo(image.dtype).max
1496+
1497+
14891498
@pytest.mark.parametrize("should_vflip", [True, False])
14901499
@pytest.mark.parametrize("single_dim", [True, False])
14911500
def test_ten_crop(should_vflip, single_dim):

torchvision/transforms/functional_tensor.py

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -44,22 +44,19 @@ def get_image_num_channels(img: Tensor) -> int:
4444
raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")
4545

4646

47-
def _max_value(dtype: torch.dtype) -> float:
48-
# TODO: replace this method with torch.iinfo when it gets torchscript support.
49-
# https://github.com/pytorch/pytorch/issues/41492
50-
51-
a = torch.tensor(2, dtype=dtype)
52-
signed = 1 if torch.tensor(0, dtype=dtype).is_signed() else 0
53-
bits = 1
54-
max_value = torch.tensor(-signed, dtype=torch.long)
55-
while True:
56-
next_value = a.pow(bits - signed).sub(1)
57-
if next_value > max_value:
58-
max_value = next_value
59-
bits *= 2
60-
else:
61-
break
62-
return max_value.item()
47+
def _max_value(dtype: torch.dtype) -> int:
48+
if dtype == torch.uint8:
49+
return int(2 ** 8) - 1
50+
elif dtype == torch.int8:
51+
return int(2 ** 7) - 1
52+
elif dtype == torch.int16:
53+
return int(2 ** 15) - 1
54+
elif dtype == torch.int32:
55+
return int(2 ** 31) - 1
56+
elif dtype == torch.int64:
57+
return int(2 ** 63) - 1
58+
else:
59+
return 1
6360

6461

6562
def _assert_channels(img: Tensor, permitted: List[int]) -> None:
@@ -91,19 +88,19 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
9188
# `max + 1 - epsilon` provides more evenly distributed mapping of
9289
# ranges of floats to ints.
9390
eps = 1e-3
94-
max_val = _max_value(dtype)
91+
max_val = float(_max_value(dtype))
9592
result = image.mul(max_val + 1.0 - eps)
9693
return result.to(dtype)
9794
else:
98-
input_max = _max_value(image.dtype)
95+
input_max = float(_max_value(image.dtype))
9996

10097
# int to float
10198
# TODO: replace with dtype.is_floating_point when torchscript supports it
10299
if torch.tensor(0, dtype=dtype).is_floating_point():
103100
image = image.to(dtype)
104101
return image / input_max
105102

106-
output_max = _max_value(dtype)
103+
output_max = float(_max_value(dtype))
107104

108105
# int to int
109106
if input_max > output_max:

0 commit comments

Comments
 (0)