77from torchvision .prototype import features
88from torchvision .prototype .transforms import Transform , functional as F
99
10- from ._utils import query_image , get_image_dimensions
10+ from ._utils import query_image , get_image_dimensions , has_all , has_any
1111
1212
1313class RandomErasing (Transform ):
@@ -33,7 +33,6 @@ def __init__(
3333 raise ValueError ("Scale should be between 0 and 1" )
3434 if p < 0 or p > 1 :
3535 raise ValueError ("Random erasing probability should be between 0 and 1" )
36- # TODO: deprecate p in favor of wrapping the transform in a RandomApply
3736 self .p = p
3837 self .scale = scale
3938 self .ratio = ratio
@@ -88,9 +87,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
8887 return dict (zip ("ijhwv" , (i , j , h , w , v )))
8988
9089 def _transform (self , input : Any , params : Dict [str , Any ]) -> Any :
91- if isinstance (input , (features .BoundingBox , features .SegmentationMask )):
92- raise TypeError (f"{ type (input ).__name__ } 's are not supported by { type (self ).__name__ } ()" )
93- elif isinstance (input , features .Image ):
90+ if isinstance (input , features .Image ):
9491 output = F .erase_image_tensor (input , ** params )
9592 return features .Image .new_like (input , output )
9693 elif isinstance (input , torch .Tensor ):
@@ -99,10 +96,13 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
9996 return input
10097
10198 def forward (self , * inputs : Any ) -> Any :
102- if torch .rand (1 ) >= self .p :
103- return inputs if len (inputs ) > 1 else inputs [0 ]
99+ sample = inputs if len (inputs ) > 1 else inputs [0 ]
100+ if has_any (sample , features .BoundingBox , features .SegmentationMask ):
101+ raise TypeError (f"BoundingBox'es and SegmentationMask's are not supported by { type (self ).__name__ } ()" )
102+ elif torch .rand (1 ) >= self .p :
103+ return sample
104104
105- return super ().forward (* inputs )
105+ return super ().forward (sample )
106106
107107
108108class RandomMixup (Transform ):
@@ -115,9 +115,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
115115 return dict (lam = float (self ._dist .sample (())))
116116
117117 def _transform (self , input : Any , params : Dict [str , Any ]) -> Any :
118- if isinstance (input , (features .BoundingBox , features .SegmentationMask )):
119- raise TypeError (f"{ type (input ).__name__ } 's are not supported by { type (self ).__name__ } ()" )
120- elif isinstance (input , features .Image ):
118+ if isinstance (input , features .Image ):
121119 output = F .mixup_image_tensor (input , ** params )
122120 return features .Image .new_like (input , output )
123121 elif isinstance (input , features .OneHotLabel ):
@@ -126,6 +124,14 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
126124 else :
127125 return input
128126
127+ def forward (self , * inputs : Any ) -> Any :
128+ sample = inputs if len (inputs ) > 1 else inputs [0 ]
129+ if has_any (sample , features .BoundingBox , features .SegmentationMask ):
130+ raise TypeError (f"BoundingBox'es and SegmentationMask's are not supported by { type (self ).__name__ } ()" )
131+ elif not has_all (sample , features .Image , features .OneHotLabel ):
132+ raise TypeError (f"{ type (self ).__name__ } () is only defined for Image's *and* OneHotLabel's." )
133+ return super ().forward (sample )
134+
129135
130136class RandomCutmix (Transform ):
131137 def __init__ (self , * , alpha : float ) -> None :
@@ -157,13 +163,19 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
157163 return dict (box = box , lam_adjusted = lam_adjusted )
158164
159165 def _transform (self , input : Any , params : Dict [str , Any ]) -> Any :
160- if isinstance (input , (features .BoundingBox , features .SegmentationMask )):
161- raise TypeError (f"{ type (input ).__name__ } 's are not supported by { type (self ).__name__ } ()" )
162- elif isinstance (input , features .Image ):
166+ if isinstance (input , features .Image ):
163167 output = F .cutmix_image_tensor (input , box = params ["box" ])
164168 return features .Image .new_like (input , output )
165169 elif isinstance (input , features .OneHotLabel ):
166170 output = F .cutmix_one_hot_label (input , lam_adjusted = params ["lam_adjusted" ])
167171 return features .OneHotLabel .new_like (input , output )
168172 else :
169173 return input
174+
175+ def forward (self , * inputs : Any ) -> Any :
176+ sample = inputs if len (inputs ) > 1 else inputs [0 ]
177+ if has_any (sample , features .BoundingBox , features .SegmentationMask ):
178+ raise TypeError (f"BoundingBox'es and SegmentationMask's are not supported by { type (self ).__name__ } ()" )
179+ elif not has_all (sample , features .Image , features .OneHotLabel ):
180+ raise TypeError (f"{ type (self ).__name__ } () is only defined for Image's *and* OneHotLabel's." )
181+ return super ().forward (sample )
0 commit comments