Skip to content

Commit 94e29a4

Browse files
committed
Unified TenCrop and F.ten_crop
1 parent b356e8b commit 94e29a4

File tree

3 files changed

+52
-14
lines changed

3 files changed

+52
-14
lines changed

test/test_transforms_tensor.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,6 +195,24 @@ def test_five_crop(self):
195195
"five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
196196
)
197197

198+
def test_ten_crop(self):
199+
fn_kwargs = meth_kwargs = {"size": (5,)}
200+
self._test_geom_op_list_output(
201+
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
202+
)
203+
fn_kwargs = meth_kwargs = {"size": [5, ]}
204+
self._test_geom_op_list_output(
205+
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
206+
)
207+
fn_kwargs = meth_kwargs = {"size": (4, 5)}
208+
self._test_geom_op_list_output(
209+
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
210+
)
211+
fn_kwargs = meth_kwargs = {"size": [4, 5]}
212+
self._test_geom_op_list_output(
213+
"ten_crop", "TenCrop", out_length=10, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
214+
)
215+
198216

199217
if __name__ == '__main__':
200218
unittest.main()

torchvision/transforms/functional.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -636,19 +636,22 @@ def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Ten
636636
return tl, tr, bl, br, center
637637

638638

639-
def ten_crop(img, size, vertical_flip=False):
640-
"""Generate ten cropped images from the given PIL Image.
641-
Crop the given PIL Image into four corners and the central crop plus the
639+
def ten_crop(img: Tensor, size: List[int], vertical_flip: bool = False) -> List[Tensor]:
640+
"""Generate ten cropped images from the given image.
641+
Crop the given image into four corners and the central crop plus the
642642
flipped version of these (horizontal flipping is used by default).
643+
The image can be a PIL Image or a Tensor, in which case it is expected
644+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
643645
644646
.. Note::
645647
This transform returns a tuple of images and there may be a
646648
mismatch in the number of inputs and targets your ``Dataset`` returns.
647649
648650
Args:
651+
img (PIL Image or Tensor): Image to be cropped.
649652
size (sequence or int): Desired output size of the crop. If size is an
650653
int instead of sequence like (h, w), a square crop (size, size) is
651-
made.
654+
made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
652655
vertical_flip (bool): Use vertical flipping instead of horizontal
653656
654657
Returns:
@@ -658,8 +661,11 @@ def ten_crop(img, size, vertical_flip=False):
658661
"""
659662
if isinstance(size, numbers.Number):
660663
size = (int(size), int(size))
661-
else:
662-
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
664+
elif isinstance(size, (tuple, list)) and len(size) == 1:
665+
size = (size[0], size[0])
666+
667+
if len(size) != 2:
668+
raise ValueError("Please provide only two dimensions (h, w) for size.")
663669

664670
first_five = five_crop(img, size)
665671

torchvision/transforms/transforms.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -826,17 +826,20 @@ def forward(self, img):
826826
img (PIL Image or Tensor): Image to be cropped.
827827
828828
Returns:
829-
PIL Image or Tensor: Cropped image.
829+
tuple of 5 images. Image can be PIL Image or Tensor
830830
"""
831831
return F.five_crop(img, self.size)
832832

833833
def __repr__(self):
834834
return self.__class__.__name__ + '(size={0})'.format(self.size)
835835

836836

837-
class TenCrop(object):
838-
"""Crop the given PIL Image into four corners and the central crop plus the flipped version of
839-
these (horizontal flipping is used by default)
837+
class TenCrop(torch.nn.Module):
838+
"""Crop the given image into four corners and the central crop plus the flipped version of
839+
these (horizontal flipping is used by default).
840+
The image can be a PIL Image or a Tensor, in which case it is expected
841+
to have [..., H, W] shape, where ... means an arbitrary number of leading
842+
dimensions
840843
841844
.. Note::
842845
This transform returns a tuple of images and there may be a mismatch in the number of
@@ -846,7 +849,7 @@ class TenCrop(object):
846849
Args:
847850
size (sequence or int): Desired output size of the crop. If size is an
848851
int instead of sequence like (h, w), a square crop (size, size) is
849-
made.
852+
made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
850853
vertical_flip (bool): Use vertical flipping instead of horizontal
851854
852855
Example:
@@ -862,15 +865,26 @@ class TenCrop(object):
862865
"""
863866

864867
def __init__(self, size, vertical_flip=False):
865-
self.size = size
868+
super().__init__()
866869
if isinstance(size, numbers.Number):
867870
self.size = (int(size), int(size))
871+
elif isinstance(size, Sequence) and len(size) == 1:
872+
self.size = (size[0], size[0])
868873
else:
869-
assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
874+
if len(size) != 2:
875+
raise ValueError("Please provide only two dimensions (h, w) for size.")
876+
870877
self.size = size
871878
self.vertical_flip = vertical_flip
872879

873-
def __call__(self, img):
880+
def forward(self, img):
881+
"""
882+
Args:
883+
img (PIL Image or Tensor): Image to be cropped.
884+
885+
Returns:
886+
tuple of 10 images. Image can be PIL Image or Tensor
887+
"""
874888
return F.ten_crop(img, self.size, self.vertical_flip)
875889

876890
def __repr__(self):

0 commit comments

Comments
 (0)