Skip to content

Commit f610463

Browse files
committed
Move convert_image_dtype to functional_tensor since it only works on tensors
1 parent 794bbb1 commit f610463

File tree

5 files changed

+90
-91
lines changed

5 files changed

+90
-91
lines changed

test/test_functional_tensor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,7 +309,7 @@ def test_adjust_gamma(self):
309309
for dt in [torch.float64, torch.float32, None]:
310310

311311
if dt is not None:
312-
tensor = F.convert_image_dtype(tensor, dt)
312+
tensor = F_t.convert_image_dtype(tensor, dt)
313313

314314
gammas = [0.8, 1.0, 1.2]
315315
gains = [0.7, 1.0, 1.3]
@@ -323,7 +323,7 @@ def test_adjust_gamma(self):
323323

324324
rbg_tensor = adjusted_tensor
325325
if adjusted_tensor.dtype != torch.uint8:
326-
rbg_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8)
326+
rbg_tensor = F_t.convert_image_dtype(adjusted_tensor, torch.uint8)
327327

328328
self.compareTensorToPIL(rbg_tensor, adjusted_pil)
329329

test/test_transforms.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch
33
import torchvision.transforms as transforms
44
import torchvision.transforms.functional as F
5+
import torchvision.transforms.functional_tensor as F_t
56
from torch._utils_internal import get_file_path_2
67
from numpy.testing import assert_array_almost_equal
78
import unittest
@@ -528,7 +529,7 @@ def test_to_tensor(self):
528529

529530
def test_max_value(self):
530531
for dtype in int_dtypes():
531-
self.assertEqual(F._max_value(dtype), torch.iinfo(dtype).max)
532+
self.assertEqual(F_t._max_value(dtype), torch.iinfo(dtype).max)
532533

533534
def test_convert_image_dtype_float_to_float(self):
534535
for input_dtype, output_dtypes in cycle_over(float_dtypes()):

torchvision/transforms/functional.py

Lines changed: 0 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -124,88 +124,6 @@ def pil_to_tensor(pic):
124124
return img
125125

126126

127-
# torch.iinfo isn't scriptable so using this helper function
128-
# https://github.com/pytorch/pytorch/issues/41492
129-
def _max_value(dtype: int) -> int:
130-
a = torch.tensor(2, dtype=dtype)
131-
signed = 1 if torch.tensor(0, dtype=dtype).is_signed() else 0
132-
bits = 1
133-
max_value = torch.tensor(-signed, dtype=torch.long)
134-
while(True):
135-
next_value = a.pow(bits - signed).sub(1)
136-
if next_value > max_value:
137-
max_value = next_value
138-
bits *= 2
139-
else:
140-
return max_value.item()
141-
return max_value.item()
142-
143-
144-
def convert_image_dtype(image: torch.Tensor, dtype: int = torch.float) -> torch.Tensor:
145-
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly
146-
147-
Args:
148-
image (torch.Tensor): Image to be converted
149-
dtype (torch.dtype): Desired data type of the output
150-
151-
Returns:
152-
(torch.Tensor): Converted image
153-
154-
.. note::
155-
156-
When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
157-
If converted back and forth, this mismatch has no effect.
158-
159-
Raises:
160-
RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
161-
well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
162-
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
163-
of the integer ``dtype``.
164-
"""
165-
if image.dtype == dtype:
166-
return image
167-
168-
if torch.empty(0, dtype=image.dtype).is_floating_point():
169-
# float to float
170-
if torch.tensor(0, dtype=dtype).is_floating_point():
171-
return image.to(dtype)
172-
173-
# float to int
174-
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
175-
image.dtype == torch.float64 and dtype == torch.int64
176-
):
177-
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
178-
raise RuntimeError(msg)
179-
180-
# https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
181-
# For data in the range 0-1, (float * 255).to(uint) is only 255
182-
# when float is exactly 1.0.
183-
# `max + 1 - epsilon` provides more evenly distributed mapping of
184-
# ranges of floats to ints.
185-
eps = 1e-3
186-
max_val = _max_value(dtype)
187-
result = image.mul(max_val + 1.0 - eps)
188-
return result.to(dtype)
189-
else:
190-
input_max = _max_value(image.dtype)
191-
output_max = _max_value(dtype)
192-
193-
# int to float
194-
if torch.tensor(0, dtype=dtype).is_floating_point():
195-
image = image.to(dtype)
196-
return image / input_max
197-
198-
# int to int
199-
if input_max > output_max:
200-
factor = (input_max + 1) // (output_max + 1)
201-
image = image // factor
202-
return image.to(dtype)
203-
else:
204-
factor = (output_max + 1) // (input_max + 1)
205-
image = image.to(dtype)
206-
return image * factor
207-
208-
209127
def to_pil_image(pic, mode=None):
210128
"""Convert a tensor or an ndarray to PIL Image.
211129

torchvision/transforms/functional_tensor.py

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
from torch.nn.functional import affine_grid, grid_sample
77
from torch.jit.annotations import List, BroadcastingList2
88

