@@ -107,8 +107,11 @@ def __init__(self, alpha: float, p: float = 0.5) -> None:
107107 self ._dist = torch .distributions .Beta (torch .tensor ([alpha ]), torch .tensor ([alpha ]))
108108
109109 def forward (self , * inputs : Any ) -> Any :
110- if not (has_any (inputs , features .Image , features .is_simple_tensor ) and has_any (inputs , features .OneHotLabel )):
111- raise TypeError (f"{ type (self ).__name__ } () is only defined for tensor images and one-hot labels." )
110+ if not (
111+ has_any (inputs , features .Image , features .Video , features .is_simple_tensor )
112+ and has_any (inputs , features .OneHotLabel )
113+ ):
114+ raise TypeError (f"{ type (self ).__name__ } () is only defined for tensor images/videos and one-hot labels." )
112115 if has_any (inputs , PIL .Image .Image , features .BoundingBox , features .Mask , features .Label ):
113116 raise TypeError (
114117 f"{ type (self ).__name__ } () does not support PIL images, bounding boxes, masks and plain labels."
@@ -119,7 +122,7 @@ def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features
119122 if inpt .ndim < 2 :
120123 raise ValueError ("Need a batch of one hot labels" )
121124 output = inpt .clone ()
122- output = output .roll (1 , - 2 ).mul_ (1 - lam ).add_ (output .mul_ (lam ))
125+ output = output .roll (1 , 0 ).mul_ (1.0 - lam ).add_ (output .mul_ (lam ))
123126 return features .OneHotLabel .wrap_like (inpt , output )
124127
125128
@@ -129,14 +132,15 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
129132
130133 def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
131134 lam = params ["lam" ]
132- if isinstance (inpt , features .Image ) or features .is_simple_tensor (inpt ):
133- if inpt .ndim < 4 :
134- raise ValueError ("Need a batch of images" )
135+ if isinstance (inpt , (features .Image , features .Video )) or features .is_simple_tensor (inpt ):
136+ expected_ndim = 5 if isinstance (inpt , features .Video ) else 4
137+ if inpt .ndim < expected_ndim :
138+ raise ValueError ("The transform expects a batched input" )
135139 output = inpt .clone ()
136- output = output .roll (1 , - 4 ).mul_ (1 - lam ).add_ (output .mul_ (lam ))
140+ output = output .roll (1 , 0 ).mul_ (1.0 - lam ).add_ (output .mul_ (lam ))
137141
138- if isinstance (inpt , features .Image ):
139- output = features . Image . wrap_like (inpt , output )
142+ if isinstance (inpt , ( features .Image , features . Video ) ):
143+ output = type ( inpt ). wrap_like (inpt , output ) # type: ignore[arg-type]
140144
141145 return output
142146 elif isinstance (inpt , features .OneHotLabel ):
@@ -169,17 +173,18 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
169173 return dict (box = box , lam_adjusted = lam_adjusted )
170174
171175 def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
172- if isinstance (inpt , features .Image ) or features .is_simple_tensor (inpt ):
176+ if isinstance (inpt , ( features .Image , features . Video ) ) or features .is_simple_tensor (inpt ):
173177 box = params ["box" ]
174- if inpt .ndim < 4 :
175- raise ValueError ("Need a batch of images" )
178+ expected_ndim = 5 if isinstance (inpt , features .Video ) else 4
179+ if inpt .ndim < expected_ndim :
180+ raise ValueError ("The transform expects a batched input" )
176181 x1 , y1 , x2 , y2 = box
177- image_rolled = inpt .roll (1 , - 4 )
182+ rolled = inpt .roll (1 , 0 )
178183 output = inpt .clone ()
179- output [..., y1 :y2 , x1 :x2 ] = image_rolled [..., y1 :y2 , x1 :x2 ]
184+ output [..., y1 :y2 , x1 :x2 ] = rolled [..., y1 :y2 , x1 :x2 ]
180185
181- if isinstance (inpt , features .Image ):
182- output = features . Image . wrap_like (inpt , output )
186+ if isinstance (inpt , ( features .Image , features . Video ) ):
187+ output = inpt . wrap_like (inpt , output ) # type: ignore[arg-type]
183188
184189 return output
185190 elif isinstance (inpt , features .OneHotLabel ):
0 commit comments