
Commit 12adc54

Add video support on MixUp and CutMix (#6733)
* Add video support on MixUp and CutMix
* Switch back to roll
* Fix tests and mypy
* Another mypy fix
1 parent a3fe870 commit 12adc54

File tree: 3 files changed, +29 −21 lines


test/test_prototype_transforms.py

Lines changed: 6 additions & 3 deletions
@@ -112,9 +112,12 @@ def test_common(self, transform, input):
         (
             transform,
             [
-                dict(image=image, one_hot_label=one_hot_label)
-                for image, one_hot_label in itertools.product(
-                    make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
+                dict(inpt=inpt, one_hot_label=one_hot_label)
+                for inpt, one_hot_label in itertools.product(
+                    itertools.chain(
+                        make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
+                        make_videos(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
+                    ),
                     make_one_hot_labels(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
                 )
             ],
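The parametrization now feeds both image and video batches through the same MixUp/CutMix test cases by chaining the two fixture generators before taking the product with the label fixtures. A minimal sketch of that itertools pattern, with illustrative placeholder values standing in for the real tensor fixtures:

    import itertools

    # Stand-ins for make_images(...), make_videos(...), and
    # make_one_hot_labels(...); the real test yields tensors.
    images = ["image_batch_a", "image_batch_b"]
    videos = ["video_batch_a"]
    labels = ["one_hot_x", "one_hot_y"]

    # chain() merges the two input kinds into one stream; product() then
    # pairs every input with every label, so videos get the same coverage
    # as images without duplicating the test case.
    for inpt, one_hot_label in itertools.product(itertools.chain(images, videos), labels):
        print(dict(inpt=inpt, one_hot_label=one_hot_label))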

torchvision/prototype/transforms/_augment.py

Lines changed: 21 additions & 16 deletions
@@ -107,8 +107,11 @@ def __init__(self, alpha: float, p: float = 0.5) -> None:
         self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

     def forward(self, *inputs: Any) -> Any:
-        if not (has_any(inputs, features.Image, features.is_simple_tensor) and has_any(inputs, features.OneHotLabel)):
-            raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.")
+        if not (
+            has_any(inputs, features.Image, features.Video, features.is_simple_tensor)
+            and has_any(inputs, features.OneHotLabel)
+        ):
+            raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
         if has_any(inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
             raise TypeError(
                 f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
@@ -119,7 +122,7 @@ def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features
         if inpt.ndim < 2:
             raise ValueError("Need a batch of one hot labels")
         output = inpt.clone()
-        output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam))
+        output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam))
         return features.OneHotLabel.wrap_like(inpt, output)

@@ -129,14 +132,15 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
         lam = params["lam"]
-        if isinstance(inpt, features.Image) or features.is_simple_tensor(inpt):
-            if inpt.ndim < 4:
-                raise ValueError("Need a batch of images")
+        if isinstance(inpt, (features.Image, features.Video)) or features.is_simple_tensor(inpt):
+            expected_ndim = 5 if isinstance(inpt, features.Video) else 4
+            if inpt.ndim < expected_ndim:
+                raise ValueError("The transform expects a batched input")
             output = inpt.clone()
-            output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam))
+            output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam))

-            if isinstance(inpt, features.Image):
-                output = features.Image.wrap_like(inpt, output)
+            if isinstance(inpt, (features.Image, features.Video)):
+                output = type(inpt).wrap_like(inpt, output)  # type: ignore[arg-type]

             return output
         elif isinstance(inpt, features.OneHotLabel):
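The key change in the hunks above ("switch back to roll" in the commit message) is rolling along batch dimension 0 instead of a negative dimension, so the same line works for 4D image batches (B, C, H, W), 5D video batches (B, T, C, H, W), and batched one-hot labels alike. A self-contained sketch of that mixing step; the shape and lam value are illustrative:

    import torch

    lam = 0.3
    batch = torch.rand(4, 8, 3, 32, 32)  # (B, T, C, H, W) video batch, illustrative

    output = batch.clone()
    # roll(1, 0) pairs each sample with its predecessor in the batch (with
    # wraparound), so output[i] == lam * batch[i] + (1 - lam) * batch[i - 1].
    output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam))

    assert output.shape == batch.shape
    assert torch.allclose(output[2], lam * batch[2] + (1.0 - lam) * batch[1])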
@@ -169,17 +173,18 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
         return dict(box=box, lam_adjusted=lam_adjusted)

     def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
-        if isinstance(inpt, features.Image) or features.is_simple_tensor(inpt):
+        if isinstance(inpt, (features.Image, features.Video)) or features.is_simple_tensor(inpt):
             box = params["box"]
-            if inpt.ndim < 4:
-                raise ValueError("Need a batch of images")
+            expected_ndim = 5 if isinstance(inpt, features.Video) else 4
+            if inpt.ndim < expected_ndim:
+                raise ValueError("The transform expects a batched input")
             x1, y1, x2, y2 = box
-            image_rolled = inpt.roll(1, -4)
+            rolled = inpt.roll(1, 0)
             output = inpt.clone()
-            output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
+            output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]

-            if isinstance(inpt, features.Image):
-                output = features.Image.wrap_like(inpt, output)
+            if isinstance(inpt, (features.Image, features.Video)):
+                output = inpt.wrap_like(inpt, output)  # type: ignore[arg-type]

             return output
         elif isinstance(inpt, features.OneHotLabel):
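The same batch-dimension roll drives CutMix: the sampled box is copied from the rolled batch, i.e. from each sample's neighbor. Because the box indexes only the trailing (H, W) dimensions through an ellipsis, identical code handles both image and video batches. A self-contained sketch with a hard-coded box (the real transform samples it in _get_params):

    import torch

    batch = torch.rand(4, 8, 3, 32, 32)  # (B, T, C, H, W) video batch, illustrative
    x1, y1, x2, y2 = 8, 8, 24, 24        # fixed box for the sketch

    rolled = batch.roll(1, 0)            # each sample's neighbor along the batch
    output = batch.clone()
    # The ellipsis leaves every leading dim (batch, frames, channels) untouched
    # and pastes the neighbor's pixels only inside the (H, W) box.
    output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2]

    assert torch.equal(output[1, ..., y1:y2, x1:x2], batch[0, ..., y1:y2, x1:x2])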

torchvision/prototype/transforms/_auto_augment.py

Lines changed: 2 additions & 2 deletions
@@ -483,8 +483,8 @@ def forward(self, *inputs: Any) -> Any:
         augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

         orig_dims = list(image_or_video.shape)
-        expected_dim = 5 if isinstance(orig_image_or_video, features.Video) else 4
-        batch = image_or_video.view([1] * max(expected_dim - image_or_video.ndim, 0) + orig_dims)
+        expected_ndim = 5 if isinstance(orig_image_or_video, features.Video) else 4
+        batch = image_or_video.view([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
         batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

         # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
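Beyond the expected_dim → expected_ndim rename for consistency with the new MixUp/CutMix code, this hunk shows how the transform normalizes its input to a batched view before sampling per-sample mixing weights: left-pad the shape with singleton dims until the tensor reaches the expected batched rank. A small sketch of that reshape, with an illustrative input shape:

    import torch

    image_or_video = torch.rand(3, 32, 32)  # unbatched image; a video would be (T, C, H, W)
    expected_ndim = 4                       # 5 if the input were a features.Video

    orig_dims = list(image_or_video.shape)
    # Prepend singleton dims until the tensor has the expected batched rank;
    # already-batched inputs pass through unchanged (max(..., 0) clamps at zero).
    batch = image_or_video.view([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
    print(batch.shape)  # torch.Size([1, 3, 32, 32])

    batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)
    print(batch_dims)   # [1, 1, 1, 1] -- broadcastable shape for per-sample weights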
