Skip to content

Support integer values for interpolation in the prototype transforms #7248

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Feb 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1534,7 +1534,7 @@ def test__get_params(self, mocker):
assert int(spatial_size[1] * r_min) <= width <= int(spatial_size[1] * r_max)

def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock()
interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
antialias_sentinel = mocker.MagicMock()

transform = transforms.ScaleJitter(
Expand Down Expand Up @@ -1581,7 +1581,7 @@ def test__get_params(self, min_size, max_size, mocker):
assert shorter in min_size

def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock()
interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
antialias_sentinel = mocker.MagicMock()

transform = transforms.RandomShortestSize(
Expand Down Expand Up @@ -1945,7 +1945,7 @@ def test__get_params(self):
assert min_size <= size < max_size

def test__transform(self, mocker):
interpolation_sentinel = mocker.MagicMock()
interpolation_sentinel = mocker.MagicMock(spec=InterpolationMode)
antialias_sentinel = mocker.MagicMock()

transform = transforms.RandomResize(
Expand Down
46 changes: 36 additions & 10 deletions test/test_prototype_transforms_consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ def __init__(
ArgsKwargs((32, 29)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((30, 27), interpolation=PIL.Image.NEAREST),
ArgsKwargs((35, 29), interpolation=PIL.Image.BILINEAR),
ArgsKwargs((34, 25), interpolation=PIL.Image.BICUBIC),
NotScriptableArgsKwargs(31, max_size=32),
ArgsKwargs([31], max_size=32),
NotScriptableArgsKwargs(30, max_size=100),
Expand Down Expand Up @@ -305,6 +308,8 @@ def __init__(
ArgsKwargs(25, ratio=(0.5, 1.5)),
ArgsKwargs((31, 28), interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs((33, 26), interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs((31, 28), interpolation=PIL.Image.NEAREST),
ArgsKwargs((33, 26), interpolation=PIL.Image.BICUBIC),
ArgsKwargs((29, 32), antialias=False),
ArgsKwargs((28, 31), antialias=True),
],
Expand Down Expand Up @@ -352,6 +357,8 @@ def __init__(
ArgsKwargs(sigma=(2.5, 3.9)),
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(interpolation=prototype_transforms.InterpolationMode.BICUBIC),
ArgsKwargs(interpolation=PIL.Image.NEAREST),
ArgsKwargs(interpolation=PIL.Image.BICUBIC),
ArgsKwargs(fill=1),
],
# ElasticTransform needs larger images to avoid the needed internal padding being larger than the actual image
Expand Down Expand Up @@ -386,6 +393,7 @@ def __init__(
ArgsKwargs(degrees=0.0, shear=(4, 5, 4, 13)),
ArgsKwargs(degrees=(-20.0, 10.0), translate=(0.4, 0.6), scale=(0.3, 0.8), shear=(4, 5, 4, 13)),
ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(degrees=30.0, interpolation=PIL.Image.NEAREST),
ArgsKwargs(degrees=30.0, fill=1),
ArgsKwargs(degrees=30.0, fill=(2, 3, 4)),
ArgsKwargs(degrees=30.0, center=(0, 0)),
Expand Down Expand Up @@ -420,6 +428,7 @@ def __init__(
ArgsKwargs(p=1),
ArgsKwargs(p=1, distortion_scale=0.3),
ArgsKwargs(p=1, distortion_scale=0.2, interpolation=prototype_transforms.InterpolationMode.NEAREST),
ArgsKwargs(p=1, distortion_scale=0.2, interpolation=PIL.Image.NEAREST),
ArgsKwargs(p=1, distortion_scale=0.1, fill=1),
ArgsKwargs(p=1, distortion_scale=0.4, fill=(1, 2, 3)),
],
Expand All @@ -432,6 +441,7 @@ def __init__(
ArgsKwargs(degrees=30.0),
ArgsKwargs(degrees=(-20.0, 10.0)),
ArgsKwargs(degrees=30.0, interpolation=prototype_transforms.InterpolationMode.BILINEAR),
ArgsKwargs(degrees=30.0, interpolation=PIL.Image.BILINEAR),
ArgsKwargs(degrees=30.0, expand=True),
ArgsKwargs(degrees=30.0, center=(0, 0)),
ArgsKwargs(degrees=30.0, fill=1),
Expand Down Expand Up @@ -851,7 +861,11 @@ class TestAATransforms:
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_randaug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.RandAugment(interpolation=interpolation, num_ops=1)
Expand Down Expand Up @@ -889,7 +903,11 @@ def test_randaug(self, inpt, interpolation, mocker):
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_trivial_aug(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.TrivialAugmentWide(interpolation=interpolation)
Expand Down Expand Up @@ -937,7 +955,11 @@ def test_trivial_aug(self, inpt, interpolation, mocker):
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_augmix(self, inpt, interpolation, mocker):
t_ref = legacy_transforms.AugMix(interpolation=interpolation, mixture_width=1, chain_depth=1)
Expand Down Expand Up @@ -986,7 +1008,11 @@ def test_augmix(self, inpt, interpolation, mocker):
)
@pytest.mark.parametrize(
"interpolation",
[prototype_transforms.InterpolationMode.NEAREST, prototype_transforms.InterpolationMode.BILINEAR],
[
prototype_transforms.InterpolationMode.NEAREST,
prototype_transforms.InterpolationMode.BILINEAR,
PIL.Image.NEAREST,
],
)
def test_aa(self, inpt, interpolation):
aa_policy = legacy_transforms.AutoAugmentPolicy("imagenet")
Expand Down Expand Up @@ -1264,13 +1290,13 @@ def test_random_resize_eval(self, mocker):
(legacy_F.convert_image_dtype, {}),
(legacy_F.to_pil_image, {}),
(legacy_F.normalize, {}),
(legacy_F.resize, {}),
(legacy_F.resize, {"interpolation"}),
(legacy_F.pad, {"padding", "fill"}),
(legacy_F.crop, {}),
(legacy_F.center_crop, {}),
(legacy_F.resized_crop, {}),
(legacy_F.resized_crop, {"interpolation"}),
(legacy_F.hflip, {}),
(legacy_F.perspective, {"startpoints", "endpoints", "fill"}),
(legacy_F.perspective, {"startpoints", "endpoints", "fill", "interpolation"}),
(legacy_F.vflip, {}),
(legacy_F.five_crop, {}),
(legacy_F.ten_crop, {}),
Expand All @@ -1279,8 +1305,8 @@ def test_random_resize_eval(self, mocker):
(legacy_F.adjust_saturation, {}),
(legacy_F.adjust_hue, {}),
(legacy_F.adjust_gamma, {}),
(legacy_F.rotate, {"center", "fill"}),
(legacy_F.affine, {"angle", "translate", "center", "fill"}),
(legacy_F.rotate, {"center", "fill", "interpolation"}),
(legacy_F.affine, {"angle", "translate", "center", "fill", "interpolation"}),
(legacy_F.to_grayscale, {}),
(legacy_F.rgb_to_grayscale, {}),
(legacy_F.to_tensor, {}),
Expand All @@ -1292,7 +1318,7 @@ def test_random_resize_eval(self, mocker):
(legacy_F.adjust_sharpness, {}),
(legacy_F.autocontrast, {}),
(legacy_F.equalize, {}),
(legacy_F.elastic_transform, {"fill"}),
(legacy_F.elastic_transform, {"fill", "interpolation"}),
],
)
def test_dispatcher_signature_consistency(legacy_dispatcher, name_only_params):
Expand Down
12 changes: 6 additions & 6 deletions torchvision/prototype/datapoints/_bounding_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def vertical_flip(self) -> BoundingBox:
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> BoundingBox:
Expand Down Expand Up @@ -107,7 +107,7 @@ def resized_crop(
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> BoundingBox:
output, spatial_size = self._F.resized_crop_bounding_box(
Expand All @@ -133,7 +133,7 @@ def pad(
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: FillTypeJIT = None,
Expand All @@ -154,7 +154,7 @@ def affine(
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> BoundingBox:
Expand All @@ -174,7 +174,7 @@ def perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> BoundingBox:
Expand All @@ -191,7 +191,7 @@ def perspective(
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> BoundingBox:
output = self._F.elastic_bounding_box(
Expand Down
12 changes: 6 additions & 6 deletions torchvision/prototype/datapoints/_datapoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def vertical_flip(self) -> Datapoint:
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Datapoint:
Expand All @@ -162,7 +162,7 @@ def resized_crop(
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> Datapoint:
return self
Expand All @@ -178,7 +178,7 @@ def pad(
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: FillTypeJIT = None,
Expand All @@ -191,7 +191,7 @@ def affine(
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Datapoint:
Expand All @@ -201,7 +201,7 @@ def perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> Datapoint:
Expand All @@ -210,7 +210,7 @@ def perspective(
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Datapoint:
return self
Expand Down
12 changes: 6 additions & 6 deletions torchvision/prototype/datapoints/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def vertical_flip(self) -> Image:
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Image:
Expand All @@ -86,7 +86,7 @@ def resized_crop(
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
antialias: Optional[Union[str, bool]] = "warn",
) -> Image:
output = self._F.resized_crop_image_tensor(
Expand All @@ -113,7 +113,7 @@ def pad(
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: FillTypeJIT = None,
Expand All @@ -129,7 +129,7 @@ def affine(
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Image:
Expand All @@ -149,7 +149,7 @@ def perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> Image:
Expand All @@ -166,7 +166,7 @@ def perspective(
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
interpolation: Union[InterpolationMode, int] = InterpolationMode.BILINEAR,
fill: FillTypeJIT = None,
) -> Image:
output = self._F.elastic_image_tensor(
Expand Down
12 changes: 6 additions & 6 deletions torchvision/prototype/datapoints/_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def vertical_flip(self) -> Mask:
def resize( # type: ignore[override]
self,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
max_size: Optional[int] = None,
antialias: Optional[Union[str, bool]] = "warn",
) -> Mask:
Expand All @@ -75,7 +75,7 @@ def resized_crop(
height: int,
width: int,
size: List[int],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
antialias: Optional[Union[str, bool]] = "warn",
) -> Mask:
output = self._F.resized_crop_mask(self.as_subclass(torch.Tensor), top, left, height, width, size=size)
Expand All @@ -93,7 +93,7 @@ def pad(
def rotate(
self,
angle: float,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
expand: bool = False,
center: Optional[List[float]] = None,
fill: FillTypeJIT = None,
Expand All @@ -107,7 +107,7 @@ def affine(
translate: List[float],
scale: float,
shear: List[float],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
center: Optional[List[float]] = None,
) -> Mask:
Expand All @@ -126,7 +126,7 @@ def perspective(
self,
startpoints: Optional[List[List[int]]],
endpoints: Optional[List[List[int]]],
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
coefficients: Optional[List[float]] = None,
) -> Mask:
Expand All @@ -138,7 +138,7 @@ def perspective(
def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
interpolation: Union[InterpolationMode, int] = InterpolationMode.NEAREST,
fill: FillTypeJIT = None,
) -> Mask:
output = self._F.elastic_mask(self.as_subclass(torch.Tensor), displacement, fill=fill)
Expand Down
Loading