9-
import torchvision.transforms.functional as F
10-
119

1210
def _is_tensor_a_torch_image(x: Tensor) -> bool:
1311
return x.ndim >= 2
@@ -20,6 +18,88 @@ def _get_image_size(img: Tensor) -> List[int]:
2018
raise TypeError("Unexpected type {}".format(type(img)))
2119

2220

21+
# torch.iinfo isn't scriptable so using this helper function
22+
# https://github.com/pytorch/pytorch/issues/41492
23+
def _max_value(dtype: int) -> int:
24+
a = torch.tensor(2, dtype=dtype)
25+
signed = 1 if torch.tensor(0, dtype=dtype).is_signed() else 0
26+
bits = 1
27+
max_value = torch.tensor(-signed, dtype=torch.long)
28+
while(True):
29+
next_value = a.pow(bits - signed).sub(1)
30+
if next_value > max_value:
31+
max_value = next_value
32+
bits *= 2
33+
else:
34+
return max_value.item()
35+
return max_value.item()
36+
37+
38+
def convert_image_dtype(image: torch.Tensor, dtype: int = torch.float) -> torch.Tensor:
39+
"""Convert a tensor image to the given ``dtype`` and scale the values accordingly
40+
41+
Args:
42+
image (torch.Tensor): Image to be converted
43+
dtype (torch.dtype): Desired data type of the output
44+
45+
Returns:
46+
(torch.Tensor): Converted image
47+
48+
.. note::
49+
50+
When converting from a smaller to a larger integer ``dtype`` the maximum values are **not** mapped exactly.
51+
If converted back and forth, this mismatch has no effect.
52+
53+
Raises:
54+
RuntimeError: When trying to cast :class:`torch.float32` to :class:`torch.int32` or :class:`torch.int64` as
55+
well as for trying to cast :class:`torch.float64` to :class:`torch.int64`. These conversions might lead to
56+
overflow errors since the floating point ``dtype`` cannot store consecutive integers over the whole range
57+
of the integer ``dtype``.
58+
"""
59+
if image.dtype == dtype:
60+
return image
61+
62+
if torch.empty(0, dtype=image.dtype).is_floating_point():
63+
# float to float
64+
if torch.tensor(0, dtype=dtype).is_floating_point():
65+
return image.to(dtype)
66+
67+
# float to int
68+
if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
69+
image.dtype == torch.float64 and dtype == torch.int64
70+
):
71+
msg = f"The cast from {image.dtype} to {dtype} cannot be performed safely."
72+
raise RuntimeError(msg)
73+
74+
# https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
75+
# For data in the range 0-1, (float * 255).to(uint) is only 255
76+
# when float is exactly 1.0.
77+
# `max + 1 - epsilon` provides more evenly distributed mapping of
78+
# ranges of floats to ints.
79+
eps = 1e-3
80+
max_val = _max_value(dtype)
81+
result = image.mul(max_val + 1.0 - eps)
82+
return result.to(dtype)
83+
else:
84+
input_max = _max_value(image.dtype)
85+
output_max = _max_value(dtype)
86+
87+
# int to float
88+
if torch.tensor(0, dtype=dtype).is_floating_point():
89+
image = image.to(dtype)
90+
return image / input_max
91+
92+
# int to int
93+
if input_max > output_max:
94+
factor = (input_max + 1) // (output_max + 1)
95+
image = image // factor
96+
return image.to(dtype)
97+
else:
98+
factor = (output_max + 1) // (input_max + 1)
99+
image = image.to(dtype)
100+
return image * factor
101+
102+
23103
def vflip(img: Tensor) -> Tensor:
24104
"""Vertically flip the given Image Tensor.
25105
@@ -230,11 +310,11 @@ def adjust_gamma(img: Tensor, gamma: float, gain: float = 1) -> Tensor:
230310
result = img
231311
dtype = img.dtype
232312
if not torch.is_floating_point(img):
233-
result = F.convert_image_dtype(result, torch.get_default_dtype())
313+
result = convert_image_dtype(result, torch.float32)
234314

235315
result = (gain * result ** gamma).clamp(0, 1)
236316

237-
result = F.convert_image_dtype(result, dtype)
317+
result = convert_image_dtype(result, dtype)
238318
result = result.to(dtype)
239319
return result
240320

torchvision/transforms/transforms.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
accimage = None
1717

1818
from . import functional as F
19-
19+
from . import functional_tensor as F_t
2020

2121
__all__ = ["Compose", "ToTensor", "PILToTensor", "ConvertImageDtype", "ToPILImage", "Normalize", "Resize", "Scale",
2222
"CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop",
@@ -131,7 +131,7 @@ def __init__(self, dtype: torch.dtype) -> None:
131131
self.dtype = dtype
132132

133133
def __call__(self, image: torch.Tensor) -> torch.Tensor:
134-
return F.convert_image_dtype(image, self.dtype)
134+
return F_t.convert_image_dtype(image, self.dtype)
135135

136136

137137
class ToPILImage(object):

0 commit comments

Comments
 (0)