@@ -705,281 +705,6 @@ def test_to_tensor(self):
705
705
assert_equal (prototype_transform (image_numpy ), legacy_transform (image_numpy ))
706
706
707
707
708
- class TestAATransforms :
709
- @pytest .mark .parametrize (
710
- "inpt" ,
711
- [
712
- torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 ),
713
- PIL .Image .new ("RGB" , (256 , 256 ), 123 ),
714
- tv_tensors .Image (torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )),
715
- ],
716
- )
717
- @pytest .mark .parametrize (
718
- "interpolation" ,
719
- [
720
- v2_transforms .InterpolationMode .NEAREST ,
721
- v2_transforms .InterpolationMode .BILINEAR ,
722
- PIL .Image .NEAREST ,
723
- ],
724
- )
725
- def test_randaug (self , inpt , interpolation , mocker ):
726
- t_ref = legacy_transforms .RandAugment (interpolation = interpolation , num_ops = 1 )
727
- t = v2_transforms .RandAugment (interpolation = interpolation , num_ops = 1 )
728
-
729
- le = len (t ._AUGMENTATION_SPACE )
730
- keys = list (t ._AUGMENTATION_SPACE .keys ())
731
- randint_values = []
732
- for i in range (le ):
733
- # Stable API, op_index random call
734
- randint_values .append (i )
735
- # Stable API, if signed there is another random call
736
- if t ._AUGMENTATION_SPACE [keys [i ]][1 ]:
737
- randint_values .append (0 )
738
- # New API, _get_random_item
739
- randint_values .append (i )
740
- randint_values = iter (randint_values )
741
-
742
- mocker .patch ("torch.randint" , side_effect = lambda * arg , ** kwargs : torch .tensor (next (randint_values )))
743
- mocker .patch ("torch.rand" , return_value = 1.0 )
744
-
745
- for i in range (le ):
746
- expected_output = t_ref (inpt )
747
- output = t (inpt )
748
-
749
- assert_close (expected_output , output , atol = 1 , rtol = 0.1 )
750
-
751
- @pytest .mark .parametrize (
752
- "interpolation" ,
753
- [
754
- v2_transforms .InterpolationMode .NEAREST ,
755
- v2_transforms .InterpolationMode .BILINEAR ,
756
- ],
757
- )
758
- @pytest .mark .parametrize ("fill" , [None , 85 , (10 , - 10 , 10 ), 0.7 , [0.0 , 0.0 , 0.0 ], [1 ], 1 ])
759
- def test_randaug_jit (self , interpolation , fill ):
760
- inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
761
- t_ref = legacy_transforms .RandAugment (interpolation = interpolation , num_ops = 1 , fill = fill )
762
- t = v2_transforms .RandAugment (interpolation = interpolation , num_ops = 1 , fill = fill )
763
-
764
- tt_ref = torch .jit .script (t_ref )
765
- tt = torch .jit .script (t )
766
-
767
- torch .manual_seed (12 )
768
- expected_output = tt_ref (inpt )
769
-
770
- torch .manual_seed (12 )
771
- scripted_output = tt (inpt )
772
-
773
- assert_equal (scripted_output , expected_output )
774
-
775
- @pytest .mark .parametrize (
776
- "inpt" ,
777
- [
778
- torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 ),
779
- PIL .Image .new ("RGB" , (256 , 256 ), 123 ),
780
- tv_tensors .Image (torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )),
781
- ],
782
- )
783
- @pytest .mark .parametrize (
784
- "interpolation" ,
785
- [
786
- v2_transforms .InterpolationMode .NEAREST ,
787
- v2_transforms .InterpolationMode .BILINEAR ,
788
- PIL .Image .NEAREST ,
789
- ],
790
- )
791
- def test_trivial_aug (self , inpt , interpolation , mocker ):
792
- t_ref = legacy_transforms .TrivialAugmentWide (interpolation = interpolation )
793
- t = v2_transforms .TrivialAugmentWide (interpolation = interpolation )
794
-
795
- le = len (t ._AUGMENTATION_SPACE )
796
- keys = list (t ._AUGMENTATION_SPACE .keys ())
797
- randint_values = []
798
- for i in range (le ):
799
- # Stable API, op_index random call
800
- randint_values .append (i )
801
- key = keys [i ]
802
- # Stable API, random magnitude
803
- aug_op = t ._AUGMENTATION_SPACE [key ]
804
- magnitudes = aug_op [0 ](2 , 0 , 0 )
805
- if magnitudes is not None :
806
- randint_values .append (5 )
807
- # Stable API, if signed there is another random call
808
- if aug_op [1 ]:
809
- randint_values .append (0 )
810
- # New API, _get_random_item
811
- randint_values .append (i )
812
- # New API, random magnitude
813
- if magnitudes is not None :
814
- randint_values .append (5 )
815
-
816
- randint_values = iter (randint_values )
817
-
818
- mocker .patch ("torch.randint" , side_effect = lambda * arg , ** kwargs : torch .tensor (next (randint_values )))
819
- mocker .patch ("torch.rand" , return_value = 1.0 )
820
-
821
- for _ in range (le ):
822
- expected_output = t_ref (inpt )
823
- output = t (inpt )
824
-
825
- assert_close (expected_output , output , atol = 1 , rtol = 0.1 )
826
-
827
- @pytest .mark .parametrize (
828
- "interpolation" ,
829
- [
830
- v2_transforms .InterpolationMode .NEAREST ,
831
- v2_transforms .InterpolationMode .BILINEAR ,
832
- ],
833
- )
834
- @pytest .mark .parametrize ("fill" , [None , 85 , (10 , - 10 , 10 ), 0.7 , [0.0 , 0.0 , 0.0 ], [1 ], 1 ])
835
- def test_trivial_aug_jit (self , interpolation , fill ):
836
- inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
837
- t_ref = legacy_transforms .TrivialAugmentWide (interpolation = interpolation , fill = fill )
838
- t = v2_transforms .TrivialAugmentWide (interpolation = interpolation , fill = fill )
839
-
840
- tt_ref = torch .jit .script (t_ref )
841
- tt = torch .jit .script (t )
842
-
843
- torch .manual_seed (12 )
844
- expected_output = tt_ref (inpt )
845
-
846
- torch .manual_seed (12 )
847
- scripted_output = tt (inpt )
848
-
849
- assert_equal (scripted_output , expected_output )
850
-
851
- @pytest .mark .parametrize (
852
- "inpt" ,
853
- [
854
- torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 ),
855
- PIL .Image .new ("RGB" , (256 , 256 ), 123 ),
856
- tv_tensors .Image (torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )),
857
- ],
858
- )
859
- @pytest .mark .parametrize (
860
- "interpolation" ,
861
- [
862
- v2_transforms .InterpolationMode .NEAREST ,
863
- v2_transforms .InterpolationMode .BILINEAR ,
864
- PIL .Image .NEAREST ,
865
- ],
866
- )
867
- def test_augmix (self , inpt , interpolation , mocker ):
868
- t_ref = legacy_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 )
869
- t_ref ._sample_dirichlet = lambda t : t .softmax (dim = - 1 )
870
- t = v2_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 )
871
- t ._sample_dirichlet = lambda t : t .softmax (dim = - 1 )
872
-
873
- le = len (t ._AUGMENTATION_SPACE )
874
- keys = list (t ._AUGMENTATION_SPACE .keys ())
875
- randint_values = []
876
- for i in range (le ):
877
- # Stable API, op_index random call
878
- randint_values .append (i )
879
- key = keys [i ]
880
- # Stable API, random magnitude
881
- aug_op = t ._AUGMENTATION_SPACE [key ]
882
- magnitudes = aug_op [0 ](2 , 0 , 0 )
883
- if magnitudes is not None :
884
- randint_values .append (5 )
885
- # Stable API, if signed there is another random call
886
- if aug_op [1 ]:
887
- randint_values .append (0 )
888
- # New API, _get_random_item
889
- randint_values .append (i )
890
- # New API, random magnitude
891
- if magnitudes is not None :
892
- randint_values .append (5 )
893
-
894
- randint_values = iter (randint_values )
895
-
896
- mocker .patch ("torch.randint" , side_effect = lambda * arg , ** kwargs : torch .tensor (next (randint_values )))
897
- mocker .patch ("torch.rand" , return_value = 1.0 )
898
-
899
- expected_output = t_ref (inpt )
900
- output = t (inpt )
901
-
902
- assert_equal (expected_output , output )
903
-
904
- @pytest .mark .parametrize (
905
- "interpolation" ,
906
- [
907
- v2_transforms .InterpolationMode .NEAREST ,
908
- v2_transforms .InterpolationMode .BILINEAR ,
909
- ],
910
- )
911
- @pytest .mark .parametrize ("fill" , [None , 85 , (10 , - 10 , 10 ), 0.7 , [0.0 , 0.0 , 0.0 ], [1 ], 1 ])
912
- def test_augmix_jit (self , interpolation , fill ):
913
- inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
914
-
915
- t_ref = legacy_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 , fill = fill )
916
- t = v2_transforms .AugMix (interpolation = interpolation , mixture_width = 1 , chain_depth = 1 , fill = fill )
917
-
918
- tt_ref = torch .jit .script (t_ref )
919
- tt = torch .jit .script (t )
920
-
921
- torch .manual_seed (12 )
922
- expected_output = tt_ref (inpt )
923
-
924
- torch .manual_seed (12 )
925
- scripted_output = tt (inpt )
926
-
927
- assert_equal (scripted_output , expected_output )
928
-
929
- @pytest .mark .parametrize (
930
- "inpt" ,
931
- [
932
- torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 ),
933
- PIL .Image .new ("RGB" , (256 , 256 ), 123 ),
934
- tv_tensors .Image (torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )),
935
- ],
936
- )
937
- @pytest .mark .parametrize (
938
- "interpolation" ,
939
- [
940
- v2_transforms .InterpolationMode .NEAREST ,
941
- v2_transforms .InterpolationMode .BILINEAR ,
942
- PIL .Image .NEAREST ,
943
- ],
944
- )
945
- def test_aa (self , inpt , interpolation ):
946
- aa_policy = legacy_transforms .AutoAugmentPolicy ("imagenet" )
947
- t_ref = legacy_transforms .AutoAugment (aa_policy , interpolation = interpolation )
948
- t = v2_transforms .AutoAugment (aa_policy , interpolation = interpolation )
949
-
950
- torch .manual_seed (12 )
951
- expected_output = t_ref (inpt )
952
-
953
- torch .manual_seed (12 )
954
- output = t (inpt )
955
-
956
- assert_equal (expected_output , output )
957
-
958
- @pytest .mark .parametrize (
959
- "interpolation" ,
960
- [
961
- v2_transforms .InterpolationMode .NEAREST ,
962
- v2_transforms .InterpolationMode .BILINEAR ,
963
- ],
964
- )
965
- def test_aa_jit (self , interpolation ):
966
- inpt = torch .randint (0 , 256 , size = (1 , 3 , 256 , 256 ), dtype = torch .uint8 )
967
- aa_policy = legacy_transforms .AutoAugmentPolicy ("imagenet" )
968
- t_ref = legacy_transforms .AutoAugment (aa_policy , interpolation = interpolation )
969
- t = v2_transforms .AutoAugment (aa_policy , interpolation = interpolation )
970
-
971
- tt_ref = torch .jit .script (t_ref )
972
- tt = torch .jit .script (t )
973
-
974
- torch .manual_seed (12 )
975
- expected_output = tt_ref (inpt )
976
-
977
- torch .manual_seed (12 )
978
- scripted_output = tt (inpt )
979
-
980
- assert_equal (scripted_output , expected_output )
981
-
982
-
983
708
def import_transforms_from_references (reference ):
984
709
HERE = Path (__file__ ).parent
985
710
PROJECT_ROOT = HERE .parent
0 commit comments