
Commit 01175db

Unified FiveCrop and F.five_crop

1 parent 69cb9c5

File tree: 3 files changed, +85 -19 lines

test/test_transforms_tensor.py

Lines changed: 39 additions & 0 deletions
@@ -118,6 +118,45 @@ def test_center_crop(self):
             "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )
 
+    def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
+        if fn_kwargs is None:
+            fn_kwargs = {}
+        if meth_kwargs is None:
+            meth_kwargs = {}
+        tensor, pil_img = self._create_data(height=20, width=20)
+        transformed_t_list = getattr(F, func)(tensor, **fn_kwargs)
+        transformed_p_list = getattr(F, func)(pil_img, **fn_kwargs)
+        self.assertEqual(len(transformed_t_list), len(transformed_p_list))
+        self.assertEqual(len(transformed_t_list), out_length)
+        for transformed_tensor, transformed_pil_img in zip(transformed_t_list, transformed_p_list):
+            self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
+
+        scripted_fn = torch.jit.script(getattr(F, func))
+        transformed_t_list_script = scripted_fn(tensor.detach().clone(), **fn_kwargs)
+        self.assertEqual(len(transformed_t_list), len(transformed_t_list_script))
+        self.assertEqual(len(transformed_t_list_script), out_length)
+        for transformed_tensor, transformed_tensor_script in zip(transformed_t_list, transformed_t_list_script):
+            self.assertTrue(transformed_tensor.equal(transformed_tensor_script),
+                            msg="{} vs {}".format(transformed_tensor, transformed_tensor_script))
+
+        # test for class interface
+        f = getattr(T, method)(**meth_kwargs)
+        scripted_fn = torch.jit.script(f)
+        output = scripted_fn(tensor)
+        self.assertEqual(len(output), len(transformed_t_list_script))
+
+    def test_five_crop(self):
+        fn_kwargs = {"size": (5,)}
+        meth_kwargs = {"size": (5, )}
+        self._test_geom_op_list_output(
+            "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
+        )
+        fn_kwargs = {"size": (4, 5)}
+        meth_kwargs = {"size": (4, 5)}
+        self._test_geom_op_list_output(
+            "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
+        )
+
 
 if __name__ == '__main__':
     unittest.main()
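
Since the test module ends in unittest.main(), the new case can be run directly from the shell; a sketch, assuming the command is issued from the repository root on Python 3.7+ (where unittest's CLI supports the -k filter):

    python test/test_transforms_tensor.py -k test_five_crop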

torchvision/transforms/functional.py

Lines changed: 21 additions & 13 deletions
@@ -10,7 +10,7 @@
 
 import torch
 from torch import Tensor
-from torch.jit.annotations import List
+from torch.jit.annotations import List, Tuple
 
 try:
     import accimage
@@ -423,13 +423,15 @@ def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
         img (PIL Image or Tensor): Image to be cropped.
         output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int
             it is used for both directions
+
     Returns:
         PIL Image or Tensor: Cropped image.
     """
     if isinstance(output_size, numbers.Number):
         output_size = (int(output_size), int(output_size))
     elif isinstance(output_size, (tuple, list)) and len(output_size) == 1:
         output_size = (output_size[0], output_size[0])
+
     image_width, image_height = _get_image_size(img)
     crop_height, crop_width = output_size
 
@@ -589,8 +591,10 @@ def vflip(img: Tensor) -> Tensor:
     return F_t.vflip(img)
 
 
-def five_crop(img, size):
-    """Crop the given PIL Image into four corners and the central crop.
+def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
+    """Crop the given image into four corners and the central crop.
+    The image can be a PIL Image or a torch Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
 
     .. Note::
         This transform returns a tuple of images and there may be a
@@ -607,22 +611,26 @@ def five_crop(img, size):
     """
     if isinstance(size, numbers.Number):
         size = (int(size), int(size))
-    else:
-        assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+    elif isinstance(size, (tuple, list)) and len(size) == 1:
+        size = (size[0], size[0])
+
+    if len(size) != 2:
+        raise ValueError("Please provide only two dimensions (h, w) for size.")
 
-    image_width, image_height = img.size
+    image_width, image_height = _get_image_size(img)
     crop_height, crop_width = size
     if crop_width > image_width or crop_height > image_height:
         msg = "Requested crop size {} is bigger than input size {}"
         raise ValueError(msg.format(size, (image_height, image_width)))
 
-    tl = img.crop((0, 0, crop_width, crop_height))
-    tr = img.crop((image_width - crop_width, 0, image_width, crop_height))
-    bl = img.crop((0, image_height - crop_height, crop_width, image_height))
-    br = img.crop((image_width - crop_width, image_height - crop_height,
-                   image_width, image_height))
-    center = center_crop(img, (crop_height, crop_width))
-    return (tl, tr, bl, br, center)
+    tl = crop(img, 0, 0, crop_height, crop_width)
+    tr = crop(img, 0, image_width - crop_width, crop_height, crop_width)
+    bl = crop(img, image_height - crop_height, 0, crop_height, crop_width)
+    br = crop(img, image_height - crop_height, image_width - crop_width, crop_height, crop_width)
+
+    center = center_crop(img, [crop_height, crop_width])
+
+    return tl, tr, bl, br, center
 
 
 def ten_crop(img, size, vertical_flip=False):
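
With the tensor-aware crop()/center_crop() helpers replacing the PIL-only img.crop(...) calls, and the List/Tuple annotations added for TorchScript, F.five_crop now handles both PIL Images and tensors. A minimal usage sketch, not part of the commit (the tensor shape and crop size here are arbitrary):

    import torch
    from torchvision.transforms import functional as F

    img = torch.randint(0, 256, (3, 20, 20), dtype=torch.uint8)  # [..., H, W] tensor
    # size is (h, w); passing a list keeps the call TorchScript-friendly, as in the new tests
    tl, tr, bl, br, center = F.five_crop(img, [4, 5])
    assert center.shape[-2:] == (4, 5)

    # The free function itself can now be scripted, which is what the new test exercises.
    scripted_five_crop = torch.jit.script(F.five_crop)
    crops = scripted_five_crop(img, [4, 5])

As in the tests, a one-element size such as (5,) is expanded to a square (5, 5) crop by both the function and the transform.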

torchvision/transforms/transforms.py

Lines changed: 25 additions & 6 deletions
@@ -260,7 +260,7 @@ class CenterCrop(torch.nn.Module):
     Args:
         size (sequence or int): Desired output size of the crop. If size is an
             int instead of sequence like (h, w), a square crop (size, size) is
-            made. For scripted operation please use a list: (size, ) or (size_x, size_y)
+            made. For scripted operation, please use a list: (size, ) or (size_x, size_y)
     """
 
     def __init__(self, size):
@@ -270,6 +270,9 @@ def __init__(self, size):
         elif isinstance(size, (tuple, list)) and len(size) == 1:
             self.size = (size[0], size[0])
         else:
+            if len(size) != 2:
+                raise ValueError("Please provide only two dimensions (h, w) for size.")
+
             self.size = size
 
     def forward(self, img):
@@ -572,7 +575,7 @@ def __repr__(self):
 
 
 class RandomVerticalFlip(torch.nn.Module):
-    """Vertically flip the given PIL Image randomly with a given probability.
+    """Vertically flip the given image randomly with a given probability.
     The image can be a PIL Image or a torch Tensor, in which case it is expected
     to have [..., H, W] shape, where ... means an arbitrary number of leading
     dimensions
@@ -769,8 +772,11 @@ def __init__(self, *args, **kwargs):
         super(RandomSizedCrop, self).__init__(*args, **kwargs)
 
 
-class FiveCrop(object):
-    """Crop the given PIL Image into four corners and the central crop
+class FiveCrop(torch.nn.Module):
+    """Crop the given image into four corners and the central crop.
+    The image can be a PIL Image or a torch Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading
+    dimensions
 
     .. Note::
         This transform returns a tuple of images and there may be a mismatch in the number of
@@ -780,6 +786,7 @@ class FiveCrop(object):
     Args:
         size (sequence or int): Desired output size of the crop. If size is an ``int``
             instead of sequence like (h, w), a square crop of size (size, size) is made.
+            For scripted operation, please use a list: (size, ) or (size_x, size_y)
 
     Example:
         >>> transform = Compose([
@@ -794,14 +801,26 @@ class FiveCrop(object):
     """
 
     def __init__(self, size):
+        super().__init__()
         self.size = size
         if isinstance(size, numbers.Number):
             self.size = (int(size), int(size))
+        elif isinstance(size, (tuple, list)) and len(size) == 1:
+            self.size = (size[0], size[0])
         else:
-            assert len(size) == 2, "Please provide only two dimensions (h, w) for size."
+            if len(size) != 2:
+                raise ValueError("Please provide only two dimensions (h, w) for size.")
+
             self.size = size
 
-    def __call__(self, img):
+    def forward(self, img):
+        """
+        Args:
+            img (PIL Image or Tensor): Image to be cropped.
+
+        Returns:
+            PIL Image or Tensor: Cropped image.
+        """
         return F.five_crop(img, self.size)
 
     def __repr__(self):
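
Because FiveCrop now subclasses torch.nn.Module and moves its logic from __call__ into forward(), the transform can be scripted like the other unified transforms. A small sketch, not part of the commit:

    import torch
    import torchvision.transforms as T

    five_crop = T.FiveCrop(size=(5,))        # one-element sequence expands to a (5, 5) square
    scripted = torch.jit.script(five_crop)   # possible now that FiveCrop is an nn.Module
    crops = scripted(torch.rand(3, 20, 20))  # tuple (tl, tr, bl, br, center), each 3 x 5 x 5
    assert len(crops) == 5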
