port FiveCrop and TenCrop to prototype API #5513

Merged
merged 10 commits on Mar 7, 2022
2 changes: 1 addition & 1 deletion torchvision/prototype/transforms/__init__.py
@@ -7,7 +7,7 @@
from ._augment import RandomErasing, RandomMixup, RandomCutmix
from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
from ._container import Compose, RandomApply, RandomChoice, RandomOrder
from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop, FiveCrop, TenCrop, BatchMultiCrop
from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
from ._misc import Identity, Normalize, ToDtype, Lambda
from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
77 changes: 77 additions & 0 deletions torchvision/prototype/transforms/_geometry.py
@@ -1,3 +1,4 @@
import functools
import math
import warnings
from typing import Any, Dict, List, Union, Sequence, Tuple, cast
@@ -6,6 +7,8 @@
import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
from torchvision.prototype.utils._internal import apply_recursively
from torchvision.transforms.functional import pil_to_tensor
from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int

from ._utils import query_image, get_image_dimensions, has_any
@@ -168,3 +171,77 @@ def forward(self, *inputs: Any) -> Any:
        if has_any(sample, features.BoundingBox, features.SegmentationMask):
            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
        return super().forward(sample)


class FiveCrop(Transform):
    def __init__(self, size: Union[int, Sequence[int]]) -> None:
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        if isinstance(input, features.Image):
            output = F.five_crop_image_tensor(input, self.size)
            return F._FiveCropResult(*[features.Image.new_like(input, o) for o in output])
        elif type(input) is torch.Tensor:
            return F.five_crop_image_tensor(input, self.size)
        elif isinstance(input, PIL.Image.Image):
            return F.five_crop_image_pil(input, self.size)
        else:
            return input

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        if has_any(sample, features.BoundingBox, features.SegmentationMask):
            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
        return super().forward(sample)


class TenCrop(Transform):
    def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
        super().__init__()
        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
        self.vertical_flip = vertical_flip

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        if isinstance(input, features.Image):
            output = F.ten_crop_image_tensor(input, self.size, vertical_flip=self.vertical_flip)
            return F._TenCropResult(*[features.Image.new_like(input, o) for o in output])
        elif type(input) is torch.Tensor:
            return F.ten_crop_image_tensor(input, self.size, vertical_flip=self.vertical_flip)
        elif isinstance(input, PIL.Image.Image):
            return F.ten_crop_image_pil(input, self.size, vertical_flip=self.vertical_flip)
        else:
            return input

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        if has_any(sample, features.BoundingBox, features.SegmentationMask):
            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
        return super().forward(sample)


class BatchMultiCrop(Transform):
    _MULTI_CROP_TYPES = (F._FiveCropResult, F._TenCropResult)

    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
        if isinstance(input, self._MULTI_CROP_TYPES):
            crops = input
            if isinstance(input[0], PIL.Image.Image):
                crops = [pil_to_tensor(crop) for crop in crops]  # type: ignore[assignment]

            batch = torch.stack(crops)

            if isinstance(input[0], features.Image):
                batch = features.Image.new_like(input[0], batch)

            return batch
        else:
            return input

    def forward(self, *inputs: Any) -> Any:
        sample = inputs if len(inputs) > 1 else inputs[0]
        return apply_recursively(
            functools.partial(self._transform, params=self._get_params(sample)),
            sample,
            exclude_sequence_types=(str, *self._MULTI_CROP_TYPES),
Collaborator Author
We need this exclude here because named tuples would otherwise be recognized as a sequence, and thus we would only get the individual elements rather than everything at once.

Contributor
I think that's one more reason not to use named tuples.

        )
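
A rough usage sketch, not part of the diff: assuming the prototype imports added in this PR and that features.Image can wrap a plain tensor, FiveCrop/TenCrop produce a crop named tuple and BatchMultiCrop stacks it into a single batched Image.

import torch
from torchvision.prototype import features
from torchvision.prototype.transforms import FiveCrop, BatchMultiCrop

image = features.Image(torch.rand(3, 256, 256))

# FiveCrop(224) returns a _FiveCropResult named tuple of five 3x224x224 Images ...
crops = FiveCrop(224)(image)
# ... and BatchMultiCrop stacks them into one Image batch of shape (5, 3, 224, 224).
batch = BatchMultiCrop()(crops)
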
6 changes: 6 additions & 0 deletions torchvision/prototype/transforms/functional/__init__.py
@@ -60,6 +60,12 @@
    perspective_image_pil,
    vertical_flip_image_tensor,
    vertical_flip_image_pil,
    _FiveCropResult,
    five_crop_image_tensor,
    five_crop_image_pil,
    _TenCropResult,
    ten_crop_image_tensor,
    ten_crop_image_pil,
)
from ._misc import normalize_image_tensor, gaussian_blur_image_tensor
from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
95 changes: 94 additions & 1 deletion torchvision/prototype/transforms/functional/_geometry.py
@@ -1,4 +1,5 @@
import numbers
from typing import NamedTuple
from typing import Tuple, List, Optional, Sequence, Union

import PIL.Image
@@ -10,7 +11,6 @@

from ._meta import convert_bounding_box_format, get_dimensions_image_tensor, get_dimensions_image_pil


horizontal_flip_image_tensor = _FT.hflip
horizontal_flip_image_pil = _FP.hflip

@@ -314,3 +314,96 @@ def resized_crop_image_pil(
) -> PIL.Image.Image:
    img = crop_image_pil(img, top, left, height, width)
    return resize_image_pil(img, size, interpolation=interpolation)


class _FiveCropResult(NamedTuple):
    top_left: torch.Tensor
    top_right: torch.Tensor
    bottom_left: torch.Tensor
    bottom_right: torch.Tensor
    center: torch.Tensor


def _parse_five_crop_size(size: List[int]) -> List[int]:
    if isinstance(size, numbers.Number):
        size = (int(size), int(size))
    elif isinstance(size, (tuple, list)) and len(size) == 1:
        size = (size[0], size[0])  # type: ignore[assignment]

    if len(size) != 2:
        raise ValueError("Please provide only two dimensions (h, w) for size.")

    return size


def five_crop_image_tensor(img: torch.Tensor, size: List[int]) -> _FiveCropResult:
    crop_height, crop_width = _parse_five_crop_size(size)
    _, image_height, image_width = get_dimensions_image_tensor(img)

    if crop_width > image_width or crop_height > image_height:
        msg = "Requested crop size {} is bigger than input size {}"
        raise ValueError(msg.format(size, (image_height, image_width)))

    tl = crop_image_tensor(img, 0, 0, crop_height, crop_width)
    tr = crop_image_tensor(img, 0, image_width - crop_width, crop_height, crop_width)
    bl = crop_image_tensor(img, image_height - crop_height, 0, crop_height, crop_width)
    br = crop_image_tensor(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
    center = center_crop_image_tensor(img, [crop_height, crop_width])

    return _FiveCropResult(tl, tr, bl, br, center)


def five_crop_image_pil(img: PIL.Image.Image, size: List[int]) -> _FiveCropResult:
    crop_height, crop_width = _parse_five_crop_size(size)
    _, image_height, image_width = get_dimensions_image_pil(img)

    if crop_width > image_width or crop_height > image_height:
        msg = "Requested crop size {} is bigger than input size {}"
        raise ValueError(msg.format(size, (image_height, image_width)))

    tl = crop_image_pil(img, 0, 0, crop_height, crop_width)
    tr = crop_image_pil(img, 0, image_width - crop_width, crop_height, crop_width)
    bl = crop_image_pil(img, image_height - crop_height, 0, crop_height, crop_width)
    br = crop_image_pil(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
    center = center_crop_image_pil(img, [crop_height, crop_width])

    return _FiveCropResult(tl, tr, bl, br, center)


class _TenCropResult(NamedTuple):
    top_left: torch.Tensor
    top_right: torch.Tensor
    bottom_left: torch.Tensor
    bottom_right: torch.Tensor
    center: torch.Tensor
    top_left_flip: torch.Tensor
    top_right_flip: torch.Tensor
    bottom_left_flip: torch.Tensor
    bottom_right_flip: torch.Tensor
    center_flip: torch.Tensor


def ten_crop_image_tensor(img: torch.Tensor, size: List[int], vertical_flip: bool = False) -> _TenCropResult:
    tl, tr, bl, br, center = five_crop_image_tensor(img, size)

    if vertical_flip:
        img = vertical_flip_image_tensor(img)
    else:
        img = horizontal_flip_image_tensor(img)

    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_tensor(img, size)

    return _TenCropResult(tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)


def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: bool = False) -> _TenCropResult:
    tl, tr, bl, br, center = five_crop_image_pil(img, size)

    if vertical_flip:
        img = vertical_flip_image_pil(img)
    else:
        img = horizontal_flip_image_pil(img)

    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(img, size)

    return _TenCropResult(tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)
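
For orientation only (not part of the diff), a quick shape check assuming the functional exports listed above: every entry of the returned named tuples has the requested crop size.

import torch
from torchvision.prototype.transforms import functional as F

img = torch.rand(3, 256, 256)

five = F.five_crop_image_tensor(img, [224, 224])
assert len(five) == 5 and five.center.shape == (3, 224, 224)

ten = F.ten_crop_image_tensor(img, [224, 224], vertical_flip=False)
assert len(ten) == 10 and ten.center_flip.shape == (3, 224, 224)
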
40 changes: 35 additions & 5 deletions torchvision/prototype/utils/_internal.py
@@ -25,6 +25,7 @@
    TypeVar,
    Union,
    Optional,
    Type,
)

import numpy as np
Expand Down Expand Up @@ -301,13 +302,42 @@ def read(self, size: int = -1) -> bytes:
        return self._memory[slice(cursor, self.seek(offset, whence))].tobytes()


def apply_recursively(fn: Callable, obj: Any) -> Any:
    if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
        return [apply_recursively(fn, item) for item in obj]
    elif isinstance(obj, collections.abc.Mapping):
        return {key: apply_recursively(fn, item) for key, item in obj.items()}
def apply_recursively(
    fn: Callable,
    obj: Any,
    *,
    include_sequence_types: Collection[Type] = (collections.abc.Sequence,),
    # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
    # "a" == "a"[0][0]...
    exclude_sequence_types: Collection[Type] = (str,),
Collaborator Author
We need this addition to be able to exclude named tuples as sequences in the BatchMultiCrop transform. My gut says we are going to need this fine-grained control again for other transforms that are not yet ported / implemented. If that turns out not to be true, I'm happy to simply implement a custom solution only in the BatchMultiCrop transform, given that we probably deprecate it anyway.

Contributor
We should start from the assumption that this is not needed and add it later. Starting with a simple solution first is a good prior. Please simplify as much as possible.

    include_mapping_types: Collection[Type] = (collections.abc.Mapping,),
    exclude_mapping_types: Collection[Type] = (),
) -> Any:
    if isinstance(obj, tuple(include_sequence_types)) and not isinstance(obj, tuple(exclude_sequence_types)):
        return [
            apply_recursively(
                fn,
                item,
                include_sequence_types=include_sequence_types,
                exclude_sequence_types=exclude_sequence_types,
                include_mapping_types=include_mapping_types,
                exclude_mapping_types=exclude_mapping_types,
            )
            for item in obj
        ]

    if isinstance(obj, tuple(include_mapping_types)) and not isinstance(obj, tuple(exclude_mapping_types)):
        return {
            key: apply_recursively(
                fn,
                item,
                include_sequence_types=include_sequence_types,
                exclude_sequence_types=exclude_sequence_types,
                include_mapping_types=include_mapping_types,
                exclude_mapping_types=exclude_mapping_types,
            )
            for key, item in obj.items()
        }
    else:
        return fn(obj)
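
To illustrate the discussion above about why named tuples must be excluded, a self-contained sketch using only the standard library (Crops is a hypothetical stand-in for _FiveCropResult):

import collections.abc
from typing import NamedTuple

class Crops(NamedTuple):  # hypothetical stand-in for _FiveCropResult
    top_left: int
    center: int

crops = Crops(1, 2)
print(isinstance(crops, collections.abc.Sequence))  # True -- named tuples are tuples
# With the default include_sequence_types, apply_recursively would recurse into the
# fields and call fn on each one; listing the named tuple type in exclude_sequence_types
# makes it a leaf, so fn receives the whole Crops instance at once.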
