diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 899835ba276..d561705fdfe 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -82,6 +82,7 @@ class TestSmoke:
         transforms.RandomZoomOut(),
         transforms.RandomRotation(degrees=(-45, 45)),
         transforms.RandomAffine(degrees=(-45, 45)),
+        transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True),
     )
     def test_common(self, transform, input):
         transform(input)
@@ -566,3 +567,80 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker
         params = transform._get_params(inpt)
 
         fn.assert_called_once_with(inpt, **params, interpolation=interpolation, fill=fill, center=center)
+
+
+class TestRandomCrop:
+    def test_assertions(self):
+        with pytest.raises(ValueError, match="Please provide only two dimensions"):
+            transforms.RandomCrop([10, 12, 14])
+
+        with pytest.raises(TypeError, match="Got inappropriate padding arg"):
+            transforms.RandomCrop([10, 12], padding="abc")
+
+        with pytest.raises(ValueError, match="Padding must be an int or a 1, 2, or 4"):
+            transforms.RandomCrop([10, 12], padding=[-0.7, 0, 0.7])
+
+        with pytest.raises(TypeError, match="Got inappropriate fill arg"):
+            transforms.RandomCrop([10, 12], padding=1, fill="abc")
+
+        with pytest.raises(ValueError, match="Padding mode should be either"):
+            transforms.RandomCrop([10, 12], padding=1, padding_mode="abc")
+
+    def test__get_params(self):
+        image = features.Image(torch.rand(1, 3, 32, 32))
+        h, w = image.shape[-2:]
+
+        transform = transforms.RandomCrop([10, 10])
+        params = transform._get_params(image)
+
+        assert 0 <= params["top"] <= h - transform.size[0] + 1
+        assert 0 <= params["left"] <= w - transform.size[1] + 1
+        assert params["height"] == 10
+        assert params["width"] == 10
+
+    @pytest.mark.parametrize("padding", [None, 1, [2, 3], [1, 2, 3, 4]])
+    @pytest.mark.parametrize("pad_if_needed", [False, True])
+    @pytest.mark.parametrize("fill", [False, True])
+    @pytest.mark.parametrize("padding_mode", ["constant", "edge"])
+    def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker):
+        output_size = [10, 12]
+        transform = transforms.RandomCrop(
+            output_size, padding=padding, pad_if_needed=pad_if_needed, fill=fill, padding_mode=padding_mode
+        )
+
+        inpt = features.Image(torch.rand(1, 3, 32, 32))
+        expected = mocker.MagicMock(spec=features.Image)
+        expected.num_channels = 3
+        if isinstance(padding, int):
+            expected.image_size = (inpt.image_size[0] + padding, inpt.image_size[1] + padding)
+        elif isinstance(padding, list):
+            expected.image_size = (
+                inpt.image_size[0] + sum(padding[0::2]),
+                inpt.image_size[1] + sum(padding[1::2]),
+            )
+        else:
+            expected.image_size = inpt.image_size
+        _ = mocker.patch("torchvision.prototype.transforms.functional.pad", return_value=expected)
+        fn_crop = mocker.patch("torchvision.prototype.transforms.functional.crop")
+
+        # vfdev-5, Feature Request: let's store params as Transform attribute
+        # This could also be helpful for users
+        torch.manual_seed(12)
+        _ = transform(inpt)
+        torch.manual_seed(12)
+        if padding is None and not pad_if_needed:
+            params = transform._get_params(inpt)
+            fn_crop.assert_called_once_with(
+                inpt, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
+            )
+        elif not pad_if_needed:
+            params = transform._get_params(expected)
+            fn_crop.assert_called_once_with(
+                expected, top=params["top"], left=params["left"], height=output_size[0], width=output_size[1]
+            )
+        elif padding is None:
+            # vfdev-5: I do not know how to mock and test this case
+            pass
+        else:
+            # vfdev-5: I do not know how to mock and test this case
+            pass
diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index 2075ea7c52b..db1d006336f 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -10,6 +10,7 @@
     Resize,
     CenterCrop,
     RandomResizedCrop,
+    RandomCrop,
     FiveCrop,
     TenCrop,
     BatchMultiCrop,
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index fd14ac0296b..88a118dbc9a 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -35,7 +35,8 @@ def __init__(
         antialias: Optional[bool] = None,
     ) -> None:
         super().__init__()
-        self.size = [size] if isinstance(size, int) else list(size)
+
+        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
         self.interpolation = interpolation
         self.max_size = max_size
         self.antialias = antialias
@@ -80,7 +81,6 @@ def __init__(
         if (scale[0] > scale[1]) or (ratio[0] > ratio[1]):
             warnings.warn("Scale and ratio should be of kind (min, max)")
 
-        self.size = size
         self.scale = scale
         self.ratio = ratio
         self.interpolation = interpolation
@@ -225,6 +225,19 @@ def _check_fill_arg(fill: Union[int, float, Sequence[int], Sequence[float]]) ->
         raise TypeError("Got inappropriate fill arg")
 
 
+def _check_padding_arg(padding: Union[int, Sequence[int]]) -> None:
+    if not isinstance(padding, (numbers.Number, tuple, list)):
+        raise TypeError("Got inappropriate padding arg")
+
+    if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
+        raise ValueError(f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple")
+
+
+def _check_padding_mode_arg(padding_mode: Literal["constant", "edge", "reflect", "symmetric"]) -> None:
+    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
+        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
+
+
 class Pad(Transform):
     def __init__(
         self,
@@ -233,18 +246,10 @@ def __init__(
         padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
     ) -> None:
         super().__init__()
-        if not isinstance(padding, (numbers.Number, tuple, list)):
-            raise TypeError("Got inappropriate padding arg")
-
-        if isinstance(padding, (tuple, list)) and len(padding) not in [1, 2, 4]:
-            raise ValueError(
-                f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
-            )
+        _check_padding_arg(padding)
 
         _check_fill_arg(fill)
-
-        if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
-            raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")
+        _check_padding_mode_arg(padding_mode)
 
         self.padding = padding
         self.fill = fill
@@ -416,3 +421,75 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
             fill=self.fill,
             center=self.center,
         )
+
+
+class RandomCrop(Transform):
+    def __init__(
+        self,
+        size: Union[int, Sequence[int]],
+        padding: Optional[Union[int, Sequence[int]]] = None,
+        pad_if_needed: bool = False,
+        fill: Union[int, float, Sequence[int], Sequence[float]] = 0,
+        padding_mode: Literal["constant", "edge", "reflect", "symmetric"] = "constant",
+    ) -> None:
+        super().__init__()
+
+        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
+
+        if padding is not None:
+            _check_padding_arg(padding)
+
+        if (padding is not None) or pad_if_needed:
+            _check_padding_mode_arg(padding_mode)
+            _check_fill_arg(fill)
+
+        self.padding = padding
+        self.pad_if_needed = pad_if_needed
+        self.fill = fill
+        self.padding_mode = padding_mode
+
+    def _get_params(self, sample: Any) -> Dict[str, Any]:
+        image = query_image(sample)
+        _, height, width = get_image_dimensions(image)
+        output_height, output_width = self.size
+
+        if height + 1 < output_height or width + 1 < output_width:
+            raise ValueError(
+                f"Required crop size {(output_height, output_width)} is larger than input image size {(height, width)}"
+            )
+
+        if width == output_width and height == output_height:
+            return dict(top=0, left=0, height=height, width=width)
+
+        top = torch.randint(0, height - output_height + 1, size=(1,)).item()
+        left = torch.randint(0, width - output_width + 1, size=(1,)).item()
+        return dict(top=top, left=left, height=output_height, width=output_width)
+
+    def _forward(self, flat_inputs: List[Any]) -> List[Any]:
+        if self.padding is not None:
+            flat_inputs = [F.pad(flat_input, self.padding, self.fill, self.padding_mode) for flat_input in flat_inputs]
+
+        image = query_image(flat_inputs)
+        _, height, width = get_image_dimensions(image)
+
+        # pad the width if needed
+        if self.pad_if_needed and width < self.size[1]:
+            padding = [self.size[1] - width, 0]
+            flat_inputs = [F.pad(flat_input, padding, self.fill, self.padding_mode) for flat_input in flat_inputs]
+        # pad the height if needed
+        if self.pad_if_needed and height < self.size[0]:
+            padding = [0, self.size[0] - height]
+            flat_inputs = [F.pad(flat_input, padding, self.fill, self.padding_mode) for flat_input in flat_inputs]
+
+        params = self._get_params(flat_inputs)
+
+        return [F.crop(flat_input, **params) for flat_input in flat_inputs]
+
+    def forward(self, *inputs: Any) -> Any:
+        from torch.utils._pytree import tree_flatten, tree_unflatten
+
+        sample = inputs if len(inputs) > 1 else inputs[0]
+
+        flat_inputs, spec = tree_flatten(sample)
+        out_flat_inputs = self._forward(flat_inputs)
+        return tree_unflatten(out_flat_inputs, spec)
diff --git a/torchvision/prototype/transforms/_utils.py b/torchvision/prototype/transforms/_utils.py
index 0517757a758..c41ef294975 100644
--- a/torchvision/prototype/transforms/_utils.py
+++ b/torchvision/prototype/transforms/_utils.py
@@ -2,6 +2,7 @@
 
 import PIL.Image
 import torch
+from torch.utils._pytree import tree_flatten
 from torchvision.prototype import features
 from torchvision.prototype.utils._internal import query_recursively
 
@@ -9,10 +10,20 @@
 
 
 def query_image(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
+    flat_sample, _ = tree_flatten(sample)
+    for i in flat_sample:
+        if type(i) == torch.Tensor or isinstance(i, (PIL.Image.Image, features.Image)):
+            return i
+
+    raise TypeError("No image was found in the sample")
+
+
+# vfdev-5: let's use tree_flatten instead of query_recursively and internal fn to make the code simpler
+def query_image_(sample: Any) -> Union[PIL.Image.Image, torch.Tensor, features.Image]:
     def fn(
         id: Tuple[Any, ...], input: Any
     ) -> Optional[Tuple[Tuple[Any, ...], Union[PIL.Image.Image, torch.Tensor, features.Image]]]:
-        if type(input) in {torch.Tensor, features.Image} or isinstance(input, PIL.Image.Image):
+        if type(input) == torch.Tensor or isinstance(input, (PIL.Image.Image, features.Image)):
             return id, input
         return None
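
For reviewers trying the patch out, here is a minimal usage sketch of the new transform. It is not part of the diff: it assumes the prototype API exactly as added above, and the `torchvision.prototype` namespace is experimental and subject to change.

```python
import torch
from torchvision.prototype import features, transforms

# Same configuration as the new TestSmoke entry: pad by 1 pixel on every side,
# then take a random 16x16 crop. With pad_if_needed=True, any input that is
# still smaller than the target size is padded up before cropping.
transform = transforms.RandomCrop([16, 16], padding=1, pad_if_needed=True)

# features.Image wraps a CHW tensor; 20x20 is deliberately close to the crop
# size so that the padding paths are exercised.
image = features.Image(torch.rand(3, 20, 20))
out = transform(image)
assert tuple(out.shape[-2:]) == (16, 16)
```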