@@ -44,22 +44,19 @@ def get_image_num_channels(img: Tensor) -> int:
44
44
raise TypeError (f"Input ndim should be 2 or more. Got { img .ndim } " )
45
45
46
46
47
- def _max_value (dtype : torch .dtype ) -> float :
48
- # TODO: replace this method with torch.iinfo when it gets torchscript support.
49
- # https://github.com/pytorch/pytorch/issues/41492
50
-
51
- a = torch .tensor (2 , dtype = dtype )
52
- signed = 1 if torch .tensor (0 , dtype = dtype ).is_signed () else 0
53
- bits = 1
54
- max_value = torch .tensor (- signed , dtype = torch .long )
55
- while True :
56
- next_value = a .pow (bits - signed ).sub (1 )
57
- if next_value > max_value :
58
- max_value = next_value
59
- bits *= 2
60
- else :
61
- break
62
- return max_value .item ()
47
+ def _max_value (dtype : torch .dtype ) -> int :
48
+ if dtype == torch .uint8 :
49
+ return int (2 ** 8 ) - 1
50
+ elif dtype == torch .int8 :
51
+ return int (2 ** 7 ) - 1
52
+ elif dtype == torch .int16 :
53
+ return int (2 ** 15 ) - 1
54
+ elif dtype == torch .int32 :
55
+ return int (2 ** 31 ) - 1
56
+ elif dtype == torch .int64 :
57
+ return int (2 ** 63 ) - 1
58
+ else :
59
+ return 1
63
60
64
61
65
62
def _assert_channels (img : Tensor , permitted : List [int ]) -> None :
@@ -91,19 +88,19 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
91
88
# `max + 1 - epsilon` provides more evenly distributed mapping of
92
89
# ranges of floats to ints.
93
90
eps = 1e-3
94
- max_val = _max_value (dtype )
91
+ max_val = float ( _max_value (dtype ) )
95
92
result = image .mul (max_val + 1.0 - eps )
96
93
return result .to (dtype )
97
94
else :
98
- input_max = _max_value (image .dtype )
95
+ input_max = float ( _max_value (image .dtype ) )
99
96
100
97
# int to float
101
98
# TODO: replace with dtype.is_floating_point when torchscript supports it
102
99
if torch .tensor (0 , dtype = dtype ).is_floating_point ():
103
100
image = image .to (dtype )
104
101
return image / input_max
105
102
106
- output_max = _max_value (dtype )
103
+ output_max = float ( _max_value (dtype ) )
107
104
108
105
# int to int
109
106
if input_max > output_max :
0 commit comments