Skip to content

[proto] Added elastic transform and tests #6295

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Jul 22, 2022
18 changes: 13 additions & 5 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1352,16 +1352,24 @@ def test_ten_crop(device):
assert_equal(transformed_batch, s_transformed_batch)


def test_elastic_transform_asserts():
with pytest.raises(TypeError, match="Argument displacement should be a Tensor"):
_ = F.elastic_transform("abc", displacement=None)

with pytest.raises(TypeError, match="img should be PIL Image or Tensor"):
_ = F.elastic_transform("abc", displacement=torch.rand(1))

img_tensor = torch.rand(1, 3, 32, 24)
with pytest.raises(ValueError, match="Argument displacement shape should"):
_ = F.elastic_transform(img_tensor, displacement=torch.rand(1, 2))


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC])
@pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize(
"fill",
[
None,
[255, 255, 255],
(2.0,),
],
[None, [255, 255, 255], (2.0,)],
)
def test_elastic_transform_consistency(device, interpolation, dt, fill):
script_elastic_transform = torch.jit.script(F.elastic_transform)
Expand Down
75 changes: 75 additions & 0 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ def test__transform(self, fill, side_range, mocker):
fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
Expand Down Expand Up @@ -456,6 +457,7 @@ def test__transform(self, degrees, expand, fill, center, mocker):
inpt = mocker.MagicMock(spec=features.Image)
# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
Expand Down Expand Up @@ -576,6 +578,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker

# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
Expand Down Expand Up @@ -645,6 +648,7 @@ def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker):

# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
Expand Down Expand Up @@ -716,6 +720,7 @@ def test__transform(self, kernel_size, sigma, mocker):

# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
Expand Down Expand Up @@ -795,10 +800,80 @@ def test__transform(self, distortion_scale, mocker):
inpt.image_size = (24, 32)
# vfdev-5, Feature Request: let's store params as Transform attribute
# This could be also helpful for users
# Otherwise, we can mock transform._get_params
torch.manual_seed(12)
_ = transform(inpt)
torch.manual_seed(12)
torch.rand(1) # random apply changes random state
params = transform._get_params(inpt)

fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)


class TestElasticTransform:
def test_assertions(self):

with pytest.raises(TypeError, match="alpha should be float or a sequence of floats"):
transforms.ElasticTransform({})

with pytest.raises(ValueError, match="alpha is a sequence its length should be one of 2"):
transforms.ElasticTransform([1.0, 2.0, 3.0])

with pytest.raises(ValueError, match="alpha should be a sequence of floats"):
transforms.ElasticTransform([1, 2])

with pytest.raises(TypeError, match="sigma should be float or a sequence of floats"):
transforms.ElasticTransform(1.0, {})

with pytest.raises(ValueError, match="sigma is a sequence its length should be one of 2"):
transforms.ElasticTransform(1.0, [1.0, 2.0, 3.0])

with pytest.raises(ValueError, match="sigma should be a sequence of floats"):
transforms.ElasticTransform(1.0, [1, 2])

with pytest.raises(TypeError, match="Got inappropriate fill arg"):
transforms.ElasticTransform(1.0, 2.0, fill="abc")

def test__get_params(self, mocker):
alpha = 2.0
sigma = 3.0
transform = transforms.ElasticTransform(alpha, sigma)
image = mocker.MagicMock(spec=features.Image)
image.num_channels = 3
image.image_size = (24, 32)

params = transform._get_params(image)

h, w = image.image_size
displacement = params["displacement"]
assert displacement.shape == (1, h, w, 2)
assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all()
assert (-alpha / h <= displacement[0, ..., 1]).all() and (displacement[0, ..., 1] <= alpha / h).all()

@pytest.mark.parametrize("alpha", [5.0, [5.0, 10.0]])
@pytest.mark.parametrize("sigma", [2.0, [2.0, 5.0]])
def test__transform(self, alpha, sigma, mocker):
interpolation = InterpolationMode.BILINEAR
fill = 12
transform = transforms.ElasticTransform(alpha, sigma=sigma, fill=fill, interpolation=interpolation)

if isinstance(alpha, float):
assert transform.alpha == [alpha, alpha]
else:
assert transform.alpha == alpha

if isinstance(sigma, float):
assert transform.sigma == [sigma, sigma]
else:
assert transform.sigma == sigma

fn = mocker.patch("torchvision.prototype.transforms.functional.elastic")
inpt = mocker.MagicMock(spec=features.Image)
inpt.num_channels = 3
inpt.image_size = (24, 32)

