@@ -397,6 +397,7 @@ def test__transform(self, fill, side_range, mocker):
397
397
fn = mocker .patch ("torchvision.prototype.transforms.functional.pad" )
398
398
# vfdev-5, Feature Request: let's store params as Transform attribute
399
399
# This could be also helpful for users
400
+ # Otherwise, we can mock transform._get_params
400
401
torch .manual_seed (12 )
401
402
_ = transform (inpt )
402
403
torch .manual_seed (12 )
@@ -456,6 +457,7 @@ def test__transform(self, degrees, expand, fill, center, mocker):
456
457
inpt = mocker .MagicMock (spec = features .Image )
457
458
# vfdev-5, Feature Request: let's store params as Transform attribute
458
459
# This could be also helpful for users
460
+ # Otherwise, we can mock transform._get_params
459
461
torch .manual_seed (12 )
460
462
_ = transform (inpt )
461
463
torch .manual_seed (12 )
@@ -576,6 +578,7 @@ def test__transform(self, degrees, translate, scale, shear, fill, center, mocker
576
578
577
579
# vfdev-5, Feature Request: let's store params as Transform attribute
578
580
# This could be also helpful for users
581
+ # Otherwise, we can mock transform._get_params
579
582
torch .manual_seed (12 )
580
583
_ = transform (inpt )
581
584
torch .manual_seed (12 )
@@ -645,6 +648,7 @@ def test_forward(self, padding, pad_if_needed, fill, padding_mode, mocker):
645
648
646
649
# vfdev-5, Feature Request: let's store params as Transform attribute
647
650
# This could be also helpful for users
651
+ # Otherwise, we can mock transform._get_params
648
652
torch .manual_seed (12 )
649
653
_ = transform (inpt )
650
654
torch .manual_seed (12 )
@@ -716,6 +720,7 @@ def test__transform(self, kernel_size, sigma, mocker):
716
720
717
721
# vfdev-5, Feature Request: let's store params as Transform attribute
718
722
# This could be also helpful for users
723
+ # Otherwise, we can mock transform._get_params
719
724
torch .manual_seed (12 )
720
725
_ = transform (inpt )
721
726
torch .manual_seed (12 )
@@ -795,10 +800,80 @@ def test__transform(self, distortion_scale, mocker):
795
800
inpt .image_size = (24 , 32 )
796
801
# vfdev-5, Feature Request: let's store params as Transform attribute
797
802
# This could be also helpful for users
803
+ # Otherwise, we can mock transform._get_params
798
804
torch .manual_seed (12 )
799
805
_ = transform (inpt )
800
806
torch .manual_seed (12 )
801
807
torch .rand (1 ) # random apply changes random state
802
808
params = transform ._get_params (inpt )
803
809
804
810
fn .assert_called_once_with (inpt , ** params , fill = fill , interpolation = interpolation )
811
+
812
+
813
+ class TestElasticTransform :
814
+ def test_assertions (self ):
815
+
816
+ with pytest .raises (TypeError , match = "alpha should be float or a sequence of floats" ):
817
+ transforms .ElasticTransform ({})
818
+
819
+ with pytest .raises (ValueError , match = f"alpha is a sequence its length should be one of 2" ):
820
+ transforms .ElasticTransform ([1.0 , 2.0 , 3.0 ])
821
+
822
+ with pytest .raises (ValueError , match = f"alpha should be a sequence of floats" ):
823
+ transforms .ElasticTransform ([1 , 2 ])
824
+
825
+ with pytest .raises (TypeError , match = "sigma should be float or a sequence of floats" ):
826
+ transforms .ElasticTransform (1.0 , {})
827
+
828
+ with pytest .raises (ValueError , match = f"sigma is a sequence its length should be one of 2" ):
829
+ transforms .ElasticTransform (1.0 , [1.0 , 2.0 , 3.0 ])
830
+
831
+ with pytest .raises (ValueError , match = f"sigma should be a sequence of floats" ):
832
+ transforms .ElasticTransform (1.0 , [1 , 2 ])
833
+
834
+ with pytest .raises (TypeError , match = "Got inappropriate fill arg" ):
835
+ transforms .ElasticTransform (1.0 , 2.0 , fill = "abc" )
836
+
837
+ def test__get_params (self , mocker ):
838
+ alpha = 2.0
839
+ sigma = 3.0
840
+ transform = transforms .ElasticTransform (alpha , sigma )
841
+ image = mocker .MagicMock (spec = features .Image )
842
+ image .num_channels = 3
843
+ image .image_size = (24 , 32 )
844
+
845
+ params = transform ._get_params (image )
846
+
847
+ h , w = image .image_size
848
+ displacement = params ["displacement" ]
849
+ assert displacement .shape == (1 , h , w , 2 )
850
+ assert (- alpha / w <= displacement [0 , ..., 0 ]).all () and (displacement [0 , ..., 0 ] <= alpha / w ).all ()
851
+ assert (- alpha / h <= displacement [0 , ..., 1 ]).all () and (displacement [0 , ..., 1 ] <= alpha / h ).all ()
852
+
853
+ @pytest .mark .parametrize ("alpha" , [5.0 , [5.0 , 10.0 ]])
854
+ @pytest .mark .parametrize ("sigma" , [2.0 , [2.0 , 5.0 ]])
855
+ def test__transform (self , alpha , sigma , mocker ):
856
+ interpolation = InterpolationMode .BILINEAR
857
+ fill = 12
858
+ transform = transforms .ElasticTransform (alpha , sigma = sigma , fill = fill , interpolation = interpolation )
859
+
860
+ if isinstance (alpha , float ):
861
+ assert transform .alpha == [alpha , alpha ]
862
+ else :
863
+ assert transform .alpha == alpha
864
+
865
+ if isinstance (sigma , float ):
866
+ assert transform .sigma == [sigma , sigma ]
867
+ else :
868
+ assert transform .sigma == sigma
869
+
870
+ fn = mocker .patch ("torchvision.prototype.transforms.functional.elastic" )
871
+ inpt = mocker .MagicMock (spec = features .Image )
872
+ inpt .num_channels = 3
873
+ inpt .image_size = (24 , 32 )
874
+
875
+ # Let's mock transform._get_params to control the output:
876
+ transform ._get_params = mocker .MagicMock ()
877
+ _ = transform (inpt )
878
+ params = transform ._get_params (inpt )
879
+ fn .assert_called_once_with (inpt , ** params , fill = fill , interpolation = interpolation )
0 commit comments