port FiveCrop and TenCrop to prototype API #5513
@@ -1,3 +1,4 @@
+import functools
 import math
 import warnings
 from typing import Any, Dict, List, Union, Sequence, Tuple, cast
@@ -6,6 +7,8 @@
 import torch
 from torchvision.prototype import features
 from torchvision.prototype.transforms import Transform, InterpolationMode, functional as F
+from torchvision.prototype.utils._internal import apply_recursively
+from torchvision.transforms.functional import pil_to_tensor
 from torchvision.transforms.transforms import _setup_size, _interpolation_modes_from_int

 from ._utils import query_image, get_image_dimensions, has_any
@@ -168,3 +171,77 @@ def forward(self, *inputs: Any) -> Any:
         if has_any(sample, features.BoundingBox, features.SegmentationMask):
             raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
         return super().forward(sample)
+
+
+class FiveCrop(Transform):
+    def __init__(self, size: Union[int, Sequence[int]]) -> None:
+        super().__init__()
+        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
+
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if isinstance(input, features.Image):
+            output = F.five_crop_image_tensor(input, self.size)
+            return F._FiveCropResult(*[features.Image.new_like(input, o) for o in output])
+        elif type(input) is torch.Tensor:
+            return F.five_crop_image_tensor(input, self.size)
+        elif isinstance(input, PIL.Image.Image):
+            return F.five_crop_image_pil(input, self.size)
+        else:
+            return input
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        return super().forward(sample)
+
+
+class TenCrop(Transform):
+    def __init__(self, size: Union[int, Sequence[int]], vertical_flip: bool = False) -> None:
+        super().__init__()
+        self.size = _setup_size(size, error_msg="Please provide only two dimensions (h, w) for size.")
+        self.vertical_flip = vertical_flip
+
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if isinstance(input, features.Image):
+            output = F.ten_crop_image_tensor(input, self.size, vertical_flip=self.vertical_flip)
+            return F._TenCropResult(*[features.Image.new_like(input, o) for o in output])
+        elif type(input) is torch.Tensor:
+            return F.ten_crop_image_tensor(input, self.size, vertical_flip=self.vertical_flip)
+        elif isinstance(input, PIL.Image.Image):
+            return F.ten_crop_image_pil(input, self.size, vertical_flip=self.vertical_flip)
+        else:
+            return input
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        if has_any(sample, features.BoundingBox, features.SegmentationMask):
+            raise TypeError(f"BoundingBox'es and SegmentationMask's are not supported by {type(self).__name__}()")
+        return super().forward(sample)
+
+
+class BatchMultiCrop(Transform):
+    _MULTI_CROP_TYPES = (F._FiveCropResult, F._TenCropResult)
+
+    def _transform(self, input: Any, params: Dict[str, Any]) -> Any:
+        if isinstance(input, self._MULTI_CROP_TYPES):
+            crops = input
+            if isinstance(input[0], PIL.Image.Image):
+                crops = [pil_to_tensor(crop) for crop in crops]  # type: ignore[assignment]
+
+            batch = torch.stack(crops)
+
+            if isinstance(input[0], features.Image):
+                batch = features.Image.new_like(input[0], batch)
+
+            return batch
+        else:
+            return input
+
+    def forward(self, *inputs: Any) -> Any:
+        sample = inputs if len(inputs) > 1 else inputs[0]
+        return apply_recursively(
+            functools.partial(self._transform, params=self._get_params(sample)),
+            sample,
+            exclude_sequence_types=(str, *self._MULTI_CROP_TYPES),
+        )

Review comment: We need this exclude here, because named tuples by default would be recognized as a sequence, and thus we would only get the individual elements rather than everything at once.

Review reply: I think that's one more reason not to use named tuples.
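The premise of that comment is easy to check in plain Python: named tuples subclass tuple, and tuple registers as collections.abc.Sequence. A minimal check, with a hypothetical stand-in type:

    import collections.abc
    from collections import namedtuple

    # Hypothetical stand-in for _FiveCropResult / _TenCropResult.
    Point = namedtuple("Point", ["x", "y"])

    print(isinstance(Point(1, 2), collections.abc.Sequence))  # True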
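Taken together, the three transforms are meant to compose: FiveCrop/TenCrop fan an image out into a named tuple of crops, and BatchMultiCrop collapses that tuple into one batched tensor, while bounding boxes and segmentation masks are rejected with a TypeError. A rough usage sketch, assuming the classes are exported from torchvision.prototype.transforms (not confirmed by this diff) and using made-up shapes:

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import FiveCrop, BatchMultiCrop  # assumed exports

    image = features.Image(torch.rand(3, 256, 256))

    crops = FiveCrop(224)(image)     # _FiveCropResult holding five 3x224x224 Images
    batch = BatchMultiCrop()(crops)  # features.Image of shape (5, 3, 224, 224)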
torchvision/prototype/utils/_internal.py
@@ -25,6 +25,7 @@
     TypeVar,
     Union,
     Optional,
+    Type,
 )

 import numpy as np
@@ -301,13 +302,42 @@ def read(self, size: int = -1) -> bytes:
         return self._memory[slice(cursor, self.seek(offset, whence))].tobytes()


-def apply_recursively(fn: Callable, obj: Any) -> Any:
-    # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
-    # "a" == "a"[0][0]...
-    if isinstance(obj, collections.abc.Sequence) and not isinstance(obj, str):
-        return [apply_recursively(fn, item) for item in obj]
-    elif isinstance(obj, collections.abc.Mapping):
-        return {key: apply_recursively(fn, item) for key, item in obj.items()}
+def apply_recursively(
+    fn: Callable,
+    obj: Any,
+    *,
+    include_sequence_types: Collection[Type] = (collections.abc.Sequence,),
+    # We explicitly exclude str's here since they are self-referential and would cause an infinite recursion loop:
+    # "a" == "a"[0][0]...
+    exclude_sequence_types: Collection[Type] = (str,),
Review comment: We need this addition to be able to exclude named tuples as sequences in the BatchMultiCrop transform above.

Review reply: We should start from the assumption that this is not needed and add it later. Starting with a simple solution first is a good prior. Please simplify as much as possible.
+    include_mapping_types: Collection[Type] = (collections.abc.Mapping,),
+    exclude_mapping_types: Collection[Type] = (),
+) -> Any:
+    if isinstance(obj, tuple(include_sequence_types)) and not isinstance(obj, tuple(exclude_sequence_types)):
+        return [
+            apply_recursively(
+                fn,
+                item,
+                include_sequence_types=include_sequence_types,
+                exclude_sequence_types=exclude_sequence_types,
+                include_mapping_types=include_mapping_types,
+                exclude_mapping_types=exclude_mapping_types,
+            )
+            for item in obj
+        ]
+
+    if isinstance(obj, tuple(include_mapping_types)) and not isinstance(obj, tuple(exclude_mapping_types)):
+        return {
+            key: apply_recursively(
+                fn,
+                item,
+                include_sequence_types=include_sequence_types,
+                exclude_sequence_types=exclude_sequence_types,
+                include_mapping_types=include_mapping_types,
+                exclude_mapping_types=exclude_mapping_types,
+            )
+            for key, item in obj.items()
+        }
     else:
         return fn(obj)
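A quick sketch of what the new keyword arguments buy us, written against the signature in this diff (Crops is a hypothetical stand-in for the multi-crop named tuples):

    from collections import namedtuple
    from torchvision.prototype.utils._internal import apply_recursively

    Crops = namedtuple("Crops", ["tl", "tr"])  # hypothetical stand-in for F._FiveCropResult

    def describe(leaf):
        return type(leaf).__name__

    sample = {"image": Crops("a", "b"), "label": 0}

    # Default: a named tuple registers as a Sequence, so fn only ever sees its elements.
    print(apply_recursively(describe, sample))
    # {'image': ['str', 'str'], 'label': 'int'}

    # Excluding the named tuple type hands the whole tuple to fn in one call,
    # which is exactly what BatchMultiCrop needs.
    print(apply_recursively(describe, sample, exclude_sequence_types=(str, Crops)))
    # {'image': 'Crops', 'label': 'int'}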