Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions test/test_prototype_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,12 @@ def test_common(self, transform, input):
(
transform,
[
dict(image=image, one_hot_label=one_hot_label)
for image, one_hot_label in itertools.product(
make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
dict(inpt=inpt, one_hot_label=one_hot_label)
for inpt, one_hot_label in itertools.product(
itertools.chain(
make_images(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
make_videos(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
),
make_one_hot_labels(extra_dims=BATCH_EXTRA_DIMS, dtypes=[torch.float]),
)
],
Expand Down
37 changes: 21 additions & 16 deletions torchvision/prototype/transforms/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,11 @@ def __init__(self, alpha: float, p: float = 0.5) -> None:
self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha]))

def forward(self, *inputs: Any) -> Any:
if not (has_any(inputs, features.Image, features.is_simple_tensor) and has_any(inputs, features.OneHotLabel)):
raise TypeError(f"{type(self).__name__}() is only defined for tensor images and one-hot labels.")
if not (
has_any(inputs, features.Image, features.Video, features.is_simple_tensor)
and has_any(inputs, features.OneHotLabel)
):
raise TypeError(f"{type(self).__name__}() is only defined for tensor images/videos and one-hot labels.")
if has_any(inputs, PIL.Image.Image, features.BoundingBox, features.Mask, features.Label):
raise TypeError(
f"{type(self).__name__}() does not support PIL images, bounding boxes, masks and plain labels."
Expand All @@ -119,7 +122,7 @@ def _mixup_onehotlabel(self, inpt: features.OneHotLabel, lam: float) -> features
if inpt.ndim < 2:
raise ValueError("Need a batch of one hot labels")
output = inpt.clone()
output = output.roll(1, -2).mul_(1 - lam).add_(output.mul_(lam))
output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam))
return features.OneHotLabel.wrap_like(inpt, output)


Expand All @@ -129,14 +132,15 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
lam = params["lam"]
if isinstance(inpt, features.Image) or features.is_simple_tensor(inpt):
if inpt.ndim < 4:
raise ValueError("Need a batch of images")
if isinstance(inpt, (features.Image, features.Video)) or features.is_simple_tensor(inpt):
expected_ndim = 5 if isinstance(inpt, features.Video) else 4
if inpt.ndim < expected_ndim:
raise ValueError("The transform expects a batched input")
output = inpt.clone()
output = output.roll(1, -4).mul_(1 - lam).add_(output.mul_(lam))
output = output.roll(1, 0).mul_(1.0 - lam).add_(output.mul_(lam)) # type: ignore[arg-type]

if isinstance(inpt, features.Image):
output = features.Image.wrap_like(inpt, output)
if isinstance(inpt, (features.Image, features.Video)):
output = type(inpt).wrap_like(inpt, output)

return output
elif isinstance(inpt, features.OneHotLabel):
Expand Down Expand Up @@ -169,17 +173,18 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(box=box, lam_adjusted=lam_adjusted)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features.Image) or features.is_simple_tensor(inpt):
if isinstance(inpt, (features.Image, features.Video)) or features.is_simple_tensor(inpt):
box = params["box"]
if inpt.ndim < 4:
raise ValueError("Need a batch of images")
expected_ndim = 5 if isinstance(inpt, features.Video) else 4
if inpt.ndim < expected_ndim:
raise ValueError("The transform expects a batched input")
x1, y1, x2, y2 = box
image_rolled = inpt.roll(1, -4)
rolled = inpt.roll(1, 0)
output = inpt.clone()
output[..., y1:y2, x1:x2] = image_rolled[..., y1:y2, x1:x2]
output[..., y1:y2, x1:x2] = rolled[..., y1:y2, x1:x2] # type: ignore[arg-type]

if isinstance(inpt, features.Image):
output = features.Image.wrap_like(inpt, output)
if isinstance(inpt, (features.Image, features.Video)):
output = inpt.wrap_like(inpt, output) # type: ignore[arg-type]

return output
elif isinstance(inpt, features.OneHotLabel):
Expand Down
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/_auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,8 +483,8 @@ def forward(self, *inputs: Any) -> Any:
augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

orig_dims = list(image_or_video.shape)
expected_dim = 5 if isinstance(orig_image_or_video, features.Video) else 4
batch = image_or_video.view([1] * max(expected_dim - image_or_video.ndim, 0) + orig_dims)
expected_ndim = 5 if isinstance(orig_image_or_video, features.Video) else 4
batch = image_or_video.view([1] * max(expected_ndim - image_or_video.ndim, 0) + orig_dims)
batch_dims = [batch.size(0)] + [1] * (batch.ndim - 1)

# Sample the beta weights for combining the original and augmented image or video. To get Beta, we use a
Expand Down