# Let's mock transform._get_params to control the output:
transform._get_params = mocker.MagicMock()
_ = transform(inpt)
params = transform._get_params(inpt)
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
106 changes: 98 additions & 8 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def make_images(
yield make_image(size, color_space=color_space, dtype=dtype)

for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims):
yield make_image(color_space=color_space, extra_dims=extra_dims_, dtype=dtype)
yield make_image(size=sizes[0], color_space=color_space, extra_dims=extra_dims_, dtype=dtype)


def randint_with_tensor_bounds(arg1, arg2=None, **kwargs):
Expand Down Expand Up @@ -149,12 +149,12 @@ def make_segmentation_mask(size=None, *, num_categories=80, extra_dims=(), dtype


def make_segmentation_masks(
image_sizes=((16, 16), (7, 33), (31, 9)),
sizes=((16, 16), (7, 33), (31, 9)),
dtypes=(torch.long,),
extra_dims=((), (4,), (2, 3)),
):
for image_size, dtype, extra_dims_ in itertools.product(image_sizes, dtypes, extra_dims):
yield make_segmentation_mask(size=image_size, dtype=dtype, extra_dims=extra_dims_)
for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims):
yield make_segmentation_mask(size=size, dtype=dtype, extra_dims=extra_dims_)


class SampleInput:
Expand Down Expand Up @@ -533,6 +533,40 @@ def perspective_segmentation_mask():
)


@register_kernel_info_from_sample_inputs_fn
def elastic_image_tensor():
for image, fill in itertools.product(
make_images(extra_dims=((), (4,))),
[None, [128], [12.0]], # fill
):
h, w = image.shape[-2:]
displacement = torch.rand(1, h, w, 2)
yield SampleInput(image, displacement=displacement, fill=fill)


@register_kernel_info_from_sample_inputs_fn
def elastic_bounding_box():
for bounding_box in make_bounding_boxes():
h, w = bounding_box.image_size
displacement = torch.rand(1, h, w, 2)
yield SampleInput(
bounding_box,
format=bounding_box.format,
displacement=displacement,
)


@register_kernel_info_from_sample_inputs_fn
def elastic_segmentation_mask():
for mask in make_segmentation_masks(extra_dims=((), (4,))):
h, w = mask.shape[-2:]
displacement = torch.rand(1, h, w, 2)
yield SampleInput(
mask,
displacement=displacement,
)


@register_kernel_info_from_sample_inputs_fn
def center_crop_image_tensor():
for mask, output_size in itertools.product(
Expand All @@ -553,7 +587,7 @@ def center_crop_bounding_box():
@register_kernel_info_from_sample_inputs_fn
def center_crop_segmentation_mask():
for mask, output_size in itertools.product(
make_segmentation_masks(image_sizes=((16, 16), (7, 33), (31, 9))),
make_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))),
[[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size
):
yield SampleInput(mask, output_size)
Expand Down Expand Up @@ -654,10 +688,20 @@ def test_scriptable(kernel):
feature_type not in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label", "pil"}
)
and name
not in {"to_image_tensor", "InterpolationMode", "decode_video_with_av", "crop", "rotate", "perspective"}
not in {
"to_image_tensor",
"InterpolationMode",
"decode_video_with_av",
"crop",
"rotate",
"perspective",
"elastic_transform",
"elastic",
}
# We skip 'crop' due to missing 'height' and 'width'
# We skip 'rotate' due to non implemented yet expand=True case for bboxes
# We skip 'perspective' as it requires different input args than perspective_image_tensor etc
# Skip 'elastic', TODO: inspect why test is failing
],
)
def test_functional_mid_level(func):
Expand All @@ -670,7 +714,9 @@ def test_functional_mid_level(func):
if key in kwargs:
del kwargs[key]
output = func(*sample_input.args, **kwargs)
torch.testing.assert_close(output, expected, msg=f"finfo={finfo}, output={output}, expected={expected}")
torch.testing.assert_close(
output, expected, msg=f"finfo={finfo.name}, output={output}, expected={expected}"
)
break


Expand Down Expand Up @@ -1739,5 +1785,49 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s
torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor)
)

out = fn(tensor, kernel_size=ksize, sigma=sigma)
image = features.Image(tensor)

out = fn(image, kernel_size=ksize, sigma=sigma)
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")


