diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py
index 27c2e1d581c..b21a3c62878 100644
--- a/test/test_prototype_transforms.py
+++ b/test/test_prototype_transforms.py
@@ -1,5 +1,7 @@
 import itertools
 
+import PIL.Image
+
 import pytest
 import torch
 from common_utils import assert_equal
@@ -879,3 +881,98 @@ def test__transform(self, alpha, sigma, mocker):
         _ = transform(inpt)
         params = transform._get_params(inpt)
         fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
+
+
+class TestRandomErasing:
+    def test_assertions(self, mocker):
+        with pytest.raises(TypeError, match="Argument value should be either a number or str or a sequence"):
+            transforms.RandomErasing(value={})
+
+        with pytest.raises(ValueError, match="If value is str, it should be 'random'"):
+            transforms.RandomErasing(value="abc")
+
+        with pytest.raises(TypeError, match="Scale should be a sequence"):
+            transforms.RandomErasing(scale=123)
+
+        with pytest.raises(TypeError, match="Ratio should be a sequence"):
+            transforms.RandomErasing(ratio=123)
+
+        with pytest.raises(ValueError, match="Scale should be between 0 and 1"):
+            transforms.RandomErasing(scale=[-1, 2])
+
+        image = mocker.MagicMock(spec=features.Image)
+        image.num_channels = 3
+        image.image_size = (24, 32)
+
+        transform = transforms.RandomErasing(value=[1, 2, 3, 4])
+
+        with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"):
+            transform._get_params(image)
+
+    @pytest.mark.parametrize("value", [5.0, [1, 2, 3], "random"])
+    def test__get_params(self, value, mocker):
+        image = mocker.MagicMock(spec=features.Image)
+        image.num_channels = 3
+        image.image_size = (24, 32)
+
+        transform = transforms.RandomErasing(value=value)
+        params = transform._get_params(image)
+
+        v = params["v"]
+        h, w = params["h"], params["w"]
+        i, j = params["i"], params["j"]
+        assert isinstance(v, torch.Tensor)
+        if value == "random":
+            assert v.shape == (image.num_channels, h, w)
+        elif isinstance(value, (int, float)):
+            assert v.shape == (1, 1, 1)
+        elif isinstance(value, (list, tuple)):
+            assert v.shape == (image.num_channels, 1, 1)
+
+        assert 0 <= i <= image.image_size[0] - h
+        assert 0 <= j <= image.image_size[1] - w
+
+    @pytest.mark.parametrize("p", [0.0, 1.0])
+    @pytest.mark.parametrize(
+        "inpt_type",
+        [
+            (torch.Tensor, {"shape": (3, 24, 32)}),
+            (PIL.Image.Image, {"size": (24, 32), "mode": "RGB"}),
+        ],
+    )
+    def test__transform(self, p, inpt_type, mocker):
+        value = 1.0
+        transform = transforms.RandomErasing(p=p, value=value)
+
+        inpt = mocker.MagicMock(spec=inpt_type[0], **inpt_type[1])
+        erase_image_tensor_inpt = inpt
+        fn = mocker.patch(
+            "torchvision.prototype.transforms.functional.erase_image_tensor",
+            return_value=mocker.MagicMock(spec=torch.Tensor),
+        )
+        if inpt_type[0] == PIL.Image.Image:
+            erase_image_tensor_inpt = mocker.MagicMock(spec=torch.Tensor)
+
+            # vfdev-5: I do not know how to patch pil_to_tensor if it is already imported
+            # TODO: patch pil_to_tensor and run below checks for PIL.Image.Image inputs
+            if p > 0.0:
+                return
+
+            mocker.patch(
+                "torchvision.transforms.functional.pil_to_tensor",
+                return_value=erase_image_tensor_inpt,
+            )
+            mocker.patch(
+                "torchvision.transforms.functional.to_pil_image",
+                return_value=mocker.MagicMock(spec=PIL.Image.Image),
+            )
+
+        # Let's mock transform._get_params to control the output:
+        transform._get_params = mocker.MagicMock()
+        output = transform(inpt)
+        assert isinstance(output, inpt_type[0])
+        params = transform._get_params(inpt)
+        if p > 0.0:
+            fn.assert_called_once_with(erase_image_tensor_inpt, **params)
+        else:
+            assert fn.call_count == 0
diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py
index d1c3db816ad..2c71a5faf64 100644
--- a/torchvision/prototype/transforms/_augment.py
+++ b/torchvision/prototype/transforms/_augment.py
@@ -7,6 +7,7 @@
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import functional as F, Transform
+from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 
 from ._transform import _RandomApplyTransform
 from ._utils import get_image_dimensions, has_all, has_any, query_image
@@ -92,8 +93,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
                 return features.Image.new_like(inpt, output)
             return output
         elif isinstance(inpt, PIL.Image.Image):
-            # TODO: We should implement a fallback to tensor, like gaussian_blur etc
-            raise RuntimeError("Not implemented")
+            t_img = pil_to_tensor(inpt)
+            output = F.erase_image_tensor(t_img, **params)
+            return to_pil_image(output, mode=inpt.mode)
         else:
             return inpt
 
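
Usage sketch (assumes the prototype namespace is imported as "from torchvision.prototype import transforms", mirroring the tests above): with the PIL fallback in place, RandomErasing routes a PIL input through pil_to_tensor, erase_image_tensor, and to_pil_image instead of raising RuntimeError, so a PIL image comes back as a PIL image:

    import PIL.Image
    from torchvision.prototype import transforms  # assumed import path, as used by the tests

    img = PIL.Image.new("RGB", (32, 24))
    transform = transforms.RandomErasing(p=1.0, value=0.0)  # p=1.0 forces the erase branch
    out = transform(img)
    assert isinstance(out, PIL.Image.Image)  # erased box is filled with `value`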