@@ -107,8 +107,11 @@ def __init__(self, alpha: float, p: float = 0.5) -> None:
107
107
self ._dist = torch .distributions .Beta (torch .tensor ([alpha ]), torch .tensor ([alpha ]))
108
108
109
109
def forward (self , * inputs : Any ) -> Any :
110
- if not (has_any (inputs , features .Image , features .is_simple_tensor ) and has_any (inputs , features .OneHotLabel )):
111
- raise TypeError (f"{ type (self ).__name__ } () is only defined for tensor images and one-hot labels." )
110
+ if not (
111
+ has_any (inputs , features .Image , features .Video , features .is_simple_tensor )
112
+ and has_any (inputs , features .OneHotLabel )
113
+ ):
114
+ raise TypeError (f"{ type (self ).__name__ } () is only defined for tensor images/videos and one-hot labels." )
112
115
if has_any (inputs , PIL .Image .Image , features .BoundingBox , features .Mask , features .Label ):
113
116
raise TypeError (
114
117
f"{ type (self ).__name__ } () does not support PIL images, bounding boxes, masks and plain labels."
@@ -119,7 +122,7 @@ def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features
119
122
if inpt .ndim < 2 :
120
123
raise ValueError ("Need a batch of one hot labels" )
121
124
output = inpt .clone ()
122
- output = output .roll (1 , - 2 ).mul_ (1 - lam ).add_ (output .mul_ (lam ))
125
+ output = output .roll (1 , 0 ).mul_ (1.0 - lam ).add_ (output .mul_ (lam ))
123
126
return features .OneHotLabel .wrap_like (inpt , output )
124
127
125
128
@@ -129,14 +132,15 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
129
132
130
133
def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
131
134
lam = params ["lam" ]
132
- if isinstance (inpt , features .Image ) or features .is_simple_tensor (inpt ):
133
- if inpt .ndim < 4 :
134
- raise ValueError ("Need a batch of images" )
135
+ if isinstance (inpt , (features .Image , features .Video )) or features .is_simple_tensor (inpt ):
136
+ expected_ndim = 5 if isinstance (inpt , features .Video ) else 4
137
+ if inpt .ndim < expected_ndim :
138
+ raise ValueError ("The transform expects a batched input" )
135
139
output = inpt .clone ()
136
- output = output .roll (1 , - 4 ).mul_ (1 - lam ).add_ (output .mul_ (lam ))
140
+ output = output .roll (1 , 0 ).mul_ (1.0 - lam ).add_ (output .mul_ (lam ))
137
141
138
- if isinstance (inpt , features .Image ):
139
- output = features . Image . wrap_like (inpt , output )
142
+ if isinstance (inpt , ( features .Image , features . Video ) ):
143
+ output = type ( inpt ). wrap_like (inpt , output ) # type: ignore[arg-type]
140
144
141
145
return output
142
146
elif isinstance (inpt , features .OneHotLabel ):
@@ -169,17 +173,18 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
169
173
return dict (box = box , lam_adjusted = lam_adjusted )
170
174
171
175
def _transform (self , inpt : Any , params : Dict [str , Any ]) -> Any :
172
- if isinstance (inpt , features .Image ) or features .is_simple_tensor (inpt ):
176
+ if isinstance (inpt , ( features .Image , features . Video ) ) or features .is_simple_tensor (inpt ):
173
177
box = params ["box" ]
174
- if inpt .ndim < 4 :
175
- raise ValueError ("Need a batch of images" )
178
+ expected_ndim = 5 if isinstance (inpt , features .Video ) else 4
179
+ if inpt .ndim < expected_ndim :
180
+ raise ValueError ("The transform expects a batched input" )
176
181
x1 , y1 , x2 , y2 = box
177
- image_rolled = inpt .roll (1 , - 4 )
182
+ rolled = inpt .roll (1 , 0 )
178
183
output = inpt .clone ()
179
- output [..., y1 :y2 , x1 :x2 ] = image_rolled [..., y1 :y2 , x1 :x2 ]
184
+ output [..., y1 :y2 , x1 :x2 ] = rolled [..., y1 :y2 , x1 :x2 ]
180
185
181
- if isinstance (inpt , features .Image ):
182
- output = features . Image . wrap_like (inpt , output )
186
+ if isinstance (inpt , ( features .Image , features . Video ) ):
187
+ output = inpt . wrap_like (inpt , output ) # type: ignore[arg-type]
183
188
184
189
return output
185
190
elif isinstance (inpt , features .OneHotLabel ):
0 commit comments