|
1 | 1 | import itertools
|
2 | 2 |
|
| 3 | +import PIL.Image |
| 4 | + |
3 | 5 | import pytest
|
4 | 6 | import torch
|
5 | 7 | from common_utils import assert_equal
|
@@ -879,3 +881,99 @@ def test__transform(self, alpha, sigma, mocker):
|
879 | 881 | _ = transform(inpt)
|
880 | 882 | params = transform._get_params(inpt)
|
881 | 883 | fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
|
| 884 | + |
| 885 | + |
| 886 | +class TestRandomErasing: |
| 887 | + def test_assertions(self, mocker): |
| 888 | + with pytest.raises(TypeError, match="Argument value should be either a number or str or a sequence"): |
| 889 | + transforms.RandomErasing(value={}) |
| 890 | + |
| 891 | + with pytest.raises(ValueError, match="If value is str, it should be 'random'"): |
| 892 | + transforms.RandomErasing(value="abc") |
| 893 | + |
| 894 | + with pytest.raises(TypeError, match="Scale should be a sequence"): |
| 895 | + transforms.RandomErasing(scale=123) |
| 896 | + |
| 897 | + with pytest.raises(TypeError, match="Ratio should be a sequence"): |
| 898 | + transforms.RandomErasing(ratio=123) |
| 899 | + |
| 900 | + with pytest.raises(ValueError, match="Scale should be between 0 and 1"): |
| 901 | + transforms.RandomErasing(scale=[-1, 2]) |
| 902 | + |
| 903 | + image = mocker.MagicMock(spec=features.Image) |
| 904 | + image.num_channels = 3 |
| 905 | + image.image_size = (24, 32) |
| 906 | + |
| 907 | + transform = transforms.RandomErasing(value=[1, 2, 3, 4]) |
| 908 | + |
| 909 | + with pytest.raises(ValueError, match="If value is a sequence, it should have either a single value"): |
| 910 | + transform._get_params(image) |
| 911 | + |
| 912 | + @pytest.mark.parametrize("value", [5.0, [1, 2, 3], "random"]) |
| 913 | + def test__get_params(self, value, mocker): |
| 914 | + image = mocker.MagicMock(spec=features.Image) |
| 915 | + image.num_channels = 3 |
| 916 | + image.image_size = (24, 32) |
| 917 | + |
| 918 | + transform = transforms.RandomErasing(value=value) |
| 919 | + params = transform._get_params(image) |
| 920 | + |
| 921 | + v = params["v"] |
| 922 | + h, w = params["h"], params["w"] |
| 923 | + i, j = params["i"], params["j"] |
| 924 | + assert isinstance(v, torch.Tensor) |
| 925 | + if value == "random": |
| 926 | + assert v.shape == (image.num_channels, h, w) |
| 927 | + elif isinstance(value, (int, float)): |
| 928 | + assert v.shape == (1, 1, 1) |
| 929 | + elif isinstance(value, (list, tuple)): |
| 930 | + assert v.shape == (image.num_channels, 1, 1) |
| 931 | + |
| 932 | + assert 0 <= i <= image.image_size[0] - h |
| 933 | + assert 0 <= j <= image.image_size[1] - w |
| 934 | + |
| 935 | + @pytest.mark.parametrize("p", [0.0, 1.0]) |
| 936 | + @pytest.mark.parametrize( |
| 937 | + "inpt_type", |
| 938 | + [ |
| 939 | + (torch.Tensor, {"shape": (3, 24, 32)}), |
| 940 | + (PIL.Image.Image, {"size": (24, 32), "mode": "RGB"}), |
| 941 | + ], |
| 942 | + ) |
| 943 | + def test__transform(self, p, inpt_type, mocker): |
| 944 | + value = 1.0 |
| 945 | + transform = transforms.RandomErasing(p=p, value=value) |
| 946 | + |
| 947 | + inpt = mocker.MagicMock(spec=inpt_type[0], **inpt_type[1]) |
| 948 | + erase_image_tensor_inpt = inpt |
| 949 | + fn = mocker.patch( |
| 950 | + "torchvision.prototype.transforms.functional.erase_image_tensor", |
| 951 | + return_value=mocker.MagicMock(spec=torch.Tensor), |
| 952 | + ) |
| 953 | + if inpt_type[0] == PIL.Image.Image: |
| 954 | + erase_image_tensor_inpt = mocker.MagicMock(spec=torch.Tensor) |
| 955 | + |
| 956 | + # vfdev-5: I do not know how to patch pil_to_tensor if it is already imported |
| 957 | + # TODO: patch pil_to_tensor and run below checks for PIL.Image.Image inputs |
| 958 | + if p > 0.0: |
| 959 | + return |
| 960 | + |
| 961 | + mocker.patch( |
| 962 | + "torchvision.transforms.functional.pil_to_tensor", |
| 963 | + return_value=erase_image_tensor_inpt, |
| 964 | + ) |
| 965 | + mocker.patch( |
| 966 | + "torchvision.transforms.functional.to_pil_image", |
| 967 | + return_value=mocker.MagicMock(spec=PIL.Image.Image), |
| 968 | + ) |
| 969 | + |
| 970 | + # Let's mock transform._get_params to control the output: |
| 971 | + transform._get_params = mocker.MagicMock() |
| 972 | + output = transform(inpt) |
| 973 | + print(inpt_type) |
| 974 | + assert isinstance(output, inpt_type[0]) |
| 975 | + params = transform._get_params(inpt) |
| 976 | + if p > 0.0: |
| 977 | + fn.assert_called_once_with(erase_image_tensor_inpt, **params) |
| 978 | + else: |
| 979 | + fn.call_count == 0 |
0 commit comments