From e21ff74c403a0a2650c0d0eef55577833af1839b Mon Sep 17 00:00:00 2001
From: vfdev-5
Date: Mon, 18 Jul 2022 16:58:00 +0200
Subject: [PATCH] Added RandomPerspective and tests

- replaced real image creation with mocks for other tests
---
 test/test_prototype_transforms.py             | 108 +++++++++++++++---
 test/test_prototype_transforms_functional.py  |   4 +-
 torchvision/prototype/transforms/__init__.py  |   1 +
 torchvision/prototype/transforms/_geometry.py |  58 ++++++++++
 .../transforms/functional/_geometry.py        |   6 +-
 5 files changed, 157 insertions(+), 20 deletions(-)

diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index d6987f6b71b..52becda6c37 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -344,7 +344,7 @@ def test__transform(self, padding, fill, padding_mode, mocker):
         transform = transforms.Pad(padding, fill=fill, padding_mode=padding_mode)
 
         fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
-        inpt = mocker.MagicMock(spec=torch.Tensor)
+        inpt = mocker.MagicMock(spec=features.Image)
         _ = transform(inpt)
 
         fn.assert_called_once_with(inpt, padding=padding, fill=fill, padding_mode=padding_mode)
@@ -363,11 +363,12 @@ def test_assertions(self):
 
     @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
     @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]])
-    def test__get_params(self, fill, side_range):
+    def test__get_params(self, fill, side_range, mocker):
         transform = transforms.RandomZoomOut(fill=fill, side_range=side_range)
 
-        image = features.Image(torch.rand(1, 3, 32, 32))
-        c, h, w = image.shape[-3:]
+        image = mocker.MagicMock(spec=features.Image)
+        c = image.num_channels = 3
+        h, w = image.image_size = (24, 32)
 
         params = transform._get_params(image)
 
@@ -381,19 +382,22 @@
     @pytest.mark.parametrize("fill", [0, [1, 2, 3], (2, 3, 4)])
     @pytest.mark.parametrize("side_range", [(1.0, 4.0), [2.0, 5.0]])
     def test__transform(self, fill, side_range, mocker):
-        image = features.Image(torch.rand(1, 3, 32, 32))
+        inpt = mocker.MagicMock(spec=features.Image)
+        inpt.num_channels = 3
+        inpt.image_size = (24, 32)
+
         transform = transforms.RandomZoomOut(fill=fill, side_range=side_range, p=1)
 
         fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
         # vfdev-5, Feature Request: let's store params as Transform attribute
         # This could be also helpful for users
         torch.manual_seed(12)
-        _ = transform(image)
+        _ = transform(inpt)
         torch.manual_seed(12)
         torch.rand(1)  # random apply changes random state
-        params = transform._get_params(image)
+        params = transform._get_params(inpt)
 
-        fn.assert_called_once_with(image, **params)
+        fn.assert_called_once_with(inpt, **params)
 
 
 class TestRandomRotation:
@@ -443,7 +447,7 @@ def test__transform(self, degrees, expand, fill, center, mocker):
         assert transform.degrees == [float(-degrees), float(degrees)]
 
         fn = mocker.patch("torchvision.prototype.transforms.functional.rotate")
-        inpt = mocker.MagicMock(spec=torch.Tensor)
+        inpt = mocker.MagicMock(spec=features.Image)
         # vfdev-5, Feature Request: let's store params as Transform attribute
         # This could be also helpful for users
         torch.manual_seed(12)
@@ -498,9 +502,11 @@ def test_assertions(self):
     @pytest.mark.parametrize("translate", [None, [0.1, 0.2]])
     @pytest.mark.parametrize("scale", [None, [0.7, 1.2]])
    @pytest.mark.parametrize("shear", [None, 2.0, [5.0, 15.0], [1.0, 2.0, 3.0, 4.0]])
-    def test__get_params(self, degrees, translate, scale, shear):
-        image = features.Image(torch.rand(1, 3, 32, 32))
-        h, w = image.shape[-2:]
+    def test__get_params(self, degrees, translate, scale, shear, mocker):
+        image = mocker.MagicMock(spec=features.Image)
+        image.num_channels = 3
+        image.image_size = (24, 32)
+        h, w = image.image_size
         transform = transforms.RandomAffine(degrees, translate=translate, scale=scale, shear=shear)
 
         params = transform._get_params(image)
@@ -558,7 +564,10 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker
         assert transform.degrees == [float(-degrees), float(degrees)]
 
         fn = mocker.patch("torchvision.prototype.transforms.functional.affine")
-        inpt = features.Image(torch.rand(1, 3, 32, 32))
+        inpt = mocker.MagicMock(spec=features.Image)
+        inpt.num_channels = 3
+        inpt.image_size = (24, 32)
+
         # vfdev-5, Feature Request: let's store params as Transform attribute
         # This could be also helpful for users
         torch.manual_seed(12)
@@ -586,9 +595,11 @@ def test_assertions(self):
         with pytest.raises(ValueError, match="Padding mode should be either"):
             transforms.RandomCrop([10, 12], padding=1, padding_mode="abc")
 
-    def test__get_params(self):
-        image = features.Image(torch.rand(1, 3, 32, 32))
-        h, w = image.shape[-2:]
+    def test__get_params(self, mocker):
+        image = mocker.MagicMock(spec=features.Image)
+        image.num_channels = 3
+        image.image_size = (24, 32)
+        h, w = image.image_size
         transform = transforms.RandomCrop([10, 10])
 
         params = transform._get_params(image)
@@ -608,7 +619,10 @@ def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker):
             output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode
         )
 
-        inpt = features.Image(torch.rand(1, 3, 32, 32))
+        inpt = mocker.MagicMock(spec=features.Image)
+        inpt.num_channels = 3
+        inpt.image_size = (32, 32)
+
         expected = mocker.MagicMock(spec=features.Image)
         expected.num_channels = 3
         if isinstance(padding, int):
@@ -690,7 +704,10 @@ def test__transform(self, kernel_size, sigma, mocker):
         assert transform.sigma == (sigma, sigma)
 
         fn = mocker.patch("torchvision.prototype.transforms.functional.gaussian_blur")
-        inpt = features.Image(torch.rand(1, 3, 32, 32))
+        inpt = mocker.MagicMock(spec=features.Image)
+        inpt.num_channels = 3
+        inpt.image_size = (24, 32)
+
         # vfdev-5, Feature Request: let's store params as Transform attribute
         # This could be also helpful for users
         torch.manual_seed(12)
@@ -699,3 +716,58 @@
         params = transform._get_params(inpt)
 
         fn.assert_called_once_with(inpt, **params)
+
+
+class TestRandomPerspective:
+    def test_assertions(self):
+        with pytest.raises(ValueError, match="Argument distortion_scale value should be between 0 and 1"):
+            transforms.RandomPerspective(distortion_scale=-1.0)
+
+        with pytest.raises(TypeError, match="Got inappropriate fill arg"):
+            transforms.RandomPerspective(0.5, fill="abc")
+
+    def test__get_params(self, mocker):
+        dscale = 0.5
+        transform = transforms.RandomPerspective(dscale)
+        image = mocker.MagicMock(spec=features.Image)
+        image.num_channels = 3
+        image.image_size = (24, 32)
+
+        params = transform._get_params(image)
+
+        h, w = image.image_size
+        assert len(params["startpoints"]) == 4
+        for x, y in params["startpoints"]:
+            assert x in (0, w - 1)
+            assert y in (0, h - 1)
+
+        assert len(params["endpoints"]) == 4
+        for (x, y), name in zip(params["endpoints"], ["tl", "tr", "br", "bl"]):
+            if "t" in name:
+                assert 0 <= y <= int(dscale * h // 2), (x, y, name)
+            if "b" in name:
+                assert h - int(dscale * h // 2) - 1 <= y <= h, (x, y, name)
+            if "l" in name:
+                assert 0 <= x <= int(dscale * w // 2), (x, y, name)
+            if "r" in name:
+                assert w - int(dscale * w // 2) - 1 <= x <= w, (x, y, name)
+
+    @pytest.mark.parametrize("distortion_scale", [0.1, 0.7])
+    def test__transform(self, distortion_scale, mocker):
+        interpolation = InterpolationMode.BILINEAR
+        fill = 12
+        transform = transforms.RandomPerspective(distortion_scale, fill=fill, interpolation=interpolation)
+
+        fn = mocker.patch("torchvision.prototype.transforms.functional.perspective")
+        inpt = mocker.MagicMock(spec=features.Image)
+        inpt.num_channels = 3
+        inpt.image_size = (24, 32)
+        # vfdev-5, Feature Request: let's store params as Transform attribute
+        # This could also be helpful for users
+        torch.manual_seed(12)
+        _ = transform(inpt)
+        torch.manual_seed(12)
+        torch.rand(1)  # random apply changes random state
+        params = transform._get_params(inpt)
+
+        fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
diff --git a/test/test_prototype_transforms_functional.py b/test/test_prototype_transforms_functional.py
index e39eb4b6632..16ca67168f3 100644
--- a/test/test_prototype_transforms_functional.py
+++ b/test/test_prototype_transforms_functional.py
@@ -548,9 +548,11 @@ def test_scriptable(kernel):
         and all(
             feature_type not in name for feature_type in {"image", "segmentation_mask", "bounding_box", "label", "pil"}
         )
-        and name not in {"to_image_tensor", "InterpolationMode", "decode_video_with_av", "crop", "rotate"}
+        and name
+        not in {"to_image_tensor", "InterpolationMode", "decode_video_with_av", "crop", "rotate", "perspective"}
         # We skip 'crop' due to missing 'height' and 'width'
         # We skip 'rotate' due to non implemented yet expand=True case for bboxes
+        # We skip 'perspective' as it requires different input args than perspective_image_tensor, etc.
     ],
 )
 def test_functional_mid_level(func):
diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index f77b36d4643..fc8664cddbb 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -20,6 +20,7 @@
     RandomZoomOut,
     RandomRotation,
     RandomAffine,
+    RandomPerspective,
 )
 from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
 from ._misc import Identity, GaussianBlur, Normalize, ToDtype, Lambda
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 88a118dbc9a..3cf3858720e 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -292,6 +292,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         bottom = canvas_height - (top + orig_h)
         padding = [left, top, right, bottom]
 
+        # vfdev-5: Can we put that into pad_image_tensor?
         fill = self.fill
         if not isinstance(fill, collections.abc.Sequence):
             fill = [fill] * orig_c
@@ -493,3 +494,60 @@ def forward(self, *inputs: Any) -> Any:
         flat_inputs, spec = tree_flatten(sample)
         out_flat_inputs = self._forward(flat_inputs)
         return tree_unflatten(out_flat_inputs, spec)
+
+
+class RandomPerspective(_RandomApplyTransform):
+    def __init__(
+        self,
+        distortion_scale: float,
+        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        interpolation: InterpolationMode = InterpolationMode.BILINEAR,
+        p: float = 0.5,
+    ) -> None:
+        super().__init__(p=p)
+
+        _check_fill_arg(fill)
+        if not (0 <= distortion_scale <= 1):
+            raise ValueError("Argument distortion_scale value should be between 0 and 1")
+
+        self.distortion_scale = distortion_scale
+        self.interpolation = interpolation
+        self.fill = fill
+
+    def _get_params(self, sample: Any) -> Dict[str, Any]:
+        # Get image size
+        # TODO: make it work with bboxes and segm masks
+        image = query_image(sample)
+        _, height, width = get_image_dimensions(image)
+
+        distortion_scale = self.distortion_scale
+
+        half_height = height // 2
+        half_width = width // 2
+        topleft = [
+            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
+            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
+        ]
+        topright = [
+            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
+            int(torch.randint(0, int(distortion_scale * half_height) + 1, size=(1,)).item()),
+        ]
+        botright = [
+            int(torch.randint(width - int(distortion_scale * half_width) - 1, width, size=(1,)).item()),
+            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
+        ]
+        botleft = [
+            int(torch.randint(0, int(distortion_scale * half_width) + 1, size=(1,)).item()),
+            int(torch.randint(height - int(distortion_scale * half_height) - 1, height, size=(1,)).item()),
+        ]
+        startpoints = [[0, 0], [width - 1, 0], [width - 1, height - 1], [0, height - 1]]
+        endpoints = [topleft, topright, botright, botleft]
+        return dict(startpoints=startpoints, endpoints=endpoints)
+
+    def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
+        return F.perspective(
+            inpt,
+            **params,
+            fill=self.fill,
+            interpolation=self.interpolation,
+        )
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 8d3ed675047..87419ba8640 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -11,6 +11,7 @@
     _get_inverse_affine_matrix,
     InterpolationMode,
     _compute_output_size,
+    _get_perspective_coeffs,
 )
 
 from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil
@@ -765,10 +766,13 @@ def perspective_segmentation_mask(img: torch.Tensor, perspective_coeffs: List[fl
 
 def perspective(
     inpt: DType,
-    perspective_coeffs: List[float],
+    startpoints: List[List[int]],
+    endpoints: List[List[int]],
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     fill: Optional[Union[int, float, Sequence[int], Sequence[float]]] = None,
 ) -> DType:
+    perspective_coeffs = _get_perspective_coeffs(startpoints, endpoints)
+
     if isinstance(inpt, features._Feature):
         return inpt.perspective(perspective_coeffs, interpolation=interpolation, fill=fill)
     elif isinstance(inpt, PIL.Image.Image):
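
For context, a minimal usage sketch of what this patch adds (not part of the diff itself). It assumes the prototype import paths shown above; the image shape and corner coordinates are illustrative only.

import torch
from torchvision.prototype import features, transforms
from torchvision.prototype.transforms import functional as F

# Wrap a dummy 3-channel image in the prototype Image feature.
img = features.Image(torch.rand(1, 3, 24, 32))

# p=1.0 so the random-apply wrapper always applies the warp in this sketch.
transform = transforms.RandomPerspective(distortion_scale=0.5, fill=0, p=1.0)
out = transform(img)

# The mid-level functional now takes the four corner points directly and
# derives the perspective coefficients internally via _get_perspective_coeffs.
startpoints = [[0, 0], [31, 0], [31, 23], [0, 23]]  # tl, tr, br, bl of a 24x32 image
endpoints = [[2, 1], [29, 2], [30, 22], [1, 21]]  # arbitrary distorted corners
out = F.perspective(img, startpoints=startpoints, endpoints=endpoints)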