diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 969aedf6d2d..bec868c88fd 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -1352,16 +1352,24 @@ def test_ten_crop(device): assert_equal(transformed_batch, s_transformed_batch) +def test_elastic_transform_asserts(): + with pytest.raises(TypeError, match="Argument displacement should be a Tensor"): + _ = F.elastic_transform("abc", displacement=None) + + with pytest.raises(TypeError, match="img should be PIL Image or Tensor"): + _ = F.elastic_transform("abc", displacement=torch.rand(1)) + + img_tensor = torch.rand(1, 3, 32, 24) + with pytest.raises(ValueError, match="Argument displacement shape should"): + _ = F.elastic_transform(img_tensor, displacement=torch.rand(1, 2)) + + @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC]) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize( "fill", - [ - None, - [255, 255, 255], - (2.0,), - ], + [None, [255, 255, 255], (2.0,)], ) def test_elastic_transform_consistency(device, interpolation, dt, fill): script_elastic_transform = torch.jit.script(F.elastic_transform) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index b8cfabcccc9..51eba38c7a6 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -397,6 +397,7 @@ def test__transform(self, fill, side_range, mocker): fn = mocker.patch("torchvision.prototype.transforms.functional.pad") # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users + # Otherwise, we can mock transform._get_params torch.manual_seed(12) _ = transform(inpt) torch.manual_seed(12) @@ -456,6 +457,7 @@ def test__transform(self, degrees, expand, fill, center, mocker): inpt = mocker.MagicMock(spec=features.Image) # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users + # Otherwise, we can mock transform._get_params torch.manual_seed(12) _ = transform(inpt) torch.manual_seed(12) @@ -576,6 +578,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users + # Otherwise, we can mock transform._get_params torch.manual_seed(12) _ = transform(inpt) torch.manual_seed(12) @@ -645,6 +648,7 @@ def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker): # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users + # Otherwise, we can mock transform._get_params torch.manual_seed(12) _ = transform(inpt) torch.manual_seed(12) @@ -716,6 +720,7 @@ def test__transform(self, kernel_size, sigma, mocker): # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users + # Otherwise, we can mock transform._get_params torch.manual_seed(12) _ = transform(inpt) torch.manual_seed(12) @@ -795,6 +800,7 @@ def test__transform(self, distortion_scale, mocker): inpt.image_size = (24, 32) # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users + # Otherwise, we can mock transform._get_params torch.manual_seed(12) _ = transform(inpt) torch.manual_seed(12) @@ -802,3 +808,72 @@ def test__transform(self, distortion_scale, mocker): params = transform._get_params(inpt) 
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation) + + +class TestElasticTransform: + def test_assertions(self): + + with pytest.raises(TypeError, match="alpha should be float or a sequence of floats"): + transforms.ElasticTransform({}) + + with pytest.raises(ValueError, match="alpha is a sequence its length should be one of 2"): + transforms.ElasticTransform([1.0, 2.0, 3.0]) + + with pytest.raises(ValueError, match="alpha should be a sequence of floats"): + transforms.ElasticTransform([1, 2]) + + with pytest.raises(TypeError, match="sigma should be float or a sequence of floats"): + transforms.ElasticTransform(1.0, {}) + + with pytest.raises(ValueError, match="sigma is a sequence its length should be one of 2"): + transforms.ElasticTransform(1.0, [1.0, 2.0, 3.0]) + + with pytest.raises(ValueError, match="sigma should be a sequence of floats"): + transforms.ElasticTransform(1.0, [1, 2]) + + with pytest.raises(TypeError, match="Got inappropriate fill arg"): + transforms.ElasticTransform(1.0, 2.0, fill="abc") + + def test__get_params(self, mocker): + alpha = 2.0 + sigma = 3.0 + transform = transforms.ElasticTransform(alpha, sigma) + image = mocker.MagicMock(spec=features.Image) + image.num_channels = 3 + image.image_size = (24, 32) + + params = transform._get_params(image) + + h, w = image.image_size + displacement = params["displacement"] + assert displacement.shape == (1, h, w, 2) + assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all() + assert (-alpha / h <= displacement[0, ..., 1]).all() and (displacement[0, ..., 1] <= alpha / h).all() + + @pytest.mark.parametrize("alpha", [5.0, [5.0, 10.0]]) + @pytest.mark.parametrize("sigma", [2.0, [2.0, 5.0]]) + def test__transform(self, alpha, sigma, mocker): + interpolation = InterpolationMode.BILINEAR + fill = 12 + transform = transforms.ElasticTransform(alpha, sigma=sigma, fill=fill, interpolation=interpolation) + + if isinstance(alpha, float): + assert transform.alpha == [alpha, alpha] + else: + assert transform.alpha == alpha + + if isinstance(sigma, float): + assert transform.sigma == [sigma, sigma] + else: + assert transform.sigma == sigma + + fn = mocker.patch("torchvision.prototype.transforms.functional.elastic") + inpt = mocker.MagicMock(spec=features.Image) + inpt.num_channels = 3 + inpt.image_size = (24, 32) + + # Let's mock transform._get_params to control the output: + transform._get_params = mocker.MagicMock() + _ = transform(inpt) + params = transform._get_params(inpt) + fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 61d1adfab18..fb5f10459fe 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -59,7 +59,7 @@ def make_images( yield make_image(size, color_space=color_space, dtype=dtype) for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims): - yield make_image(color_space=color_space, extra_dims=extra_dims_, dtype=dtype) + yield make_image(size=sizes[0], color_space=color_space, extra_dims=extra_dims_, dtype=dtype) def randint_with_tensor_bounds(arg1, arg2=None, **kwargs): @@ -149,12 +149,12 @@ def make_segmentation_mask(size=None, *, num_categories=80, extra_dims=(), dtype def make_segmentation_masks( - image_sizes=((16, 16), (7, 33), (31, 9)), + sizes=((16, 16), (7, 33), (31, 9)), dtypes=(torch.long,), extra_dims=((), 
(4,), (2, 3)), ): - for image_size, dtype, extra_dims_ in itertools.product(image_sizes, dtypes, extra_dims): - yield make_segmentation_mask(size=image_size, dtype=dtype, extra_dims=extra_dims_) + for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims): + yield make_segmentation_mask(size=size, dtype=dtype, extra_dims=extra_dims_) class SampleInput: @@ -533,6 +533,40 @@ def perspective_segmentation_mask(): ) +@register_kernel_info_from_sample_inputs_fn +def elastic_image_tensor(): + for image, fill in itertools.product( + make_images(extra_dims=((), (4,))), + [None, [128], [12.0]], # fill + ): + h, w = image.shape[-2:] + displacement = torch.rand(1, h, w, 2) + yield SampleInput(image, displacement=displacement, fill=fill) + + +@register_kernel_info_from_sample_inputs_fn +def elastic_bounding_box(): + for bounding_box in make_bounding_boxes(): + h, w = bounding_box.image_size + displacement = torch.rand(1, h, w, 2) + yield SampleInput( + bounding_box, + format=bounding_box.format, + displacement=displacement, + ) + + +@register_kernel_info_from_sample_inputs_fn +def elastic_segmentation_mask(): + for mask in make_segmentation_masks(extra_dims=((), (4,))): + h, w = mask.shape[-2:] + displacement = torch.rand(1, h, w, 2) + yield SampleInput( + mask, + displacement=displacement, + ) + + @register_kernel_info_from_sample_inputs_fn def center_crop_image_tensor(): for mask, output_size in itertools.product( @@ -553,7 +587,7 @@ def center_crop_bounding_box(): @register_kernel_info_from_sample_inputs_fn def center_crop_segmentation_mask(): for mask, output_size in itertools.product( - make_segmentation_masks(image_sizes=((16, 16), (7, 33), (31, 9))), + make_segmentation_masks(sizes=((16, 16), (7, 33), (31, 9))), [[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size ): yield SampleInput(mask, output_size) @@ -654,10 +688,20 @@ def test_scriptable(kernel): feature_type not in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label", "pil"} ) and name - not in {"to_image_tensor", "InterpolationMode", "decode_video_with_av", "crop", "rotate", "perspective"} + not in { + "to_image_tensor", + "InterpolationMode", + "decode_video_with_av", + "crop", + "rotate", + "perspective", + "elastic_transform", + "elastic", + } # We skip 'crop' due to missing 'height' and 'width' # We skip 'rotate' due to non implemented yet expand=True case for bboxes # We skip 'perspective' as it requires different input args than perspective_image_tensor etc + # Skip 'elastic', TODO: inspect why test is failing ], ) def test_functional_mid_level(func): @@ -670,7 +714,9 @@ def test_functional_mid_level(func): if key in kwargs: del kwargs[key] output = func(*sample_input.args, **kwargs) - torch.testing.assert_close(output, expected, msg=f"finfo={finfo}, output={output}, expected={expected}") + torch.testing.assert_close( + output, expected, msg=f"finfo={finfo.name}, output={output}, expected={expected}" + ) break @@ -1739,5 +1785,49 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor) ) - out = fn(tensor, kernel_size=ksize, sigma=sigma) + image = features.Image(tensor) + + out = fn(image, kernel_size=ksize, sigma=sigma) torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}") + + +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize( + "fn, make_samples", 
[(F.elastic_image_tensor, make_images), (F.elastic_segmentation_mask, make_segmentation_masks)] +) +def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples): + in_box = [10, 15, 25, 35] + for sample in make_samples(sizes=((64, 76),), extra_dims=((), (4,))): + c, h, w = sample.shape[-3:] + # Setup a dummy image with 4 points + sample[..., in_box[1], in_box[0]] = torch.tensor([12, 34, 96, 112])[:c] + sample[..., in_box[3] - 1, in_box[0]] = torch.tensor([12, 34, 96, 112])[:c] + sample[..., in_box[3] - 1, in_box[2] - 1] = torch.tensor([12, 34, 96, 112])[:c] + sample[..., in_box[1], in_box[2] - 1] = torch.tensor([12, 34, 96, 112])[:c] + sample = sample.to(device) + + if fn == F.elastic_image_tensor: + sample = features.Image(sample) + kwargs = {"interpolation": F.InterpolationMode.NEAREST} + else: + sample = features.SegmentationMask(sample) + kwargs = {} + + # Create a displacement grid using sin + n, m = 5.0, 0.1 + d1 = m * torch.sin(torch.arange(h, dtype=torch.float) * torch.pi * n / h) + d2 = m * torch.sin(torch.arange(w, dtype=torch.float) * torch.pi * n / w) + + d1 = d1[:, None].expand((h, w)) + d2 = d2[None, :].expand((h, w)) + + displacement = torch.cat([d1[..., None], d2[..., None]], dim=-1) + displacement = displacement.reshape(1, h, w, 2) + + output = fn(sample, displacement=displacement, **kwargs) + + # Check places where transformed points should be + torch.testing.assert_close(output[..., 12, 9], sample[..., in_box[1], in_box[0]]) + torch.testing.assert_close(output[..., 17, 27], sample[..., in_box[1], in_box[2] - 1]) + torch.testing.assert_close(output[..., 31, 6], sample[..., in_box[3] - 1, in_box[0]]) + torch.testing.assert_close(output[..., 37, 23], sample[..., in_box[3] - 1, in_box[2] - 1]) diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index eb9d1f6ac3a..f1957df5f1d 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -207,3 +207,14 @@ def perspective( output = _F.perspective_bounding_box(self, self.format, perspective_coeffs) return BoundingBox.new_like(self, output, dtype=output.dtype) + + def elastic( + self, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, + ) -> BoundingBox: + from torchvision.prototype.transforms import functional as _F + + output = _F.elastic_bounding_box(self, self.format, displacement) + return BoundingBox.new_like(self, output, dtype=output.dtype) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 6013672d7ef..b2dedea86d3 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -157,6 +157,14 @@ def perspective( ) -> Any: return self + def elastic( + self, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, + ) -> Any: + return self + def adjust_brightness(self, brightness_factor: float) -> Any: return self diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 0abda7b01d8..18cf7c1964d 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -244,6 +244,19 @@ def perspective( output = _F.perspective_image_tensor(self, perspective_coeffs, interpolation=interpolation, 
fill=fill) return Image.new_like(self, output) + def elastic( + self, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, + ) -> Image: + from torchvision.prototype.transforms.functional import _geometry as _F + + fill = _F._convert_fill_arg(fill) + + output = _F.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill) + return Image.new_like(self, output) + def adjust_brightness(self, brightness_factor: float) -> Image: from torchvision.prototype.transforms import functional as _F diff --git a/torchvision/prototype/features/_segmentation_mask.py b/torchvision/prototype/features/_segmentation_mask.py index f894f33d1b2..5f7ea80430b 100644 --- a/torchvision/prototype/features/_segmentation_mask.py +++ b/torchvision/prototype/features/_segmentation_mask.py @@ -2,6 +2,7 @@ from typing import List, Optional, Union, Sequence +import torch from torchvision.transforms import InterpolationMode from ._feature import _Feature @@ -119,3 +120,14 @@ def perspective( output = _F.perspective_segmentation_mask(self, perspective_coeffs) return SegmentationMask.new_like(self, output) + + def elastic( + self, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, + ) -> SegmentationMask: + from torchvision.prototype.transforms import functional as _F + + output = _F.elastic_segmentation_mask(self, displacement) + return SegmentationMask.new_like(self, output, dtype=output.dtype) diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index c41171a05be..483f811ef7d 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -30,11 +30,10 @@ RandomRotation, RandomAffine, RandomPerspective, + ElasticTransform, ) from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace from ._misc import Identity, GaussianBlur, Normalize, ToDtype, Lambda from ._type_conversion import DecodeImage, LabelToOneHot from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip - -# TODO: add RandomPerspective, ElasticTransform diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 3cf3858720e..8a1b94060c4 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -551,3 +551,72 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill=self.fill, interpolation=self.interpolation, ) + + +def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]: + if not isinstance(arg, (float, Sequence)): + raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}") + if isinstance(arg, Sequence) and len(arg) != req_size: + raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}") + if isinstance(arg, Sequence): + for element in arg: + if not isinstance(element, float): + raise ValueError(f"{name} should be a sequence of floats. 
Got {type(element)}") + + if isinstance(arg, float): + arg = [float(arg), float(arg)] + if isinstance(arg, (list, tuple)) and len(arg) == 1: + arg = [arg[0], arg[0]] + return arg + + +class ElasticTransform(Transform): + def __init__( + self, + alpha: Union[float, Sequence[float]] = 50.0, + sigma: Union[float, Sequence[float]] = 5.0, + fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + ) -> None: + super().__init__() + self.alpha = _setup_float_or_seq(alpha, "alpha", 2) + self.sigma = _setup_float_or_seq(sigma, "sigma", 2) + + _check_fill_arg(fill) + + self.interpolation = interpolation + self.fill = fill + + def _get_params(self, sample: Any) -> Dict[str, Any]: + # Get image size + # TODO: make it work with bboxes and segm masks + image = query_image(sample) + _, *size = get_image_dimensions(image) + + dx = torch.rand([1, 1] + size) * 2 - 1 + if self.sigma[0] > 0.0: + kx = int(8 * self.sigma[0] + 1) + # if kernel size is even we have to make it odd + if kx % 2 == 0: + kx += 1 + dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma)) + dx = dx * self.alpha[0] / size[0] + + dy = torch.rand([1, 1] + size) * 2 - 1 + if self.sigma[1] > 0.0: + ky = int(8 * self.sigma[1] + 1) + # if kernel size is even we have to make it odd + if ky % 2 == 0: + ky += 1 + dy = F.gaussian_blur(dy, [ky, ky], list(self.sigma)) + dy = dy * self.alpha[1] / size[1] + displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2 + return dict(displacement=displacement) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.elastic( + inpt, + **params, + fill=self.fill, + interpolation=self.interpolation, + ) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 2d2618cf497..638049f96fb 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -89,6 +89,12 @@ perspective_image_tensor, perspective_image_pil, perspective_segmentation_mask, + elastic, + elastic_transform, + elastic_bounding_box, + elastic_image_tensor, + elastic_image_pil, + elastic_segmentation_mask, vertical_flip, vertical_flip_image_tensor, vertical_flip_image_pil, diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 87419ba8640..e48701bb436 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -12,6 +12,8 @@ InterpolationMode, _compute_output_size, _get_perspective_coeffs, + pil_to_tensor, + to_pil_image, ) from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil @@ -760,8 +762,10 @@ def perspective_bounding_box( ).view(original_shape) -def perspective_segmentation_mask(img: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor: - return perspective_image_tensor(img, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST) +def perspective_segmentation_mask(mask: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor: + return perspective_image_tensor( + mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST + ) def perspective( @@ -783,6 +787,90 @@ def perspective( return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill) +def elastic_image_tensor( + img: torch.Tensor, + 
displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[List[float]] = None, +) -> torch.Tensor: + return _FT.elastic_transform(img, displacement, interpolation=interpolation.value, fill=fill) + + +def elastic_image_pil( + img: PIL.Image.Image, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, +) -> PIL.Image.Image: + t_img = pil_to_tensor(img) + fill = _convert_fill_arg(fill) + + output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill) + return to_pil_image(output, mode=img.mode) + + +def elastic_bounding_box( + bounding_box: torch.Tensor, + format: features.BoundingBoxFormat, + displacement: torch.Tensor, +) -> torch.Tensor: + displacement = displacement.to(bounding_box.device) + + original_shape = bounding_box.shape + bounding_box = convert_bounding_box_format( + bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY + ).view(-1, 4) + + # Question (vfdev-5): should we rely on good displacement shape and fetch image size from it + # Or add image_size arg and check displacement shape + image_size = displacement.shape[-3], displacement.shape[-2] + + id_grid = _FT._create_identity_grid(list(image_size)).to(bounding_box.device) + # We construct an approximation of inverse grid as inv_grid = id_grid - displacement + # This is not an exact inverse of the grid + inv_grid = id_grid - displacement + + # Get points from bboxes + points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2) + index_x = torch.floor(points[:, 0] + 0.5).to(dtype=torch.long) + index_y = torch.floor(points[:, 1] + 0.5).to(dtype=torch.long) + # Transform points: + t_size = torch.tensor(image_size[::-1], device=displacement.device, dtype=displacement.dtype) + transformed_points = (inv_grid[0, index_y, index_x, :] + 1) * 0.5 * t_size - 0.5 + + transformed_points = transformed_points.view(-1, 4, 2) + out_bbox_mins, _ = torch.min(transformed_points, dim=1) + out_bbox_maxs, _ = torch.max(transformed_points, dim=1) + out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1) + + return convert_bounding_box_format( + out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False + ).view(original_shape) + + +def elastic_segmentation_mask(mask: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor: + return elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST) + + +def elastic( + inpt: DType, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, +) -> DType: + if isinstance(inpt, features._Feature): + return inpt.elastic(displacement, interpolation=interpolation, fill=fill) + elif isinstance(inpt, PIL.Image.Image): + return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill) + else: + fill = _convert_fill_arg(fill) + + return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill) + + +elastic_transform = elastic + + def _center_crop_parse_output_size(output_size: List[int]) -> List[int]: if isinstance(output_size, numbers.Number): return [int(output_size), int(output_size)] diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index eea53a228a9..442e8d4288d 100644 --- a/torchvision/transforms/functional.py +++ 
b/torchvision/transforms/functional.py @@ -1555,7 +1555,7 @@ def elastic_transform( If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. If img is PIL Image, it is expected to be in mode "P", "L" or "RGB". - displacement (Tensor): The displacement field. + displacement (Tensor): The displacement field. Expected shape is [1, H, W, 2]. interpolation (InterpolationMode): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. @@ -1577,7 +1577,7 @@ def elastic_transform( interpolation = _interpolation_modes_from_int(interpolation) if not isinstance(displacement, torch.Tensor): - raise TypeError("displacement should be a Tensor") + raise TypeError("Argument displacement should be a Tensor") t_img = img if not isinstance(img, torch.Tensor): @@ -1585,6 +1585,15 @@ def elastic_transform( raise TypeError(f"img should be PIL Image or Tensor. Got {type(img)}") t_img = pil_to_tensor(img) + shape = t_img.shape + shape = (1,) + shape[-2:] + (2,) + if shape != displacement.shape: + raise ValueError(f"Argument displacement shape should be {shape}, but given {displacement.shape}") + + # TODO: if image shape is [N1, N2, ..., C, H, W] and + # displacement is [1, H, W, 2] we need to reshape input image + # such grid_sampler takes internal code for 4D input + output = F_t.elastic_transform( t_img, displacement, diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 2b0872acf8a..c35edfb74b0 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -932,6 +932,12 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool return img +def _create_identity_grid(size: List[int]) -> Tensor: + hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size] + grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij") + return torch.stack([grid_x, grid_y], -1).unsqueeze(0) # 1 x H x W x 2 + + def elastic_transform( img: Tensor, displacement: Tensor, @@ -945,8 +951,6 @@ def elastic_transform( size = list(img.shape[-2:]) displacement = displacement.to(img.device) - hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size] - grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij") - identity_grid = torch.stack([grid_x, grid_y], -1).unsqueeze(0) # 1 x H x W x 2 + identity_grid = _create_identity_grid(size) grid = identity_grid.to(img.device) + displacement return _apply_grid_transform(img, grid, interpolation, fill)
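
Note (illustrative, not part of the patch): the displacement field introduced above is expressed in the normalized [-1, 1] grid coordinates consumed by grid_sample, so a value of 2/W in the x channel moves the sampling location by one pixel horizontally. The sketch below re-creates the identity-grid helper locally and shows why elastic_bounding_box can approximate the mapping of box corner points with inv_grid = id_grid - displacement.

import torch

def create_identity_grid(size):
    # Local copy of the helper added to functional_tensor.py: a 1 x H x W x 2
    # grid of normalized (x, y) sampling coordinates in [-1, 1].
    hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size]
    grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij")
    return torch.stack([grid_x, grid_y], -1).unsqueeze(0)

h, w = 4, 6
img = torch.arange(h * w, dtype=torch.float32).reshape(1, 1, h, w)

displacement = torch.zeros(1, h, w, 2)
displacement[..., 0] = 2.0 / w  # one pixel along x, in normalized coordinates

# Image path (backward warp): grid = identity + displacement, so every output
# pixel samples the input one pixel to its right and the content shifts left.
grid = create_identity_grid([h, w]) + displacement
out = torch.nn.functional.grid_sample(img, grid, mode="nearest", align_corners=False)

# Bounding boxes need the forward mapping of corner points instead, which
# elastic_bounding_box approximates by sampling id_grid - displacement.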
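
Usage sketch (also illustrative only, assuming the prototype API as introduced in this patch):

import torch
from torchvision.prototype import features, transforms
from torchvision.prototype.transforms import functional as F

image = features.Image(torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8))

# Transform class: alpha and sigma accept a float or a two-element sequence of floats.
transform = transforms.ElasticTransform(alpha=50.0, sigma=5.0)
output = transform(image)

# Mid-level dispatcher: the displacement field is expected to have shape [1, H, W, 2].
displacement = 0.1 * torch.rand(1, 32, 32, 2)
output = F.elastic(image, displacement=displacement)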