Commit e212cc8

Unified input for resize op (#2394)
* [WIP] F.resize with tensor
* Adapted T.Resize and F.resize with a test
* According to the review, fixed copy-pasted messages and unused imports
1 parent 971c3e4 commit e212cc8

File tree (6 files changed: +246 −42 lines)

  test/test_functional_tensor.py
  test/test_transforms_tensor.py
  torchvision/transforms/functional.py
  torchvision/transforms/functional_pil.py
  torchvision/transforms/functional_tensor.py
  torchvision/transforms/transforms.py

test/test_functional_tensor.py

Lines changed: 55 additions & 6 deletions
@@ -1,14 +1,17 @@
-import torch
-import torchvision.transforms as transforms
-import torchvision.transforms.functional_tensor as F_t
-import torchvision.transforms.functional_pil as F_pil
-import torchvision.transforms.functional as F
-import numpy as np
 import unittest
 import random
 import colorsys

 from PIL import Image
+from PIL.Image import NEAREST, BILINEAR, BICUBIC
+
+import numpy as np
+
+import torch
+import torchvision.transforms as transforms
+import torchvision.transforms.functional_tensor as F_t
+import torchvision.transforms.functional_pil as F_pil
+import torchvision.transforms.functional as F


 class Tester(unittest.TestCase):
@@ -22,6 +25,14 @@ def compareTensorToPIL(self, tensor, pil_image, msg=None):
         pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
         self.assertTrue(tensor.equal(pil_tensor), msg)

+    def approxEqualTensorToPIL(self, tensor, pil_image, tol=1e-5, msg=None):
+        pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1))).to(tensor)
+        mae = torch.abs(tensor - pil_tensor).mean().item()
+        self.assertTrue(
+            mae < tol,
+            msg="{}: mae={}, tol={}: \n{}\nvs\n{}".format(msg, mae, tol, tensor[0, :10, :10], pil_tensor[0, :10, :10])
+        )
+
     def test_vflip(self):
         script_vflip = torch.jit.script(F_t.vflip)
         img_tensor = torch.randn(3, 16, 16)
@@ -282,6 +293,44 @@ def test_pad(self):
         with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
             F_t.pad(tensor, (-2, -3), padding_mode="symmetric")

+    def test_resize(self):
+        script_fn = torch.jit.script(F_t.resize)
+        tensor, pil_img = self._create_data(26, 36)
+
+        for dt in [None, torch.float32, torch.float64]:
+            if dt is not None:
+                # This is a trivial cast to float of uint8 data to test all cases
+                tensor = tensor.to(dt)
+            for size in [32, [32, ], [32, 32], (32, 32), ]:
+                for interpolation in [BILINEAR, BICUBIC, NEAREST]:
+                    resized_tensor = F_t.resize(tensor, size=size, interpolation=interpolation)
+                    resized_pil_img = F_pil.resize(pil_img, size=size, interpolation=interpolation)
+
+                    self.assertEqual(
+                        resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation)
+                    )
+
+                    if interpolation != NEAREST:
+                        # We cannot check values if mode = NEAREST, as results are different
+                        # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
+                        # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
+                        resized_tensor_f = resized_tensor
+                        # cast a uint8 tensor to float before comparing with the PIL result
+                        if resized_tensor_f.dtype == torch.uint8:
+                            resized_tensor_f = resized_tensor_f.to(torch.float)
+
+                        # Pay attention to the high tolerance for MAE
+                        self.approxEqualTensorToPIL(
+                            resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation)
+                        )
+
+                    if isinstance(size, int):
+                        script_size = [size, ]
+                    else:
+                        script_size = size
+                    resized_tensor_script = script_fn(tensor, size=script_size, interpolation=interpolation)
+                    self.assertTrue(resized_tensor.equal(resized_tensor_script), msg="{}, {}".format(size, interpolation))


 if __name__ == '__main__':
     unittest.main()

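Note on the comparison strategy in the test above: PIL and torch implement bilinear and bicubic filtering slightly differently, so the tensor result is compared to the PIL result only by mean absolute error with a deliberately loose tolerance, while exact equality is reserved for the eager-vs-scripted check. A minimal standalone sketch of that MAE comparison (the helper name here is illustrative, not part of the commit):

import numpy as np
import torch

def mae_vs_pil(tensor_img, pil_img):
    # Convert the PIL image to a CHW float tensor and compute the mean absolute error.
    pil_tensor = torch.as_tensor(np.array(pil_img).transpose((2, 0, 1))).float()
    return torch.abs(tensor_img.float() - pil_tensor).mean().item()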
test/test_transforms_tensor.py

Lines changed: 28 additions & 0 deletions
@@ -2,6 +2,7 @@
 from torchvision import transforms as T
 from torchvision.transforms import functional as F
 from PIL import Image
+from PIL.Image import NEAREST, BILINEAR, BICUBIC

 import numpy as np

@@ -217,6 +218,33 @@ def test_ten_crop(self):
             "ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )

+    def test_resize(self):
+        tensor, _ = self._create_data(height=34, width=36)
+        script_fn = torch.jit.script(F.resize)
+
+        for dt in [None, torch.float32, torch.float64]:
+            if dt is not None:
+                # This is a trivial cast to float of uint8 data to test all cases
+                tensor = tensor.to(dt)
+            for size in [32, [32, ], [32, 32], (32, 32), ]:
+                for interpolation in [BILINEAR, BICUBIC, NEAREST]:
+
+                    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)
+
+                    if isinstance(size, int):
+                        script_size = [size, ]
+                    else:
+                        script_size = size
+
+                    s_resized_tensor = script_fn(tensor, size=script_size, interpolation=interpolation)
+                    self.assertTrue(s_resized_tensor.equal(resized_tensor))
+
+                    transform = T.Resize(size=script_size, interpolation=interpolation)
+                    resized_tensor = transform(tensor)
+                    script_transform = torch.jit.script(transform)
+                    s_resized_tensor = script_transform(tensor)
+                    self.assertTrue(s_resized_tensor.equal(resized_tensor))
+

 if __name__ == '__main__':
     unittest.main()

torchvision/transforms/functional.py

Lines changed: 13 additions & 25 deletions
@@ -311,41 +311,29 @@ def normalize(tensor, mean, std, inplace=False):
     return tensor


-def resize(img, size, interpolation=Image.BILINEAR):
-    r"""Resize the input PIL Image to the given size.
+def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
+    r"""Resize the input image to the given size.
+    The image can be a PIL Image or a torch Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

     Args:
-        img (PIL Image): Image to be resized.
+        img (PIL Image or Tensor): Image to be resized.
         size (sequence or int): Desired output size. If size is a sequence like
             (h, w), the output size will be matched to this. If size is an int,
             the smaller edge of the image will be matched to this number maintaining
             the aspect ratio. i.e, if height > width, then image will be rescaled to
-            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`
-        interpolation (int, optional): Desired interpolation. Default is
-            ``PIL.Image.BILINEAR``
+            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
+            In torchscript mode size as a single int is not supported, use a tuple or
+            list of length 1: ``[size, ]``.
+        interpolation (int, optional): Desired interpolation. Default is bilinear.

     Returns:
-        PIL Image: Resized image.
+        PIL Image or Tensor: Resized image.
     """
-    if not F_pil._is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
-    if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
-        raise TypeError('Got inappropriate size arg: {}'.format(size))
+    if not isinstance(img, torch.Tensor):
+        return F_pil.resize(img, size=size, interpolation=interpolation)

-    if isinstance(size, int):
-        w, h = img.size
-        if (w <= h and w == size) or (h <= w and h == size):
-            return img
-        if w < h:
-            ow = size
-            oh = int(size * h / w)
-            return img.resize((ow, oh), interpolation)
-        else:
-            oh = size
-            ow = int(size * w / h)
-            return img.resize((ow, oh), interpolation)
-    else:
-        return img.resize(size[::-1], interpolation)
+    return F_t.resize(img, size=size, interpolation=interpolation)


 def scale(*args, **kwargs):

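With this change F.resize accepts either a PIL Image or a Tensor and dispatches to the matching backend. A minimal usage sketch (illustrative, not part of the commit; assumes a torchvision build that includes this change):

from PIL import Image
import torch
import torchvision.transforms.functional as F

pil_img = Image.new("RGB", (36, 26))                 # PIL size is (width, height)
tensor_img = torch.randint(0, 256, (3, 26, 36), dtype=torch.uint8)

out_pil = F.resize(pil_img, size=32, interpolation=Image.BILINEAR)        # PIL backend
out_tensor = F.resize(tensor_img, size=32, interpolation=Image.BILINEAR)  # tensor backend

# Under torch.jit.script an int size is not supported; pass a length-1 list instead.
scripted_resize = torch.jit.script(F.resize)
out_scripted = scripted_resize(tensor_img, size=[32], interpolation=2)    # 2 == bilinear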
torchvision/transforms/functional_pil.py

Lines changed: 42 additions & 1 deletion
@@ -1,5 +1,5 @@
 import numbers
-from typing import Any, List
+from typing import Any, List, Sequence

 import torch
 try:
@@ -286,3 +286,44 @@ def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Imag
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     return img.crop((left, top, left + width, top + height))
+
+
+@torch.jit.unused
+def resize(img, size, interpolation=Image.BILINEAR):
+    r"""Resize the input PIL Image to the given size.
+
+    Args:
+        img (PIL Image): Image to be resized.
+        size (sequence or int): Desired output size. If size is a sequence like
+            (h, w), the output size will be matched to this. If size is an int,
+            the smaller edge of the image will be matched to this number maintaining
+            the aspect ratio. i.e, if height > width, then image will be rescaled to
+            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
+            For compatibility reasons with ``functional_tensor.resize``, if a tuple or list of length 1 is provided,
+            it is interpreted as a single int.
+        interpolation (int, optional): Desired interpolation. Default is ``PIL.Image.BILINEAR``.
+
+    Returns:
+        PIL Image: Resized image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+    if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
+        raise TypeError('Got inappropriate size arg: {}'.format(size))
+
+    if isinstance(size, int) or len(size) == 1:
+        if isinstance(size, Sequence):
+            size = size[0]
+        w, h = img.size
+        if (w <= h and w == size) or (h <= w and h == size):
+            return img
+        if w < h:
+            ow = size
+            oh = int(size * h / w)
+            return img.resize((ow, oh), interpolation)
+        else:
+            oh = size
+            ow = int(size * w / h)
+            return img.resize((ow, oh), interpolation)
+    else:
+        return img.resize(size[::-1], interpolation)

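For the PIL backend a length-1 sequence is now accepted and treated like a plain int, so the torchscript-friendly ``[size, ]`` spelling works for both backends. A short hedged sketch (not part of the commit):

from PIL import Image
import torchvision.transforms.functional_pil as F_pil

img = Image.new("RGB", (36, 26))
a = F_pil.resize(img, size=32)    # smaller edge matched to 32, aspect ratio kept
b = F_pil.resize(img, size=[32])  # length-1 sequence, interpreted exactly like size=32
assert a.size == b.size           # both should come out as (44, 32) here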
torchvision/transforms/functional_tensor.py

Lines changed: 91 additions & 0 deletions
@@ -8,6 +8,7 @@ def _is_tensor_a_torch_image(x: Tensor) -> bool:


 def _get_image_size(img: Tensor) -> List[int]:
+    """Returns (w, h) of tensor image"""
     if _is_tensor_a_torch_image(img):
         return [img.shape[-1], img.shape[-2]]
     raise TypeError("Unexpected type {}".format(type(img)))
@@ -433,6 +434,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con

     if isinstance(padding, int):
         if torch.jit.is_scripting():
+            # This may be unreachable
             raise ValueError("padding can't be an int while torchscripting, set it as a list [value, ]")
         pad_left = pad_right = pad_top = pad_bottom = padding
     elif len(padding) == 1:
@@ -480,3 +482,92 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
         img = img.to(out_dtype)

     return img
+
+
+def resize(img: Tensor, size: List[int], interpolation: int = 2) -> Tensor:
+    r"""Resize the input Tensor to the given size.
+
+    Args:
+        img (Tensor): Image to be resized.
+        size (int or tuple or list): Desired output size. If size is a sequence like
+            (h, w), the output size will be matched to this. If size is an int,
+            the smaller edge of the image will be matched to this number maintaining
+            the aspect ratio. i.e, if height > width, then image will be rescaled to
+            :math:`\left(\text{size} \times \frac{\text{height}}{\text{width}}, \text{size}\right)`.
+            In torchscript mode size as a single int is not supported, use a tuple or
+            list of length 1: ``[size, ]``.
+        interpolation (int, optional): Desired interpolation. Default is bilinear.
+
+    Returns:
+        Tensor: Resized image.
+    """
+    if not _is_tensor_a_torch_image(img):
+        raise TypeError("tensor is not a torch image.")
+
+    if not isinstance(size, (int, tuple, list)):
+        raise TypeError("Got inappropriate size arg")
+    if not isinstance(interpolation, int):
+        raise TypeError("Got inappropriate interpolation arg")
+
+    _interpolation_modes = {
+        0: "nearest",
+        2: "bilinear",
+        3: "bicubic",
+    }
+
+    if interpolation not in _interpolation_modes:
+        raise ValueError("This interpolation mode is unsupported with Tensor input")
+
+    if isinstance(size, tuple):
+        size = list(size)
+
+    if isinstance(size, list) and len(size) not in [1, 2]:
+        raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a "
+                         "{} element tuple/list".format(len(size)))
+
+    w, h = _get_image_size(img)
+
+    if isinstance(size, int):
+        size_w, size_h = size, size
+    elif len(size) < 2:
+        size_w, size_h = size[0], size[0]
+    else:
+        size_w, size_h = size[0], size[1]
+
+    if isinstance(size, int) or len(size) < 2:
+        if w < h:
+            size_h = int(size_w * h / w)
+        else:
+            size_w = int(size_h * w / h)
+
+        if (w <= h and w == size_w) or (h <= w and h == size_h):
+            return img
+
+    # make image NCHW
+    need_squeeze = False
+    if img.ndim < 4:
+        img = img.unsqueeze(dim=0)
+        need_squeeze = True
+
+    mode = _interpolation_modes[interpolation]
+
+    out_dtype = img.dtype
+    need_cast = False
+    if img.dtype not in (torch.float32, torch.float64):
+        need_cast = True
+        img = img.to(torch.float32)
+
+    # Define align_corners to avoid warnings
+    align_corners = False if mode in ["bilinear", "bicubic"] else None
+
+    img = torch.nn.functional.interpolate(img, size=(size_h, size_w), mode=mode, align_corners=align_corners)
+
+    if need_squeeze:
+        img = img.squeeze(dim=0)
+
+    if need_cast:
+        if mode == "bicubic":
+            img = img.clamp(min=0, max=255)
+        img = img.to(out_dtype)
+
+    return img

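The tensor path maps the PIL interpolation codes (0 nearest, 2 bilinear, 3 bicubic) onto torch.nn.functional.interpolate modes, temporarily promoting the image to a float NCHW batch and casting back afterwards. A rough equivalence sketch under those assumptions (illustrative only, not the library's exact internals):

import torch
import torchvision.transforms.functional_tensor as F_t

img = torch.randint(0, 256, (3, 26, 36), dtype=torch.uint8)

out = F_t.resize(img, size=[32, 32], interpolation=2)  # 2 -> "bilinear"
print(out.shape, out.dtype)                            # torch.Size([3, 32, 32]) torch.uint8

# Roughly what happens internally for this uint8 CHW input:
ref = torch.nn.functional.interpolate(
    img.unsqueeze(0).float(), size=(32, 32), mode="bilinear", align_corners=False
).squeeze(0).to(torch.uint8)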
torchvision/transforms/transforms.py

Lines changed: 17 additions & 10 deletions
@@ -2,7 +2,7 @@
 import numbers
 import random
 import warnings
-from collections.abc import Sequence, Iterable
+from collections.abc import Sequence
 from typing import Tuple, List, Optional

 import numpy as np
@@ -209,31 +209,38 @@ def __repr__(self):
         return self.__class__.__name__ + '(mean={0}, std={1})'.format(self.mean, self.std)


-class Resize(object):
-    """Resize the input PIL Image to the given size.
+class Resize(torch.nn.Module):
+    """Resize the input image to the given size.
+    The image can be a PIL Image or a torch Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions.

     Args:
         size (sequence or int): Desired output size. If size is a sequence like
             (h, w), output size will be matched to this. If size is an int,
             smaller edge of the image will be matched to this number.
             i.e, if height > width, then image will be rescaled to
-            (size * height / width, size)
-        interpolation (int, optional): Desired interpolation. Default is
-            ``PIL.Image.BILINEAR``
+            (size * height / width, size).
+            In torchscript mode size as a single int is not supported, use a tuple or
+            list of length 1: ``[size, ]``.
+        interpolation (int, optional): Desired interpolation. Default is ``PIL.Image.BILINEAR``
     """

     def __init__(self, size, interpolation=Image.BILINEAR):
-        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
+        super().__init__()
+        if not isinstance(size, (int, Sequence)):
+            raise TypeError("Size should be int or sequence. Got {}".format(type(size)))
+        if isinstance(size, Sequence) and len(size) not in (1, 2):
+            raise ValueError("If size is a sequence, it should have 1 or 2 values")
         self.size = size
         self.interpolation = interpolation

-    def __call__(self, img):
+    def forward(self, img):
         """
         Args:
-            img (PIL Image): Image to be scaled.
+            img (PIL Image or Tensor): Image to be scaled.

         Returns:
-            PIL Image: Rescaled image.
+            PIL Image or Tensor: Rescaled image.
         """
         return F.resize(img, self.size, self.interpolation)

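Because Resize is now an nn.Module with a forward method, the transform object itself can be scripted and applied to tensors. A minimal sketch (assumes a build with this commit; size is passed as a length-1 list so the scripted module accepts it):

import torch
from torchvision import transforms as T

transform = T.Resize(size=[32], interpolation=2)  # 2 == PIL.Image.BILINEAR
scripted = torch.jit.script(transform)

img = torch.randint(0, 256, (3, 26, 36), dtype=torch.uint8)
assert scripted(img).equal(transform(img))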