@pytest.mark.parametrize("device", cpu_and_gpu())
@pytest.mark.parametrize(
"fn, make_samples", [(F.elastic_image_tensor, make_images), (F.elastic_segmentation_mask, make_segmentation_masks)]
)
def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
in_box = [10, 15, 25, 35]
for sample in make_samples(sizes=((64, 76),), extra_dims=((), (4,))):
c, h, w = sample.shape[-3:]
# Setup a dummy image with 4 points
sample[..., in_box[1], in_box[0]] = torch.tensor([12, 34, 96, 112])[:c]
sample[..., in_box[3] - 1, in_box[0]] = torch.tensor([12, 34, 96, 112])[:c]
sample[..., in_box[3] - 1, in_box[2] - 1] = torch.tensor([12, 34, 96, 112])[:c]
sample[..., in_box[1], in_box[2] - 1] = torch.tensor([12, 34, 96, 112])[:c]
sample = sample.to(device)

if fn == F.elastic_image_tensor:
sample = features.Image(sample)
kwargs = {"interpolation": F.InterpolationMode.NEAREST}
else:
sample = features.SegmentationMask(sample)
kwargs = {}

# Create a displacement grid using sin
n, m = 5.0, 0.1
d1 = m * torch.sin(torch.arange(h, dtype=torch.float) * torch.pi * n / h)
d2 = m * torch.sin(torch.arange(w, dtype=torch.float) * torch.pi * n / w)

d1 = d1[:, None].expand((h, w))
d2 = d2[None, :].expand((h, w))

displacement = torch.cat([d1[..., None], d2[..., None]], dim=-1)
displacement = displacement.reshape(1, h, w, 2)

output = fn(sample, displacement=displacement, **kwargs)

# Check places where transformed points should be
torch.testing.assert_close(output[..., 12, 9], sample[..., in_box[1], in_box[0]])
torch.testing.assert_close(output[..., 17, 27], sample[..., in_box[1], in_box[2] - 1])
torch.testing.assert_close(output[..., 31, 6], sample[..., in_box[3] - 1, in_box[0]])
torch.testing.assert_close(output[..., 37, 23], sample[..., in_box[3] - 1, in_box[2] - 1])
11 changes: 11 additions & 0 deletions torchvision/prototype/features/_bounding_box.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,14 @@ def perspective(

output = _F.perspective_bounding_box(self, self.format, perspective_coeffs)
return BoundingBox.new_like(self, output, dtype=output.dtype)

def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> BoundingBox:
from torchvision.prototype.transforms import functional as _F

output = _F.elastic_bounding_box(self, self.format, displacement)
return BoundingBox.new_like(self, output, dtype=output.dtype)
8 changes: 8 additions & 0 deletions torchvision/prototype/features/_feature.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,14 @@ def perspective(
) -> Any:
return self

def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Any:
return self

def adjust_brightness(self, brightness_factor: float) -> Any:
return self

Expand Down
13 changes: 13 additions & 0 deletions torchvision/prototype/features/_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,19 @@ def perspective(
output = _F.perspective_image_tensor(self, perspective_coeffs, interpolation=interpolation, fill=fill)
return Image.new_like(self, output)

def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> Image:
from torchvision.prototype.transforms.functional import _geometry as _F

fill = _F._convert_fill_arg(fill)

output = _F.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill)
return Image.new_like(self, output)

def adjust_brightness(self, brightness_factor: float) -> Image:
from torchvision.prototype.transforms import functional as _F

Expand Down
12 changes: 12 additions & 0 deletions torchvision/prototype/features/_segmentation_mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import List, Optional, Union, Sequence

import torch
from torchvision.transforms import InterpolationMode

from ._feature import _Feature
Expand Down Expand Up @@ -119,3 +120,14 @@ def perspective(

output = _F.perspective_segmentation_mask(self, perspective_coeffs)
return SegmentationMask.new_like(self, output)

def elastic(
self,
displacement: torch.Tensor,
interpolation: InterpolationMode = InterpolationMode.NEAREST,
fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
) -> SegmentationMask:
from torchvision.prototype.transforms import functional as _F

output = _F.elastic_segmentation_mask(self, displacement)
return SegmentationMask.new_like(self, output, dtype=output.dtype)
3 changes: 1 addition & 2 deletions torchvision/prototype/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@
RandomRotation,
RandomAffine,
RandomPerspective,
ElasticTransform,
)
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
from ._misc import Identity, GaussianBlur, Normalize, ToDtype, Lambda
from ._type_conversion import DecodeImage, LabelToOneHot

from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip

# TODO: add RandomPerspective, ElasticTransform
Loading