Skip to content

Commit dff382f

Browse files
committed
Added tests for ElasticTransform
1 parent 2099d00 commit dff382f

File tree

3 files changed

+76
-2
lines changed

3 files changed

+76
-2
lines changed

test/test_prototype_transforms.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -397,6 +397,7 @@ def test__transform(self, fill, side_range, mocker):
397397
fn = mocker.patch("torchvision.prototype.transforms.functional.pad")
398398
# vfdev-5, Feature Request: let's store params as Transform attribute
399399
# This could be also helpful for users
400+
# Otherwise, we can mock transform._get_params
400401
torch.manual_seed(12)
401402
_ = transform(inpt)
402403
torch.manual_seed(12)
@@ -456,6 +457,7 @@ def test__transform(self, degrees, expand, fill, center, mocker):
456457
inpt = mocker.MagicMock(spec=features.Image)
457458
# vfdev-5, Feature Request: let's store params as Transform attribute
458459
# This could be also helpful for users
460+
# Otherwise, we can mock transform._get_params
459461
torch.manual_seed(12)
460462
_ = transform(inpt)
461463
torch.manual_seed(12)
@@ -576,6 +578,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker
576578

577579
# vfdev-5, Feature Request: let's store params as Transform attribute
578580
# This could be also helpful for users
581+
# Otherwise, we can mock transform._get_params
579582
torch.manual_seed(12)
580583
_ = transform(inpt)
581584
torch.manual_seed(12)
@@ -645,6 +648,7 @@ def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker):
645648

646649
# vfdev-5, Feature Request: let's store params as Transform attribute
647650
# This could be also helpful for users
651+
# Otherwise, we can mock transform._get_params
648652
torch.manual_seed(12)
649653
_ = transform(inpt)
650654
torch.manual_seed(12)
@@ -716,6 +720,7 @@ def test__transform(self, kernel_size, sigma, mocker):
716720

717721
# vfdev-5, Feature Request: let's store params as Transform attribute
718722
# This could be also helpful for users
723+
# Otherwise, we can mock transform._get_params
719724
torch.manual_seed(12)
720725
_ = transform(inpt)
721726
torch.manual_seed(12)
@@ -795,10 +800,80 @@ def test__transform(self, distortion_scale, mocker):
795800
inpt.image_size = (24, 32)
796801
# vfdev-5, Feature Request: let's store params as Transform attribute
797802
# This could be also helpful for users
803+
# Otherwise, we can mock transform._get_params
798804
torch.manual_seed(12)
799805
_ = transform(inpt)
800806
torch.manual_seed(12)
801807
torch.rand(1) # random apply changes random state
802808
params = transform._get_params(inpt)
803809

804810
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)
811+
812+
813+
class TestElasticTransform:
814+
def test_assertions(self):
815+
816+
with pytest.raises(TypeError, match="alpha should be float or a sequence of floats"):
817+
transforms.ElasticTransform({})
818+
819+
with pytest.raises(ValueError, match=f"alpha is a sequence its length should be one of 2"):
820+
transforms.ElasticTransform([1.0, 2.0, 3.0])
821+
822+
with pytest.raises(ValueError, match=f"alpha should be a sequence of floats"):
823+
transforms.ElasticTransform([1, 2])
824+
825+
with pytest.raises(TypeError, match="sigma should be float or a sequence of floats"):
826+
transforms.ElasticTransform(1.0, {})
827+
828+
with pytest.raises(ValueError, match=f"sigma is a sequence its length should be one of 2"):
829+
transforms.ElasticTransform(1.0, [1.0, 2.0, 3.0])
830+
831+
with pytest.raises(ValueError, match=f"sigma should be a sequence of floats"):
832+
transforms.ElasticTransform(1.0, [1, 2])
833+
834+
with pytest.raises(TypeError, match="Got inappropriate fill arg"):
835+
transforms.ElasticTransform(1.0, 2.0, fill="abc")
836+
837+
def test__get_params(self, mocker):
838+
alpha = 2.0
839+
sigma = 3.0
840+
transform = transforms.ElasticTransform(alpha, sigma)
841+
image = mocker.MagicMock(spec=features.Image)
842+
image.num_channels = 3
843+
image.image_size = (24, 32)
844+
845+
params = transform._get_params(image)
846+
847+
h, w = image.image_size
848+
displacement = params["displacement"]
849+
assert displacement.shape == (1, h, w, 2)
850+
assert (-alpha / w <= displacement[0, ..., 0]).all() and (displacement[0, ..., 0] <= alpha / w).all()
851+
assert (-alpha / h <= displacement[0, ..., 1]).all() and (displacement[0, ..., 1] <= alpha / h).all()
852+
853+
@pytest.mark.parametrize("alpha", [5.0, [5.0, 10.0]])
854+
@pytest.mark.parametrize("sigma", [2.0, [2.0, 5.0]])
855+
def test__transform(self, alpha, sigma, mocker):
856+
interpolation = InterpolationMode.BILINEAR
857+
fill = 12
858+
transform = transforms.ElasticTransform(alpha, sigma=sigma, fill=fill, interpolation=interpolation)
859+
860+
if isinstance(alpha, float):
861+
assert transform.alpha == [alpha, alpha]
862+
else:
863+
assert transform.alpha == alpha
864+
865+
if isinstance(sigma, float):
866+
assert transform.sigma == [sigma, sigma]
867+
else:
868+
assert transform.sigma == sigma
869+
870+
fn = mocker.patch("torchvision.prototype.transforms.functional.elastic")
871+
inpt = mocker.MagicMock(spec=features.Image)
872+
inpt.num_channels = 3
873+
inpt.image_size = (24, 32)
874+
875+
# Let's mock transform._get_params to control the output:
876+
transform._get_params = mocker.MagicMock()
877+
_ = transform(inpt)
878+
params = transform._get_params(inpt)
879+
fn.assert_called_once_with(inpt, **params, fill=fill, interpolation=interpolation)

test/test_prototype_transforms_functional.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1824,7 +1824,6 @@ def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):
18241824
displacement = torch.cat([d1[..., None], d2[..., None]], dim=-1)
18251825
displacement = displacement.reshape(1, h, w, 2)
18261826

1827-
print(sample.dtype, sample.shape)
18281827
output = fn(sample, displacement=displacement, **kwargs)
18291828

18301829
# Check places where transformed points should be

torchvision/prototype/transforms/_geometry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ def _setup_float_or_seq(arg: Union[float, Sequence[float]], name: str, req_size:
561561
if isinstance(arg, Sequence):
562562
for element in arg:
563563
if not isinstance(element, float):
564-
raise TypeError(f"{name} should be a sequence of floats. Got {type(element)}")
564+
raise ValueError(f"{name} should be a sequence of floats. Got {type(element)}")
565565

566566
if isinstance(arg, float):
567567
arg = [float(arg), float(arg)]

0 commit comments

Comments
 (0)