
Cleanup of prototype transforms #6492

Merged: 14 commits, Aug 25, 2022
113 changes: 71 additions & 42 deletions torchvision/prototype/transforms/_auto_augment.py
@@ -1,16 +1,17 @@
import math
import numbers
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, Type, TypeVar, Union

import PIL.Image
import torch

from torch.utils._pytree import tree_flatten, tree_unflatten
from torchvision.prototype import features
from torchvision.prototype.transforms import functional as F, Transform
from torchvision.transforms.autoaugment import AutoAugmentPolicy
from torchvision.transforms.functional import InterpolationMode, pil_to_tensor, to_pil_image

from ._utils import is_simple_tensor, query_chw
from ._utils import _isinstance, get_chw, is_simple_tensor

K = TypeVar("K")
V = TypeVar("V")
@@ -35,9 +36,31 @@ def _get_random_item(self, dct: Dict[K, V]) -> Tuple[K, V]:
key = keys[int(torch.randint(len(keys), ()))]
return key, dct[key]

def _get_params(self, sample: Any) -> Dict[str, Any]:
_, height, width = query_chw(sample)
return dict(height=height, width=width)
def _extract_image(
self,
sample: Any,
unsupported_types: Tuple[Type, ...] = (features.BoundingBox, features.SegmentationMask),
) -> Tuple[int, Union[PIL.Image.Image, torch.Tensor, features.Image]]:
sample_flat, _ = tree_flatten(sample)
images = []
for id, inpt in enumerate(sample_flat):
if _isinstance(inpt, (features.Image, PIL.Image.Image, is_simple_tensor)):
images.append((id, inpt))
elif isinstance(inpt, unsupported_types):
raise TypeError(f"Inputs of type {type(inpt).__name__} are not supported by {type(self).__name__}()")

if not images:
raise TypeError("Found no image in the sample.")
if len(images) > 1:
raise TypeError(
f"Auto augment transformations are only properly defined for a single image, but found {len(images)}."
)
return images[0]

def _put_into_sample(self, sample: Any, id: int, item: Any) -> Any:
sample_flat, spec = tree_flatten(sample)
sample_flat[id] = item
return tree_unflatten(sample_flat, spec)
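
A quick aside on the pytree round-trip these two helpers rely on. This is a minimal sketch, not part of the diff, assuming a plain dict sample:

    import torch
    from torch.utils._pytree import tree_flatten, tree_unflatten

    sample = {"image": torch.rand(3, 32, 32), "label": 7}
    flat, spec = tree_flatten(sample)    # leaves as a flat list, plus a structure spec
    flat[0] = flat[0].flip(-1)           # swap out one leaf by its flat index
    sample = tree_unflatten(flat, spec)  # original nesting restored around the new leaf

_extract_image records the flat index of the image leaf; _put_into_sample writes the augmented image back at that index and unflattens.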

def _apply_image_transform(
self,
@@ -242,34 +265,33 @@ def _get_policies(
else:
raise ValueError(f"The provided policy {policy} is not recognized.")

def _get_params(self, sample: Any) -> Dict[str, Any]:
params = super(AutoAugment, self)._get_params(sample)
params["policy"] = self._policies[int(torch.randint(len(self._policies), ()))]
return params
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]

id, image = self._extract_image(sample)
[Contributor comment] Rant: not a massive fan of the image naming here, but we can change once we introduce videos.

num_channels, height, width = get_chw(image)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
return inpt
policy = self._policies[int(torch.randint(len(self._policies), ()))]

for transform_id, probability, magnitude_idx in params["policy"]:
for transform_id, probability, magnitude_idx in policy:
if not torch.rand(()) <= probability:
continue

magnitudes_fn, signed = self._AUGMENTATION_SPACE[transform_id]

magnitudes = magnitudes_fn(10, params["height"], params["width"])
magnitudes = magnitudes_fn(10, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[magnitude_idx])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0

inpt = self._apply_image_transform(
inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
image = self._apply_image_transform(
image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)

return inpt
return self._put_into_sample(sample, id, image)
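
With forward overridden this way, the transform accepts an arbitrarily nested sample, augments its single image leaf, and returns the sample with its structure intact. A rough usage sketch (sample layout assumed, not taken from the PR):

    import torch
    from torchvision.prototype import transforms

    transform = transforms.AutoAugment()
    sample = {"image": torch.rand(3, 224, 224), "label": 3}
    out = transform(sample)  # same dict structure; only the image leaf is augmented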


class RandAugment(_AutoAugmentBase):
Expand Down Expand Up @@ -315,26 +337,28 @@ def __init__(
self.magnitude = magnitude
self.num_magnitude_bins = num_magnitude_bins

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
return inpt
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]

id, image = self._extract_image(sample)
num_channels, height, width = get_chw(image)

for _ in range(self.num_ops):
transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

magnitudes = magnitudes_fn(self.num_magnitude_bins, params["height"], params["width"])
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0

inpt = self._apply_image_transform(
inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
image = self._apply_image_transform(
image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)

return inpt
return self._put_into_sample(sample, id, image)


class TrivialAugmentWide(_AutoAugmentBase):
@@ -370,23 +394,26 @@ def __init__(
super().__init__(interpolation=interpolation, fill=fill)
self.num_magnitude_bins = num_magnitude_bins

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if not (isinstance(inpt, (features.Image, PIL.Image.Image)) or is_simple_tensor(inpt)):
return inpt
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]

id, image = self._extract_image(sample)
num_channels, height, width = get_chw(image)

transform_id, (magnitudes_fn, signed) = self._get_random_item(self._AUGMENTATION_SPACE)

magnitudes = magnitudes_fn(self.num_magnitude_bins, params["height"], params["width"])
magnitudes = magnitudes_fn(self.num_magnitude_bins, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.num_magnitude_bins, ()))])
if signed and torch.rand(()) <= 0.5:
magnitude *= -1
else:
magnitude = 0.0

return self._apply_image_transform(
inpt, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
image = self._apply_image_transform(
image, transform_id, magnitude, interpolation=self.interpolation, fill=self.fill
)
return self._put_into_sample(sample, id, image)


class AugMix(_AutoAugmentBase):
@@ -438,13 +465,15 @@ def _sample_dirichlet(self, params: torch.Tensor) -> torch.Tensor:
# Must be on a separate method so that we can overwrite it in tests.
return torch._sample_dirichlet(params)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features.Image) or is_simple_tensor(inpt):
image = inpt
elif isinstance(inpt, PIL.Image.Image):
image = pil_to_tensor(inpt)
else:
return inpt
def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
id, orig_image = self._extract_image(sample)
num_channels, height, width = get_chw(orig_image)

if isinstance(orig_image, torch.Tensor):
image = orig_image
else:  # isinstance(orig_image, PIL.Image.Image)
image = pil_to_tensor(orig_image)

augmentation_space = self._AUGMENTATION_SPACE if self.all_ops else self._PARTIAL_AUGMENTATION_SPACE

@@ -470,7 +499,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
for _ in range(depth):
transform_id, (magnitudes_fn, signed) = self._get_random_item(augmentation_space)

magnitudes = magnitudes_fn(self._PARAMETER_MAX, params["height"], params["width"])
magnitudes = magnitudes_fn(self._PARAMETER_MAX, height, width)
if magnitudes is not None:
magnitude = float(magnitudes[int(torch.randint(self.severity, ()))])
if signed and torch.rand(()) <= 0.5:
@@ -484,9 +513,9 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
mix.add_(combined_weights[:, i].view(batch_dims) * aug)
mix = mix.view(orig_dims).to(dtype=image.dtype)

if isinstance(inpt, features.Image):
mix = features.Image.new_like(inpt, mix)
elif isinstance(inpt, PIL.Image.Image):
if isinstance(orig_image, features.Image):
mix = features.Image.new_like(orig_image, mix)
elif isinstance(orig_image, PIL.Image.Image):
mix = to_pil_image(mix)

return mix
return self._put_into_sample(sample, id, mix)
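
For reference, the mixing above is Dirichlet-weighted. An illustrative sketch of the idea, not the diff's exact batched code, using the same private sampler:

    import torch

    image = torch.rand(3, 8, 8)
    chains = [image.roll(1, -1), image.flip(-1)]          # stand-ins for augmented copies
    m = float(torch._sample_dirichlet(torch.ones(2))[0])  # image-vs-augmented split
    w = torch._sample_dirichlet(torch.ones(len(chains)))  # weights over chains, sums to 1
    mix = m * image + (1 - m) * sum(wi * c for wi, c in zip(w, chains))
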
21 changes: 12 additions & 9 deletions torchvision/prototype/transforms/_deprecated.py
@@ -4,15 +4,14 @@
import numpy as np
import PIL.Image
import torch
import torchvision.prototype.transforms.functional as F

from torchvision.prototype import features
from torchvision.prototype.features import ColorSpace
from torchvision.prototype.transforms import Transform
from torchvision.transforms import functional as _F
from typing_extensions import Literal

from ._transform import _RandomApplyTransform
from ._utils import is_simple_tensor
from ._utils import is_simple_tensor, query_chw


class ToTensor(Transform):
@@ -59,6 +58,8 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> PIL.Image:


class Grayscale(Transform):
_transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)

def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
deprecation_msg = (
f"The transform `Grayscale(num_output_channels={num_output_channels})` "
@@ -81,13 +82,12 @@ def __init__(self, num_output_channels: Literal[1, 3] = 1) -> None:
self.num_output_channels = num_output_channels

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
if self.num_output_channels == 3:
output = F.convert_color_space(inpt, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)
return output
return _F.rgb_to_grayscale(inpt, num_output_channels=self.num_output_channels)


class RandomGrayscale(_RandomApplyTransform):
_transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)

def __init__(self, p: float = 0.1) -> None:
warnings.warn(
"The transform `RandomGrayscale(p=...)` is deprecated and will be removed in a future release. "
@@ -103,6 +103,9 @@ def __init__(self, p: float = 0.1) -> None:

super().__init__(p=p)

def _get_params(self, sample: Any) -> Dict[str, Any]:
num_input_channels, _, _ = query_chw(sample)
return dict(num_input_channels=num_input_channels)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
output = F.convert_color_space(inpt, color_space=ColorSpace.GRAY, old_color_space=ColorSpace.RGB)
return F.convert_color_space(output, color_space=ColorSpace.RGB, old_color_space=ColorSpace.GRAY)
return _F.rgb_to_grayscale(inpt, num_output_channels=params["num_input_channels"])
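
The GRAY-to-RGB round-trip is gone: the stable kernel now produces a grayscale output with as many channels as the input had, queried once per sample via query_chw. A small sketch of the underlying call, assuming a 3-channel tensor input:

    import torch
    from torchvision.transforms import functional as _F

    img = torch.rand(3, 16, 16)
    gray = _F.rgb_to_grayscale(img, num_output_channels=3)
    print(gray.shape)  # torch.Size([3, 16, 16]); all three channels are equal
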
26 changes: 6 additions & 20 deletions torchvision/prototype/transforms/_geometry.py
@@ -156,22 +156,14 @@ class FiveCrop(Transform):
torch.Size([5])
"""

_transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)

def __init__(self, size: Union[int, Sequence[int]]) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
# TODO: returning a list is technically BC breaking since FiveCrop returned a tuple before. We switched to a
# list here to align it with TenCrop.
if isinstance(inpt, features.Image):
output = F.five_crop_image_tensor(inpt, self.size)
return tuple(features.Image.new_like(inpt, o) for o in output)
elif is_simple_tensor(inpt):
return F.five_crop_image_tensor(inpt, self.size)
elif isinstance(inpt, PIL.Image.Image):
return F.five_crop_image_pil(inpt, self.size)
else:
return inpt
return F.five_crop(inpt, self.size)

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
@@ -185,21 +177,15 @@ class TenCrop(Transform):
See :class:`~torchvision.prototype.transforms.FiveCrop` for an example.
"""

_transformed_types = (features.Image, PIL.Image.Image, is_simple_tensor)

def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
super().__init__()
self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
self.vertical_flip = vertical_flip

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if isinstance(inpt, features.Image):
output = F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip)
return [features.Image.new_like(inpt, o) for o in output]
elif is_simple_tensor(inpt):
return F.ten_crop_image_tensor(inpt, self.size, vertical_flip=self.vertical_flip)
elif isinstance(inpt, PIL.Image.Image):
return F.ten_crop_image_pil(inpt, self.size, vertical_flip=self.vertical_flip)
else:
return inpt
return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)

def forward(self, *inputs: Any) -> Any:
sample = inputs if len(inputs) > 1 else inputs[0]
4 changes: 2 additions & 2 deletions torchvision/prototype/transforms/_misc.py
@@ -22,7 +22,7 @@ class Lambda(Transform):
def __init__(self, fn: Callable[[Any], Any], *types: Type):
super().__init__()
self.fn = fn
self.types = types
self.types = types or (object,)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
if type(inpt) in self.types:
@@ -137,7 +137,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
class ToDtype(Lambda):
def __init__(self, dtype: torch.dtype, *types: Type) -> None:
self.dtype = dtype
super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types)
super().__init__(functools.partial(torch.Tensor.to, dtype=dtype), *types or (torch.Tensor,))

def extra_repr(self) -> str:
return ", ".join([f"dtype={self.dtype}", f"types={[type.__name__ for type in self.types]}"])
2 changes: 2 additions & 0 deletions torchvision/prototype/transforms/functional/__init__.py
@@ -65,6 +65,7 @@
elastic_image_tensor,
elastic_segmentation_mask,
elastic_transform,
five_crop,
five_crop_image_pil,
five_crop_image_tensor,
horizontal_flip,
@@ -97,6 +98,7 @@
rotate_image_pil,
rotate_image_tensor,
rotate_segmentation_mask,
ten_crop,
ten_crop_image_pil,
ten_crop_image_tensor,
vertical_flip,
21 changes: 21 additions & 0 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -1078,6 +1078,17 @@ def five_crop_image_pil(
return tl, tr, bl, br, center


def five_crop(inpt: DType, size: List[int]) -> Tuple[DType, DType, DType, DType, DType]:
# TODO: consider breaking BC here to return List[DType] to align this op with `ten_crop`
if isinstance(inpt, torch.Tensor):
output = five_crop_image_tensor(inpt, size)
if isinstance(inpt, features.Image):
output = tuple(features.Image.new_like(inpt, item) for item in output) # type: ignore[assignment]
return output
else: # isinstance(inpt, PIL.Image.Image):
return five_crop_image_pil(inpt, size)
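
Usage sketch for the new dispatcher (illustrative values):

    import torch
    from torchvision.prototype.transforms import functional as F

    img = torch.rand(3, 64, 64)
    tl, tr, bl, br, center = F.five_crop(img, [32, 32])
    print(center.shape)  # torch.Size([3, 32, 32])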


def ten_crop_image_tensor(img: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
tl, tr, bl, br, center = five_crop_image_tensor(img, size)

@@ -1102,3 +1113,13 @@ def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: bool
tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(img, size)

return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]


def ten_crop(inpt: DType, size: List[int], *, vertical_flip: bool = False) -> List[DType]:
if isinstance(inpt, torch.Tensor):
output = ten_crop_image_tensor(inpt, size, vertical_flip=vertical_flip)
if isinstance(inpt, features.Image):
output = [features.Image.new_like(inpt, item) for item in output]
return output
else: # isinstance(inpt, PIL.Image.Image):
return ten_crop_image_pil(inpt, size, vertical_flip=vertical_flip)
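
And the matching sketch for ten_crop, which returns the five crops plus their flipped counterparts as a list:

    import torch
    from torchvision.prototype.transforms import functional as F

    img = torch.rand(3, 64, 64)
    crops = F.ten_crop(img, [32, 32], vertical_flip=False)
    print(len(crops))  # 10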