Skip to content

Commit 316cc25

Browse files
authored
Ten crop annotation (#7254)
1 parent f0b7000 commit 316cc25

File tree

4 files changed

+79
-18
lines changed

4 files changed

+79
-18
lines changed

torchvision/prototype/transforms/_geometry.py

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -234,7 +234,18 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
234234

235235
def _transform(
236236
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
237-
) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]:
237+
) -> Tuple[
238+
ImageOrVideoTypeJIT,
239+
ImageOrVideoTypeJIT,
240+
ImageOrVideoTypeJIT,
241+
ImageOrVideoTypeJIT,
242+
ImageOrVideoTypeJIT,
243+
ImageOrVideoTypeJIT,
244+
ImageOrVideoTypeJIT,
245+
ImageOrVideoTypeJIT,
246+
ImageOrVideoTypeJIT,
247+
ImageOrVideoTypeJIT,
248+
]:
238249
return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
239250

240251

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 63 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1964,8 +1964,6 @@ def five_crop(
19641964
if not torch.jit.is_scripting():
19651965
_log_api_usage_once(five_crop)
19661966

1967-
# TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
1968-
# `ten_crop`
19691967
if torch.jit.is_scripting() or is_simple_tensor(inpt):
19701968
return five_crop_image_tensor(inpt, size)
19711969
elif isinstance(inpt, datapoints.Image):
@@ -1983,40 +1981,90 @@ def five_crop(
19831981
)
19841982

19851983

1986-
def ten_crop_image_tensor(image: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
1987-
tl, tr, bl, br, center = five_crop_image_tensor(image, size)
1984+
def ten_crop_image_tensor(
1985+
image: torch.Tensor, size: List[int], vertical_flip: bool = False
1986+
) -> Tuple[
1987+
torch.Tensor,
1988+
torch.Tensor,
1989+
torch.Tensor,
1990+
torch.Tensor,
1991+
torch.Tensor,
1992+
torch.Tensor,
1993+
torch.Tensor,
1994+
torch.Tensor,
1995+
torch.Tensor,
1996+
torch.Tensor,
1997+
]:
1998+
non_flipped = five_crop_image_tensor(image, size)
19881999

19892000
if vertical_flip:
19902001
image = vertical_flip_image_tensor(image)
19912002
else:
19922003
image = horizontal_flip_image_tensor(image)
19932004

1994-
tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_tensor(image, size)
2005+
flipped = five_crop_image_tensor(image, size)
19952006

1996-
return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
2007+
return non_flipped + flipped
19972008

19982009

19992010
@torch.jit.unused
2000-
def ten_crop_image_pil(image: PIL.Image.Image, size: List[int], vertical_flip: bool = False) -> List[PIL.Image.Image]:
2001-
tl, tr, bl, br, center = five_crop_image_pil(image, size)
2011+
def ten_crop_image_pil(
2012+
image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
2013+
) -> Tuple[
2014+
PIL.Image.Image,
2015+
PIL.Image.Image,
2016+
PIL.Image.Image,
2017+
PIL.Image.Image,
2018+
PIL.Image.Image,
2019+
PIL.Image.Image,
2020+
PIL.Image.Image,
2021+
PIL.Image.Image,
2022+
PIL.Image.Image,
2023+
PIL.Image.Image,
2024+
]:
2025+
non_flipped = five_crop_image_pil(image, size)
20022026

20032027
if vertical_flip:
20042028
image = vertical_flip_image_pil(image)
20052029
else:
20062030
image = horizontal_flip_image_pil(image)
20072031

2008-
tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(image, size)
2009-
2010-
return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
2011-
2012-
2013-
def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
2032+
flipped = five_crop_image_pil(image, size)
2033+
2034+
return non_flipped + flipped
2035+
2036+
2037+
def ten_crop_video(
2038+
video: torch.Tensor, size: List[int], vertical_flip: bool = False
2039+
) -> Tuple[
2040+
torch.Tensor,
2041+
torch.Tensor,
2042+
torch.Tensor,
2043+
torch.Tensor,
2044+
torch.Tensor,
2045+
torch.Tensor,
2046+
torch.Tensor,
2047+
torch.Tensor,
2048+
torch.Tensor,
2049+
torch.Tensor,
2050+
]:
20142051
return ten_crop_image_tensor(video, size, vertical_flip=vertical_flip)
20152052

20162053

20172054
def ten_crop(
20182055
inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], size: List[int], vertical_flip: bool = False
2019-
) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]:
2056+
) -> Tuple[
2057+
ImageOrVideoTypeJIT,
2058+
ImageOrVideoTypeJIT,
2059+
ImageOrVideoTypeJIT,
2060+
ImageOrVideoTypeJIT,
2061+
ImageOrVideoTypeJIT,
2062+
ImageOrVideoTypeJIT,
2063+
ImageOrVideoTypeJIT,
2064+
ImageOrVideoTypeJIT,
2065+
ImageOrVideoTypeJIT,
2066+
ImageOrVideoTypeJIT,
2067+
]:
20202068
if not torch.jit.is_scripting():
20212069
_log_api_usage_once(ten_crop)
20222070

torchvision/transforms/functional.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -827,7 +827,9 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten
827827
return tl, tr, bl, br, center
828828

829829

830-
def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[Tensor]:
830+
def ten_crop(
831+
img: Tensor, size: List[int], vertical_flip: bool = False
832+
) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
831833
"""Generate ten cropped images from the given image.
832834
Crop the given image into four corners and the central crop plus the
833835
flipped version of these (horizontal flipping is used by default).

torchvision/transforms/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1049,7 +1049,7 @@ class TenCrop(torch.nn.Module):
10491049
10501050
Example:
10511051
>>> transform = Compose([
1052-
>>> TenCrop(size), # this is a list of PIL Images
1052+
>>> TenCrop(size), # this is a tuple of PIL Images
10531053
>>> Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
10541054
>>> ])
10551055
>>> #In your test loop you can do the following:

0 commit comments

Comments
 (0)