@@ -927,6 +927,29 @@ def test_randaug(self, inpt, interpolation, mocker):
927927
928928 assert_close (expected_output , output , atol = 1 , rtol = 0.1 )
929929
930+ @pytest .mark .parametrize (
931+ "interpolation" ,
932+ [
933+ v2_transforms .InterpolationMode .NEAREST ,
934+ v2_transforms .InterpolationMode .BILINEAR ,
935+ ],
936+ )
937+ def test_randaug_jit (self , interpolation ):
938+ inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
939+ t_ref = legacy_transforms .RandAugment (interpolation = interpolation , num_ops = 1 )
940+ t = v2_transforms .RandAugment (interpolation = interpolation , num_ops = 1 )
941+
942+ tt_ref = torch .jit .script (t_ref )
943+ tt = torch .jit .script (t )
944+
945+ torch .manual_seed (12 )
946+ expected_output = tt_ref (inpt )
947+
948+ torch .manual_seed (12 )
949+ scripted_output = tt (inpt )
950+
951+ assert_equal (scripted_output , expected_output )
952+
930953 @pytest .mark .parametrize (
931954 "inpt" ,
932955 [
@@ -979,6 +1002,29 @@ def test_trivial_aug(self, inpt, interpolation, mocker):
9791002
9801003 assert_close (expected_output , output , atol = 1 , rtol = 0.1 )
9811004
1005+ @pytest .mark .parametrize (
1006+ "interpolation" ,
1007+ [
1008+ v2_transforms .InterpolationMode .NEAREST ,
1009+ v2_transforms .InterpolationMode .BILINEAR ,
1010+ ],
1011+ )
1012+ def test_trivial_aug_jit (self , interpolation ):
1013+ inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
1014+ t_ref = legacy_transforms .TrivialAugmentWide (interpolation = interpolation )
1015+ t = v2_transforms .TrivialAugmentWide (interpolation = interpolation )
1016+
1017+ tt_ref = torch .jit .script (t_ref )
1018+ tt = torch .jit .script (t )
1019+
1020+ torch .manual_seed (12 )
1021+ expected_output = tt_ref (inpt )
1022+
1023+ torch .manual_seed (12 )
1024+ scripted_output = tt (inpt )
1025+
1026+ assert_equal (scripted_output , expected_output )
1027+
9821028 @pytest .mark .parametrize (
9831029 "inpt" ,
9841030 [
@@ -1032,6 +1078,30 @@ def test_augmix(self, inpt, interpolation, mocker):
10321078
10331079 assert_equal (expected_output , output )
10341080
1081+ @pytest .mark .parametrize (
1082+ "interpolation" ,
1083+ [
1084+ v2_transforms .InterpolationMode .NEAREST ,
1085+ v2_transforms .InterpolationMode .BILINEAR ,
1086+ ],
1087+ )
1088+ def test_augmix_jit (self , interpolation ):
1089+ inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
1090+
1091+ t_ref = legacy_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 )
1092+ t = v2_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 )
1093+
1094+ tt_ref = torch .jit .script (t_ref )
1095+ tt = torch .jit .script (t )
1096+
1097+ torch .manual_seed (12 )
1098+ expected_output = tt_ref (inpt )
1099+
1100+ torch .manual_seed (12 )
1101+ scripted_output = tt (inpt )
1102+
1103+ assert_equal (scripted_output , expected_output )
1104+
10351105 @pytest .mark .parametrize (
10361106 "inpt" ,
10371107 [
@@ -1061,6 +1131,30 @@ def test_aa(self, inpt, interpolation):
10611131
10621132 assert_equal (expected_output , output )
10631133
1134+ @pytest .mark .parametrize (
1135+ "interpolation" ,
1136+ [
1137+ v2_transforms .InterpolationMode .NEAREST ,
1138+ v2_transforms .InterpolationMode .BILINEAR ,
1139+ ],
1140+ )
1141+ def test_aa_jit (self , interpolation ):
1142+ inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
1143+ aa_policy = legacy_transforms .AutoAugmentPolicy ("imagenet" )
1144+ t_ref = legacy_transforms .AutoAugment (aa_policy , interpolation = interpolation )
1145+ t = v2_transforms .AutoAugment (aa_policy , interpolation = interpolation )
1146+
1147+ tt_ref = torch .jit .script (t_ref )
1148+ tt = torch .jit .script (t )
1149+
1150+ torch .manual_seed (12 )
1151+ expected_output = tt_ref (inpt )
1152+
1153+ torch .manual_seed (12 )
1154+ scripted_output = tt (inpt )
1155+
1156+ assert_equal (scripted_output , expected_output )
1157+
10641158
10651159def import_transforms_from_references (reference ):
10661160 HERE = Path (__file__ ).parent
0 commit comments