Skip to content

Prototype transforms cleanup #5504

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 11 commits into from
Mar 1, 2022
40 changes: 26 additions & 14 deletions torchvision/prototype/transforms/_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, functional as F

from ._utils import query_image, get_image_dimensions
from ._utils import query_image, get_image_dimensions, has_all, has_any


class RandomErasing(Transform):
Expand All @@ -33,7 +33,6 @@ def __init__(
raise ValueError("Scale should be between 0 and 1")
if p < 0 or p > 1:
raise ValueError("Random erasing probability should be between 0 and 1")
# TODO: deprecate p in favor of wrapping the transform in a RandomApply
self.p = p
self.scale = scale
self.ratio = ratio
Expand Down Expand Up @@ -88,9 +87,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(zip("ijhwv", (i, j, h, w, v)))

def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
elif isinstance(input, features.Image):
if isinstance(input, features.Image):
output = F.erase_image_tensor(input, **params)
return features.Image.new_like(input, output)
elif isinstance(input, torch.Tensor):
Expand All @@ -99,10 +96,13 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
return input

def forward(self, *inputs: Any) -> Any:
if torch.rand(1) >= self.p:
return inputs if len(inputs) > 1 else inputs[0]
sample = inputs if len(inputs) > 1 else inputs[0]
if has_any(sample, features.BoundingBox, features.SegmentationMask):
raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
elif torch.rand(1) >= self.p:
return sample

return super().forward(*inputs)
return super().forward(sample)


class RandomMixup(Transform):
Expand All @@ -115,9 +115,7 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(lam=float(self._dist.sample(())))

def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
elif isinstance(input, features.Image):
if isinstance(input, features.Image):
output = F.mixup_image_tensor(input, **params)
return features.Image.new_like(input, output)
elif isinstance(input, features.OneHotLabel):
Expand All @@ -126,6 +124,14 @@ def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
else:
return input

def forward(self, *inputs: Any) -> Any:
    """Validate the whole sample before dispatching the per-input transform.

    MixUp blends an image with a one-hot label, so the sample must contain
    both; bounding boxes and segmentation masks cannot be mixed meaningfully
    and are rejected up front.

    Raises:
        TypeError: if the sample contains a ``BoundingBox`` or
            ``SegmentationMask``, or is missing either an ``Image`` or a
            ``OneHotLabel``.
    """
    # A single positional argument is treated as the sample itself;
    # multiple arguments are wrapped as a tuple sample.
    sample = inputs[0] if len(inputs) < 2 else inputs
    if has_any(sample, features.BoundingBox, features.SegmentationMask):
        raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
    if not has_all(sample, features.Image, features.OneHotLabel):
        raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
    return super().forward(sample)


class RandomCutmix(Transform):
def __init__(self, *, alpha: float) -> None:
Expand Down Expand Up @@ -157,13 +163,19 @@ def _get_params(self, sample: Any) -> Dict[str, Any]:
return dict(box=box, lam_adjusted=lam_adjusted)

def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
if isinstance(input, (features.BoundingBox, features.SegmentationMask)):
raise TypeError(f"{type(input).__name__}'s are not supported by {type(self).__name__}()")
elif isinstance(input, features.Image):
if isinstance(input, features.Image):
output = F.cutmix_image_tensor(input, box=params["box"])
return features.Image.new_like(input, output)
elif isinstance(input, features.OneHotLabel):
output = F.cutmix_one_hot_label(input, lam_adjusted=params["lam_adjusted"])
return features.OneHotLabel.new_like(input, output)
else:
return input

def forward(self, *inputs: Any) -> Any:
    """Validate the whole sample before dispatching the per-input transform.

    CutMix pastes a crop between images and adjusts the paired one-hot
    labels, so the sample must contain both; bounding boxes and segmentation
    masks are rejected up front as unsupported.

    Raises:
        TypeError: if the sample contains a ``BoundingBox`` or
            ``SegmentationMask``, or is missing either an ``Image`` or a
            ``OneHotLabel``.
    """
    # A single positional argument is treated as the sample itself;
    # multiple arguments are wrapped as a tuple sample.
    sample = inputs[0] if len(inputs) < 2 else inputs
    if has_any(sample, features.BoundingBox, features.SegmentationMask):
        raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
    if not has_all(sample, features.Image, features.OneHotLabel):
        raise TypeError(f"{type(self).__name__}() is only defined for Image's *and* OneHotLabel's.")
    return super().forward(sample)
Loading