From 02d8f622ad811b4b026603ee1f5946fa76556cbc Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 14 Feb 2023 08:13:29 +0100 Subject: [PATCH 1/9] undeprecate integer interpolate in transforms --- torchvision/transforms/_pil_constants.py | 25 --------------- torchvision/transforms/transforms.py | 40 ++++++++++++++---------- 2 files changed, 24 insertions(+), 41 deletions(-) delete mode 100644 torchvision/transforms/_pil_constants.py diff --git a/torchvision/transforms/_pil_constants.py b/torchvision/transforms/_pil_constants.py deleted file mode 100644 index 46f6ce5d24d..00000000000 --- a/torchvision/transforms/_pil_constants.py +++ /dev/null @@ -1,25 +0,0 @@ -from PIL import Image - -# See https://pillow.readthedocs.io/en/stable/releasenotes/9.1.0.html#deprecations -# TODO: Remove this file once PIL minimal version is >= 9.1 - -if hasattr(Image, "Resampling"): - BICUBIC = Image.Resampling.BICUBIC - BILINEAR = Image.Resampling.BILINEAR - LINEAR = Image.Resampling.BILINEAR - NEAREST = Image.Resampling.NEAREST - - AFFINE = Image.Transform.AFFINE - FLIP_LEFT_RIGHT = Image.Transpose.FLIP_LEFT_RIGHT - FLIP_TOP_BOTTOM = Image.Transpose.FLIP_TOP_BOTTOM - PERSPECTIVE = Image.Transform.PERSPECTIVE -else: - BICUBIC = Image.BICUBIC - BILINEAR = Image.BILINEAR - NEAREST = Image.NEAREST - LINEAR = Image.LINEAR - - AFFINE = Image.AFFINE - FLIP_LEFT_RIGHT = Image.FLIP_LEFT_RIGHT - FLIP_TOP_BOTTOM = Image.FLIP_TOP_BOTTOM - PERSPECTIVE = Image.PERSPECTIVE diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 88cc1c0d978..7b240c43131 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -298,6 +298,7 @@ class Resize(torch.nn.Module): :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. max_size (int, optional): The maximum allowed for the longer edge of the resized image: if the longer edge of the image is greater than ``max_size`` after being resized according to ``size``, then @@ -336,6 +337,9 @@ def __init__(self, size, interpolation=InterpolationMode.BILINEAR, max_size=None self.size = size self.max_size = max_size + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + self.interpolation = interpolation self.antialias = antialias @@ -756,6 +760,7 @@ class RandomPerspective(torch.nn.Module): interpolation (InterpolationMode): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. fill (sequence or number): Pixel fill value for the area outside the transformed image. Default is ``0``. If given a number, the value is used for all bands respectively. """ @@ -765,6 +770,9 @@ def __init__(self, distortion_scale=0.5, p=0.5, interpolation=InterpolationMode. _log_api_usage_once(self) self.p = p + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + self.interpolation = interpolation self.distortion_scale = distortion_scale @@ -861,6 +869,7 @@ class RandomResizedCrop(torch.nn.Module): :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. antialias (bool, optional): Whether to apply antialiasing. It only affects **tensors** with bilinear or bicubic modes and it is ignored otherwise: on PIL images, antialiasing is always applied on @@ -900,9 +909,11 @@ def __init__( if (scale[0] > scale[1]) or (ratio[0] > ratio[1]): warnings.warn("Scale and ratio should be of kind (min, max)") + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + self.interpolation = interpolation self.antialias = antialias - self.interpolation = interpolation self.scale = scale self.ratio = ratio @@ -1107,11 +1118,6 @@ def __init__(self, transformation_matrix, mean_vector): f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}" ) - if transformation_matrix.dtype != mean_vector.dtype: - raise ValueError( - f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}" - ) - self.transformation_matrix = transformation_matrix self.mean_vector = mean_vector @@ -1139,10 +1145,9 @@ def forward(self, tensor: Tensor) -> Tensor: ) flat_tensor = tensor.view(-1, n) - self.mean_vector - - transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype) - transformed_tensor = torch.mm(flat_tensor, transformation_matrix) - return transformed_tensor.view(shape) + transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) + tensor = transformed_tensor.view(shape) + return tensor def __repr__(self) -> str: s = ( @@ -1293,6 +1298,7 @@ class RandomRotation(torch.nn.Module): interpolation (InterpolationMode): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. expand (bool, optional): Optional expansion flag. If true, expands the output to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. @@ -1310,6 +1316,9 @@ def __init__(self, degrees, interpolation=InterpolationMode.NEAREST, expand=Fals super().__init__() _log_api_usage_once(self) + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) if center is not None: @@ -1393,6 +1402,7 @@ class RandomAffine(torch.nn.Module): interpolation (InterpolationMode): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. fill (sequence or number): Pixel fill value for the area outside the transformed image. Default is ``0``. If given a number, the value is used for all bands respectively. center (sequence, optional): Optional center of rotation, (x, y). Origin is the upper left corner. @@ -1415,6 +1425,9 @@ def __init__( super().__init__() _log_api_usage_once(self) + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + self.degrees = _setup_angle(degrees, name="degrees", req_sizes=(2,)) if translate is not None: @@ -2039,7 +2052,7 @@ class ElasticTransform(torch.nn.Module): interpolation (InterpolationMode): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. - For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. fill (sequence or number): Pixel fill value for the area outside the transformed image. Default is ``0``. If given a number, the value is used for all bands respectively. @@ -2080,12 +2093,7 @@ def __init__(self, alpha=50.0, sigma=5.0, interpolation=InterpolationMode.BILINE self.sigma = sigma - # Backward compatibility with integer value if isinstance(interpolation, int): - warnings.warn( - "Argument interpolation should be of type InterpolationMode instead of int. " - "Please, use InterpolationMode enum." - ) interpolation = _interpolation_modes_from_int(interpolation) self.interpolation = interpolation From f0067107cebd6e06fdb55e7028bfbdbd3cd8574c Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 14 Feb 2023 08:13:44 +0100 Subject: [PATCH 2/9] remove pil constants helper --- torchvision/transforms/functional_pil.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/torchvision/transforms/functional_pil.py b/torchvision/transforms/functional_pil.py index a75c46b4958..120998d0072 100644 --- a/torchvision/transforms/functional_pil.py +++ b/torchvision/transforms/functional_pil.py @@ -9,7 +9,6 @@ import accimage except ImportError: accimage = None -from . import _pil_constants @torch.jit.unused @@ -54,7 +53,7 @@ def hflip(img: Image.Image) -> Image.Image: if not _is_pil_image(img): raise TypeError(f"img should be PIL Image. Got {type(img)}") - return img.transpose(_pil_constants.FLIP_LEFT_RIGHT) + return img.transpose(Image.FLIP_LEFT_RIGHT) @torch.jit.unused @@ -62,7 +61,7 @@ def vflip(img: Image.Image) -> Image.Image: if not _is_pil_image(img): raise TypeError(f"img should be PIL Image. Got {type(img)}") - return img.transpose(_pil_constants.FLIP_TOP_BOTTOM) + return img.transpose(Image.FLIP_TOP_BOTTOM) @torch.jit.unused @@ -240,7 +239,7 @@ def crop( def resize( img: Image.Image, size: Union[List[int], int], - interpolation: int = _pil_constants.BILINEAR, + interpolation: int = Image.BILINEAR, ) -> Image.Image: if not _is_pil_image(img): @@ -284,7 +283,7 @@ def _parse_fill( def affine( img: Image.Image, matrix: List[float], - interpolation: int = _pil_constants.NEAREST, + interpolation: int = Image.NEAREST, fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, ) -> Image.Image: @@ -293,14 +292,14 @@ def affine( output_size = img.size opts = _parse_fill(fill, img) - return img.transform(output_size, _pil_constants.AFFINE, matrix, interpolation, **opts) + return img.transform(output_size, Image.AFFINE, matrix, interpolation, **opts) @torch.jit.unused def rotate( img: Image.Image, angle: float, - interpolation: int = _pil_constants.NEAREST, + interpolation: int = Image.NEAREST, expand: bool = False, center: Optional[Tuple[int, int]] = None, fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, @@ -317,7 +316,7 @@ def rotate( def perspective( img: Image.Image, perspective_coeffs: List[float], - interpolation: int = _pil_constants.BICUBIC, + interpolation: int = Image.BICUBIC, fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, ) -> Image.Image: @@ -326,7 +325,7 @@ def perspective( opts = _parse_fill(fill, img) - return img.transform(img.size, _pil_constants.PERSPECTIVE, perspective_coeffs, interpolation, **opts) + return img.transform(img.size, Image.PERSPECTIVE, perspective_coeffs, interpolation, **opts) @torch.jit.unused From 4b3f9a900c7e7c6c2a1cda4de01036cf2fdc850b Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 14 Feb 2023 08:18:59 +0100 Subject: [PATCH 3/9] undeprecate interger interpolate in functional --- torchvision/transforms/functional.py | 47 +++++++++++++++++++--------- 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 76c79df93d1..84e33eb9eb3 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -421,6 +421,7 @@ def resize( Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. max_size (int, optional): The maximum allowed for the longer edge of the resized image: if the longer edge of the image is greater than ``max_size`` after being resized according to ``size``, then @@ -454,8 +455,12 @@ def resize( if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(resize) - if not isinstance(interpolation, InterpolationMode): - raise TypeError("Argument interpolation should be a InterpolationMode") + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + elif not isinstance(interpolation, InterpolationMode): + raise TypeError( + "Argument interpolation should be a InterpolationMode or an corresponding Pillow integer constant" + ) if isinstance(size, (list, tuple)): if len(size) not in [1, 2]: @@ -476,8 +481,6 @@ def resize( if (image_height, image_width) == output_size: return img - antialias = _check_antialias(img, antialias, interpolation) - if not isinstance(img, torch.Tensor): if antialias is not None and not antialias: warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") @@ -610,7 +613,7 @@ def resized_crop( width: int, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - antialias: Optional[Union[str, bool]] = "warn", + antialias: Optional[bool] = None, ) -> Tensor: """Crop the given image and resize it to desired size. If the image is torch Tensor, it is expected @@ -630,6 +633,7 @@ def resized_crop( Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.NEAREST_EXACT``, ``InterpolationMode.BILINEAR`` and ``InterpolationMode.BICUBIC`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. antialias (bool, optional): Whether to apply antialiasing. It only affects **tensors** with bilinear or bicubic modes and it is ignored otherwise: on PIL images, antialiasing is always applied on @@ -726,6 +730,7 @@ def perspective( interpolation (InterpolationMode): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. fill (sequence or number, optional): Pixel fill value for the area outside the transformed image. If given a number, the value is used for all bands respectively. @@ -741,8 +746,12 @@ def perspective( coeffs = _get_perspective_coeffs(startpoints, endpoints) - if not isinstance(interpolation, InterpolationMode): - raise TypeError("Argument interpolation should be a InterpolationMode") + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + elif not isinstance(interpolation, InterpolationMode): + raise TypeError( + "Argument interpolation should be a InterpolationMode or an corresponding Pillow integer constant" + ) if not isinstance(img, torch.Tensor): pil_interpolation = pil_modes_mapping[interpolation] @@ -1076,6 +1085,7 @@ def rotate( interpolation (InterpolationMode): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. expand (bool, optional): Optional expansion flag. If true, expands the output image to make it large enough to hold the entire rotated image. If false or omitted, make the output image the same size as the input image. @@ -1097,15 +1107,19 @@ def rotate( if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(rotate) + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + elif not isinstance(interpolation, InterpolationMode): + raise TypeError( + "Argument interpolation should be a InterpolationMode or an corresponding Pillow integer constant" + ) + if not isinstance(angle, (int, float)): raise TypeError("Argument angle should be int or float") if center is not None and not isinstance(center, (list, tuple)): raise TypeError("Argument center should be a sequence") - if not isinstance(interpolation, InterpolationMode): - raise TypeError("Argument interpolation should be a InterpolationMode") - if not isinstance(img, torch.Tensor): pil_interpolation = pil_modes_mapping[interpolation] return F_pil.rotate(img, angle=angle, interpolation=pil_interpolation, expand=expand, center=center, fill=fill) @@ -1147,6 +1161,7 @@ def affine( interpolation (InterpolationMode): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.NEAREST``. If input is Tensor, only ``InterpolationMode.NEAREST``, ``InterpolationMode.BILINEAR`` are supported. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. fill (sequence or number, optional): Pixel fill value for the area outside the transformed image. If given a number, the value is used for all bands respectively. @@ -1162,6 +1177,13 @@ def affine( if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(affine) + if isinstance(interpolation, int): + interpolation = _interpolation_modes_from_int(interpolation) + elif not isinstance(interpolation, InterpolationMode): + raise TypeError( + "Argument interpolation should be a InterpolationMode or an corresponding Pillow integer constant" + ) + if not isinstance(angle, (int, float)): raise TypeError("Argument angle should be int or float") @@ -1177,9 +1199,6 @@ def affine( if not isinstance(shear, (numbers.Number, (list, tuple))): raise TypeError("Shear should be either a single value or a sequence of two values") - if not isinstance(interpolation, InterpolationMode): - raise TypeError("Argument interpolation should be a InterpolationMode") - if isinstance(angle, int): angle = float(angle) @@ -1524,7 +1543,7 @@ def elastic_transform( interpolation (InterpolationMode): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. - For backward compatibility integer values (e.g. ``PIL.Image.NEAREST``) are still acceptable. + The corresponding Pillow integer constants, e.g. ``PIL.Image.BILINEAR`` are accepted as well. fill (number or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of length 3, it is used to fill R, G, B channels respectively. This value is only used when the padding_mode is constant. From a2dc528937958a0a58ba90d9bd02ac842816ed33 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 14 Feb 2023 08:24:13 +0100 Subject: [PATCH 4/9] put tests back --- test/test_functional_tensor.py | 29 +++++++++++++++++++++++++++++ test/test_transforms.py | 17 ++++++----------- 2 files changed, 35 insertions(+), 11 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 58ba98bdf74..d87551a05a8 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -144,6 +144,12 @@ def test_rotate_batch(self, device, dt): center = (20, 22) _test_fn_on_batch(batch_tensors, F.rotate, angle=32, interpolation=NEAREST, expand=True, center=center) + def test_rotate_interpolation_type(self): + tensor, _ = _create_data(26, 26) + res1 = F.rotate(tensor, 45, interpolation=2) + res2 = F.rotate(tensor, 45, interpolation=BILINEAR) + assert_equal(res1, res2) + class TestAffine: @@ -350,6 +356,14 @@ def test_batches(self, device, dt): _test_fn_on_batch(batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + def test_warnings(self, device): + tensor, pil_img = _create_data(26, 26, device=device) + + res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2) + res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR) + assert_equal(res1, res2) + def _get_data_dims_and_points_for_perspective(): # Ideally we would parametrize independently over data dims and points, but @@ -448,6 +462,17 @@ def test_perspective_batch(device, dims_and_points, dt): ) +def test_perspective_interpolation_warning(): + # assert changed type warning + spoints = [[0, 0], [33, 0], [33, 25], [0, 25]] + epoints = [[3, 2], [32, 3], [30, 24], [2, 25]] + tensor = torch.randint(0, 256, (3, 26, 26)) + + res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=2) + res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR) + assert_equal(res1, res2) + + @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize( @@ -523,6 +548,10 @@ def test_resize_asserts(device): tensor, pil_img = _create_data(26, 36, device=device) + res1 = F.resize(tensor, size=32, interpolation=2) + res2 = F.resize(tensor, size=32, interpolation=BILINEAR) + assert_equal(res1, res2) + for img in (tensor, pil_img): exp_msg = "max_size should only be passed if size specifies the length of the smaller edge" with pytest.raises(ValueError, match=exp_msg): diff --git a/test/test_transforms.py b/test/test_transforms.py index a9074909cf0..2c6441d60a3 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -2,7 +2,6 @@ import os import random import re -import warnings from functools import partial import numpy as np @@ -440,16 +439,6 @@ def test_resize_antialias_error(): t(img) -def test_resize_antialias_default_warning(): - - img = Image.new("RGB", size=(10, 10), color=127) - # We make sure we don't warn for PIL images since the default behaviour doesn't change - with warnings.catch_warnings(): - warnings.simplefilter("error") - transforms.Resize((20, 20))(img) - transforms.RandomResizedCrop((20, 20))(img) - - @pytest.mark.parametrize("height, width", ((32, 64), (64, 32))) def test_resize_size_equals_small_edge_size(height, width): # Non-regression test for https://github.com/pytorch/vision/issues/5405 @@ -1883,6 +1872,9 @@ def test_random_rotation(): # Checking if RandomRotation can be printed as string t.__repr__() + t = transforms.RandomRotation((-10, 10), interpolation=2) + assert t.interpolation == transforms.InterpolationMode.BILINEAR + def test_random_rotation_error(): # assert fill being either a Sequence or a Number @@ -2212,6 +2204,9 @@ def test_random_affine(): t = transforms.RandomAffine(10, interpolation=transforms.InterpolationMode.BILINEAR) assert "bilinear" in t.__repr__() + t = transforms.RandomAffine(10, interpolation=2) + assert t.interpolation == transforms.InterpolationMode.BILINEAR + def test_elastic_transformation(): with pytest.raises(TypeError, match=r"alpha should be float or a sequence of floats"): From fce283ae7791232f3e47e4ad15d92444ddf7b4db Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 14 Feb 2023 09:37:18 +0100 Subject: [PATCH 5/9] cleanup --- test/test_functional_tensor.py | 12 ++++-------- test/test_transforms.py | 21 +++++++++++++++------ torchvision/transforms/functional.py | 4 +++- torchvision/transforms/transforms.py | 8 +++++++- 4 files changed, 29 insertions(+), 16 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index d87551a05a8..46708f1e3ba 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -357,7 +357,7 @@ def test_batches(self, device, dt): _test_fn_on_batch(batch_tensors, F.affine, angle=-43, translate=[-3, 4], scale=1.2, shear=[4.0, 5.0]) @pytest.mark.parametrize("device", cpu_and_gpu()) - def test_warnings(self, device): + def test_interpolation_type(self, device): tensor, pil_img = _create_data(26, 26, device=device) res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2) @@ -462,7 +462,7 @@ def test_perspective_batch(device, dims_and_points, dt): ) -def test_perspective_interpolation_warning(): +def test_perspective_interpolation_type(): # assert changed type warning spoints = [[0, 0], [33, 0], [33, 25], [0, 25]] epoints = [[3, 2], [32, 3], [30, 24], [2, 25]] @@ -514,9 +514,7 @@ def test_resize(device, dt, size, max_size, interpolation): assert resized_tensor.size()[1:] == resized_pil_img.size[::-1] - if interpolation not in [ - NEAREST, - ]: + if interpolation not in [NEAREST]: # We can not check values if mode = NEAREST, as results are different # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]] # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]] @@ -529,9 +527,7 @@ def test_resize(device, dt, size, max_size, interpolation): _assert_approx_equal_tensor_to_pil(resized_tensor_f, resized_pil_img, tol=8.0) if isinstance(size, int): - script_size = [ - size, - ] + script_size = [size] else: script_size = size diff --git a/test/test_transforms.py b/test/test_transforms.py index 2c6441d60a3..cbae55a603d 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -2,13 +2,13 @@ import os import random import re +import warnings from functools import partial import numpy as np import pytest import torch import torchvision.transforms as transforms -import torchvision.transforms._pil_constants as _pil_constants import torchvision.transforms.functional as F import torchvision.transforms.functional_tensor as F_t from PIL import Image @@ -439,6 +439,16 @@ def test_resize_antialias_error(): t(img) +def test_resize_antialias_default_warning(): + + img = Image.new("RGB", size=(10, 10), color=127) + # We make sure we don't warn for PIL images since the default behaviour doesn't change + with warnings.catch_warnings(): + warnings.simplefilter("error") + transforms.Resize((20, 20))(img) + transforms.RandomResizedCrop((20, 20))(img) + + @pytest.mark.parametrize("height, width", ((32, 64), (64, 32))) def test_resize_size_equals_small_edge_size(height, width): # Non-regression test for https://github.com/pytorch/vision/issues/5405 @@ -1522,10 +1532,10 @@ def test_ten_crop(should_vflip, single_dim): five_crop.__repr__() if should_vflip: - vflipped_img = img.transpose(_pil_constants.FLIP_TOP_BOTTOM) + vflipped_img = img.transpose(Image.FLIP_TOP_BOTTOM) expected_output += five_crop(vflipped_img) else: - hflipped_img = img.transpose(_pil_constants.FLIP_LEFT_RIGHT) + hflipped_img = img.transpose(Image.FLIP_LEFT_RIGHT) expected_output += five_crop(hflipped_img) assert len(results) == 10 @@ -2223,9 +2233,8 @@ def test_elastic_transformation(): with pytest.raises(ValueError, match=r"sigma is a sequence its length should be 2"): transforms.ElasticTransform(alpha=2.0, sigma=[1.0, 0.0, 1.0]) - with pytest.warns(UserWarning, match=r"Argument interpolation should be of type InterpolationMode"): - t = transforms.transforms.ElasticTransform(alpha=2.0, sigma=2.0, interpolation=2) - assert t.interpolation == transforms.InterpolationMode.BILINEAR + t = transforms.transforms.ElasticTransform(alpha=2.0, sigma=2.0, interpolation=2) + assert t.interpolation == transforms.InterpolationMode.BILINEAR with pytest.raises(TypeError, match=r"fill should be int or float"): transforms.ElasticTransform(alpha=1.0, sigma=1.0, fill={}) diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 84e33eb9eb3..04f9a86f6fe 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -481,6 +481,8 @@ def resize( if (image_height, image_width) == output_size: return img + antialias = _check_antialias(img, antialias, interpolation) + if not isinstance(img, torch.Tensor): if antialias is not None and not antialias: warnings.warn("Anti-alias option is always applied for PIL Image input. Argument antialias is ignored.") @@ -613,7 +615,7 @@ def resized_crop( width: int, size: List[int], interpolation: InterpolationMode = InterpolationMode.BILINEAR, - antialias: Optional[bool] = None, + antialias: Optional[Union[str, bool]] = "warn", ) -> Tensor: """Crop the given image and resize it to desired size. If the image is torch Tensor, it is expected diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 7b240c43131..d7858353be9 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -1118,6 +1118,11 @@ def __init__(self, transformation_matrix, mean_vector): f"Input tensors should be on the same device. Got {transformation_matrix.device} and {mean_vector.device}" ) + if transformation_matrix.dtype != mean_vector.dtype: + raise ValueError( + f"Input tensors should have the same dtype. Got {transformation_matrix.dtype} and {mean_vector.dtype}" + ) + self.transformation_matrix = transformation_matrix self.mean_vector = mean_vector @@ -1145,7 +1150,8 @@ def forward(self, tensor: Tensor) -> Tensor: ) flat_tensor = tensor.view(-1, n) - self.mean_vector - transformed_tensor = torch.mm(flat_tensor, self.transformation_matrix) + transformation_matrix = self.transformation_matrix.to(flat_tensor.dtype) + transformed_tensor = torch.mm(flat_tensor, transformation_matrix) tensor = transformed_tensor.view(shape) return tensor From 055c05f79b306719dda2e1b26f0ffb2306ebc057 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 14 Feb 2023 09:39:16 +0100 Subject: [PATCH 6/9] remove outdated comment --- test/test_functional_tensor.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 46708f1e3ba..985389a4bfe 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -463,7 +463,6 @@ def test_perspective_batch(device, dims_and_points, dt): def test_perspective_interpolation_type(): - # assert changed type warning spoints = [[0, 0], [33, 0], [33, 25], [0, 25]] epoints = [[3, 2], [32, 3], [30, 24], [2, 25]] tensor = torch.randint(0, 256, (3, 26, 26)) From 182ec34fc646379bbd507c5f4616736797537882 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 14 Feb 2023 10:30:08 +0100 Subject: [PATCH 7/9] remove remaining _pil_constant usages --- test/test_onnx.py | 3 +-- test/test_transforms.py | 2 +- test/test_transforms_tensor.py | 8 ++++---- 3 files changed, 6 insertions(+), 7 deletions(-) diff --git a/test/test_onnx.py b/test/test_onnx.py index 09c73accc3b..0af76072e9e 100644 --- a/test/test_onnx.py +++ b/test/test_onnx.py @@ -407,13 +407,12 @@ def forward(self_module, images, features): def get_image(self, rel_path: str, size: Tuple[int, int]) -> torch.Tensor: import os - import torchvision.transforms._pil_constants as _pil_constants from PIL import Image from torchvision.transforms import functional as F data_dir = os.path.join(os.path.dirname(__file__), "assets") path = os.path.join(data_dir, *rel_path.split("/")) - image = Image.open(path).convert("RGB").resize(size, _pil_constants.BILINEAR) + image = Image.open(path).convert("RGB").resize(size, Image.BILINEAR) return F.convert_image_dtype(F.pil_to_tensor(image)) diff --git a/test/test_transforms.py b/test/test_transforms.py index cbae55a603d..de8ca02fba4 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -174,7 +174,7 @@ def test_accimage_pil_to_tensor(self): def test_accimage_resize(self): trans = transforms.Compose( [ - transforms.Resize(256, interpolation=_pil_constants.LINEAR), + transforms.Resize(256, interpolation=Image.LINEAR), transforms.PILToTensor(), transforms.ConvertImageDtype(dtype=torch.float), ] diff --git a/test/test_transforms_tensor.py b/test/test_transforms_tensor.py index b58e2420338..ef26f393db2 100644 --- a/test/test_transforms_tensor.py +++ b/test/test_transforms_tensor.py @@ -3,9 +3,9 @@ import warnings import numpy as np +import PIL.Image import pytest import torch -import torchvision.transforms._pil_constants as _pil_constants from common_utils import ( _assert_approx_equal_tensor_to_pil, _assert_equal_tensor_to_pil, @@ -657,13 +657,13 @@ def shear(pil_img, level, mode, resample): matrix = (1, level, 0, 0, 1, 0) elif mode == "Y": matrix = (1, 0, 0, level, 1, 0) - return pil_img.transform((image_size, image_size), _pil_constants.AFFINE, matrix, resample=resample) + return pil_img.transform((image_size, image_size), PIL.Image.AFFINE, matrix, resample=resample) t_img, pil_img = _create_data(image_size, image_size) resample_pil = { - F.InterpolationMode.NEAREST: _pil_constants.NEAREST, - F.InterpolationMode.BILINEAR: _pil_constants.BILINEAR, + F.InterpolationMode.NEAREST: PIL.Image.NEAREST, + F.InterpolationMode.BILINEAR: PIL.Image.BILINEAR, }[interpolation] level = 0.3 From 0c1889362db11b92237f7e6442b23125ecc2e4a1 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 14 Feb 2023 10:32:27 +0100 Subject: [PATCH 8/9] use named constants --- test/test_functional_tensor.py | 9 +++++---- test/test_transforms.py | 6 +++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 985389a4bfe..c40131c525b 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -7,6 +7,7 @@ from typing import Sequence import numpy as np +import PIL.Image import pytest import torch import torchvision.transforms as T @@ -146,7 +147,7 @@ def test_rotate_batch(self, device, dt): def test_rotate_interpolation_type(self): tensor, _ = _create_data(26, 26) - res1 = F.rotate(tensor, 45, interpolation=2) + res1 = F.rotate(tensor, 45, interpolation=PIL.Image.BILINEAR) res2 = F.rotate(tensor, 45, interpolation=BILINEAR) assert_equal(res1, res2) @@ -360,7 +361,7 @@ def test_batches(self, device, dt): def test_interpolation_type(self, device): tensor, pil_img = _create_data(26, 26, device=device) - res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=2) + res1 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=PIL.Image.BILINEAR) res2 = F.affine(tensor, 45, translate=[0, 0], scale=1.0, shear=[0.0, 0.0], interpolation=BILINEAR) assert_equal(res1, res2) @@ -467,7 +468,7 @@ def test_perspective_interpolation_type(): epoints = [[3, 2], [32, 3], [30, 24], [2, 25]] tensor = torch.randint(0, 256, (3, 26, 26)) - res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=2) + res1 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=PIL.Image.BILINEAR) res2 = F.perspective(tensor, startpoints=spoints, endpoints=epoints, interpolation=BILINEAR) assert_equal(res1, res2) @@ -543,7 +544,7 @@ def test_resize_asserts(device): tensor, pil_img = _create_data(26, 36, device=device) - res1 = F.resize(tensor, size=32, interpolation=2) + res1 = F.resize(tensor, size=32, interpolation=PIL.Image.BILINEAR) res2 = F.resize(tensor, size=32, interpolation=BILINEAR) assert_equal(res1, res2) diff --git a/test/test_transforms.py b/test/test_transforms.py index de8ca02fba4..57e61bbad70 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -1882,7 +1882,7 @@ def test_random_rotation(): # Checking if RandomRotation can be printed as string t.__repr__() - t = transforms.RandomRotation((-10, 10), interpolation=2) + t = transforms.RandomRotation((-10, 10), interpolation=Image.BILINEAR) assert t.interpolation == transforms.InterpolationMode.BILINEAR @@ -2214,7 +2214,7 @@ def test_random_affine(): t = transforms.RandomAffine(10, interpolation=transforms.InterpolationMode.BILINEAR) assert "bilinear" in t.__repr__() - t = transforms.RandomAffine(10, interpolation=2) + t = transforms.RandomAffine(10, interpolation=Image.BILINEAR) assert t.interpolation == transforms.InterpolationMode.BILINEAR @@ -2233,7 +2233,7 @@ def test_elastic_transformation(): with pytest.raises(ValueError, match=r"sigma is a sequence its length should be 2"): transforms.ElasticTransform(alpha=2.0, sigma=[1.0, 0.0, 1.0]) - t = transforms.transforms.ElasticTransform(alpha=2.0, sigma=2.0, interpolation=2) + t = transforms.transforms.ElasticTransform(alpha=2.0, sigma=2.0, interpolation=Image.BILINEAR) assert t.interpolation == transforms.InterpolationMode.BILINEAR with pytest.raises(TypeError, match=r"fill should be int or float"): From 61140a100dc1d28b9d22c0915dc7497eac164945 Mon Sep 17 00:00:00 2001 From: Philip Meier Date: Tue, 14 Feb 2023 10:34:08 +0100 Subject: [PATCH 9/9] address comments --- test/test_functional_tensor.py | 2 +- torchvision/transforms/functional.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index c40131c525b..3e0ca881acb 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -514,7 +514,7 @@ def test_resize(device, dt, size, max_size, interpolation): assert resized_tensor.size()[1:] == resized_pil_img.size[::-1] - if interpolation not in [NEAREST]: + if interpolation != NEAREST: # We can not check values if mode = NEAREST, as results are different # E.g. resized_tensor = [[a, a, b, c, d, d, e, ...]] # E.g. resized_pil_img = [[a, b, c, c, d, e, f, ...]] diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 04f9a86f6fe..29940837a41 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -459,7 +459,7 @@ def resize( interpolation = _interpolation_modes_from_int(interpolation) elif not isinstance(interpolation, InterpolationMode): raise TypeError( - "Argument interpolation should be a InterpolationMode or an corresponding Pillow integer constant" + "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant" ) if isinstance(size, (list, tuple)): @@ -752,7 +752,7 @@ def perspective( interpolation = _interpolation_modes_from_int(interpolation) elif not isinstance(interpolation, InterpolationMode): raise TypeError( - "Argument interpolation should be a InterpolationMode or an corresponding Pillow integer constant" + "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant" ) if not isinstance(img, torch.Tensor): @@ -1113,7 +1113,7 @@ def rotate( interpolation = _interpolation_modes_from_int(interpolation) elif not isinstance(interpolation, InterpolationMode): raise TypeError( - "Argument interpolation should be a InterpolationMode or an corresponding Pillow integer constant" + "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant" ) if not isinstance(angle, (int, float)): @@ -1183,7 +1183,7 @@ def affine( interpolation = _interpolation_modes_from_int(interpolation) elif not isinstance(interpolation, InterpolationMode): raise TypeError( - "Argument interpolation should be a InterpolationMode or an corresponding Pillow integer constant" + "Argument interpolation should be a InterpolationMode or a corresponding Pillow integer constant" ) if not isinstance(angle, (int, float)):