7
7
from torchvision .prototype import features
8
8
from torchvision .prototype .transforms import Transform , functional as F
9
9
10
- from ._utils import query_image , get_image_dimensions
10
+ from ._utils import query_image , get_image_dimensions , has_all , has_any
11
11
12
12
13
13
class RandomErasing (Transform ):
@@ -33,7 +33,6 @@ def __init__(
33
33
raise ValueError ("Scale should be between 0 and 1" )
34
34
if p < 0 or p > 1 :
35
35
raise ValueError ("Random erasing probability should be between 0 and 1" )
36
- # TODO: deprecate p in favor of wrapping the transform in a RandomApply
37
36
self .p = p
38
37
self .scale = scale
39
38
self .ratio = ratio
@@ -88,9 +87,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
88
87
return dict (zip ("ijhwv" , (i , j , h , w , v )))
89
88
90
89
def _transform (self , input : Any , params : Dict [str , Any ]) -> Any :
91
- if isinstance (input , (features .BoundingBox , features .SegmentationMask )):
92
- raise TypeError (f"{ type (input ).__name__ } 's are not supported by { type (self ).__name__ } ()" )
93
- elif isinstance (input , features .Image ):
90
+ if isinstance (input , features .Image ):
94
91
output = F .erase_image_tensor (input , ** params )
95
92
return features .Image .new_like (input , output )
96
93
elif isinstance (input , torch .Tensor ):
@@ -99,10 +96,13 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
99
96
return input
100
97
101
98
def forward (self , * inputs : Any ) -> Any :
102
- if torch .rand (1 ) >= self .p :
103
- return inputs if len (inputs ) > 1 else inputs [0 ]
99
+ sample = inputs if len (inputs ) > 1 else inputs [0 ]
100
+ if has_any (sample , features .BoundingBox , features .SegmentationMask ):
101
+ raise TypeError (f"BoundingBox'es and SegmentationMask's are not supported by { type (self ).__name__ } ()" )
102
+ elif torch .rand (1 ) >= self .p :
103
+ return sample
104
104
105
- return super ().forward (* inputs )
105
+ return super ().forward (sample )
106
106
107
107
108
108
class RandomMixup (Transform ):
@@ -115,9 +115,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
115
115
return dict (lam = float (self ._dist .sample (())))
116
116
117
117
def _transform (self , input : Any , params : Dict [str , Any ]) -> Any :
118
- if isinstance (input , (features .BoundingBox , features .SegmentationMask )):
119
- raise TypeError (f"{ type (input ).__name__ } 's are not supported by { type (self ).__name__ } ()" )
120
- elif isinstance (input , features .Image ):
118
+ if isinstance (input , features .Image ):
121
119
output = F .mixup_image_tensor (input , ** params )
122
120
return features .Image .new_like (input , output )
123
121
elif isinstance (input , features .OneHotLabel ):
@@ -126,6 +124,14 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
126
124
else :
127
125
return input
128
126
127
+ def forward (self , * inputs : Any ) -> Any :
128
+ sample = inputs if len (inputs ) > 1 else inputs [0 ]
129
+ if has_any (sample , features .BoundingBox , features .SegmentationMask ):
130
+ raise TypeError (f"BoundingBox'es and SegmentationMask's are not supported by { type (self ).__name__ } ()" )
131
+ elif not has_all (sample , features .Image , features .OneHotLabel ):
132
+ raise TypeError (f"{ type (self ).__name__ } () is only defined for Image's *and* OneHotLabel's." )
133
+ return super ().forward (sample )
134
+
129
135
130
136
class RandomCutmix (Transform ):
131
137
def __init__ (self , * , alpha : float ) -> None :
@@ -157,13 +163,19 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
157
163
return dict (box = box , lam_adjusted = lam_adjusted )
158
164
159
165
def _transform (self , input : Any , params : Dict [str , Any ]) -> Any :
160
- if isinstance (input , (features .BoundingBox , features .SegmentationMask )):
161
- raise TypeError (f"{ type (input ).__name__ } 's are not supported by { type (self ).__name__ } ()" )
162
- elif isinstance (input , features .Image ):
166
+ if isinstance (input , features .Image ):
163
167
output = F .cutmix_image_tensor (input , box = params ["box" ])
164
168
return features .Image .new_like (input , output )
165
169
elif isinstance (input , features .OneHotLabel ):
166
170
output = F .cutmix_one_hot_label (input , lam_adjusted = params ["lam_adjusted" ])
167
171
return features .OneHotLabel .new_like (input , output )
168
172
else :
169
173
return input
174
+
175
+ def forward (self , * inputs : Any ) -> Any :
176
+ sample = inputs if len (inputs ) > 1 else inputs [0 ]
177
+ if has_any (sample , features .BoundingBox , features .SegmentationMask ):
178
+ raise TypeError (f"BoundingBox'es and SegmentationMask's are not supported by { type (self ).__name__ } ()" )
179
+ elif not has_all (sample , features .Image , features .OneHotLabel ):
180
+ raise TypeError (f"{ type (self ).__name__ } () is only defined for Image's *and* OneHotLabel's." )
181
+ return super ().forward (sample )
0 commit comments