Skip to content

Undeprecate int constants for interpolation #7241

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Feb 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 31 additions & 6 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Sequence

import numpy as np
import PIL.Image
import pytest
import torch
import torchvision.transforms as T
Expand Down Expand Up @@ -144,6 +145,12 @@ def test_rotate_batch(self, device, dt):
center = (20, 22)
_test_fn_on_batch(batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center)

def test_rotate_interpolation_type(self):
    """Check that ``F.rotate`` treats the Pillow constant like ``BILINEAR``.

    The same tensor is rotated by 45 degrees twice: once with
    ``PIL.Image.BILINEAR`` and once with the module-level ``BILINEAR`` name
    (presumably ``InterpolationMode.BILINEAR`` -- its import is outside this
    view). The two outputs must compare equal.
    """
    # Only the first element returned by _create_data is needed here.
    img, _ = _create_data(26, 26)

    def _rotated(mode):
        return F.rotate(img, 45, interpolation=mode)

    from_pil_constant = _rotated(PIL.Image.BILINEAR)
    from_named_mode = _rotated(BILINEAR)
    assert_equal(from_pil_constant, from_named_mode)


class TestAffine:

Expand Down Expand Up @@ -350,6 +357,14 @@ def test_batches(self, device, dt):

_test_fn_on_batch(batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0])

@pytest.mark.parametrize("device", cpu_and_gpu())
def test_interpolation_type(self, device):
    """Check that ``F.affine`` treats the Pillow constant like ``BILINEAR``.

    Runs once per device yielded by ``cpu_and_gpu()``. The same 45-degree
    affine transform (no translation, unit scale, no shear) is applied with
    ``PIL.Image.BILINEAR`` and with the module-level ``BILINEAR`` name
    (presumably ``InterpolationMode.BILINEAR`` -- its import is outside this
    view); the two outputs must compare equal.
    """
    # The second element returned by _create_data is not used by this test.
    img, _ = _create_data(26, 26, device=device)

    def _transformed(mode):
        # Fresh list literals on every call, exactly as in two separate calls.
        return F.affine(img, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=mode)

    from_pil_constant = _transformed(PIL.Image.BILINEAR)
    from_named_mode = _transformed(BILINEAR)
    assert_equal(from_pil_constant, from_named_mode)


def _get_data_dims_and_points_for_perspective():
# Ideally we would parametrize independently over data dims and points, but
Expand Down Expand Up @@ -448,6 +463,16 @@ def test_perspective_batch(device, dims_and_points, dt):
)


def test_perspective_interpolation_type():
    """Check that ``F.perspective`` treats the Pillow constant like ``BILINEAR``.

    A random integer tensor of shape (3, 26, 26) with values in [0, 256) is
    warped with the same start/end points twice: once with
    ``PIL.Image.BILINEAR`` and once with the module-level ``BILINEAR`` name
    (presumably ``InterpolationMode.BILINEAR`` -- its import is outside this
    view). The two outputs must compare equal.
    """
    start_corners = [[0, 0], [33, 0], [33, 25], [0, 25]]
    end_corners = [[3, 2], [32, 3], [30, 24], [2, 25]]
    img = torch.randint(0, 256, (3, 26, 26))

    def _warped(mode):
        return F.perspective(img, startpoints=start_corners, endpoints=end_corners, interpolation=mode)

    from_pil_constant = _warped(PIL.Image.BILINEAR)
    from_named_mode = _warped(BILINEAR)
    assert_equal(from_pil_constant, from_named_mode)


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize(
Expand Down Expand Up @@ -489,9 +514,7 @@ def test_resize(device, dt, size, max_size, interpolation):

assert resized_tensor.size()[1:] == resized_pil_img.size[::-1]

if interpolation not in [
NEAREST,
]:
if interpolation != NEAREST:
# We cannot check values when mode is NEAREST, because the tensor and PIL results differ
# E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]]
# E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]]
Expand All @@ -504,9 +527,7 @@ def test_resize(device, dt, size, max_size, interpolation):
_assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=8.0)

if isinstance(size, int):
script_size = [
size,
]
script_size = [size]
else:
script_size = size

Expand All @@ -523,6 +544,10 @@ def test_resize_asserts(device):

tensor, pil_img = _create_data(26, 36, device=device)

res1 = F.resize(tensor, size=32, interpolation=PIL.Image.BILINEAR)
res2 = F.resize(tensor, size=32, interpolation=BILINEAR)
assert_equal(res1, res2)

