
Commit 7039c2c

pmeier and datumbox authored
port FiveCrop and TenCrop to prototype API (#5513)
* port FiveCrop and TenCrop to prototype API
* fix ten crop for pil
* Update torchvision/prototype/transforms/_geometry.py
  Co-authored-by: Philip Meier <[email protected]>
* simplify implementation
* minor cleanup

Co-authored-by: Vasilis Vryniotis <[email protected]>
1 parent 79892d3 · commit 7039c2c

File tree

4 files changed: +169 −1 lines changed

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 from ._augment import RandomErasing, RandomMixup, RandomCutmix
 from ._auto_augment import RandAugment, TrivialAugmentWide, AutoAugment, AugMix
 from ._container import Compose, RandomApply, RandomChoice, RandomOrder
-from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop
+from ._geometry import HorizontalFlip, Resize, CenterCrop, RandomResizedCrop, FiveCrop, TenCrop, BatchMultiCrop
 from ._meta import ConvertBoundingBoxFormat, ConvertImageDtype, ConvertImageColorSpace
 from ._misc import Identity, Normalize, ToDtype, Lambda
 from ._presets import CocoEval, ImageNetEval, VocEval, Kinect400Eval, RaftEval
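
With this change the new transforms are re-exported from the prototype namespace. A minimal import sketch (assuming a torchvision build that ships the prototype package; the crop size 224 is arbitrary):

    from torchvision.prototype import transforms

    # The three names added by this commit sit alongside the existing geometry transforms.
    five_crop = transforms.FiveCrop(224)
    ten_crop = transforms.TenCrop(224, vertical_flip=True)
    batcher = transforms.BatchMultiCrop()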

torchvision/prototype/transforms/_geometry.py

Lines changed: 88 additions & 0 deletions
@@ -1,3 +1,4 @@
+import collections.abc
 import math
 import warnings
 from typing import Any, Dict, List, Union, Sequence, Tuple, cast
@@ -6,6 +7,7 @@
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
+from torchvision.transforms.functional import pil_to_tensor
 from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int

 from ._utils import query_image, get_image_dimensions, has_any, is_simple_tensor
@@ -168,3 +170,89 @@ def forward(self, *inputs: Any) -> Any:
         if has_any(sample, features.BoundingBox, features.SegmentationMask):
             raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
         return super().forward(sample)
+
+
+class MultiCropResult(list):
+    """Helper class for :class:`~torchvision.prototype.transforms.BatchMultiCrop`.
+
+    Outputs of multi crop transforms such as :class:`~torchvision.prototype.transforms.FiveCrop` and
+    :class:`~torchvision.prototype.transforms.TenCrop` should be wrapped in this in order to be batched correctly by
+    :class:`~torchvision.prototype.transforms.BatchMultiCrop`.
+    """
+
+    pass
+
+
+class FiveCrop(Transform):
+    def __init__(self, size: Union[int, Sequence[int]]) -> None:
+        super().__init__()
+        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
+
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if isinstance(input, features.Image):
+            output = F.five_crop_image_tensor(input, self.size)
+            return MultiCropResult(features.Image.new_like(input, o) for o in output)
+        elif is_simple_tensor(input):
+            return MultiCropResult(F.five_crop_image_tensor(input, self.size))
+        elif isinstance(input, PIL.Image.Image):
+            return MultiCropResult(F.five_crop_image_pil(input, self.size))
+        else:
+            return input
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        return super().forward(sample)
+
+
+class TenCrop(Transform):
+    def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
+        super().__init__()
+        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
+        self.vertical_flip = vertical_flip
+
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if isinstance(input, features.Image):
+            output = F.ten_crop_image_tensor(input, self.size, vertical_flip=self.vertical_flip)
+            return MultiCropResult(features.Image.new_like(input, o) for o in output)
+        elif is_simple_tensor(input):
+            return MultiCropResult(F.ten_crop_image_tensor(input, self.size, vertical_flip=self.vertical_flip))
+        elif isinstance(input, PIL.Image.Image):
+            return MultiCropResult(F.ten_crop_image_pil(input, self.size, vertical_flip=self.vertical_flip))
+        else:
+            return input
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        return super().forward(sample)
+
+
+class BatchMultiCrop(Transform):
+    def forward(self, *inputs: Any) -> Any:
+        # This is basically the functionality of `torchvision.prototype.utils._internal.apply_recursively` with one
+        # significant difference:
+        # Since we need multiple images to batch them together, we need to explicitly exclude `MultiCropResult` from
+        # the sequence case.
+        def apply_recursively(obj: Any) -> Any:
+            if isinstance(obj, MultiCropResult):
+                crops = obj
+                if isinstance(obj[0], PIL.Image.Image):
+                    crops = [pil_to_tensor(crop) for crop in crops]  # type: ignore[assignment]
+
+                batch = torch.stack(crops)
+
+                if isinstance(obj[0], features.Image):
+                    batch = features.Image.new_like(obj[0], batch)
+
+                return batch
+            elif isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
+                return [apply_recursively(item) for item in obj]
+            elif isinstance(obj, collections.abc.Mapping):
+                return {key: apply_recursively(item) for key, item in obj.items()}
+            else:
+                return obj
+
+        return apply_recursively(inputs if len(inputs) > 1 else inputs[0])
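
Taken together, FiveCrop and TenCrop emit a MultiCropResult, which BatchMultiCrop then stacks into a single batch. A minimal usage sketch, assuming the prototype features.Image wrapper accepts a plain CHW tensor (the 256/224 sizes are arbitrary and not part of the commit):

    import torch
    from torchvision.prototype import features, transforms

    image = features.Image(torch.rand(3, 256, 256))

    crops = transforms.FiveCrop(224)(image)     # MultiCropResult holding five 224x224 crops
    batch = transforms.BatchMultiCrop()(crops)  # stacked into a single features.Image
    # Expected shape: (5, 3, 224, 224), one leading entry per crop.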

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -60,6 +60,10 @@
     perspective_image_pil,
     vertical_flip_image_tensor,
     vertical_flip_image_pil,
+    five_crop_image_tensor,
+    five_crop_image_pil,
+    ten_crop_image_tensor,
+    ten_crop_image_pil,
 )
 from ._misc import normalize_image_tensor, gaussian_blur_image_tensor
 from ._type_conversion import decode_image_with_pil, decode_video_with_av, label_to_one_hot

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 76 additions & 0 deletions
@@ -314,3 +314,79 @@ def resized_crop_image_pil(
 ) -> PIL.Image.Image:
     img = crop_image_pil(img, top, left, height, width)
     return resize_image_pil(img, size, interpolation=interpolation)
+
+
+def _parse_five_crop_size(size: List[int]) -> List[int]:
+    if isinstance(size, numbers.Number):
+        size = (int(size), int(size))
+    elif isinstance(size, (tuple, list)) and len(size) == 1:
+        size = (size[0], size[0])  # type: ignore[assignment]
+
+    if len(size) != 2:
+        raise ValueError("Please provide only two dimensions (h, w) for size.")
+
+    return size
+
+
+def five_crop_image_tensor(
+    img: torch.Tensor, size: List[int]
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    crop_height, crop_width = _parse_five_crop_size(size)
+    _, image_height, image_width = get_dimensions_image_tensor(img)
+
+    if crop_width > image_width or crop_height > image_height:
+        msg = "Requested crop size {} is bigger than input size {}"
+        raise ValueError(msg.format(size, (image_height, image_width)))
+
+    tl = crop_image_tensor(img, 0, 0, crop_height, crop_width)
+    tr = crop_image_tensor(img, 0, image_width - crop_width, crop_height, crop_width)
+    bl = crop_image_tensor(img, image_height - crop_height, 0, crop_height, crop_width)
+    br = crop_image_tensor(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
+    center = center_crop_image_tensor(img, [crop_height, crop_width])
+
+    return tl, tr, bl, br, center
+
+
+def five_crop_image_pil(
+    img: PIL.Image.Image, size: List[int]
+) -> Tuple[PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image, PIL.Image.Image]:
+    crop_height, crop_width = _parse_five_crop_size(size)
+    _, image_height, image_width = get_dimensions_image_pil(img)
+
+    if crop_width > image_width or crop_height > image_height:
+        msg = "Requested crop size {} is bigger than input size {}"
+        raise ValueError(msg.format(size, (image_height, image_width)))
+
+    tl = crop_image_pil(img, 0, 0, crop_height, crop_width)
+    tr = crop_image_pil(img, 0, image_width - crop_width, crop_height, crop_width)
+    bl = crop_image_pil(img, image_height - crop_height, 0, crop_height, crop_width)
+    br = crop_image_pil(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
+    center = center_crop_image_pil(img, [crop_height, crop_width])
+
+    return tl, tr, bl, br, center
+
+
+def ten_crop_image_tensor(img: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
+    tl, tr, bl, br, center = five_crop_image_tensor(img, size)
+
+    if vertical_flip:
+        img = vertical_flip_image_tensor(img)
+    else:
+        img = horizontal_flip_image_tensor(img)
+
+    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_tensor(img, size)
+
+    return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
+
+
+def ten_crop_image_pil(img: PIL.Image.Image, size: List[int], vertical_flip: bool = False) -> List[PIL.Image.Image]:
+    tl, tr, bl, br, center = five_crop_image_pil(img, size)
+
+    if vertical_flip:
+        img = vertical_flip_image_pil(img)
+    else:
+        img = horizontal_flip_image_pil(img)
+
+    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(img, size)
+
+    return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
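
For reference, a small sketch that exercises the new functional kernels directly on a plain tensor; the sizes are arbitrary and chosen only to illustrate the expected shapes:

    import torch
    from torchvision.prototype.transforms import functional as F

    img = torch.rand(3, 256, 256)

    tl, tr, bl, br, center = F.five_crop_image_tensor(img, [224, 224])
    assert center.shape == (3, 224, 224)  # each of the five crops is 224x224

    crops = F.ten_crop_image_tensor(img, [224, 224], vertical_flip=False)
    assert len(crops) == 10  # five crops plus crops of the horizontally flipped image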
