Skip to content

Commit a03b386

Browse files
committed
Removed _max_value method and added a dictionary
Related to pytorch#5502
1 parent 71d2bb0 commit a03b386

File tree

1 file changed

+21
-22
lines changed

1 file changed

+21
-22
lines changed

torchvision/transforms/functional_tensor.py

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import warnings
2-
from typing import Optional, Tuple, List
2+
from typing import Dict, Optional, Tuple, List
33

44
import torch
55
from torch import Tensor
@@ -44,24 +44,6 @@ def get_image_num_channels(img: Tensor) -> int:
4444
raise TypeError(f"Input ndim should be 2 or more. Got {img.ndim}")
4545

4646

47-
def _max_value(dtype: torch.dtype) -> float:
48-
# TODO: replace this method with torch.iinfo when it gets torchscript support.
49-
# https://github.com/pytorch/pytorch/issues/41492
50-
51-
a = torch.tensor(2, dtype=dtype)
52-
signed = 1 if torch.tensor(0, dtype=dtype).is_signed() else 0
53-
bits = 1
54-
max_value = torch.tensor(-signed, dtype=torch.long)
55-
while True:
56-
next_value = a.pow(bits - signed).sub(1)
57-
if next_value > max_value:
58-
max_value = next_value
59-
bits *= 2
60-
else:
61-
break
62-
return max_value.item()
63-
64-
6547
def _assert_channels(img: Tensor, permitted: List[int]) -> None:
6648
c = get_dimensions(img)[0]
6749
if c not in permitted:
@@ -72,6 +54,14 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
7254
if image.dtype == dtype:
7355
return image
7456

57+
_max_values: Dict[torch.dtype, float] = {
58+
torch.uint8: float(255),
59+
torch.int8: float(127),
60+
torch.int16: float(32767),
61+
torch.int32: float(2147483647),
62+
torch.int64: float(9223372036854775807),
63+
}
64+
7565
if image.is_floating_point():
7666

7767
# TODO: replace with dtype.is_floating_point when torchscript supports it
@@ -91,19 +81,28 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
9181
# `max + 1 - epsilon` provides more evenly distributed mapping of
9282
# ranges of floats to ints.
9383
eps = 1e-3
94-
max_val = _max_value(dtype)
84+
if dtype not in _max_values:
85+
msg = f"Internal error when casting from {image.dtype} to {dtype}."
86+
raise RuntimeError(msg)
87+
max_val = _max_values[dtype]
9588
result = image.mul(max_val + 1.0 - eps)
9689
return result.to(dtype)
9790
else:
98-
input_max = _max_value(image.dtype)
91+
if image.dtype not in _max_values:
92+
msg = f"Internal error when casting from {image.dtype} to {dtype}."
93+
raise RuntimeError(msg)
94+
input_max = _max_values[image.dtype]
9995

10096
# int to float
10197
# TODO: replace with dtype.is_floating_point when torchscript supports it
10298
if torch.tensor(0, dtype=dtype).is_floating_point():
10399
image = image.to(dtype)
104100
return image / input_max
105101

106-
output_max = _max_value(dtype)
102+
if dtype not in _max_values:
103+
msg = f"Internal error when casting from {image.dtype} to {dtype}."
104+
raise RuntimeError(msg)
105+
output_max = _max_values[dtype]
107106

108107
# int to int
109108
if input_max > output_max:

0 commit comments

Comments
 (0)