for img in (tensor, pil_img):
exp_msg = "max_size should only be passed if size specifies the length of the smaller edge"
with pytest.raises(ValueError, match=exp_msg):
Expand Down
3 changes: 1 addition & 2 deletions test/test_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -407,13 +407,12 @@ def forward(self_module, images, features):
def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor:
import os

import torchvision.transforms._pil_constants as _pil_constants
from PIL import Image
from torchvision.transforms import functional as F

data_dir = os.path.join(os.path.dirname(__file__), "assets")
path = os.path.join(data_dir, *rel_path.split("/"))
image = Image.open(path).convert("RGB").resize(size, _pil_constants.BILINEAR)
image = Image.open(path).convert("RGB").resize(size, Image.BILINEAR)

return F.convert_image_dtype(F.pil_to_tensor(image))

Expand Down
18 changes: 11 additions & 7 deletions test/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
import pytest
import torch
import torchvision.transforms as transforms
import torchvision.transforms._pil_constants as _pil_constants
import torchvision.transforms.functional as F
import torchvision.transforms.functional_tensor as F_t
from PIL import Image
Expand Down Expand Up @@ -175,7 +174,7 @@ def test_accimage_pil_to_tensor(self):
def test_accimage_resize(self):
trans = transforms.Compose(
[
transforms.Resize(256, interpolation=_pil_constants.LINEAR),
transforms.Resize(256, interpolation=Image.LINEAR),
transforms.PILToTensor(),
transforms.ConvertImageDtype(dtype=torch.float),
]
Expand Down Expand Up @@ -1533,10 +1532,10 @@ def test_ten_crop(should_vflip, single_dim):
five_crop.__repr__()

if should_vflip:
vflipped_img = img.transpose(_pil_constants.FLIP_TOP_BOTTOM)
vflipped_img = img.transpose(Image.FLIP_TOP_BOTTOM)
expected_output += five_crop(vflipped_img)
else:
hflipped_img = img.transpose(_pil_constants.FLIP_LEFT_RIGHT)
hflipped_img = img.transpose(Image.FLIP_LEFT_RIGHT)
expected_output += five_crop(hflipped_img)

assert len(results) == 10
Expand Down Expand Up @@ -1883,6 +1882,9 @@ def test_random_rotation():
# Checking if RandomRotation can be printed as string
t.__repr__()

t = transforms.RandomRotation((-10, 10), interpolation=Image.BILINEAR)
assert t.interpolation == transforms.InterpolationMode.BILINEAR


def test_random_rotation_error():
# assert fill being either a Sequence or a Number
Expand Down Expand Up @@ -2212,6 +2214,9 @@ def test_random_affine():
t = transforms.RandomAffine(10, interpolation=transforms.InterpolationMode.BILINEAR)
assert "bilinear" in t.__repr__()

t = transforms.RandomAffine(10, interpolation=Image.BILINEAR)
assert t.interpolation == transforms.InterpolationMode.BILINEAR


def test_elastic_transformation():
with pytest.raises(TypeError, match=r"alpha should be float or a sequence of floats"):
Expand All @@ -2228,9 +2233,8 @@ def test_elastic_transformation():
with pytest.raises(ValueError, match=r"sigma is a sequence its length should be 2"):
transforms.ElasticTransform(alpha=2.0, sigma=[1.0, 0.0, 1.0])

with pytest.warns(UserWarning, match=r"Argument interpolation should be of type InterpolationMode"):
t = transforms.transforms.ElasticTransform(alpha=2.0, sigma=2.0, interpolation=2)
assert t.interpolation == transforms.InterpolationMode.BILINEAR
t = transforms.transforms.ElasticTransform(alpha=2.0, sigma=2.0, interpolation=Image.BILINEAR)
assert t.interpolation == transforms.InterpolationMode.BILINEAR

with pytest.raises(TypeError, match=r"fill should be int or float"):
transforms.ElasticTransform(alpha=1.0, sigma=1.0, fill={})
Expand Down
8 changes: 4 additions & 4 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
import warnings

import numpy as np
import PIL.Image
import pytest
import torch
import torchvision.transforms._pil_constants as _pil_constants
from common_utils import (
_assert_approx_equal_tensor_to_pil,
_assert_equal_tensor_to_pil,
Expand Down Expand Up @@ -657,13 +657,13 @@ def shear(pil_img, level, mode, resample):
matrix = (1, level, 0, 0, 1, 0)
elif mode == "Y":
matrix = (1, 0, 0, level, 1, 0)
return pil_img.transform((image_size, image_size), _pil_constants.AFFINE, matrix, resample=resample)
return pil_img.transform((image_size, image_size), PIL.Image.AFFINE, matrix, resample=resample)

t_img, pil_img = _create_data(image_size, image_size)

resample_pil = {
F.InterpolationMode.NEAREST: _pil_constants.NEAREST,
F.InterpolationMode.BILINEAR: _pil_constants.BILINEAR,
F.InterpolationMode.NEAREST: PIL.Image.NEAREST,
F.InterpolationMode.BILINEAR: PIL.Image.BILINEAR,
}[interpolation]

level = 0.3
Expand Down
25 changes: 0 additions & 25 deletions torchvision/transforms/_pil_constants.py

This file was deleted.

43 changes: 32 additions & 11 deletions torchvision/transforms/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,7 @@ def resize(
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR``, are accepted as well.
max_size (int, optional): The maximum allowed for the longer edge of
the resized image: if the longer edge of the image is greater
than ``max_size`` after being resized according to ``size``, then
Expand Down Expand Up @@ -454,8 +455,12 @@ def resize(
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(resize)

if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")
if isinstance(interpolation, int):
interpolation = _interpolation_modes_from_int(interpolation)
elif not isinstance(interpolation, InterpolationMode):
raise TypeError(
"Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
)

if isinstance(size, (list, tuple)):
if len(size) not in [1, 2]:
Expand Down Expand Up @@ -630,6 +635,7 @@ def resized_crop(
Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``,
``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are
supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR``, are accepted as well.
antialias (bool, optional): Whether to apply antialiasing.
It only affects **tensors** with bilinear or bicubic modes and it is
ignored otherwise: on PIL images, antialiasing is always applied on
Expand Down Expand Up @@ -726,6 +732,7 @@ def perspective(
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR``, are accepted as well.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.

Expand All @@ -741,8 +748,12 @@ def perspective(

coeffs = _get_perspective_coeffs(startpoints, endpoints)

if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")
if isinstance(interpolation, int):
interpolation = _interpolation_modes_from_int(interpolation)
elif not isinstance(interpolation, InterpolationMode):
raise TypeError(
"Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
)

if not isinstance(img, torch.Tensor):
pil_interpolation = pil_modes_mapping[interpolation]
Expand Down Expand Up @@ -1076,6 +1087,7 @@ def rotate(
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR``, are accepted as well.
expand (bool, optional): Optional expansion flag.
If true, expands the output image to make it large enough to hold the entire rotated image.
If false or omitted, make the output image the same size as the input image.
Expand All @@ -1097,15 +1109,19 @@ def rotate(
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(rotate)

if isinstance(interpolation, int):
interpolation = _interpolation_modes_from_int(interpolation)
elif not isinstance(interpolation, InterpolationMode):
raise TypeError(
"Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
)

if not isinstance(angle, (int, float)):
raise TypeError("Argument angle should be int or float")

if center is not None and not isinstance(center, (list, tuple)):
raise TypeError("Argument center should be a sequence")

if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")

if not isinstance(img, torch.Tensor):
pil_interpolation = pil_modes_mapping[interpolation]
return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill)
Expand Down Expand Up @@ -1147,6 +1163,7 @@ def affine(
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``.
If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR``, are accepted as well.
fill (sequence or number, optional): Pixel fill value for the area outside the transformed
image. If given a number, the value is used for all bands respectively.

Expand All @@ -1162,6 +1179,13 @@ def affine(
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(affine)

if isinstance(interpolation, int):
interpolation = _interpolation_modes_from_int(interpolation)
elif not isinstance(interpolation, InterpolationMode):
raise TypeError(
"Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant"
)

if not isinstance(angle, (int, float)):
raise TypeError("Argument angle should be int or float")

Expand All @@ -1177,9 +1201,6 @@ def affine(
if not isinstance(shear, (numbers.Number, (list, tuple))):
raise TypeError("Shear should be either a single value or a sequence of two values")

if not isinstance(interpolation, InterpolationMode):
raise TypeError("Argument interpolation should be a InterpolationMode")

if isinstance(angle, int):
angle = float(angle)

Expand Down Expand Up @@ -1524,7 +1545,7 @@ def elastic_transform(
interpolation (InterpolationMode): Desired interpolation enum defined by
:class:`torchvision.transforms.InterpolationMode`.
Default is ``InterpolationMode.BILINEAR``.
For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable.
The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR``, are accepted as well.
fill (number or str or tuple): Pixel fill value for constant fill. Default is 0.
If a tuple of length 3, it is used to fill R, G, B channels respectively.
This value is only used when the padding_mode is constant.
Expand Down
Loading