Skip to content

Commit 2ee8dca

Browse files
authored
Implemented RandomErase on PIL input as fallback to tensors (#6309)
Added tests
1 parent 14d221d commit 2ee8dca

File tree

2 files changed

+102
-2
lines changed

2 files changed

+102
-2
lines changed

test/test_prototype_transforms.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import itertools
22

3+
import PIL.Image
4+
35
import pytest
46
import torch
57
from common_utils import assert_equal
@@ -879,3 +881,99 @@ def test__transform(self, alpha, sigma, mocker):
879881
_ = transform(inpt)
880882
params = transform._get_params(inpt)
881883
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
884+
885+
886+
class TestRandomErasing:
887+
def test_assertions(self, mocker):
888+
with pytest.raises(TypeError, match="Argument value should be either a number or str or a sequence"):
889+
transforms.RandomErasing(value={})
890+
891+
with pytest.raises(ValueError, match="If value is str, it should be 'random'"):
892+
transforms.RandomErasing(value="abc")
893+
894+
with pytest.raises(TypeError, match="Scale should be a sequence"):
895+
transforms.RandomErasing(scale=123)
896+
897+
with pytest.raises(TypeError, match="Ratio should be a sequence"):
898+
transforms.RandomErasing(ratio=123)
899+
900+
with pytest.raises(ValueError, match="Scale should be between 0 and 1"):
901+
transforms.RandomErasing(scale=[-1, 2])
902+
903+
image = mocker.MagicMock(spec=features.Image)
904+
image.num_channels = 3
905+
image.image_size = (24, 32)
906+
907+
transform = transforms.RandomErasing(value=[1, 2, 3, 4])
908+
909+
with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"):
910+
transform._get_params(image)
911+
912+
@pytest.mark.parametrize("value", [5.0, [1, 2, 3], "random"])
913+
def test__get_params(self, value, mocker):
914+
image = mocker.MagicMock(spec=features.Image)
915+
image.num_channels = 3
916+
image.image_size = (24, 32)
917+
918+
transform = transforms.RandomErasing(value=value)
919+
params = transform._get_params(image)
920+
921+
v = params["v"]
922+
h, w = params["h"], params["w"]
923+
i, j = params["i"], params["j"]
924+
assert isinstance(v, torch.Tensor)
925+
if value == "random":
926+
assert v.shape == (image.num_channels, h, w)
927+
elif isinstance(value, (int, float)):
928+
assert v.shape == (1, 1, 1)
929+
elif isinstance(value, (list, tuple)):
930+
assert v.shape == (image.num_channels, 1, 1)
931+
932+
assert 0 <= i <= image.image_size[0] - h
933+
assert 0 <= j <= image.image_size[1] - w
934+
935+
@pytest.mark.parametrize("p", [0.0, 1.0])
936+
@pytest.mark.parametrize(
937+
"inpt_type",
938+
[
939+
(torch.Tensor, {"shape": (3, 24, 32)}),
940+
(PIL.Image.Image, {"size": (24, 32), "mode": "RGB"}),
941+
],
942+
)
943+
def test__transform(self, p, inpt_type, mocker):
944+
value = 1.0
945+
transform = transforms.RandomErasing(p=p, value=value)
946+
947+
inpt = mocker.MagicMock(spec=inpt_type[0], **inpt_type[1])
948+
erase_image_tensor_inpt = inpt
949+
fn = mocker.patch(
950+
"torchvision.prototype.transforms.functional.erase_image_tensor",
951+
return_value=mocker.MagicMock(spec=torch.Tensor),
952+
)
953+
if inpt_type[0] == PIL.Image.Image:
954+
erase_image_tensor_inpt = mocker.MagicMock(spec=torch.Tensor)
955+
956+
# vfdev-5: I do not know how to patch pil_to_tensor if it is already imported
957+
# TODO: patch pil_to_tensor and run below checks for PIL.Image.Image inputs
958+
if p > 0.0:
959+
return
960+
961+
mocker.patch(
962+
"torchvision.transforms.functional.pil_to_tensor",
963+
return_value=erase_image_tensor_inpt,
964+
)
965+
mocker.patch(
966+
"torchvision.transforms.functional.to_pil_image",
967+
return_value=mocker.MagicMock(spec=PIL.Image.Image),
968+
)
969+
970+
# Let's mock transform._get_params to control the output:
971+
transform._get_params = mocker.MagicMock()
972+
output = transform(inpt)
973+
print(inpt_type)
974+
assert isinstance(output, inpt_type[0])
975+
params = transform._get_params(inpt)
976+
if p > 0.0:
977+
fn.assert_called_once_with(erase_image_tensor_inpt, **params)
978+
else:
979+
fn.call_count == 0

torchvision/prototype/transforms/_augment.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
from torchvision.prototype import features
99
from torchvision.prototype.transforms import functional as F, Transform
10+
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
1011

1112
from ._transform import _RandomApplyTransform
1213
from ._utils import get_image_dimensions, has_all, has_any, query_image
@@ -92,8 +93,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
9293
return features.Image.new_like(inpt, output)
9394
return output
9495
elif isinstance(inpt, PIL.Image.Image):
95-
# TODO: We should implement a fallback to tensor, like gaussian_blur etc
96-
raise RuntimeError("Not implemented")
96+
t_img = pil_to_tensor(inpt)
97+
output = F.erase_image_tensor(t_img, **params)
98+
return to_pil_image(output, mode=inpt.mode)
9799
else:
98100
return inpt
99101

0 commit comments

Comments
 (0)