1
1
import enum
2
2
import inspect
3
+ import random
4
+ from collections import defaultdict
3
5
from importlib .machinery import SourceFileLoader
4
6
from pathlib import Path
5
7
16
18
make_image ,
17
19
make_images ,
18
20
make_label ,
21
+ make_segmentation_mask ,
19
22
)
20
23
from torchvision import transforms as legacy_transforms
21
24
from torchvision ._utils import sequence_to_str
22
25
from torchvision .prototype import features , transforms as prototype_transforms
26
+ from torchvision .prototype .transforms import functional as F
27
+ from torchvision .prototype .transforms ._utils import query_chw
23
28
from torchvision .prototype .transforms .functional import to_image_pil
24
29
25
-
26
30
DEFAULT_MAKE_IMAGES_KWARGS = dict (color_spaces = [features .ColorSpace .RGB ], extra_dims = [(4 ,)])
27
31
28
32
@@ -852,10 +856,12 @@ def test_aa(self, inpt, interpolation):
852
856
assert_equal (expected_output , output )
853
857
854
858
855
- # Import reference detection transforms here for consistency checks
856
- # torchvision/references/detection/transforms.py
857
- ref_det_filepath = Path (__file__ ).parent .parent / "references" / "detection" / "transforms.py"
858
- det_transforms = SourceFileLoader (ref_det_filepath .stem , ref_det_filepath .as_posix ()).load_module ()
859
def import_transforms_from_references(reference):
    # Load torchvision/references/<reference>/transforms.py as a module so the
    # reference transforms can be reused for consistency checks against the
    # prototype transforms.
    filepath = Path(__file__).parent.parent / "references" / reference / "transforms.py"
    loader = SourceFileLoader(filepath.stem, filepath.as_posix())
    return loader.load_module()


