1
1
import warnings
2
- from typing import Optional , Tuple , List
2
+ from typing import Dict , Optional , Tuple , List
3
3
4
4
import torch
5
5
from torch import Tensor
@@ -44,24 +44,6 @@ def get_image_num_channels(img: Tensor) -> int:
44
44
raise TypeError (f"Input ndim should be 2 or more. Got { img .ndim } " )
45
45
46
46
47
- def _max_value (dtype : torch .dtype ) -> float :
48
- # TODO: replace this method with torch.iinfo when it gets torchscript support.
49
- # https://github.com/pytorch/pytorch/issues/41492
50
-
51
- a = torch .tensor (2 , dtype = dtype )
52
- signed = 1 if torch .tensor (0 , dtype = dtype ).is_signed () else 0
53
- bits = 1
54
- max_value = torch .tensor (- signed , dtype = torch .long )
55
- while True :
56
- next_value = a .pow (bits - signed ).sub (1 )
57
- if next_value > max_value :
58
- max_value = next_value
59
- bits *= 2
60
- else :
61
- break
62
- return max_value .item ()
63
-
64
-
65
47
def _assert_channels (img : Tensor , permitted : List [int ]) -> None :
66
48
c = get_dimensions (img )[0 ]
67
49
if c not in permitted :
@@ -72,6 +54,14 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
72
54
if image .dtype == dtype :
73
55
return image
74
56
57
+ _max_values : Dict [torch .dtype , float ] = {
58
+ torch .uint8 : float (255 ),
59
+ torch .int8 : float (127 ),
60
+ torch .int16 : float (32767 ),
61
+ torch .int32 : float (2147483647 ),
62
+ torch .int64 : float (9223372036854775807 ),
63
+ }
64
+
75
65
if image .is_floating_point ():
76
66
77
67
# TODO: replace with dtype.is_floating_point when torchscript supports it
@@ -91,19 +81,28 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
91
81
# `max + 1 - epsilon` provides more evenly distributed mapping of
92
82
# ranges of floats to ints.
93
83
eps = 1e-3
94
- max_val = _max_value (dtype )
84
+ if dtype not in _max_values :
85
+ msg = f"Internal error when casting from { image .dtype } to { dtype } ."
86
+ raise RuntimeError (msg )
87
+ max_val = _max_values [dtype ]
95
88
result = image .mul (max_val + 1.0 - eps )
96
89
return result .to (dtype )
97
90
else :
98
- input_max = _max_value (image .dtype )
91
+ if image .dtype not in _max_values :
92
+ msg = f"Internal error when casting from { image .dtype } to { dtype } ."
93
+ raise RuntimeError (msg )
94
+ input_max = _max_values [image .dtype ]
99
95
100
96
# int to float
101
97
# TODO: replace with dtype.is_floating_point when torchscript supports it
102
98
if torch .tensor (0 , dtype = dtype ).is_floating_point ():
103
99
image = image .to (dtype )
104
100
return image / input_max
105
101
106
- output_max = _max_value (dtype )
102
+ if dtype not in _max_values :
103
+ msg = f"Internal error when casting from { image .dtype } to { dtype } ."
104
+ raise RuntimeError (msg )
105
+ output_max = _max_values [dtype ]
107
106
108
107
# int to int
109
108
if input_max > output_max :
0 commit comments