diff --git a/test/test_prototype_transforms.py b/test/test_prototype_transforms.py index 4037a746703..d7a41e7c12c 100644 --- a/test/test_prototype_transforms.py +++ b/test/test_prototype_transforms.py @@ -112,9 +112,12 @@ def test_common(self, transform, input): ( transform, [ - dict(image=image, one_hot_label=one_hot_label) - for image, one_hot_label in itertools.product( - make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]), + dict(inpt=inpt, one_hot_label=one_hot_label) + for inpt, one_hot_label in itertools.product( + itertools.chain( + make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]), + make_videos(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]), + ), make_one_hot_labels(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]), ) ], diff --git a/torchvision/prototype/transforms/_augment.py b/torchvision/prototype/transforms/_augment.py index 7b2dca8a601..4bfb5c9ed1e 100644 --- a/torchvision/prototype/transforms/_augment.py +++ b/torchvision/prototype/transforms/_augment.py @@ -107,8 +107,11 @@ def __init__(self, alpha: float, p: float = 0.5) -> None: self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) def forward(self, *inputs: Any) -> Any: - if not (has_any(inputs, features.Image, features.is_simple_tensor) and has_any(inputs, features.OneHotLabel)): - raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.") + if not ( + has_any(inputs, features.Image, features.Video, features.is_simple_tensor) + and has_any(inputs, features.OneHotLabel) + ): + raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.") if has_any(inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label): raise TypeError( f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels." @@ -119,7 +122,7 @@ def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features if inpt.ndim < 2: raise ValueError("Need a batch of one hot labels") output = inpt.clone() - output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam)) + output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam)) return features.OneHotLabel.wrap_like(inpt, output) @@ -129,14 +132,15 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: lam = params["lam"] - if isinstance(inpt, features.Image) or features.is_simple_tensor(inpt): - if inpt.ndim < 4: - raise ValueError("Need a batch of images") + if isinstance(inpt, (features.Image, features.Video)) or features.is_simple_tensor(inpt): + expected_ndim = 5 if isinstance(inpt, features.Video) else 4 + if inpt.ndim < expected_ndim: + raise ValueError("The transform expects a batched input") output = inpt.clone() - output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam)) + output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam)) - if isinstance(inpt, features.Image): - output = features.Image.wrap_like(inpt, output) + if isinstance(inpt, (features.Image, features.Video)): + output = type(inpt).wrap_like(inpt, output) # type: ignore[arg-type] return output elif isinstance(inpt, features.OneHotLabel): @@ -169,17 +173,18 @@ def _get_params(self, sample: Any) -> Dict[str, Any]: return dict(box=box, lam_adjusted=lam_adjusted) def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: - if isinstance(inpt, features.Image) or features.is_simple_tensor(inpt): + if isinstance(inpt, (features.Image, features.Video)) or features.is_simple_tensor(inpt): box = params["box"] - if inpt.ndim < 4: - raise ValueError("Need a batch of images") + expected_ndim = 5 if isinstance(inpt, features.Video) else 4 + if inpt.ndim < expected_ndim: + raise ValueError("The transform expects a batched input") x1, y1, x2, y2 = box - image_rolled = inpt.roll(1, -4) + rolled = inpt.roll(1, 0) output = inpt.clone() - output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2] + output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2] - if isinstance(inpt, features.Image): - output = features.Image.wrap_like(inpt, output) + if isinstance(inpt, (features.Image, features.Video)): + output = inpt.wrap_like(inpt, output) # type: ignore[arg-type] return output elif isinstance(inpt, features.OneHotLabel): diff --git a/torchvision/prototype/transforms/_auto_augment.py b/torchvision/prototype/transforms/_auto_augment.py index d078cb2d1cb..b35b5529b18 100644 --- a/torchvision/prototype/transforms/_auto_augment.py +++ b/torchvision/prototype/transforms/_auto_augment.py @@ -483,8 +483,8 @@ def forward(self, *inputs: Any) -> Any: augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE orig_dims = list(image_or_video.shape) - expected_dim = 5 if isinstance(orig_image_or_video, features.Video) else 4 - batch = image_or_video.view([1] * max(expected_dim - image_or_video.ndim, 0) + orig_dims) + expected_ndim = 5 if isinstance(orig_image_or_video, features.Video) else 4 + batch = image_or_video.view([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims) batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1) # Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a