diff --git a/torchvision/prototype/transforms/__init__.py b/torchvision/prototype/transforms/__init__.py
index 98ad7ae0d74..4641cc5ab86 100644
--- a/torchvision/prototype/transforms/__init__.py
+++ b/torchvision/prototype/transforms/__init__.py
@@ -7,7 +7,7 @@
 from ._augment import RandomErasing, RandomMixup, RandomCutmix
 from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
 from ._container import Compose, RandomApply, RandomChoice, RandomOrder
-from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
+from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop, FiveCrop, TenCrop, BatchMultiCrop
 from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
 from ._misc import Identity, Normalize, ToDtype, Lambda
 from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index 44e31dee856..e04e9f819f3 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -1,3 +1,4 @@
+import collections.abc
 import math
 import warnings
 from typing import Any, Dict, List, Union, Sequence, Tuple, cast
@@ -6,6 +7,7 @@
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
+from torchvision.transforms.functional import pil_to_tensor
 from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int
 
 from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor
@@ -168,3 +170,89 @@ def forward(self, *inputs: Any) -> Any:
         if has_any(sample, features.BoundingBox, features.SegmentationMask):
             raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
         return super().forward(sample)
+
+
+class MultiCropResult(list):
+    """Helper class for :class:`~torchvision.prototype.transforms.BatchMultiCrop`.
+
+    Outputs of multi crop transforms such as :class:`~torchvision.prototype.transforms.FiveCrop` and
+    :class:`~torchvision.prototype.transforms.TenCrop` should be wrapped in this in order to be batched correctly by
+    :class:`~torchvision.prototype.transforms.BatchMultiCrop`.
+ """ + + pass + + +class FiveCrop(Transform): + def __init__(self, size: Union[int, Sequence[int]]) -> None: + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, features.Image): + output = F.five_crop_image_tensor(input, self.size) + return MultiCropResult(features.Image.new_like(input, o) for o in output) + elif is_simple_tensor(input): + return MultiCropResult(F.five_crop_image_tensor(input, self.size)) + elif isinstance(input, PIL.Image.Image): + return MultiCropResult(F.five_crop_image_pil(input, self.size)) + else: + return input + + def forward(self, *inputs: Any) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + if has_any(sample, features.BoundingBox, features.SegmentationMask): + raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") + return super().forward(sample) + + +class TenCrop(Transform): + def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None: + super().__init__() + self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.") + self.vertical_flip = vertical_flip + + def _transform(self, input: Any, params: Dict[str, Any]) -> Any: + if isinstance(input, features.Image): + output = F.ten_crop_image_tensor(input, self.size, vertical_flip=self.vertical_flip) + return MultiCropResult(features.Image.new_like(input, o) for o in output) + elif is_simple_tensor(input): + return MultiCropResult(F.ten_crop_image_tensor(input, self.size)) + elif isinstance(input, PIL.Image.Image): + return MultiCropResult(F.ten_crop_image_pil(input, self.size)) + else: + return input + + def forward(self, *inputs: Any) -> Any: + sample = inputs if len(inputs) > 1 else inputs[0] + if has_any(sample, features.BoundingBox, features.SegmentationMask): + raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()") + return super().forward(sample) + + +class BatchMultiCrop(Transform): + def forward(self, *inputs: Any) -> Any: + # This is basically the functionality of `torchvision.prototype.utils._internal.apply_recursively` with one + # significant difference: + # Since we need multiple images to batch them together, we need to explicitly exclude `MultiCropResult` from + # the sequence case. 
+        def apply_recursively(obj: Any) -> Any:
+            if isinstance(obj, MultiCropResult):
+                crops = obj
+                if isinstance(obj[0], PIL.Image.Image):
+                    crops = [pil_to_tensor(crop) for crop in crops]  # type: ignore[assignment]
+
+                batch = torch.stack(crops)
+
+                if isinstance(obj[0], features.Image):
+                    batch = features.Image.new_like(obj[0], batch)
+
+                return batch
+            elif isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
+                return [apply_recursively(item) for item in obj]
+            elif isinstance(obj, collections.abc.Mapping):
+                return {key: apply_recursively(item) for key, item in obj.items()}
+            else:
+                return obj
+
+        return apply_recursively(inputs if len(inputs) > 1 else inputs[0])
diff --git a/torchvision/prototype/transforms/functional/__init__.py b/torchvision/prototype/transforms/functional/__init__.py
index e3fe60a7919..c0825784f66 100644
--- a/torchvision/prototype/transforms/functional/__init__.py
+++ b/torchvision/prototype/transforms/functional/__init__.py
@@ -60,6 +60,10 @@
     perspective_image_pil,
     vertical_flip_image_tensor,
     vertical_flip_image_pil,
+    five_crop_image_tensor,
+    five_crop_image_pil,
+    ten_crop_image_tensor,
+    ten_crop_image_pil,
 )
 from ._misc import normalize_image_tensor, gaussian_blur_image_tensor
 from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 080fe5da891..6c9309749af 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -314,3 +314,79 @@ def resized_crop_image_pil(
 ) -> PIL.Image.Image:
     img = crop_image_pil(img, top, left, height, width)
     return resize_image_pil(img, size, interpolation=interpolation)
+
+
+def _parse_five_crop_size(size: List[int]) -> List[int]:
+    if isinstance(size, numbers.Number):
+        size = (int(size), int(size))
+    elif isinstance(size, (tuple, list)) and len(size) == 1:
+        size = (size[0], size[0])  # type: ignore[assignment]
+
+    if len(size) != 2:
+        raise ValueError("Please provide only two dimensions (h, w) for size.")
+
+    return size
+
+
+def five_crop_image_tensor(
+    img: torch.Tensor, size: List[int]
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    crop_height, crop_width = _parse_five_crop_size(size)
+    _, image_height, image_width = get_dimensions_image_tensor(img)
+
+    if crop_width > image_width or crop_height > image_height:
+        msg = "Requested crop size {} is bigger than input size {}"
+        raise ValueError(msg.format(size, (image_height, image_width)))
+
+    tl = crop_image_tensor(img, 0, 0, crop_height, crop_width)
+    tr = crop_image_tensor(img, 0, image_width - crop_width, crop_height, crop_width)
+    bl = crop_image_tensor(img, image_height - crop_height, 0, crop_height, crop_width)
+    br = crop_image_tensor(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
+    center = center_crop_image_tensor(img, [crop_height, crop_width])
+
+    return tl, tr, bl, br, center
+
+
+def five_crop_image_pil(
+    img: PIL.Image.Image, size: List[int]
+) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
+    crop_height, crop_width = _parse_five_crop_size(size)
+    _, image_height, image_width = get_dimensions_image_pil(img)
+
+    if crop_width > image_width or crop_height > image_height:
+        msg = "Requested crop size {} is bigger than input size {}"
+        raise ValueError(msg.format(size, (image_height, image_width)))
+
+    tl = crop_image_pil(img, 0, 0, crop_height, crop_width)
+    tr = crop_image_pil(img, 0, image_width - crop_width, crop_height, crop_width)
+    bl = crop_image_pil(img, image_height - crop_height, 0, crop_height, crop_width)
+    br = crop_image_pil(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
+    center = center_crop_image_pil(img, [crop_height, crop_width])
+
+    return tl, tr, bl, br, center
+
+
+def ten_crop_image_tensor(img: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
+    tl, tr, bl, br, center = five_crop_image_tensor(img, size)
+
+    if vertical_flip:
+        img = vertical_flip_image_tensor(img)
+    else:
+        img = horizontal_flip_image_tensor(img)
+
+    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_tensor(img, size)
+
+    return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
+
+
+def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: bool = False) -> List[PIL.Image.Image]:
+    tl, tr, bl, br, center = five_crop_image_pil(img, size)
+
+    if vertical_flip:
+        img = vertical_flip_image_pil(img)
+    else:
+        img = horizontal_flip_image_pil(img)
+
+    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(img, size)
+
+    return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
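
Not part of the patch: a minimal usage sketch of the new transform classes, assuming the prototype API exactly as it appears in this diff (`FiveCrop`, `TenCrop`, `BatchMultiCrop`, `features.Image`) and assuming the prototype `Transform` base class dispatches `_transform` recursively over the sample; treat it as illustrative rather than canonical.

```python
import torch

from torchvision.prototype import features, transforms as T

# A 3-channel 256x256 image wrapped in the prototype Image feature.
image = features.Image(torch.rand(3, 256, 256))

# FiveCrop returns a MultiCropResult holding five 224x224 crops
# (the four corners plus the center crop).
crops = T.FiveCrop(224)(image)

# BatchMultiCrop stacks the crops into a single batched features.Image;
# PIL crops would first be converted with pil_to_tensor.
batch = T.BatchMultiCrop()(crops)
print(batch.shape)  # expected: torch.Size([5, 3, 224, 224])
```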
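
Likewise, a sketch of the functional kernels added in `functional/_geometry.py`; the shapes and lengths asserted here follow from the code in the diff and are stated as expectations, not verified output.

```python
import torch

from torchvision.prototype.transforms import functional as F

img = torch.rand(3, 256, 256)

# five_crop_image_tensor returns a 5-tuple:
# top-left, top-right, bottom-left, bottom-right and center crops.
tl, tr, bl, br, center = F.five_crop_image_tensor(img, [224, 224])
assert center.shape == (3, 224, 224)

# ten_crop_image_tensor crops the image, flips it (horizontally by default,
# vertically when vertical_flip=True) and crops again, returning ten crops.
crops = F.ten_crop_image_tensor(img, [224, 224], vertical_flip=False)
assert len(crops) == 10
```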