From 933b0f71484a24a1c2ed13ab92a2fd64690fd2f3 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 15 Feb 2023 14:05:10 +0100
Subject: [PATCH 1/2] fix ten_crop annotation

---
 torchvision/transforms/functional.py | 4 +++-
 torchvision/transforms/transforms.py | 2 +-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py
index beeb02cd915..c5b2a71d0d7 100644
--- a/torchvision/transforms/functional.py
+++ b/torchvision/transforms/functional.py
@@ -827,7 +827,9 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten
     return tl, tr, bl, br, center
 
 
-def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[Tensor]:
+def ten_crop(
+    img: Tensor, size: List[int], vertical_flip: bool = False
+) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
     """Generate ten cropped images from the given image.
     Crop the given image into four corners and the central crop plus the flipped version of
     these (horizontal flipping is used by default).
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index d7858353be9..90cb0374eee 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -1049,7 +1049,7 @@ class TenCrop(torch.nn.Module):
 
     Example:
         >>> transform = Compose([
-        >>>    TenCrop(size), # this is a list of PIL Images
+        >>>    TenCrop(size), # this is a tuple of PIL Images
         >>>    Lambda(lambda crops: torch.stack([PILToTensor()(crop) for crop in crops])) # returns a 4D tensor
         >>> ])
         >>> #In your test loop you can do the following:

From 7ae1f9e5bac1ebe51311ad77852ce69658922cf0 Mon Sep 17 00:00:00 2001
From: Philip Meier
Date: Wed, 15 Feb 2023 14:10:38 +0100
Subject: [PATCH 2/2] fix v2

---
 torchvision/prototype/transforms/_geometry.py | 13 +++-
 .../transforms/functional/_geometry.py        | 78 +++++++++++++++----
 2 files changed, 75 insertions(+), 16 deletions(-)

diff --git a/torchvision/prototype/transforms/_geometry.py b/torchvision/prototype/transforms/_geometry.py
index ffabb91471c..69238760be5 100644
--- a/torchvision/prototype/transforms/_geometry.py
+++ b/torchvision/prototype/transforms/_geometry.py
@@ -234,7 +234,18 @@ def _check_inputs(self, flat_inputs: List[Any]) -> None:
 
     def _transform(
         self, inpt: Union[datapoints.ImageType, datapoints.VideoType], params: Dict[str, Any]
-    ) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]:
+    ) -> Tuple[
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+        ImageOrVideoTypeJIT,
+    ]:
         return F.ten_crop(inpt, self.size, vertical_flip=self.vertical_flip)
 
 
diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py
index 814697f03a3..9f9d5f4e705 100644
--- a/torchvision/prototype/transforms/functional/_geometry.py
+++ b/torchvision/prototype/transforms/functional/_geometry.py
@@ -1959,8 +1959,6 @@ def five_crop(
     if not torch.jit.is_scripting():
         _log_api_usage_once(five_crop)
 
-    # TODO: consider breaking BC here to return List[datapoints.ImageTypeJIT/VideoTypeJIT] to align this op with
-    # `ten_crop`
     if torch.jit.is_scripting() or is_simple_tensor(inpt):
         return five_crop_image_tensor(inpt, size)
     elif isinstance(inpt, datapoints.Image):
@@ -1978,40 +1976,90 @@
     )
 
 
-def ten_crop_image_tensor(image: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
-    tl, tr, bl, br, center = five_crop_image_tensor(image, size)
+def ten_crop_image_tensor(
+    image: torch.Tensor, size: List[int], vertical_flip: bool = False
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
+    non_flipped = five_crop_image_tensor(image, size)
 
     if vertical_flip:
         image = vertical_flip_image_tensor(image)
     else:
         image = horizontal_flip_image_tensor(image)
 
-    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_tensor(image, size)
+    flipped = five_crop_image_tensor(image, size)
 
-    return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
+    return non_flipped + flipped
 
 
 @torch.jit.unused
-def ten_crop_image_pil(image: PIL.Image.Image, size: List[int], vertical_flip: bool = False) -> List[PIL.Image.Image]:
-    tl, tr, bl, br, center = five_crop_image_pil(image, size)
+def ten_crop_image_pil(
+    image: PIL.Image.Image, size: List[int], vertical_flip: bool = False
+) -> Tuple[
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+    PIL.Image.Image,
+]:
+    non_flipped = five_crop_image_pil(image, size)
 
     if vertical_flip:
         image = vertical_flip_image_pil(image)
     else:
         image = horizontal_flip_image_pil(image)
 
-    tl_flip, tr_flip, bl_flip, br_flip, center_flip = five_crop_image_pil(image, size)
-
-    return [tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip]
-
-
-def ten_crop_video(video: torch.Tensor, size: List[int], vertical_flip: bool = False) -> List[torch.Tensor]:
+    flipped = five_crop_image_pil(image, size)
+
+    return non_flipped + flipped
+
+
+def ten_crop_video(
+    video: torch.Tensor, size: List[int], vertical_flip: bool = False
+) -> Tuple[
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+    torch.Tensor,
+]:
     return ten_crop_image_tensor(video, size, vertical_flip=vertical_flip)
 
 
 def ten_crop(
     inpt: Union[datapoints.ImageTypeJIT, datapoints.VideoTypeJIT], size: List[int], vertical_flip: bool = False
-) -> Union[List[datapoints.ImageTypeJIT], List[datapoints.VideoTypeJIT]]:
+) -> Tuple[
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+    ImageOrVideoTypeJIT,
+]:
    if not torch.jit.is_scripting():
        _log_api_usage_once(ten_crop)
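
Note (not part of the patch): the corrected annotation simply matches runtime behavior, since `ten_crop` has always returned the concatenation of two `five_crop` tuples. A minimal sketch to verify this against the stable `torchvision.transforms.functional` API; the 256x256 input and the 224x224 crop size are placeholder values, not taken from the patch:

import torch
from torchvision.transforms import functional as F

# Placeholder input: a random uint8 image tensor in (C, H, W) layout.
img = torch.randint(0, 256, (3, 256, 256), dtype=torch.uint8)

# ten_crop concatenates two five_crop results, so it yields a
# 10-element tuple, matching the corrected annotation above.
crops = F.ten_crop(img, [224, 224])
assert isinstance(crops, tuple) and len(crops) == 10

# Typical consumption, mirroring the TenCrop docstring example:
# stack the crops into one 4D batch of shape (10, 3, 224, 224).
batch = torch.stack(crops)

Spelling out all ten tuple elements, rather than writing `Tuple[Tensor, ...]`, mirrors the existing `five_crop` annotation and keeps the signatures expressible in TorchScript, which does not accept variable-length tuple annotations.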