Skip to content

Commit b148f6b

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] Undeprecate PIL int constants for interpolation (#7241)
Reviewed By: vmoens Differential Revision: D44416623 fbshipit-source-id: e404afad4f4e9cab00c0365410cdeb588179378b
1 parent a3572e4 commit b148f6b

8 files changed

+110
-73
lines changed

test/test_functional_tensor.py

Lines changed: 31 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Sequence
88

99
import numpy as np
10+
import PIL.Image
1011
import pytest
1112
import torch
1213
import torchvision.transforms as T
@@ -144,6 +145,12 @@ def test_rotate_batch(self, device, dt):
144145
center = (20, 22)
145146
_test_fn_on_batch(batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center)
146147

148+
def test_rotate_interpolation_type(self):
149+
tensor, _ = _create_data(26, 26)
150+
res1 = F.rotate(tensor, 45, interpolation=PIL.Image.BILINEAR)
151+
res2 = F.rotate(tensor, 45, interpolation=BILINEAR)
152+
assert_equal(res1, res2)
153+
147154

148155
class TestAffine:
149156

@@ -350,6 +357,14 @@ def test_batches(self, device, dt):
350357

351358
_test_fn_on_batch(batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0])
352359

360+
@pytest.mark.parametrize("device", cpu_and_gpu())
361+
def test_interpolation_type(self, device):
362+
tensor, pil_img = _create_data(26, 26, device=device)
363+
364+
res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=PIL.Image.BILINEAR)
365+
res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR)
366+
assert_equal(res1, res2)
367+
353368

354369
def _get_data_dims_and_points_for_perspective():
355370
# Ideally we would parametrize independently over data dims and points, but
@@ -448,6 +463,16 @@ def test_perspective_batch(device, dims_and_points, dt):
448463
)
449464

450465

466+
def test_perspective_interpolation_type():
467+
spoints = [[0, 0], [33, 0], [33, 25], [0, 25]]
468+
epoints = [[3, 2], [32, 3], [30, 24], [2, 25]]
469+
tensor = torch.randint(0, 256, (3, 26, 26))
470+
471+
res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=PIL.Image.BILINEAR)
472+
res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR)
473+
assert_equal(res1, res2)
474+
475+
451476
@pytest.mark.parametrize("device", cpu_and_gpu())
452477
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
453478
@pytest.mark.parametrize(
@@ -489,9 +514,7 @@ def test_resize(device, dt, size, max_size, interpolation):
489514

490515
assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]
491516

492-
if interpolation not in [
493-
NEAREST,
494-
]:
517+
if interpolation != NEAREST:
495518
# We can not check values if mode = NEAREST, as results are different
496519
# E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
497520
# E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
@@ -504,9 +527,7 @@ def test_resize(device, dt, size, max_size, interpolation):
504527
_assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=8.0)
505528

506529
if isinstance(size, int):
507-
script_size = [
508-
size,
509-
]
530+
script_size = [size]
510531
else:
511532
script_size = size
512533

@@ -523,6 +544,10 @@ def test_resize_asserts(device):
523544

524545
tensor, pil_img = _create_data(26, 36, device=device)
525546

547+
res1 = F.resize(tensor, size=32, interpolation=PIL.Image.BILINEAR)
548+
res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
549+
assert_equal(res1, res2)
550+
526551
for img in (tensor, pil_img):
527552
exp_msg = "max_size should only be passed if size specifies the length of the smaller edge"
528553
with pytest.raises(ValueError, match=exp_msg):

test/test_onnx.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -407,13 +407,12 @@ def forward(self_module, images, features):
407407
def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor:
408408
import os
409409

410-
import torchvision.transforms._pil_constants as _pil_constants
411410
from PIL import Image
412411
from torchvision.transforms import functional as F
413412

414413
data_dir = os.path.join(os.path.dirname(__file__), "assets")
415414
path = os.path.join(data_dir, *rel_path.split("/"))
416-
image = Image.open(path).convert("RGB").resize(size, _pil_constants.BILINEAR)
415+
image = Image.open(path).convert("RGB").resize(size, Image.BILINEAR)
417416

418417
return F.convert_image_dtype(F.pil_to_tensor(image))
419418

