@@ -755,10 +755,11 @@ def test_randaug(self, inpt, interpolation, mocker):
755
755
v2_transforms .InterpolationMode .BILINEAR ,
756
756
],
757
757
)
758
- def test_randaug_jit (self , interpolation ):
758
+ @pytest .mark .parametrize ("fill" , [None , 85 , (10 , - 10 , 10 ), 0.7 , [0.0 , 0.0 , 0.0 ], [1 ], 1 ])
759
+ def test_randaug_jit (self , interpolation , fill ):
759
760
inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
760
- t_ref = legacy_transforms .RandAugment (interpolation = interpolation , num_ops = 1 )
761
- t = v2_transforms .RandAugment (interpolation = interpolation , num_ops = 1 )
761
+ t_ref = legacy_transforms .RandAugment (interpolation = interpolation , num_ops = 1 , fill = fill )
762
+ t = v2_transforms .RandAugment (interpolation = interpolation , num_ops = 1 , fill = fill )
762
763
763
764
tt_ref = torch .jit .script (t_ref )
764
765
tt = torch .jit .script (t )
@@ -830,10 +831,11 @@ def test_trivial_aug(self, inpt, interpolation, mocker):
830
831
v2_transforms .InterpolationMode .BILINEAR ,
831
832
],
832
833
)
833
- def test_trivial_aug_jit (self , interpolation ):
834
+ @pytest .mark .parametrize ("fill" , [None , 85 , (10 , - 10 , 10 ), 0.7 , [0.0 , 0.0 , 0.0 ], [1 ], 1 ])
835
+ def test_trivial_aug_jit (self , interpolation , fill ):
834
836
inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
835
- t_ref = legacy_transforms .TrivialAugmentWide (interpolation = interpolation )
836
- t = v2_transforms .TrivialAugmentWide (interpolation = interpolation )
837
+ t_ref = legacy_transforms .TrivialAugmentWide (interpolation = interpolation , fill = fill )
838
+ t = v2_transforms .TrivialAugmentWide (interpolation = interpolation , fill = fill )
837
839
838
840
tt_ref = torch .jit .script (t_ref )
839
841
tt = torch .jit .script (t )
@@ -906,11 +908,12 @@ def test_augmix(self, inpt, interpolation, mocker):
906
908
v2_transforms .InterpolationMode .BILINEAR ,
907
909
],
908
910
)
909
- def test_augmix_jit (self , interpolation ):
911
+ @pytest .mark .parametrize ("fill" , [None , 85 , (10 , - 10 , 10 ), 0.7 , [0.0 , 0.0 , 0.0 ], [1 ], 1 ])
912
+ def test_augmix_jit (self , interpolation , fill ):
910
913
inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
911
914
912
- t_ref = legacy_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 )
913
- t = v2_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 )
915
+ t_ref = legacy_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 , fill = fill )
916
+ t = v2_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 , fill = fill )
914
917
915
918
tt_ref = torch .jit .script (t_ref )
916
919
tt = torch .jit .script (t )
0 commit comments