Skip to content

Commit 60a5972

Browse files
nairbv and vfdev-5 committed
make convert_image_dtype scriptable (pytorch#2485)
* make convert_image_dtype scriptable
* move convert dtype to functional_tensor since only works on tensors
* retain availability of convert_image_dtype in functional.py
* Update code and tests
* Replaced int by torch.dtype
* int -> torch.dtype and use F instead of F_t
* Update functional_tensor.py

Co-authored-by: vfdev-5 <[email protected]>
1 parent f8861ec commit 60a5972

File tree

5 files changed

+141
-49
lines changed

5 files changed

+141
-49
lines changed

test/test_functional_tensor.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -744,6 +744,10 @@ def test_perspective(self):
744744
batch_tensors, F.perspective, startpoints=spoints, endpoints=epoints, interpolation=0
745745
)
746746

747+
def test_convert_image_dtype(self):
748+
# TODO: add tests of CPU/CUDA on tensor and batch
749+
pass
750+
747751

748752
@unittest.skipIf(not torch.cuda.is_available(), reason="Skip if no CUDA device")
749753
class CUDATester(Tester):

test/test_transforms.py

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import torch
33
import torchvision.transforms as transforms
44
import torchvision.transforms.functional as F
5+
import torchvision.transforms.functional_tensor as F_t
56
from torch._utils_internal import get_file_path_2
67
from numpy.testing import assert_array_almost_equal
78
import unittest
@@ -544,13 +545,26 @@ def test_to_tensor(self):
544545
output = trans(img)
545546
self.assertTrue(np.allclose(input_data.numpy(), output.numpy()))
546547

548+
def test_max_value(self):
549+
for dtype in int_dtypes():
550+
self.assertEqual(F_t._max_value(dtype), torch.iinfo(dtype).max)
551+
552+
for dtype in float_dtypes():
553+
self.assertGreater(F_t._max_value(dtype), torch.finfo(dtype).max)
554+
547555
def test_convert_image_dtype_float_to_float(self):
548556
for input_dtype, output_dtypes in cycle_over(float_dtypes()):
549557
input_image = torch.tensor((0.0, 1.0), dtype=input_dtype)
550558
for output_dtype in output_dtypes:
551559
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
552560
transform = transforms.ConvertImageDtype(output_dtype)
561+
transform_script = torch.jit.script(F.convert_image_dtype)
562+
553563
output_image = transform(input_image)
564+
output_image_script = transform_script(input_image, output_dtype)
565+
566+
script_diff = output_image_script - output_image
567+
self.assertLess(script_diff.abs().max(), 1e-6)
554568

555569
actual_min, actual_max = output_image.tolist()
556570
desired_min, desired_max = 0.0, 1.0
@@ -564,6 +578,7 @@ def test_convert_image_dtype_float_to_int(self):
564578
for output_dtype in int_dtypes():
565579
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
566580
transform = transforms.ConvertImageDtype(output_dtype)
581+
transform_script = torch.jit.script(F.convert_image_dtype)
567582

