
Commit e0e73af

fmassa authored and facebook-github-bot committed
[fbsync] [TRANS, IMP] Add new max_size parameter to Resize (#3494)
Summary:
* WIP, still needs tests and docs
* tests
* flake8
* Docs + fixed some tests
* proper error messages

Reviewed By: NicolasHug, cpuhrsch

Differential Revision: D26945732

fbshipit-source-id: 765c48af203ba27894881dea596f94d2f4a6794d
1 parent d4ff1ba commit e0e73af

File tree

7 files changed: +164 -98 lines changed

test/test_functional_tensor.py

Lines changed: 42 additions & 30 deletions

@@ -13,7 +13,7 @@
 
 from common_utils import TransformsTester
 
-from typing import Dict, List, Tuple
+from typing import Dict, List, Sequence, Tuple
 
 
 NEAREST, BILINEAR, BICUBIC = InterpolationMode.NEAREST, InterpolationMode.BILINEAR, InterpolationMode.BICUBIC
@@ -409,46 +409,58 @@ def test_resize(self):
                 batch_tensors = batch_tensors.to(dt)
 
             for size in [32, 26, [32, ], [32, 32], (32, 32), [26, 35]]:
-                for interpolation in [BILINEAR, BICUBIC, NEAREST]:
-                    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)
-                    resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation)
-
-                    self.assertEqual(
-                        resized_tensor.size()[1:], resized_pil_img.size[::-1], msg="{}, {}".format(size, interpolation)
-                    )
-
-                    if interpolation not in [NEAREST, ]:
-                        # We can not check values if mode = NEAREST, as results are different
-                        # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
-                        # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
-                        resized_tensor_f = resized_tensor
-                        # we need to cast to uint8 to compare with PIL image
-                        if resized_tensor_f.dtype == torch.uint8:
-                            resized_tensor_f = resized_tensor_f.to(torch.float)
-
-                        # Pay attention to high tolerance for MAE
-                        self.approxEqualTensorToPIL(
-                            resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation)
+                for max_size in (None, 33, 40, 1000):
+                    if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
+                        continue  # unsupported, see assertRaises below
+                    for interpolation in [BILINEAR, BICUBIC, NEAREST]:
+                        resized_tensor = F.resize(tensor, size=size, interpolation=interpolation, max_size=max_size)
+                        resized_pil_img = F.resize(pil_img, size=size, interpolation=interpolation, max_size=max_size)
+
+                        self.assertEqual(
+                            resized_tensor.size()[1:], resized_pil_img.size[::-1],
+                            msg="{}, {}".format(size, interpolation)
                         )
 
-                    if isinstance(size, int):
-                        script_size = [size, ]
-                    else:
-                        script_size = size
+                        if interpolation not in [NEAREST, ]:
+                            # We can not check values if mode = NEAREST, as results are different
+                            # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
+                            # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
+                            resized_tensor_f = resized_tensor
+                            # we need to cast to uint8 to compare with PIL image
+                            if resized_tensor_f.dtype == torch.uint8:
+                                resized_tensor_f = resized_tensor_f.to(torch.float)
+
+                            # Pay attention to high tolerance for MAE
+                            self.approxEqualTensorToPIL(
+                                resized_tensor_f, resized_pil_img, tol=8.0, msg="{}, {}".format(size, interpolation)
+                            )
 
-                    resize_result = script_fn(tensor, size=script_size, interpolation=interpolation)
-                    self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
+                        if isinstance(size, int):
+                            script_size = [size, ]
+                        else:
+                            script_size = size
 
-                    self._test_fn_on_batch(
-                        batch_tensors, F.resize, size=script_size, interpolation=interpolation
-                    )
+                        resize_result = script_fn(tensor, size=script_size, interpolation=interpolation,
+                                                  max_size=max_size)
+                        self.assertTrue(resized_tensor.equal(resize_result), msg="{}, {}".format(size, interpolation))
+
+                        self._test_fn_on_batch(
+                            batch_tensors, F.resize, size=script_size, interpolation=interpolation, max_size=max_size
+                        )
 
         # assert changed type warning
         with self.assertWarnsRegex(UserWarning, r"Argument interpolation should be of type InterpolationMode"):
             res1 = F.resize(tensor, size=32, interpolation=2)
             res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
             self.assertTrue(res1.equal(res2))
 
+        for img in (tensor, pil_img):
+            exp_msg = "max_size should only be passed if size specifies the length of the smaller edge"
+            with self.assertRaisesRegex(ValueError, exp_msg):
+                F.resize(img, size=(32, 34), max_size=35)
+            with self.assertRaisesRegex(ValueError, "max_size = 32 must be strictly greater"):
+                F.resize(img, size=32, max_size=32)
+
     def test_resized_crop(self):
         # test values of F.resized_crop in several cases:
         # 1) resize to the same size, crop to the same size => should be identity
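The two ValueError paths asserted at the end of this test can also be reproduced in isolation. A minimal sketch, with an illustrative image tensor that is not taken from the test:

import torch
from torchvision.transforms import functional as F

img = torch.randint(0, 256, size=(3, 64, 64), dtype=torch.uint8)

# size fixes both edges, so a longer-edge cap would be contradictory
try:
    F.resize(img, size=(32, 34), max_size=35)
except ValueError as e:
    print(e)  # "max_size should only be passed if size specifies the length of the smaller edge, ..."

# max_size must be strictly greater than the requested smaller-edge size
try:
    F.resize(img, size=32, max_size=32)
except ValueError as e:
    print(e)  # "max_size = 32 must be strictly greater than the requested size ..."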

test/test_transforms.py

Lines changed: 24 additions & 17 deletions

@@ -312,23 +312,30 @@ def test_resize(self):
             img = Image.new("RGB", size=(width, height), color=127)
 
             for osize in test_output_sizes_1:
-
-                t = transforms.Resize(osize)
-                result = t(img)
-
-                msg = "{}, {} - {}".format(height, width, osize)
-                osize = osize[0] if isinstance(osize, (list, tuple)) else osize
-                # If size is an int, smaller edge of the image will be matched to this number.
-                # i.e, if height > width, then image will be rescaled to (size * height / width, size).
-                if height < width:
-                    expected_size = (int(osize * width / height), osize)  # (w, h)
-                    self.assertEqual(result.size, expected_size, msg=msg)
-                elif width < height:
-                    expected_size = (osize, int(osize * height / width))  # (w, h)
-                    self.assertEqual(result.size, expected_size, msg=msg)
-                else:
-                    expected_size = (osize, osize)  # (w, h)
-                    self.assertEqual(result.size, expected_size, msg=msg)
+                for max_size in (None, 37, 1000):
+
+                    t = transforms.Resize(osize, max_size=max_size)
+                    result = t(img)
+
+                    msg = "{}, {} - {} - {}".format(height, width, osize, max_size)
+                    osize = osize[0] if isinstance(osize, (list, tuple)) else osize
+                    # If size is an int, smaller edge of the image will be matched to this number.
+                    # i.e, if height > width, then image will be rescaled to (size * height / width, size).
+                    if height < width:
+                        exp_w, exp_h = (int(osize * width / height), osize)  # (w, h)
+                        if max_size is not None and max_size < exp_w:
+                            exp_w, exp_h = max_size, int(max_size * exp_h / exp_w)
+                        self.assertEqual(result.size, (exp_w, exp_h), msg=msg)
+                    elif width < height:
+                        exp_w, exp_h = (osize, int(osize * height / width))  # (w, h)
+                        if max_size is not None and max_size < exp_h:
+                            exp_w, exp_h = int(max_size * exp_w / exp_h), max_size
+                        self.assertEqual(result.size, (exp_w, exp_h), msg=msg)
+                    else:
+                        exp_w, exp_h = (osize, osize)  # (w, h)
+                        if max_size is not None and max_size < osize:
+                            exp_w, exp_h = max_size, max_size
+                        self.assertEqual(result.size, (exp_w, exp_h), msg=msg)
 
         for height, width in input_sizes:
             img = Image.new("RGB", size=(width, height), color=127)
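The expected-size arithmetic in this test can be checked by hand. A worked instance with illustrative numbers, not values taken from the test:

# A 40x60 (h, w) image resized with osize=30 and max_size=38.
height, width, osize, max_size = 40, 60, 30, 38

# height < width, so the smaller edge (height) is matched to osize:
exp_w, exp_h = int(osize * width / height), osize   # (45, 30)

# the longer edge (45) exceeds max_size, so both edges shrink proportionally:
if max_size is not None and max_size < exp_w:
    exp_w, exp_h = max_size, int(max_size * exp_h / exp_w)

print((exp_w, exp_h))  # (38, 25)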

test/test_transforms_tensor.py

Lines changed: 14 additions & 16 deletions

@@ -7,6 +7,7 @@
 import numpy as np
 
 import unittest
+from typing import Sequence
 
 from common_utils import TransformsTester, get_tmp_dir, int_dtypes, float_dtypes
 
@@ -322,32 +323,29 @@ def test_resize(self):
 
         tensor, _ = self._create_data(height=34, width=36, device=self.device)
         batch_tensors = torch.randint(0, 256, size=(4, 3, 44, 56), dtype=torch.uint8, device=self.device)
-        script_fn = torch.jit.script(F.resize)
 
         for dt in [None, torch.float32, torch.float64]:
            if dt is not None:
                # This is a trivial cast to float of uint8 data to test all cases
                tensor = tensor.to(dt)
            for size in [32, 34, [32, ], [32, 32], (32, 32), [34, 35]]:
-                for interpolation in [BILINEAR, BICUBIC, NEAREST]:
+                for max_size in (None, 35, 1000):
+                    if max_size is not None and isinstance(size, Sequence) and len(size) != 1:
+                        continue  # Not supported
+                    for interpolation in [BILINEAR, BICUBIC, NEAREST]:
 
-                    resized_tensor = F.resize(tensor, size=size, interpolation=interpolation)
+                        if isinstance(size, int):
+                            script_size = [size, ]
+                        else:
+                            script_size = size
 
-                    if isinstance(size, int):
-                        script_size = [size, ]
-                    else:
-                        script_size = size
-
-                    s_resized_tensor = script_fn(tensor, size=script_size, interpolation=interpolation)
-                    self.assertTrue(s_resized_tensor.equal(resized_tensor))
-
-                    transform = T.Resize(size=script_size, interpolation=interpolation)
-                    s_transform = torch.jit.script(transform)
-                    self._test_transform_vs_scripted(transform, s_transform, tensor)
-                    self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
+                        transform = T.Resize(size=script_size, interpolation=interpolation, max_size=max_size)
+                        s_transform = torch.jit.script(transform)
+                        self._test_transform_vs_scripted(transform, s_transform, tensor)
+                        self._test_transform_vs_scripted_on_batch(transform, s_transform, batch_tensors)
 
         with get_tmp_dir() as tmp_dir:
-            script_fn.save(os.path.join(tmp_dir, "t_resize.pt"))
+            s_transform.save(os.path.join(tmp_dir, "t_resize.pt"))
 
     def test_resized_crop(self):
         tensor = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8, device=self.device)

torchvision/transforms/functional.py

Lines changed: 12 additions & 3 deletions

@@ -337,7 +337,8 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool
     return tensor
 
 
-def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR) -> Tensor:
+def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+           max_size: Optional[int] = None) -> Tensor:
     r"""Resize the input image to the given size.
     If the image is torch Tensor, it is expected
     to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
@@ -355,6 +356,14 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
             Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
             ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported.
             For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
+        max_size (int, optional): The maximum allowed for the longer edge of
+            the resized image: if the longer edge of the image is greater
+            than ``max_size`` after being resized according to ``size``, then
+            the image is resized again so that the longer edge is equal to
+            ``max_size``. As a result, ``size`` might be overruled, i.e the
+            smaller edge may be shorter than ``size``. This is only supported
+            if ``size`` is an int (or a sequence of length 1 in torchscript
+            mode).
 
     Returns:
         PIL Image or Tensor: Resized image.
@@ -372,9 +381,9 @@ def resize(img: Tensor, size: List[int], interpolation: InterpolationMode = Inte
 
     if not isinstance(img, torch.Tensor):
         pil_interpolation = pil_modes_mapping[interpolation]
-        return F_pil.resize(img, size=size, interpolation=pil_interpolation)
+        return F_pil.resize(img, size=size, interpolation=pil_interpolation, max_size=max_size)
 
-    return F_t.resize(img, size=size, interpolation=interpolation.value)
+    return F_t.resize(img, size=size, interpolation=interpolation.value, max_size=max_size)
 
 
 def scale(*args, **kwargs):
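To make the docstring concrete, a short usage sketch of the public entry point, with illustrative shapes:

import torch
from torchvision.transforms import functional as F

img = torch.randint(0, 256, size=(3, 100, 200), dtype=torch.uint8)

# size is an int: the smaller edge (100) is matched to 90, the longer becomes 180
print(F.resize(img, size=90).shape)                # torch.Size([3, 90, 180])

# 180 exceeds max_size=120, so the result is rescaled once more: the longer
# edge becomes 120 and the smaller edge ends up at 60, i.e. size is "overruled"
print(F.resize(img, size=90, max_size=120).shape)  # torch.Size([3, 60, 120])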

torchvision/transforms/functional_pil.py

Lines changed: 26 additions & 13 deletions

@@ -204,27 +204,40 @@ def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Imag
 
 
 @torch.jit.unused
-def resize(img, size, interpolation=Image.BILINEAR):
+def resize(img, size, interpolation=Image.BILINEAR, max_size=None):
     if not _is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
     if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
         raise TypeError('Got inappropriate size arg: {}'.format(size))
 
-    if isinstance(size, int) or len(size) == 1:
-        if isinstance(size, Sequence):
-            size = size[0]
+    if isinstance(size, Sequence) and len(size) == 1:
+        size = size[0]
+    if isinstance(size, int):
         w, h = img.size
-        if (w <= h and w == size) or (h <= w and h == size):
+
+        short, long = (w, h) if w <= h else (h, w)
+        if short == size:
             return img
-        if w < h:
-            ow = size
-            oh = int(size * h / w)
-            return img.resize((ow, oh), interpolation)
-        else:
-            oh = size
-            ow = int(size * w / h)
-            return img.resize((ow, oh), interpolation)
+
+        new_short, new_long = size, int(size * long / short)
+
+        if max_size is not None:
+            if max_size <= size:
+                raise ValueError(
+                    f"max_size = {max_size} must be strictly greater than the requested "
+                    f"size for the smaller edge size = {size}"
+                )
+            if new_long > max_size:
+                new_short, new_long = int(max_size * new_short / new_long), max_size
+
+        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
+        return img.resize((new_w, new_h), interpolation)
     else:
+        if max_size is not None:
+            raise ValueError(
+                "max_size should only be passed if size specifies the length of the smaller edge, "
+                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
+            )
         return img.resize(size[::-1], interpolation)
 
 
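The core of the new smaller-edge logic reads well as a standalone computation. A minimal sketch; the helper name _target_size is hypothetical, written here only to trace the arithmetic above:

def _target_size(w, h, size, max_size=None):
    # hypothetical helper mirroring the logic above: match the smaller
    # edge to `size`, then cap the longer edge at `max_size` if needed
    short, long = (w, h) if w <= h else (h, w)
    new_short, new_long = size, int(size * long / short)
    if max_size is not None and new_long > max_size:
        new_short, new_long = int(max_size * new_short / new_long), max_size
    return (new_short, new_long) if w <= h else (new_long, new_short)

print(_target_size(600, 400, size=300))                # (450, 300)
print(_target_size(600, 400, size=300, max_size=350))  # (350, 233)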

torchvision/transforms/functional_tensor.py

Lines changed: 33 additions & 16 deletions

@@ -470,7 +470,7 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
     return img
 
 
-def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Tensor:
+def resize(img: Tensor, size: List[int], interpolation: str = "bilinear", max_size: Optional[int] = None) -> Tensor:
     _assert_image_tensor(img)
 
     if not isinstance(size, (int, tuple, list)):
@@ -484,34 +484,51 @@ def resize(img: Tensor, size: List[int], interpolation: str = "bilinear") -> Ten
     if isinstance(size, tuple):
         size = list(size)
 
-    if isinstance(size, list) and len(size) not in [1, 2]:
-        raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a "
-                         "{} element tuple/list".format(len(size)))
+    if isinstance(size, list):
+        if len(size) not in [1, 2]:
+            raise ValueError("Size must be an int or a 1 or 2 element tuple/list, not a "
+                             "{} element tuple/list".format(len(size)))
+        if max_size is not None and len(size) != 1:
+            raise ValueError(
+                "max_size should only be passed if size specifies the length of the smaller edge, "
+                "i.e. size should be an int or a sequence of length 1 in torchscript mode."
+            )
 
     w, h = _get_image_size(img)
 
-    if isinstance(size, int):
-        size_w, size_h = size, size
-    elif len(size) < 2:
-        size_w, size_h = size[0], size[0]
-    else:
-        size_w, size_h = size[1], size[0]  # Convention (h, w)
+    if isinstance(size, int) or len(size) == 1:  # specified size only for the smallest edge
+        short, long = (w, h) if w <= h else (h, w)
 
-    if isinstance(size, int) or len(size) < 2:
-        if w < h:
-            size_h = int(size_w * h / w)
+        if isinstance(size, int):
+            requested_new_short = size
         else:
-            size_w = int(size_h * w / h)
+            requested_new_short = size[0]
 
-        if (w <= h and w == size_w) or (h <= w and h == size_h):
+        if short == requested_new_short:
             return img
 
+        new_short, new_long = requested_new_short, int(requested_new_short * long / short)
+
+        if max_size is not None:
+            if max_size <= requested_new_short:
+                raise ValueError(
+                    f"max_size = {max_size} must be strictly greater than the requested "
+                    f"size for the smaller edge size = {size}"
+                )
+            if new_long > max_size:
+                new_short, new_long = int(max_size * new_short / new_long), max_size
+
+        new_w, new_h = (new_short, new_long) if w <= h else (new_long, new_short)
+
+    else:  # specified both h and w
+        new_w, new_h = size[1], size[0]
+
     img, need_cast, need_squeeze, out_dtype = _cast_squeeze_in(img, [torch.float32, torch.float64])
 
     # Define align_corners to avoid warnings
     align_corners = False if interpolation in ["bilinear", "bicubic"] else None
 
-    img = interpolate(img, size=[size_h, size_w], mode=interpolation, align_corners=align_corners)
+    img = interpolate(img, size=[new_h, new_w], mode=interpolation, align_corners=align_corners)
 
     if interpolation == "bicubic" and out_dtype == torch.uint8:
         img = img.clamp(min=0, max=255)
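Since the torchscript restriction recurs throughout the diff, a small sketch of the scripted call path; shapes are illustrative, and as the docstring notes, in torchscript mode size must be a list and max_size requires len(size) == 1:

import torch
from torchvision.transforms import functional as F

script_fn = torch.jit.script(F.resize)
img = torch.randint(0, 256, size=(3, 44, 56), dtype=torch.uint8)

# smaller edge 44 -> 32; the longer edge would become 40, capped here at 35
out = script_fn(img, size=[32], max_size=35)
print(out.shape)  # torch.Size([3, 28, 35])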