test/test_transforms.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
import pytest
1010
import torch
1111
import torchvision.transforms as transforms
12-
import torchvision.transforms._pil_constants as _pil_constants
1312
import torchvision.transforms.functional as F
1413
import torchvision.transforms.functional_tensor as F_t
1514
from PIL import Image
@@ -175,7 +174,7 @@ def test_accimage_pil_to_tensor(self):
175174
def test_accimage_resize(self):
176175
trans = transforms.Compose(
177176
[
178-
transforms.Resize(256, interpolation=_pil_constants.LINEAR),
177+
transforms.Resize(256, interpolation=Image.LINEAR),
179178
transforms.PILToTensor(),
180179
transforms.ConvertImageDtype(dtype=torch.float),
181180
]
@@ -1533,10 +1532,10 @@ def test_ten_crop(should_vflip, single_dim):
15331532
five_crop.__repr__()
15341533

15351534
if should_vflip:
1536-
vflipped_img = img.transpose(_pil_constants.FLIP_TOP_BOTTOM)
1535+
vflipped_img = img.transpose(Image.FLIP_TOP_BOTTOM)
15371536
expected_output += five_crop(vflipped_img)
15381537
else:
1539-
hflipped_img = img.transpose(_pil_constants.FLIP_LEFT_RIGHT)
1538+
hflipped_img = img.transpose(Image.FLIP_LEFT_RIGHT)
15401539
expected_output += five_crop(hflipped_img)
15411540

15421541
assert len(results) == 10
@@ -1883,6 +1882,9 @@ def test_random_rotation():
18831882
# Checking if RandomRotation can be printed as string
18841883
t.__repr__()
18851884

1885+
t = transforms.RandomRotation((-10, 10), interpolation=Image.BILINEAR)
1886+
assert t.interpolation == transforms.InterpolationMode.BILINEAR
1887+
18861888

18871889
def test_random_rotation_error():
18881890
# assert fill being either a Sequence or a Number
@@ -2212,6 +2214,9 @@ def test_random_affine():
22122214
t = transforms.RandomAffine(10, interpolation=transforms.InterpolationMode.BILINEAR)
22132215
assert "bilinear" in t.__repr__()
22142216

2217+
t = transforms.RandomAffine(10, interpolation=Image.BILINEAR)
2218+
assert t.interpolation == transforms.InterpolationMode.BILINEAR
2219+
22152220

22162221
def test_elastic_transformation():
22172222
with pytest.raises(TypeError, match=r"alpha should be float or a sequence of floats"):
@@ -2228,9 +2233,8 @@ def test_elastic_transformation():
22282233
with pytest.raises(ValueError, match=r"sigma is a sequence its length should be 2"):
22292234
transforms.ElasticTransform(alpha=2.0, sigma=[1.0, 0.0, 1.0])
22302235

2231-
with pytest.warns(UserWarning, match=r"Argument interpolation should be of type InterpolationMode"):
2232-
t = transforms.transforms.ElasticTransform(alpha=2.0, sigma=2.0, interpolation=2)
2233-
assert t.interpolation == transforms.InterpolationMode.BILINEAR
2236+
t = transforms.transforms.ElasticTransform(alpha=2.0, sigma=2.0, interpolation=Image.BILINEAR)
2237+
assert t.interpolation == transforms.InterpolationMode.BILINEAR
22342238

22352239
with pytest.raises(TypeError, match=r"fill should be int or float"):
22362240
transforms.ElasticTransform(alpha=1.0, sigma=1.0, fill={})

test/test_transforms_tensor.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
import warnings
44

55
import numpy as np
6+
import PIL.Image
67
import pytest
78
import torch
8-
import torchvision.transforms._pil_constants as _pil_constants
99
from common_utils import (
1010
_assert_approx_equal_tensor_to_pil,
1111
_assert_equal_tensor_to_pil,
@@ -657,13 +657,13 @@ def shear(pil_img, level, mode, resample):
657657
matrix = (1, level, 0, 0, 1, 0)
658658
elif mode == "Y":
659659
matrix = (1, 0, 0, level, 1, 0)
660-
return pil_img.transform((image_size, image_size), _pil_constants.AFFINE, matrix, resample=resample)
660+
return pil_img.transform((image_size, image_size), PIL.Image.AFFINE, matrix, resample=resample)
661661

662662
t_img, pil_img = _create_data(image_size, image_size)
663663

664664
resample_pil = {
665-
F.InterpolationMode.NEAREST: _pil_constants.NEAREST,
666-
F.InterpolationMode.BILINEAR: _pil_constants.BILINEAR,
665+
F.InterpolationMode.NEAREST: PIL.Image.NEAREST,
666+
F.InterpolationMode.BILINEAR: PIL.Image.BILINEAR,
667667
}[interpolation]
668668

669669
level = 0.3

torchvision/transforms/_pil_constants.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

torchvision/transforms/functional.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,7 @@ def resize(
421421
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
422422
``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
423423
supported.
424+
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
424425
max_size (int, optional): The maximum allowed for the longer edge of
425426
the resized image: if the longer edge of the image is greater
426427
than ``max_size`` after being resized according to ``size``, then
@@ -454,8 +455,12 @@ def resize(
454455
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
455456
_log_api_usage_once(resize)
456457

457-
if not isinstance(interpolation, InterpolationMode):
458-
raise TypeError("Argument interpolation should be a InterpolationMode")
458+
if isinstance(interpolation, int):
459+
interpolation = _interpolation_modes_from_int(interpolation)
460+
elif not isinstance(interpolation, InterpolationMode):
461+
raise TypeError(
462+
"Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
463+
)
459464

460465
if isinstance(size, (list, tuple)):
461466
if len(size) not in [1, 2]:
@@ -630,6 +635,7 @@ def resized_crop(
630635
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
631636
``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
632637
supported.
638+
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
633639
antialias (bool, optional): Whether to apply antialiasing.
634640
It only affects **tensors** with bilinear or bicubic modes and it is
635641
ignored otherwise: on PIL images, antialiasing is always applied on
@@ -726,6 +732,7 @@ def perspective(
726732
interpolation (InterpolationMode): Desired interpolation enum defined by
727733
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
728734
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
735+
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
729736
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
730737
image. If given a number, the value is used for all bands respectively.
731738
@@ -741,8 +748,12 @@ def perspective(
741748

742749
coeffs = _get_perspective_coeffs(startpoints, endpoints)
743750

744-
if not isinstance(interpolation, InterpolationMode):
745-
raise TypeError("Argument interpolation should be a InterpolationMode")
751+
if isinstance(interpolation, int):
752+
interpolation = _interpolation_modes_from_int(interpolation)
753+
elif not isinstance(interpolation, InterpolationMode):
754+
raise TypeError(
755+
"Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
756+
)
746757

747758
if not isinstance(img, torch.Tensor):
748759
pil_interpolation = pil_modes_mapping[interpolation]
@@ -1076,6 +1087,7 @@ def rotate(
10761087
interpolation (InterpolationMode): Desired interpolation enum defined by
10771088
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
10781089
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
1090+
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
10791091
expand (bool, optional): Optional expansion flag.
10801092
If true, expands the output image to make it large enough to hold the entire rotated image.
10811093
If false or omitted, make the output image the same size as the input image.
@@ -1097,15 +1109,19 @@ def rotate(
10971109
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
10981110
_log_api_usage_once(rotate)
10991111

1112+
if isinstance(interpolation, int):
1113+
interpolation = _interpolation_modes_from_int(interpolation)
1114+
elif not isinstance(interpolation, InterpolationMode):
1115+
raise TypeError(
1116+
"Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
1117+
)
1118+
11001119
if not isinstance(angle, (int, float)):
11011120
raise TypeError("Argument angle should be int or float")
11021121

11031122
if center is not None and not isinstance(center, (list, tuple)):
11041123
raise TypeError("Argument center should be a sequence")
11051124

1106-
if not isinstance(interpolation, InterpolationMode):
1107-
raise TypeError("Argument interpolation should be a InterpolationMode")
1108-
11091125
if not isinstance(img, torch.Tensor):
11101126
pil_interpolation = pil_modes_mapping[interpolation]
11111127
return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill)
@@ -1147,6 +1163,7 @@ def affine(
11471163
interpolation (InterpolationMode): Desired interpolation enum defined by
11481164
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
11491165
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
1166+
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
11501167
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
11511168
image. If given a number, the value is used for all bands respectively.
11521169
@@ -1162,6 +1179,13 @@ def affine(
11621179
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
11631180
_log_api_usage_once(affine)
11641181

1182+
if isinstance(interpolation, int):
1183+
interpolation = _interpolation_modes_from_int(interpolation)
1184+
elif not isinstance(interpolation, InterpolationMode):
1185+
raise TypeError(
1186+
"Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
1187+
)
1188+
11651189
if not isinstance(angle, (int, float)):
11661190
raise TypeError("Argument angle should be int or float")
11671191

@@ -1177,9 +1201,6 @@ def affine(
11771201
if not isinstance(shear, (numbers.Number, (list, tuple))):
11781202
raise TypeError("Shear should be either a single value or a sequence of two values")
11791203

1180-
if not isinstance(interpolation, InterpolationMode):
1181-
raise TypeError("Argument interpolation should be a InterpolationMode")
1182-
11831204
if isinstance(angle, int):
11841205
angle = float(angle)
11851206

@@ -1524,7 +1545,7 @@ def elastic_transform(
15241545
interpolation (InterpolationMode): Desired interpolation enum defined by
15251546
:class:`torchvision.transforms.InterpolationMode`.
15261547
Default is ``InterpolationMode.BILINEAR``.
1527-
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
1548+
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well.
15281549
fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
15291550
If a tuple of length 3, it is used to fill R, G, B channels respectively.
15301551
This value is only used when the padding_mode is constant.

0 commit comments

Comments
 (0)