
Ten crop annotation #7254


Merged · 3 commits · Feb 15, 2023
13 changes: 12 additions & 1 deletion torchvision/prototype/transforms/_geometry.py
@@ -234,7 +234,18 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:

     def _transform(
         self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
-    ) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]:
+    ) -> Tuple[
[Comment — Collaborator (Author)]: Yes, all of this is ugly and verbose AF. We can also try to keep the List annotation (which we should have gone for in the first place).

+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
[Comment on lines +238 to +247 — Member]: .... wow

+    ]:
         return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
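[Editor's note]: The ten-entry annotation above is not just style: TorchScript requires fixed-arity `Tuple` annotations and has no variadic `Tuple[T, ...]` shorthand, so a scripted ten-crop signature has to spell out every element. A minimal sketch of the constraint (the function name and body here are illustrative, not from the PR):

```python
import torch
from typing import Tuple

# A fixed-arity tuple annotation scripts fine:
@torch.jit.script
def two_crops(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    return x, torch.flip(x, [-1])

# Annotating the return as Tuple[torch.Tensor, ...] instead would be rejected
# by the scripting compiler, hence the ten spelled-out entries above.
print(two_crops(torch.rand(3, 4, 4))[1].shape)
```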


78 changes: 63 additions & 15 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -1964,8 +1964,6 @@ def five_crop(
     if not torch.jit.is_scripting():
         _log_api_usage_once(five_crop)

-    # TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
-    # `ten_crop`
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return five_crop_image_tensor(inpt, size)
     elif isinstance(inpt, datapoints.Image):

@@ -1983,40 +1981,90 @@ def five_crop(
         )


-def ten_crop_image_tensor(image: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
-    tl, tr, bl, br, center = five_crop_image_tensor(image, size)
+def ten_crop_image_tensor(
+    image: torch.Tensor, size: List[int], vertical_flip: bool = False
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
[Comment on lines +1987 to +1996 — Member]: oh it keeps on giving 😅

+]:
+    non_flipped = five_crop_image_tensor(image, size)

     if vertical_flip:
         image = vertical_flip_image_tensor(image)
     else:
         image = horizontal_flip_image_tensor(image)

-    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_tensor(image, size)
+    flipped = five_crop_image_tensor(image, size)

-    return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
+    return non_flipped + flipped
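[Editor's note]: The `non_flipped + flipped` line leans on plain tuple concatenation: `five_crop_image_tensor` returns a 5-tuple, and `+` joins the two 5-tuples in order. A plain-Python sketch of the ordering invariant (placeholder strings stand in for tensors):

```python
# five_crop yields (tl, tr, bl, br, center); concatenating the flipped
# 5-tuple reproduces the documented 10-crop ordering.
non_flipped = ("tl", "tr", "bl", "br", "center")
flipped = ("tl_flip", "tr_flip", "bl_flip", "br_flip", "center_flip")
assert non_flipped + flipped == (
    "tl", "tr", "bl", "br", "center",
    "tl_flip", "tr_flip", "bl_flip", "br_flip", "center_flip",
)
```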


 @torch.jit.unused
-def ten_crop_image_pil(image: PIL.Image.Image, size: List[int], vertical_flip: bool = False) -> List[PIL.Image.Image]:
-    tl, tr, bl, br, center = five_crop_image_pil(image, size)
+def ten_crop_image_pil(
+    image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
+) -> Tuple[
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+]:
+    non_flipped = five_crop_image_pil(image, size)

     if vertical_flip:
         image = vertical_flip_image_pil(image)
     else:
         image = horizontal_flip_image_pil(image)

-    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(image, size)
-
-    return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
+    flipped = five_crop_image_pil(image, size)
+
+    return non_flipped + flipped
[Comment on lines +2032 to +2034 — Collaborator (Author)]: This is the actual functional fix. v1 returned a tuple and so should v2.
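[Editor's note]: To make the BC argument above concrete, a hedged check (the prototype import path below is an assumption from this era; the namespace later became `torchvision.transforms.v2`):

```python
import torch
import torchvision.transforms.functional as F_v1
import torchvision.prototype.transforms.functional as F_v2  # path assumed, pre-v2 rename

img = torch.rand(3, 64, 64)
crops_v1 = F_v1.ten_crop(img, [32, 32])
crops_v2 = F_v2.ten_crop(img, [32, 32])
# With this fix, both generations return a 10-tuple rather than a list.
assert isinstance(crops_v1, tuple) and isinstance(crops_v2, tuple)
assert len(crops_v1) == len(crops_v2) == 10
```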



-def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
+def ten_crop_video(
+    video: torch.Tensor, size: List[int], vertical_flip: bool = False
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
     return ten_crop_image_tensor(video, size, vertical_flip=vertical_flip)
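[Editor's note]: Delegating the video kernel to the image kernel works because cropping and flipping only touch the trailing (H, W) dimensions, so any leading dims (time, batch) pass through untouched. A sketch under that assumption (the import path is again assumed):

```python
import torch
from torchvision.prototype.transforms.functional import ten_crop_video  # path assumed

video = torch.rand(8, 3, 64, 64)  # (T, C, H, W)
crops = ten_crop_video(video, [32, 32])
# ten crops, each keeping the leading time dimension intact
assert len(crops) == 10
assert all(c.shape == (8, 3, 32, 32) for c in crops)
```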


 def ten_crop(
     inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], size: List[int], vertical_flip: bool = False
-) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]:
+) -> Tuple[
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+]:
     if not torch.jit.is_scripting():
         _log_api_usage_once(ten_crop)

4 changes: 3 additions & 1 deletion torchvision/transforms/functional.py
@@ -827,7 +827,9 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten
     return tl, tr, bl, br, center


-def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[Tensor]:
+def ten_crop(
+    img: Tensor, size: List[int], vertical_flip: bool = False
+) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
     """Generate ten cropped images from the given image.
     Crop the given image into four corners and the central crop plus the
     flipped version of these (horizontal flipping is used by default).
2 changes: 1 addition & 1 deletion torchvision/transforms/transforms.py
@@ -1049,7 +1049,7 @@ class TenCrop(torch.nn.Module):

Example:
>>> transform = Compose([
>>> TenCrop(size), # this is a list of PIL Images
>>> TenCrop(size), # this is a tuple of PIL Images
>>> Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
>>> ])
>>> #In your test loop you can do the following:
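[Editor's note]: The docstring excerpt stops at "In your test loop you can do the following:"; the usual continuation is the fold-and-average pattern sketched below (the model and input here are stand-ins, not part of the diff):

```python
import torch

model = torch.nn.Sequential(torch.nn.Flatten(), torch.nn.LazyLinear(10))  # stand-in classifier
input = torch.rand(4, 10, 3, 32, 32)  # (bs, ncrops, c, h, w) from the TenCrop pipeline

bs, ncrops, c, h, w = input.size()
result = model(input.view(-1, c, h, w))           # fold the crops into the batch dim
result_avg = result.view(bs, ncrops, -1).mean(1)  # average predictions over the 10 crops
assert result_avg.shape == (bs, 10)
```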