568583
if (input_dtype == torch.float32 and output_dtype in (torch.int32, torch.int64)) or (
569584
input_dtype == torch.float64 and output_dtype == torch.int64
@@ -572,6 +587,10 @@ def test_convert_image_dtype_float_to_int(self):
572587
transform(input_image)
573588
else:
574589
output_image = transform(input_image)
590+
output_image_script = transform_script(input_image, output_dtype)
591+
592+
script_diff = output_image_script - output_image
593+
self.assertLess(script_diff.abs().max(), 1e-6)
575594

576595
actual_min, actual_max = output_image.tolist()
577596
desired_min, desired_max = 0, torch.iinfo(output_dtype).max
@@ -585,7 +604,13 @@ def test_convert_image_dtype_int_to_float(self):
585604
for output_dtype in float_dtypes():
586605
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
587606
transform = transforms.ConvertImageDtype(output_dtype)
607+
transform_script = torch.jit.script(F.convert_image_dtype)
608+
588609
output_image = transform(input_image)
610+
output_image_script = transform_script(input_image, output_dtype)
611+
612+
script_diff = output_image_script - output_image
613+
self.assertLess(script_diff.abs().max(), 1e-6)
589614

590615
actual_min, actual_max = output_image.tolist()
591616
desired_min, desired_max = 0.0, 1.0
@@ -604,7 +629,15 @@ def test_convert_image_dtype_int_to_int(self):
604629

605630
with self.subTest(input_dtype=input_dtype, output_dtype=output_dtype):
606631
transform = transforms.ConvertImageDtype(output_dtype)
632+
transform_script = torch.jit.script(F.convert_image_dtype)
633+
607634
output_image = transform(input_image)
635+
output_image_script = transform_script(input_image, output_dtype)
636+
637+
script_diff = output_image_script.float() - output_image.float()
638+
self.assertLess(
639+
script_diff.abs().max(), 1e-6, msg="{} vs {}".format(output_image_script, output_image)
640+
)
608641

609642
actual_min, actual_max = output_image.tolist()
610643
desired_min, desired_max = 0, output_max

torchvision/transforms/functional.py

Lines changed: 4 additions & 42 deletions
Original file line number | Diff line number | Diff line change
@@ -152,48 +152,10 @@ def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -
152152
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
153153
of the integer ``dtype``.
154154
"""
155-
if image.dtype == dtype:
156-
return image
157-
158-
if image.dtype.is_floating_point:
159-
# float to float
160-
if dtype.is_floating_point:
161-
return image.to(dtype)
162-
163-
# float to int
164-
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
165-
image.dtype == torch.float64 and dtype == torch.int64
166-
):
167-
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
168-
raise RuntimeError(msg)
169-
170-
# https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
171-
# For data in the range 0-1, (float * 255).to(uint) is only 255
172-
# when float is exactly 1.0.
173-
# `max + 1 - epsilon` provides more evenly distributed mapping of
174-
# ranges of floats to ints.
175-
eps = 1e-3
176-
result = image.mul(torch.iinfo(dtype).max + 1 - eps)
177-
return result.to(dtype)
178-
else:
179-
# int to float
180-
if dtype.is_floating_point:
181-
max = torch.iinfo(image.dtype).max
182-
image = image.to(dtype)
183-
return image / max
184-
185-
# int to int
186-
input_max = torch.iinfo(image.dtype).max
187-
output_max = torch.iinfo(dtype).max
188-
189-
if input_max > output_max:
190-
factor = (input_max + 1) // (output_max + 1)
191-
image = image // factor
192-
return image.to(dtype)
193-
else:
194-
factor = (output_max + 1) // (input_max + 1)
195-
image = image.to(dtype)
196-
return image * factor
155+
if not isinstance(image, torch.Tensor):
156+
raise TypeError('Input img should be Tensor Image')
157+
158+
return F_t.convert_image_dtype(image, dtype)
197159

198160

199161
def to_pil_image(pic, mode=None):

torchvision/transforms/functional_tensor.py

Lines changed: 97 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -27,6 +27,101 @@ def _get_image_num_channels(img: Tensor) -> int:
2727
raise TypeError("Input ndim should be 2 or more. Got {}".format(img.ndim))
2828

2929

30+
def _max_value(dtype: torch.dtype) -> float:
31+
# TODO: replace this method with torch.iinfo when it gets torchscript support.
32+
# https://github.com/pytorch/pytorch/issues/41492
33+
34+
a = torch.tensor(2, dtype=dtype)
35+
signed = 1 if torch.tensor(0, dtype=dtype).is_signed() else 0
36+
bits = 1
37+
max_value = torch.tensor(-signed, dtype=torch.long)
38+
while True:
39+
next_value = a.pow(bits - signed).sub(1)
40+
if next_value > max_value:
41+
max_value = next_value
42+
bits *= 2
43+
else:
44+
return max_value.item()
45+
return max_value.item()
46+
47+
48+
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
49+
"""PRIVATE METHOD. Convert a tensor image to the given ``dtype`` and scale the values accordingly
50+
51+
.. warning::
52+
53+
Module ``transforms.functional_tensor`` is private and should not be used in user application.
54+
Please, consider instead using methods from `transforms.functional` module.
55+
56+
Args:
57+
image (torch.Tensor): Image to be converted
58+
dtype (torch.dtype): Desired data type of the output
59+
60+
Returns:
61+
(torch.Tensor): Converted image
62+
63+
.. note::
64+
65+
When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
66+
If converted back and forth, this mismatch has no effect.
67+
68+
Raises:
69+
RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
70+
well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
71+
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
72+
of the integer ``dtype``.
73+
"""
74+
if image.dtype == dtype:
75+
return image
76+
77+
# TODO: replace with image.dtype.is_floating_point when torchscript supports it
78+
if torch.empty(0, dtype=image.dtype).is_floating_point():
79+
80+
# TODO: replace with dtype.is_floating_point when torchscript supports it
81+
if torch.tensor(0, dtype=dtype).is_floating_point():
82+
return image.to(dtype)
83+
84+
# float to int
85+
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
86+
image.dtype == torch.float64 and dtype == torch.int64
87+
):
88+
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
89+
raise RuntimeError(msg)
90+
91+
# https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
92+
# For data in the range 0-1, (float * 255).to(uint) is only 255
93+
# when float is exactly 1.0.
94+
# `max + 1 - epsilon` provides more evenly distributed mapping of
95+
# ranges of floats to ints.
96+
eps = 1e-3
97+
max_val = _max_value(dtype)
98+
result = image.mul(max_val + 1.0 - eps)
99+
return result.to(dtype)
100+
else:
101+
input_max = _max_value(image.dtype)
102+
output_max = _max_value(dtype)
103+
104+
# int to float
105+
# TODO: replace with dtype.is_floating_point when torchscript supports it
106+
if torch.tensor(0, dtype=dtype).is_floating_point():
107+
image = image.to(dtype)
108+
return image / input_max
109+
110+
# int to int
111+
if input_max > output_max:
112+
# factor should be forced to int for torch jit script
113+
# otherwise factor is a float and image // factor can produce different results
114+
factor = int((input_max + 1) // (output_max + 1))
115+
image = image // factor
116+
return image.to(dtype)
117+
else:
118+
# factor should be forced to int for torch jit script
119+
# otherwise factor is a float and image * factor can produce different results
120+
factor = int((output_max + 1) // (input_max + 1))
121+
image = image.to(dtype)
122+
return image * factor
123+
124+
30125
def vflip(img: Tensor) -> Tensor:
31126
"""PRIVATE METHOD. Vertically flip the given the Image Tensor.
32127
@@ -302,13 +397,11 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
302397
result = img
303398
dtype = img.dtype
304399
if not torch.is_floating_point(img):
305-
result = result / 255.0
400+
result = convert_image_dtype(result, torch.float32)
306401

307402
result = (gain * result ** gamma).clamp(0, 1)
308403

309-
if result.dtype != dtype:
310-
eps = 1e-3
311-
result = (255 + 1.0 - eps) * result
404+
result = convert_image_dtype(result, dtype)
312405
result = result.to(dtype)
313406
return result
314407

torchvision/transforms/transforms.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,6 @@
1616

1717
from . import functional as F
1818

19-
2019
__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale",
2120
"CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
2221
"RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop",
@@ -127,7 +126,7 @@ def __repr__(self):
127126
return self.__class__.__name__ + '()'
128127

129128

130-
class ConvertImageDtype:
129+
class ConvertImageDtype(torch.nn.Module):
131130
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly
132131
133132
Args:
@@ -146,9 +145,10 @@ class ConvertImageDtype:
146145
"""
147146

148147
def __init__(self, dtype: torch.dtype) -> None:
148+
super().__init__()
149149
self.dtype = dtype
150150

151-
def __call__(self, image: torch.Tensor) -> torch.Tensor:
151+
def forward(self, image: torch.Tensor) -> torch.Tensor:
152152
return F.convert_image_dtype(image, self.dtype)
153153

154154

0 commit comments

Comments
 (0)