From 3169d74825f30fd95db0f112c4c183cd81ab61fe Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 18 Jul 2022 18:47:27 +0200 Subject: [PATCH 1/6] WIP [proto] Added functional elastic transform with tests --- .../prototype/features/_bounding_box.py | 11 ++++ torchvision/prototype/features/_feature.py | 8 +++ torchvision/prototype/features/_image.py | 13 ++++ .../prototype/features/_segmentation_mask.py | 12 ++++ .../transforms/functional/__init__.py | 6 ++ .../transforms/functional/_geometry.py | 59 ++++++++++++++++++- 6 files changed, 107 insertions(+), 2 deletions(-) diff --git a/torchvision/prototype/features/_bounding_box.py b/torchvision/prototype/features/_bounding_box.py index eb9d1f6ac3a..f1957df5f1d 100644 --- a/torchvision/prototype/features/_bounding_box.py +++ b/torchvision/prototype/features/_bounding_box.py @@ -207,3 +207,14 @@ def perspective( output = _F.perspective_bounding_box(self, self.format, perspective_coeffs) return BoundingBox.new_like(self, output, dtype=output.dtype) + + def elastic( + self, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, + ) -> BoundingBox: + from torchvision.prototype.transforms import functional as _F + + output = _F.elastic_bounding_box(self, self.format, displacement) + return BoundingBox.new_like(self, output, dtype=output.dtype) diff --git a/torchvision/prototype/features/_feature.py b/torchvision/prototype/features/_feature.py index 6013672d7ef..b2dedea86d3 100644 --- a/torchvision/prototype/features/_feature.py +++ b/torchvision/prototype/features/_feature.py @@ -157,6 +157,14 @@ def perspective( ) -> Any: return self + def elastic( + self, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, + ) -> Any: + return self + def adjust_brightness(self, brightness_factor: float) -> Any: return self diff --git a/torchvision/prototype/features/_image.py b/torchvision/prototype/features/_image.py index 0abda7b01d8..18cf7c1964d 100644 --- a/torchvision/prototype/features/_image.py +++ b/torchvision/prototype/features/_image.py @@ -244,6 +244,19 @@ def perspective( output = _F.perspective_image_tensor(self, perspective_coeffs, interpolation=interpolation, fill=fill) return Image.new_like(self, output) + def elastic( + self, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, + ) -> Image: + from torchvision.prototype.transforms.functional import _geometry as _F + + fill = _F._convert_fill_arg(fill) + + output = _F.elastic_image_tensor(self, displacement, interpolation=interpolation, fill=fill) + return Image.new_like(self, output) + def adjust_brightness(self, brightness_factor: float) -> Image: from torchvision.prototype.transforms import functional as _F diff --git a/torchvision/prototype/features/_segmentation_mask.py b/torchvision/prototype/features/_segmentation_mask.py index f894f33d1b2..5f7ea80430b 100644 --- a/torchvision/prototype/features/_segmentation_mask.py +++ b/torchvision/prototype/features/_segmentation_mask.py @@ -2,6 +2,7 @@ from typing import List, Optional, Union, Sequence +import torch from torchvision.transforms import InterpolationMode from ._feature import _Feature @@ -119,3 +120,14 @@ def perspective( output = _F.perspective_segmentation_mask(self, perspective_coeffs) 
return SegmentationMask.new_like(self, output) + + def elastic( + self, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.NEAREST, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, + ) -> SegmentationMask: + from torchvision.prototype.transforms import functional as _F + + output = _F.elastic_segmentation_mask(self, displacement) + return SegmentationMask.new_like(self, output, dtype=output.dtype) diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py index 2d2618cf497..638049f96fb 100644 --- a/torchvision/prototype/transforms/functional/__init__.py +++ b/torchvision/prototype/transforms/functional/__init__.py @@ -89,6 +89,12 @@ perspective_image_tensor, perspective_image_pil, perspective_segmentation_mask, + elastic, + elastic_transform, + elastic_bounding_box, + elastic_image_tensor, + elastic_image_pil, + elastic_segmentation_mask, vertical_flip, vertical_flip_image_tensor, vertical_flip_image_pil, diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 8d3ed675047..ce413914ea0 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -11,6 +11,8 @@ _get_inverse_affine_matrix, InterpolationMode, _compute_output_size, + pil_to_tensor, + to_pil_image, ) from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil @@ -759,8 +761,8 @@ def perspective_bounding_box( ).view(original_shape) -def perspective_segmentation_mask(img: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor: - return perspective_image_tensor(img, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST) +def perspective_segmentation_mask(mask: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor: + return perspective_image_tensor(mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST) def perspective( @@ -779,6 +781,59 @@ def perspective( return perspective_image_tensor(inpt, perspective_coeffs, interpolation=interpolation, fill=fill) +def elastic_image_tensor( + img: torch.Tensor, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[List[float]] = None, +) -> torch.Tensor: + return _FT.elastic_transform(img, displacement, interpolation=interpolation.value, fill=fill) + + +def elastic_image_pil( + img: PIL.Image.Image, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, +) -> PIL.Image.Image: + t_img = pil_to_tensor(img) + fill = _convert_fill_arg(fill) + + output = elastic_image_tensor(t_img, displacement, interpolation=interpolation, fill=fill) + return to_pil_image(output, mode=img.mode) + + +def elastic_bounding_box( + bounding_box: torch.Tensor, + format: features.BoundingBoxFormat, + displacement: torch.Tensor, +) -> torch.Tensor: + pass + + +def elastic_segmentation_mask(mask: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor: + return elastic_image_tensor(mask, displacement=displacement, interpolation=InterpolationMode.NEAREST) + + +def elastic( + inpt: DType, + displacement: torch.Tensor, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None, +) -> 
DType: + if isinstance(inpt, features._Feature): + return inpt.elastic(displacement, interpolation=interpolation, fill=fill) + elif isinstance(inpt, PIL.Image.Image): + return elastic_image_pil(inpt, displacement, interpolation=interpolation, fill=fill) + else: + fill = _convert_fill_arg(fill) + + return elastic_image_tensor(inpt, displacement, interpolation=interpolation, fill=fill) + + +elastic_transform = elastic + + def _center_crop_parse_output_size(output_size: List[int]) -> List[int]: if isinstance(output_size, numbers.Number): return [int(output_size), int(output_size)] From b688a532bf59b829e04a27c74d15ee1e43e28376 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 18 Jul 2022 19:02:56 +0200 Subject: [PATCH 2/6] Added more functional tests --- test/test_prototype_transforms_functional.py | 72 +++++++++++++++++++- 1 file changed, 69 insertions(+), 3 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 873516869f8..61d1adfab18 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -200,6 +200,30 @@ def horizontal_flip_bounding_box(): yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size) +@register_kernel_info_from_sample_inputs_fn +def horizontal_flip_segmentation_mask(): + for mask in make_segmentation_masks(): + yield SampleInput(mask) + + +@register_kernel_info_from_sample_inputs_fn +def vertical_flip_image_tensor(): + for image in make_images(): + yield SampleInput(image) + + +@register_kernel_info_from_sample_inputs_fn +def vertical_flip_bounding_box(): + for bounding_box in make_bounding_boxes(formats=[features.BoundingBoxFormat.XYXY]): + yield SampleInput(bounding_box, format=bounding_box.format, image_size=bounding_box.image_size) + + +@register_kernel_info_from_sample_inputs_fn +def vertical_flip_segmentation_mask(): + for mask in make_segmentation_masks(): + yield SampleInput(mask) + + @register_kernel_info_from_sample_inputs_fn def resize_image_tensor(): for image, interpolation, max_size, antialias in itertools.product( @@ -404,9 +428,17 @@ def crop_segmentation_mask(): @register_kernel_info_from_sample_inputs_fn -def vertical_flip_segmentation_mask(): - for mask in make_segmentation_masks(): - yield SampleInput(mask) +def resized_crop_image_tensor(): + for mask, top, left, height, width, size, antialias in itertools.product( + make_images(), + [-8, 9], + [-8, 9], + [12], + [12], + [(16, 18)], + [True, False], + ): + yield SampleInput(mask, top=top, left=left, height=height, width=width, size=size, antialias=antialias) @register_kernel_info_from_sample_inputs_fn @@ -457,6 +489,19 @@ def pad_bounding_box(): yield SampleInput(bounding_box, padding=padding, format=bounding_box.format) +@register_kernel_info_from_sample_inputs_fn +def perspective_image_tensor(): + for image, perspective_coeffs, fill in itertools.product( + make_images(extra_dims=((), (4,))), + [ + [1.2405, 0.1772, -6.9113, 0.0463, 1.251, -5.235, 0.00013, 0.0018], + [0.7366, -0.11724, 1.45775, -0.15012, 0.73406, 2.6019, -0.0072, -0.0063], + ], + [None, [128], [12.0]], # fill + ): + yield SampleInput(image, perspective_coeffs=perspective_coeffs, fill=fill) + + @register_kernel_info_from_sample_inputs_fn def perspective_bounding_box(): for bounding_box, perspective_coeffs in itertools.product( @@ -488,6 +533,15 @@ def perspective_segmentation_mask(): ) +@register_kernel_info_from_sample_inputs_fn +def center_crop_image_tensor(): + for mask, 
output_size in itertools.product( + make_images(sizes=((16, 16), (7, 33), (31, 9))), + [[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size + ): + yield SampleInput(mask, output_size) + + @register_kernel_info_from_sample_inputs_fn def center_crop_bounding_box(): for bounding_box, output_size in itertools.product(make_bounding_boxes(), [(24, 12), [16, 18], [46, 48], [12]]): @@ -1181,6 +1235,18 @@ def _compute_expected_mask(mask, top_, left_, height_, width_): torch.testing.assert_close(output_mask, expected_mask) +@pytest.mark.parametrize("device", cpu_and_gpu()) +def test_correctness_horizontal_flip_segmentation_mask_on_fixed_input(device): + mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) + mask[:, :, 0] = 1 + + out_mask = F.horizontal_flip_segmentation_mask(mask) + + expected_mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) + expected_mask[:, :, -1] = 1 + torch.testing.assert_close(out_mask, expected_mask) + + @pytest.mark.parametrize("device", cpu_and_gpu()) def test_correctness_vertical_flip_segmentation_mask_on_fixed_input(device): mask = torch.zeros((3, 3, 3), dtype=torch.long, device=device) From ea6821ad00b9435730bbaed6f1f0c4ffbd201a12 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Mon, 18 Jul 2022 19:43:05 +0200 Subject: [PATCH 3/6] WIP on elastic op --- test/test_prototype_transforms_functional.py | 45 ++++++++++++++++++- .../transforms/functional/_geometry.py | 5 ++- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 61d1adfab18..6e19091f861 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -533,6 +533,40 @@ def perspective_segmentation_mask(): ) +@register_kernel_info_from_sample_inputs_fn +def elastic_image_tensor(): + for image, fill in itertools.product( + make_images(extra_dims=((), (4,))), + [None, [128], [12.0]], # fill + ): + h, w = image.shape[-2:] + displacement = torch.rand(1, h, w, 2) + yield SampleInput(image, displacement=displacement, fill=fill) + + +@register_kernel_info_from_sample_inputs_fn +def elastic_bounding_box(): + for bounding_box in make_bounding_boxes(): + h, w = bounding_box.image_size + displacement = torch.rand(1, h, w, 2) + yield SampleInput( + bounding_box, + format=bounding_box.format, + displacement=displacement, + ) + + +@register_kernel_info_from_sample_inputs_fn +def elastic_segmentation_mask(): + for mask in make_segmentation_masks(extra_dims=((), (4,))): + h, w = mask.shape[-2:] + displacement = torch.rand(1, h, w, 2) + yield SampleInput( + mask, + displacement=displacement, + ) + + @register_kernel_info_from_sample_inputs_fn def center_crop_image_tensor(): for mask, output_size in itertools.product( @@ -654,10 +688,19 @@ def test_scriptable(kernel): feature_type not in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label", "pil"} ) and name - not in {"to_image_tensor", "InterpolationMode", "decode_video_with_av", "crop", "rotate", "perspective"} + not in { + "to_image_tensor", + "InterpolationMode", + "decode_video_with_av", + "crop", + "rotate", + "perspective", + "elastic_transform", + } # We skip 'crop' due to missing 'height' and 'width' # We skip 'rotate' due to non implemented yet expand=True case for bboxes # We skip 'perspective' as it requires different input args than perspective_image_tensor etc + # Skip 'elastic', TODO: inspect why test is failing ], ) def 
test_functional_mid_level(func): diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index d5a2e3b513f..f96404758ff 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -763,7 +763,9 @@ def perspective_bounding_box( def perspective_segmentation_mask(mask: torch.Tensor, perspective_coeffs: List[float]) -> torch.Tensor: - return perspective_image_tensor(mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST) + return perspective_image_tensor( + mask, perspective_coeffs=perspective_coeffs, interpolation=InterpolationMode.NEAREST + ) def perspective( @@ -812,6 +814,7 @@ def elastic_bounding_box( format: features.BoundingBoxFormat, displacement: torch.Tensor, ) -> torch.Tensor: + # TODO: implement transformation pass From d78b611ae9efb33ac817873ff995a10663cc9f1e Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Thu, 21 Jul 2022 10:08:58 +0200 Subject: [PATCH 4/6] Added elastic transform and tests --- test/test_functional_tensor.py | 18 +++-- test/test_prototype_transforms_functional.py | 5 +- torchvision/prototype/transforms/__init__.py | 3 +- torchvision/prototype/transforms/_geometry.py | 69 +++++++++++++++++++ .../transforms/functional/_geometry.py | 33 ++++++++- torchvision/transforms/functional.py | 13 +++- torchvision/transforms/functional_tensor.py | 10 ++- 7 files changed, 136 insertions(+), 15 deletions(-) diff --git a/test/test_functional_tensor.py b/test/test_functional_tensor.py index 969aedf6d2d..bec868c88fd 100644 --- a/test/test_functional_tensor.py +++ b/test/test_functional_tensor.py @@ -1352,16 +1352,24 @@ def test_ten_crop(device): assert_equal(transformed_batch, s_transformed_batch) +def test_elastic_transform_asserts(): + with pytest.raises(TypeError, match="Argument displacement should be a Tensor"): + _ = F.elastic_transform("abc", displacement=None) + + with pytest.raises(TypeError, match="img should be PIL Image or Tensor"): + _ = F.elastic_transform("abc", displacement=torch.rand(1)) + + img_tensor = torch.rand(1, 3, 32, 24) + with pytest.raises(ValueError, match="Argument displacement shape should"): + _ = F.elastic_transform(img_tensor, displacement=torch.rand(1, 2)) + + @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("interpolation", [NEAREST, BILINEAR, BICUBIC]) @pytest.mark.parametrize("dt", [None, torch.float32, torch.float64, torch.float16]) @pytest.mark.parametrize( "fill", - [ - None, - [255, 255, 255], - (2.0,), - ], + [None, [255, 255, 255], (2.0,)], ) def test_elastic_transform_consistency(device, interpolation, dt, fill): script_elastic_transform = torch.jit.script(F.elastic_transform) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 6e19091f861..1302f2cc2c7 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -696,6 +696,7 @@ def test_scriptable(kernel): "rotate", "perspective", "elastic_transform", + "elastic", } # We skip 'crop' due to missing 'height' and 'width' # We skip 'rotate' due to non implemented yet expand=True case for bboxes @@ -713,7 +714,9 @@ def test_functional_mid_level(func): if key in kwargs: del kwargs[key] output = func(*sample_input.args, **kwargs) - torch.testing.assert_close(output, expected, msg=f"finfo={finfo}, output={output}, expected={expected}") + torch.testing.assert_close( + output, expected, 
msg=f"finfo={finfo.name}, output={output}, expected={expected}" + ) break diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py index c41171a05be..483f811ef7d 100644 --- a/torchvision/prototype/transforms/__init__.py +++ b/torchvision/prototype/transforms/__init__.py @@ -30,11 +30,10 @@ RandomRotation, RandomAffine, RandomPerspective, + ElasticTransform, ) from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace from ._misc import Identity, GaussianBlur, Normalize, ToDtype, Lambda from ._type_conversion import DecodeImage, LabelToOneHot from ._deprecated import Grayscale, RandomGrayscale, ToTensor, ToPILImage, PILToTensor # usort: skip - -# TODO: add RandomPerspective, ElasticTransform diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 3cf3858720e..527af9ba014 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -551,3 +551,72 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: fill=self.fill, interpolation=self.interpolation, ) + + +def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: int = 2) -> Sequence[float]: + if not isinstance(arg, (float, Sequence)): + raise TypeError(f"{name} should be float or a sequence of floats. Got {type(arg)}") + if isinstance(arg, Sequence) and len(arg) != req_size: + raise ValueError(f"If {name} is a sequence its length should be one of {req_size}. Got {len(arg)}") + if isinstance(arg, Sequence): + for element in arg: + if not isinstance(element, float): + raise TypeError(f"{name} should be a sequence of floats. Got {type(element)}") + + if isinstance(arg, float): + arg = [float(arg), float(arg)] + if isinstance(arg, (list, tuple)) and len(arg) == 1: + arg = [arg[0], arg[0]] + return arg + + +class ElasticTransform(Transform): + def __init__( + self, + alpha: Union[float, Sequence[float]] = 50.0, + sigma: Union[float, Sequence[float]] = 5.0, + fill: Union[int, float, Sequence[int], Sequence[float]] = 0, + interpolation: InterpolationMode = InterpolationMode.BILINEAR, + ) -> None: + super().__init__() + self.alpha = _setup_float_or_seq(alpha, "alpha", 2) + self.sigma = _setup_float_or_seq(sigma, "sigma", 2) + + _check_fill_arg(fill) + + self.interpolation = interpolation + self.fill = fill + + def _get_params(self, sample: Any) -> Dict[str, Any]: + # Get image size + # TODO: make it work with bboxes and segm masks + image = query_image(sample) + _, *size = get_image_dimensions(image) + + dx = torch.rand([1, 1] + size) * 2 - 1 + if self.sigma[0] > 0.0: + kx = int(8 * self.sigma[0] + 1) + # if kernel size is even we have to make it odd + if kx % 2 == 0: + kx += 1 + dx = F.gaussian_blur(dx, [kx, kx], list(self.sigma)) + dx = dx * self.alpha[0] / size[0] + + dy = torch.rand([1, 1] + size) * 2 - 1 + if self.sigma[1] > 0.0: + ky = int(8 * self.sigma[1] + 1) + # if kernel size is even we have to make it odd + if ky % 2 == 0: + ky += 1 + dy = F.gaussian_blur(dy, [ky, ky], list(self.sigma)) + dy = dy * self.alpha[1] / size[1] + displacement = torch.concat([dx, dy], 1).permute([0, 2, 3, 1]) # 1 x H x W x 2 + return dict(displacement=displacement) + + def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: + return F.elastic( + inpt, + **params, + fill=self.fill, + interpolation=self.interpolation, + ) diff --git a/torchvision/prototype/transforms/functional/_geometry.py 
b/torchvision/prototype/transforms/functional/_geometry.py index f96404758ff..46d1afa1a20 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -814,8 +814,37 @@ def elastic_bounding_box( format: features.BoundingBoxFormat, displacement: torch.Tensor, ) -> torch.Tensor: - # TODO: implement transformation - pass + displacement = displacement.to(bounding_box.device) + + original_shape = bounding_box.shape + bounding_box = convert_bounding_box_format( + bounding_box, old_format=format, new_format=features.BoundingBoxFormat.XYXY + ).view(-1, 4) + + # Question (vfdev-5): should we rely on good displacement shape and fetch image size from it + # Or add image_size arg and check displacement shape + image_size = displacement.shape[-3], displacement.shape[-2] + + id_grid = _FT._create_identity_grid(list(image_size)).to(bounding_box.device) + # We construct inverse grid vs grid = id_grid + displacement used for images + inv_grid = id_grid - displacement + + # Get points from bboxes + points = bounding_box[:, [[0, 1], [2, 1], [2, 3], [0, 3]]].view(-1, 2) + index_x = torch.floor(points[:, 0] + 0.5).to(dtype=torch.long) + index_y = torch.floor(points[:, 1] + 0.5).to(dtype=torch.long) + # Transform points: + t_size = torch.tensor(image_size[::-1], device=displacement.device, dtype=displacement.dtype) + transformed_points = (inv_grid[0, index_y, index_x, :] + 1) * 0.5 * t_size - 0.5 + + transformed_points = transformed_points.view(-1, 4, 2) + out_bbox_mins, _ = torch.min(transformed_points, dim=1) + out_bbox_maxs, _ = torch.max(transformed_points, dim=1) + out_bboxes = torch.cat([out_bbox_mins, out_bbox_maxs], dim=1) + + return convert_bounding_box_format( + out_bboxes, old_format=features.BoundingBoxFormat.XYXY, new_format=format, copy=False + ).view(original_shape) def elastic_segmentation_mask(mask: torch.Tensor, displacement: torch.Tensor) -> torch.Tensor: diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index eea53a228a9..442e8d4288d 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -1555,7 +1555,7 @@ def elastic_transform( If img is torch Tensor, it is expected to be in [..., 1 or 3, H, W] format, where ... means it can have an arbitrary number of leading dimensions. If img is PIL Image, it is expected to be in mode "P", "L" or "RGB". - displacement (Tensor): The displacement field. + displacement (Tensor): The displacement field. Expected shape is [1, H, W, 2]. interpolation (InterpolationMode): Desired interpolation enum defined by :class:`torchvision.transforms.InterpolationMode`. Default is ``InterpolationMode.BILINEAR``. @@ -1577,7 +1577,7 @@ def elastic_transform( interpolation = _interpolation_modes_from_int(interpolation) if not isinstance(displacement, torch.Tensor): - raise TypeError("displacement should be a Tensor") + raise TypeError("Argument displacement should be a Tensor") t_img = img if not isinstance(img, torch.Tensor): @@ -1585,6 +1585,15 @@ def elastic_transform( raise TypeError(f"img should be PIL Image or Tensor. 
Got {type(img)}") t_img = pil_to_tensor(img) + shape = t_img.shape + shape = (1,) + shape[-2:] + (2,) + if shape != displacement.shape: + raise ValueError(f"Argument displacement shape should be {shape}, but given {displacement.shape}") + + # TODO: if image shape is [N1, N2, ..., C, H, W] and + # displacement is [1, H, W, 2] we need to reshape input image + # such grid_sampler takes internal code for 4D input + output = F_t.elastic_transform( t_img, displacement, diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 2b0872acf8a..c35edfb74b0 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -932,6 +932,12 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool return img +def _create_identity_grid(size: List[int]) -> Tensor: + hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size] + grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij") + return torch.stack([grid_x, grid_y], -1).unsqueeze(0) # 1 x H x W x 2 + + def elastic_transform( img: Tensor, displacement: Tensor, @@ -945,8 +951,6 @@ def elastic_transform( size = list(img.shape[-2:]) displacement = displacement.to(img.device) - hw_space = [torch.linspace((-s + 1) / s, (s - 1) / s, s) for s in size] - grid_y, grid_x = torch.meshgrid(hw_space, indexing="ij") - identity_grid = torch.stack([grid_x, grid_y], -1).unsqueeze(0) # 1 x H x W x 2 + identity_grid = _create_identity_grid(size) grid = identity_grid.to(img.device) + displacement return _apply_grid_transform(img, grid, interpolation, fill) From 2099d00a751aac654b5b37c01ece0a7c1eebe701 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 22 Jul 2022 15:26:18 +0200 Subject: [PATCH 5/6] Added tests --- test/test_prototype_transforms_functional.py | 57 +++++++++++++++++-- .../transforms/functional/_geometry.py | 3 +- 2 files changed, 53 insertions(+), 7 deletions(-) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index 1302f2cc2c7..f951eece6cc 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -59,7 +59,7 @@ def make_images( yield make_image(size, color_space=color_space, dtype=dtype) for color_space, dtype, extra_dims_ in itertools.product(color_spaces, dtypes, extra_dims): - yield make_image(color_space=color_space, extra_dims=extra_dims_, dtype=dtype) + yield make_image(size=sizes[0], color_space=color_space, extra_dims=extra_dims_, dtype=dtype) def randint_with_tensor_bounds(arg1, arg2=None, **kwargs): @@ -149,12 +149,12 @@ def make_segmentation_mask(size=None, *, num_categories=80, extra_dims=(), dtype def make_segmentation_masks( - image_sizes=((16, 16), (7, 33), (31, 9)), + sizes=((16, 16), (7, 33), (31, 9)), dtypes=(torch.long,), extra_dims=((), (4,), (2, 3)), ): - for image_size, dtype, extra_dims_ in itertools.product(image_sizes, dtypes, extra_dims): - yield make_segmentation_mask(size=image_size, dtype=dtype, extra_dims=extra_dims_) + for size, dtype, extra_dims_ in itertools.product(sizes, dtypes, extra_dims): + yield make_segmentation_mask(size=size, dtype=dtype, extra_dims=extra_dims_) class SampleInput: @@ -587,7 +587,7 @@ def center_crop_bounding_box(): @register_kernel_info_from_sample_inputs_fn def center_crop_segmentation_mask(): for mask, output_size in itertools.product( - make_segmentation_masks(image_sizes=((16, 16), (7, 33), (31, 9))), + make_segmentation_masks(sizes=((16, 16), (7, 33), (31, 
9))), [[4, 3], [42, 70], [4]], # crop sizes < image sizes, crop_sizes > image sizes, single crop size ): yield SampleInput(mask, output_size) @@ -1785,5 +1785,50 @@ def test_correctness_gaussian_blur_image_tensor(device, image_size, dt, ksize, s torch.tensor(true_cv2_results[gt_key]).reshape(shape[-2], shape[-1], shape[-3]).permute(2, 0, 1).to(tensor) ) - out = fn(tensor, kernel_size=ksize, sigma=sigma) + image = features.Image(tensor) + + out = fn(image, kernel_size=ksize, sigma=sigma) torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}") + + +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize( + "fn, make_samples", [(F.elastic_image_tensor, make_images), (F.elastic_segmentation_mask, make_segmentation_masks)] +) +def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples): + in_box = [10, 15, 25, 35] + for sample in make_samples(sizes=((64, 76),), extra_dims=((), (4,))): + c, h, w = sample.shape[-3:] + # Setup a dummy image with 4 points + sample[..., in_box[1], in_box[0]] = torch.tensor([12, 34, 96, 112])[:c] + sample[..., in_box[3] - 1, in_box[0]] = torch.tensor([12, 34, 96, 112])[:c] + sample[..., in_box[3] - 1, in_box[2] - 1] = torch.tensor([12, 34, 96, 112])[:c] + sample[..., in_box[1], in_box[2] - 1] = torch.tensor([12, 34, 96, 112])[:c] + sample = sample.to(device) + + if fn == F.elastic_image_tensor: + sample = features.Image(sample) + kwargs = {"interpolation": F.InterpolationMode.NEAREST} + else: + sample = features.SegmentationMask(sample) + kwargs = {} + + # Create a displacement grid using sin + n, m = 5.0, 0.1 + d1 = m * torch.sin(torch.arange(h, dtype=torch.float) * torch.pi * n / h) + d2 = m * torch.sin(torch.arange(w, dtype=torch.float) * torch.pi * n / w) + + d1 = d1[:, None].expand((h, w)) + d2 = d2[None, :].expand((h, w)) + + displacement = torch.cat([d1[..., None], d2[..., None]], dim=-1) + displacement = displacement.reshape(1, h, w, 2) + + print(sample.dtype, sample.shape) + output = fn(sample, displacement=displacement, **kwargs) + + # Check places where transformed points should be + torch.testing.assert_close(output[..., 12, 9], sample[..., in_box[1], in_box[0]]) + torch.testing.assert_close(output[..., 17, 27], sample[..., in_box[1], in_box[2] - 1]) + torch.testing.assert_close(output[..., 31, 6], sample[..., in_box[3] - 1, in_box[0]]) + torch.testing.assert_close(output[..., 37, 23], sample[..., in_box[3] - 1, in_box[2] - 1]) diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 46d1afa1a20..e48701bb436 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -826,7 +826,8 @@ def elastic_bounding_box( image_size = displacement.shape[-3], displacement.shape[-2] id_grid = _FT._create_identity_grid(list(image_size)).to(bounding_box.device) - # We construct inverse grid vs grid = id_grid + displacement used for images + # We construct an approximation of inverse grid as inv_grid = id_grid - displacement + # This is not an exact inverse of the grid inv_grid = id_grid - displacement # Get points from bboxes From e76011af11c29b8c4b9ccf95d026001c592ff670 Mon Sep 17 00:00:00 2001 From: vfdev-5 Date: Fri, 22 Jul 2022 18:11:27 +0200 Subject: [PATCH 6/6] Added tests for ElasticTransform --- test/test_prototype_transforms.py | 75 +++++++++++++++++++ test/test_prototype_transforms_functional.py | 1 - 
torchvision/prototype/transforms/_geometry.py | 2 +- 3 files changed, 76 insertions(+), 2 deletions(-) diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index b8cfabcccc9..51eba38c7a6 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -397,6 +397,7 @@ def test__transform(self, fill, side_range, mocker): fn = mocker.patch("torchvision.prototype.transforms.functional.pad") # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users + # Otherwise, we can mock transform._get_params torch.manual_seed(12) _ = transform(inpt) torch.manual_seed(12) @@ -456,6 +457,7 @@ def test__transform(self, degrees, expand, fill, center, mocker): inpt = mocker.MagicMock(spec=features.Image) # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users + # Otherwise, we can mock transform._get_params torch.manual_seed(12) _ = transform(inpt) torch.manual_seed(12) @@ -576,6 +578,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users + # Otherwise, we can mock transform._get_params torch.manual_seed(12) _ = transform(inpt) torch.manual_seed(12) @@ -645,6 +648,7 @@ def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker): # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users + # Otherwise, we can mock transform._get_params torch.manual_seed(12) _ = transform(inpt) torch.manual_seed(12) @@ -716,6 +720,7 @@ def test__transform(self, kernel_size, sigma, mocker): # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users + # Otherwise, we can mock transform._get_params torch.manual_seed(12) _ = transform(inpt) torch.manual_seed(12) @@ -795,6 +800,7 @@ def test__transform(self, distortion_scale, mocker): inpt.image_size = (24, 32) # vfdev-5, Feature Request: let's store params as Transform attribute # This could be also helpful for users + # Otherwise, we can mock transform._get_params torch.manual_seed(12) _ = transform(inpt) torch.manual_seed(12) @@ -802,3 +808,72 @@ def test__transform(self, distortion_scale, mocker): params = transform._get_params(inpt) fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation) + + +class TestElasticTransform: + def test_assertions(self): + + with pytest.raises(TypeError, match="alpha should be float or a sequence of floats"): + transforms.ElasticTransform({}) + + with pytest.raises(ValueError, match="alpha is a sequence its length should be one of 2"): + transforms.ElasticTransform([1.0, 2.0, 3.0]) + + with pytest.raises(ValueError, match="alpha should be a sequence of floats"): + transforms.ElasticTransform([1, 2]) + + with pytest.raises(TypeError, match="sigma should be float or a sequence of floats"): + transforms.ElasticTransform(1.0, {}) + + with pytest.raises(ValueError, match="sigma is a sequence its length should be one of 2"): + transforms.ElasticTransform(1.0, [1.0, 2.0, 3.0]) + + with pytest.raises(ValueError, match="sigma should be a sequence of floats"): + transforms.ElasticTransform(1.0, [1, 2]) + + with pytest.raises(TypeError, match="Got inappropriate fill arg"): + transforms.ElasticTransform(1.0, 2.0, fill="abc") + + def test__get_params(self, mocker): + alpha = 2.0 + sigma = 3.0 + transform = 
transforms.ElasticTransform(alpha, sigma) + image = mocker.MagicMock(spec=features.Image) + image.num_channels = 3 + image.image_size = (24, 32) + + params = transform._get_params(image) + + h, w = image.image_size + displacement = params["displacement"] + assert displacement.shape == (1, h, w, 2) + assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all() + assert (-alpha / h <= displacement[0, ..., 1]).all() and (displacement[0, ..., 1] <= alpha / h).all() + + @pytest.mark.parametrize("alpha", [5.0, [5.0, 10.0]]) + @pytest.mark.parametrize("sigma", [2.0, [2.0, 5.0]]) + def test__transform(self, alpha, sigma, mocker): + interpolation = InterpolationMode.BILINEAR + fill = 12 + transform = transforms.ElasticTransform(alpha, sigma=sigma, fill=fill, interpolation=interpolation) + + if isinstance(alpha, float): + assert transform.alpha == [alpha, alpha] + else: + assert transform.alpha == alpha + + if isinstance(sigma, float): + assert transform.sigma == [sigma, sigma] + else: + assert transform.sigma == sigma + + fn = mocker.patch("torchvision.prototype.transforms.functional.elastic") + inpt = mocker.MagicMock(spec=features.Image) + inpt.num_channels = 3 + inpt.image_size = (24, 32) + + # Let's mock transform._get_params to control the output: + transform._get_params = mocker.MagicMock() + _ = transform(inpt) + params = transform._get_params(inpt) + fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation) diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py index f951eece6cc..fb5f10459fe 100644 --- a/test/test_prototype_transforms_functional.py +++ b/test/test_prototype_transforms_functional.py @@ -1824,7 +1824,6 @@ def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples): displacement = torch.cat([d1[..., None], d2[..., None]], dim=-1) displacement = displacement.reshape(1, h, w, 2) - print(sample.dtype, sample.shape) output = fn(sample, displacement=displacement, **kwargs) # Check places where transformed points should be diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py index 527af9ba014..8a1b94060c4 100644 --- a/torchvision/prototype/transforms/_geometry.py +++ b/torchvision/prototype/transforms/_geometry.py @@ -561,7 +561,7 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size: if isinstance(arg, Sequence): for element in arg: if not isinstance(element, float): - raise TypeError(f"{name} should be a sequence of floats. Got {type(element)}") + raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}") if isinstance(arg, float): arg = [float(arg), float(arg)]
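
---

Usage sketch (reviewer note, not part of the patches above): a minimal example of the API this series adds, assuming the prototype package layout shown in the diffs. The high-level transform draws a random gaussian-smoothed displacement field in `_get_params` and hands it to the `F.elastic` dispatcher, which routes `_Feature` subclasses, PIL images, and plain tensors to their respective kernels.

import torch
from torchvision.prototype import features, transforms
from torchvision.prototype.transforms import functional as F

t = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
img = features.Image(t)

# High-level transform: `alpha` scales the displacement magnitude, `sigma`
# controls the gaussian smoothing of the random field; scalar values are
# expanded to (x, y) pairs by _setup_float_or_seq.
transform = transforms.ElasticTransform(alpha=50.0, sigma=5.0)
out = transform(img)

# Low-level kernel: the displacement field must have shape [1, H, W, 2]
# (see the shape check added to torchvision/transforms/functional.py).
# A zero displacement leaves the identity grid untouched, so the output
# matches the input up to interpolation rounding.
h, w = t.shape[-2:]
displacement = torch.zeros(1, h, w, 2)
out = F.elastic_image_tensor(t, displacement)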