@@ -927,6 +927,29 @@ def test_randaug(self, inpt, interpolation, mocker):
927
927
928
928
assert_close (expected_output , output , atol = 1 , rtol = 0.1 )
929
929
930
+ @pytest .mark .parametrize (
931
+ "interpolation" ,
932
+ [
933
+ v2_transforms .InterpolationMode .NEAREST ,
934
+ v2_transforms .InterpolationMode .BILINEAR ,
935
+ ],
936
+ )
937
+ def test_randaug_jit (self , interpolation ):
938
+ inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
939
+ t_ref = legacy_transforms .RandAugment (interpolation = interpolation , num_ops = 1 )
940
+ t = v2_transforms .RandAugment (interpolation = interpolation , num_ops = 1 )
941
+
942
+ tt_ref = torch .jit .script (t_ref )
943
+ tt = torch .jit .script (t )
944
+
945
+ torch .manual_seed (12 )
946
+ expected_output = tt_ref (inpt )
947
+
948
+ torch .manual_seed (12 )
949
+ scripted_output = tt (inpt )
950
+
951
+ assert_equal (scripted_output , expected_output )
952
+
930
953
@pytest .mark .parametrize (
931
954
"inpt" ,
932
955
[
@@ -979,6 +1002,29 @@ def test_trivial_aug(self, inpt, interpolation, mocker):
979
1002
980
1003
assert_close (expected_output , output , atol = 1 , rtol = 0.1 )
981
1004
1005
+ @pytest .mark .parametrize (
1006
+ "interpolation" ,
1007
+ [
1008
+ v2_transforms .InterpolationMode .NEAREST ,
1009
+ v2_transforms .InterpolationMode .BILINEAR ,
1010
+ ],
1011
+ )
1012
+ def test_trivial_aug_jit (self , interpolation ):
1013
+ inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
1014
+ t_ref = legacy_transforms .TrivialAugmentWide (interpolation = interpolation )
1015
+ t = v2_transforms .TrivialAugmentWide (interpolation = interpolation )
1016
+
1017
+ tt_ref = torch .jit .script (t_ref )
1018
+ tt = torch .jit .script (t )
1019
+
1020
+ torch .manual_seed (12 )
1021
+ expected_output = tt_ref (inpt )
1022
+
1023
+ torch .manual_seed (12 )
1024
+ scripted_output = tt (inpt )
1025
+
1026
+ assert_equal (scripted_output , expected_output )
1027
+
982
1028
@pytest .mark .parametrize (
983
1029
"inpt" ,
984
1030
[
@@ -1032,6 +1078,30 @@ def test_augmix(self, inpt, interpolation, mocker):
1032
1078
1033
1079
assert_equal (expected_output , output )
1034
1080
1081
+ @pytest .mark .parametrize (
1082
+ "interpolation" ,
1083
+ [
1084
+ v2_transforms .InterpolationMode .NEAREST ,
1085
+ v2_transforms .InterpolationMode .BILINEAR ,
1086
+ ],
1087
+ )
1088
+ def test_augmix_jit (self , interpolation ):
1089
+ inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
1090
+
1091
+ t_ref = legacy_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 )
1092
+ t = v2_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 )
1093
+
1094
+ tt_ref = torch .jit .script (t_ref )
1095
+ tt = torch .jit .script (t )
1096
+
1097
+ torch .manual_seed (12 )
1098
+ expected_output = tt_ref (inpt )
1099
+
1100
+ torch .manual_seed (12 )
1101
+ scripted_output = tt (inpt )
1102
+
1103
+ assert_equal (scripted_output , expected_output )
1104
+
1035
1105
@pytest .mark .parametrize (
1036
1106
"inpt" ,
1037
1107
[
@@ -1061,6 +1131,30 @@ def test_aa(self, inpt, interpolation):
1061
1131
1062
1132
assert_equal (expected_output , output )
1063
1133
1134
+ @pytest .mark .parametrize (
1135
+ "interpolation" ,
1136
+ [
1137
+ v2_transforms .InterpolationMode .NEAREST ,
1138
+ v2_transforms .InterpolationMode .BILINEAR ,
1139
+ ],
1140
+ )
1141
+ def test_aa_jit (self , interpolation ):
1142
+ inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
1143
+ aa_policy = legacy_transforms .AutoAugmentPolicy ("imagenet" )
1144
+ t_ref = legacy_transforms .AutoAugment (aa_policy , interpolation = interpolation )
1145
+ t = v2_transforms .AutoAugment (aa_policy , interpolation = interpolation )
1146
+
1147
+ tt_ref = torch .jit .script (t_ref )
1148
+ tt = torch .jit .script (t )
1149
+
1150
+ torch .manual_seed (12 )
1151
+ expected_output = tt_ref (inpt )
1152
+
1153
+ torch .manual_seed (12 )
1154
+ scripted_output = tt (inpt )
1155
+
1156
+ assert_equal (scripted_output , expected_output )
1157
+
1064
1158
1065
1159
def import_transforms_from_references (reference ):
1066
1160
HERE = Path (__file__ ).parent
0 commit comments