# Reference detection transforms, used by the consistency checks below.
det_transforms = import_transforms_from_references("detection")
859
865
860
866
861
867
class TestRefDetTransforms :
@@ -873,7 +879,7 @@ def make_datapoints(self, with_mask=True):
873
879
874
880
yield (pil_image , target )
875
881
876
- tensor_image = torch .randint ( 0 , 256 , size = ( 3 , * size ), dtype = torch . uint8 )
882
+ tensor_image = torch .Tensor ( make_image ( size = size , color_space = features . ColorSpace . RGB ) )
877
883
target = {
878
884
"boxes" : make_bounding_box (image_size = size , format = "XYXY" , extra_dims = (num_objects ,), dtype = torch .float ),
879
885
"labels" : make_label (extra_dims = (num_objects ,), categories = 80 ),
@@ -883,7 +889,7 @@ def make_datapoints(self, with_mask=True):
883
889
884
890
yield (tensor_image , target )
885
891
886
- feature_image = features . Image ( torch . randint ( 0 , 256 , size = ( 3 , * size ), dtype = torch . uint8 ) )
892
+ feature_image = make_image ( size = size , color_space = features . ColorSpace . RGB )
887
893
target = {
888
894
"boxes" : make_bounding_box (image_size = size , format = "XYXY" , extra_dims = (num_objects ,), dtype = torch .float ),
889
895
"labels" : make_label (extra_dims = (num_objects ,), categories = 80 ),
@@ -927,3 +933,165 @@ def test_transform(self, t_ref, t, data_kwargs):
927
933
expected_output = t_ref (* dp )
928
934
929
935
assert_equal (expected_output , output )
# Reference segmentation transforms, loaded from
# torchvision/references/segmentation/transforms.py for consistency checks.
seg_transforms = import_transforms_from_references("segmentation")
940
+
941
# We need this transform for two reasons:
# 1. transforms.RandomCrop uses a different scheme to pad images and masks of
#    insufficient size than its namesake in the detection references, so we
#    cannot use it with `pad_if_needed=True`.
# 2. transforms.Pad only supports a fixed padding, but the segmentation
#    datasets do not have a fixed image size.
class PadIfSmaller(prototype_transforms.Transform):
    def __init__(self, size, fill=0):
        super().__init__()
        self.size = size
        # Normalize `fill` into the per-type mapping used by the prototype transforms.
        self.fill = prototype_transforms._geometry._setup_fill_arg(fill)

    def _get_params(self, sample):
        # Only pad on the right and bottom, and only as much as needed to
        # reach `self.size` in each dimension.
        _, height, width = query_chw(sample)
        pad_right = max(self.size - width, 0)
        pad_bottom = max(self.size - height, 0)
        padding = [0, 0, pad_right, pad_bottom]
        return dict(padding=padding, needs_padding=any(padding))

    def _transform(self, inpt, params):
        if not params["needs_padding"]:
            return inpt

        fill = F._geometry._convert_fill_arg(self.fill[type(inpt)])
        return F.pad(inpt, padding=params["padding"], fill=fill)
965
+
966
+
967
class TestRefSegTransforms:
    """Consistency checks between the prototype transforms and the reference
    transforms in torchvision/references/segmentation/transforms.py.
    """

    def make_datapoints(self, supports_pil=True, image_dtype=torch.uint8):
        # Yield (prototype_sample, reference_sample) pairs for every input
        # representation the prototype transforms accept: PIL image (when
        # supported), plain tensor, and feature-wrapped image. The reference
        # transforms always take an image plus a PIL mask.
        size = (256, 640)
        num_categories = 21

        conversions = [to_image_pil] if supports_pil else []
        conversions += [torch.Tensor, lambda inpt: inpt]

        for convert in conversions:
            image = make_image(size=size, color_space=features.ColorSpace.RGB, dtype=image_dtype)
            mask = make_segmentation_mask(size=size, num_categories=num_categories, dtype=torch.uint8)

            ref_image = to_image_pil(image) if supports_pil else torch.Tensor(image)
            yield (convert(image), mask), (ref_image, to_image_pil(mask))

    def set_seed(self, seed=12):
        # The reference transforms draw from `random`, while the prototype
        # ones draw from `torch`; seed both so random decisions line up.
        torch.manual_seed(seed)
        random.seed(seed)

    def check(self, t, t_ref, data_kwargs=None):
        # Run the prototype transform and its reference counterpart from the
        # same seed and require identical outputs.
        for dp, dp_ref in self.make_datapoints(**(data_kwargs or dict())):
            self.set_seed()
            output = t(dp)

            self.set_seed()
            expected_output = t_ref(*dp_ref)

            assert_equal(output, expected_output)

    @pytest.mark.parametrize(
        ("t_ref", "t", "data_kwargs"),
        [
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=1.0),
                prototype_transforms.RandomHorizontalFlip(p=1.0),
                dict(),
            ),
            (
                seg_transforms.RandomHorizontalFlip(flip_prob=0.0),
                prototype_transforms.RandomHorizontalFlip(p=0.0),
                dict(),
            ),
            (
                seg_transforms.RandomCrop(size=480),
                prototype_transforms.Compose(
                    [
                        PadIfSmaller(size=480, fill=defaultdict(lambda: 0, {features.Mask: 255})),
                        prototype_transforms.RandomCrop(size=480),
                    ]
                ),
                dict(),
            ),
            (
                seg_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                prototype_transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                dict(supports_pil=False, image_dtype=torch.float),
            ),
        ],
    )
    def test_common(self, t_ref, t, data_kwargs):
        self.check(t, t_ref, data_kwargs)

    def check_resize(self, mocker, t_ref, t):
        # Resize draws its target size at random, so instead of comparing
        # outputs we intercept the functional resize calls on both sides and
        # compare the arguments they receive.
        mock = mocker.patch("torchvision.prototype.transforms._geometry.F.resize")
        mock_ref = mocker.patch("torchvision.transforms.functional.resize")

        for dp, dp_ref in self.make_datapoints():
            mock.reset_mock()
            mock_ref.reset_mock()

            self.set_seed()
            t(dp)
            assert mock.call_count == 2
            for call_args, inpt in zip(mock.call_args_list, dp):
                assert call_args[0][0] is inpt

            self.set_seed()
            t_ref(*dp_ref)
            assert mock_ref.call_count == 2
            for call_args, inpt in zip(mock_ref.call_args_list, dp_ref):
                assert call_args[0][0] is inpt

            # The prototype transform passes the size wrapped in a list, the
            # reference passes it as a scalar.
            for args_kwargs, args_kwargs_ref in zip(mock.call_args_list, mock_ref.call_args_list):
                assert args_kwargs[0][1] == [args_kwargs_ref[0][1]]

    def test_random_resize_train(self, mocker):
        base_size = 520
        min_size = base_size // 2
        max_size = base_size * 2

        randint = torch.randint

        def patched_randint(a, b, *other_args, **kwargs):
            # Redirect only the scalar draw `torch.randint(low, high, ())` to
            # `random.randint`; every other call goes to the real function.
            if kwargs or len(other_args) > 1 or other_args[0] != ():
                return randint(a, b, *other_args, **kwargs)

            return random.randint(a, b)

        # We are patching torch.randint -> random.randint here, because we
        # can't patch the modules that are not imported normally.
        t = prototype_transforms.RandomResize(min_size=min_size, max_size=max_size, antialias=True)
        mocker.patch(
            "torchvision.prototype.transforms._geometry.torch.randint",
            new=patched_randint,
        )

        t_ref = seg_transforms.RandomResize(min_size=min_size, max_size=max_size)

        self.check_resize(mocker, t_ref, t)

    def test_random_resize_eval(self, mocker):
        torch.manual_seed(0)
        base_size = 520

        # With min_size == max_size the reference RandomResize degenerates to
        # a fixed Resize, matching the prototype eval transform.
        t = prototype_transforms.Resize(size=base_size, antialias=True)
        t_ref = seg_transforms.RandomResize(min_size=base_size, max_size=base_size)

        self.check_resize(mocker, t_ref, t)
0 commit comments