Commit ad7b2be

[WIP] Unified Tensor/PIL crop

1 parent cd2b7f0 commit ad7b2be

File tree: 5 files changed, +128 -75 lines changed

test/test_transforms_tensor.py

Lines changed: 21 additions & 10 deletions
@@ -18,26 +18,30 @@ def compareTensorToPIL(self, tensor, pil_image):
         pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
         self.assertTrue(tensor.equal(pil_tensor))

-    def _test_flip(self, func, method):
-        tensor, pil_img = self._create_data()
-        flip_tensor = getattr(F, func)(tensor)
-        flip_pil_img = getattr(F, func)(pil_img)
-        self.compareTensorToPIL(flip_tensor, flip_pil_img)
+    def _test_geom_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
+        if fn_kwargs is None:
+            fn_kwargs = {}
+        if meth_kwargs is None:
+            meth_kwargs = {}
+        tensor, pil_img = self._create_data(height=10, width=10)
+        transformed_tensor = getattr(F, func)(tensor, **fn_kwargs)
+        transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs)
+        self.compareTensorToPIL(transformed_tensor, transformed_pil_img)

         scripted_fn = torch.jit.script(getattr(F, func))
-        flip_tensor_script = scripted_fn(tensor)
-        self.assertTrue(flip_tensor.equal(flip_tensor_script))
+        transformed_tensor_script = scripted_fn(tensor, **fn_kwargs)
+        self.assertTrue(transformed_tensor.equal(transformed_tensor_script))

         # test for class interface
-        f = getattr(T, method)()
+        f = getattr(T, method)(**meth_kwargs)
         scripted_fn = torch.jit.script(f)
         scripted_fn(tensor)

     def test_random_horizontal_flip(self):
-        self._test_flip('hflip', 'RandomHorizontalFlip')
+        self._test_geom_op('hflip', 'RandomHorizontalFlip')

     def test_random_vertical_flip(self):
-        self._test_flip('vflip', 'RandomVerticalFlip')
+        self._test_geom_op('vflip', 'RandomVerticalFlip')

     def test_adjustments(self):
         fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation']
@@ -65,6 +69,13 @@ def test_adjustments(self):
         self.assertLess(max_diff, 5 / 255 + 1e-5)
         self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)

+    def test_crop(self):
+        fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
+        meth_kwargs = {"size": (4, 5), "padding": 4, "pad_if_needed": True, }
+        self._test_geom_op(
+            'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
+        )
+

 if __name__ == '__main__':
     unittest.main()
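
The renamed helper generalizes the old `_test_flip` so that any functional-op/transform-class pair can be checked for tensor/PIL agreement and scriptability. As a rough standalone sketch of what the new `test_crop` exercises (the image setup below is an assumption; the real test builds its data with `_create_data`):

import torch
import torchvision.transforms as T
import torchvision.transforms.functional as F

# Stand-in for _create_data(height=10, width=10): a 3x10x10 uint8 image tensor.
tensor = torch.randint(0, 256, (3, 10, 10), dtype=torch.uint8)

# Functional interface: crop a 4x5 box whose top-left corner sits at (top=2, left=3).
out = F.crop(tensor, top=2, left=3, height=4, width=5)
assert out.shape == (3, 4, 5)

# Class interface: the transform should also survive torch.jit.script.
scripted = torch.jit.script(T.RandomCrop(size=(4, 5), padding=4, pad_if_needed=True))
out = scripted(tensor)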

torchvision/transforms/functional.py

Lines changed: 30 additions & 23 deletions
@@ -16,18 +16,24 @@
 from . import functional_tensor as F_t


-def _is_pil_image(img):
-    if accimage is not None:
-        return isinstance(img, (Image.Image, accimage.Image))
-    else:
-        return isinstance(img, Image.Image)
+@torch.jit.export
+def _get_image_size(img):
+    # type: (Tensor) -> List[int]
+    if isinstance(img, torch.Tensor):
+        return F_t._get_image_size(img)
+
+    return F_pil._get_image_size(img)


+@torch.jit.ignore
 def _is_numpy(img):
+    # type: (Any) -> bool
     return isinstance(img, np.ndarray)


+@torch.jit.ignore
 def _is_numpy_image(img):
+    # type: (Any) -> bool
     return img.ndim in {2, 3}
@@ -42,7 +48,7 @@ def to_tensor(pic):
     Returns:
         Tensor: Converted image.
     """
-    if not(_is_pil_image(pic) or _is_numpy(pic)):
+    if not(F_pil._is_pil_image(pic) or _is_numpy(pic)):
         raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

     if _is_numpy(pic) and not _is_numpy_image(pic):
@@ -97,7 +103,7 @@ def pil_to_tensor(pic):
     Returns:
         Tensor: Converted image.
     """
-    if not(_is_pil_image(pic)):
+    if not(F_pil._is_pil_image(pic)):
         raise TypeError('pic should be PIL Image. Got {}'.format(type(pic)))

     if accimage is not None and isinstance(pic, accimage.Image):
@@ -315,7 +321,7 @@ def resize(img, size, interpolation=Image.BILINEAR):
     Returns:
         PIL Image: Resized image.
     """
-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
     if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
         raise TypeError('Got inappropriate size arg: {}'.format(size))
@@ -374,7 +380,7 @@ def pad(img, padding, fill=0, padding_mode='constant'):
     Returns:
         PIL Image: Padded image.
     """
-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     if not isinstance(padding, (numbers.Number, tuple)):
@@ -436,23 +442,24 @@ def pad(img, padding, fill=0, padding_mode='constant'):
         return Image.fromarray(img)


-def crop(img, top, left, height, width):
+def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
     """Crop the given PIL Image.

     Args:
-        img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
+        img (PIL Image or torch.Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
         top (int): Vertical component of the top left corner of the crop box.
         left (int): Horizontal component of the top left corner of the crop box.
         height (int): Height of the crop box.
         width (int): Width of the crop box.

     Returns:
-        PIL Image: Cropped image.
+        PIL Image or torch.Tensor: Cropped image.
     """
-    if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

-    return img.crop((left, top, left + width, top + height))
+    if not isinstance(img, torch.Tensor):
+        return F_pil.crop(img, top, left, height, width)
+
+    return F_t.crop(img, top, left, height, width)


 def center_crop(img, output_size):
@@ -491,7 +498,7 @@ def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINEAR):
     Returns:
         PIL Image: Cropped image.
     """
-    assert _is_pil_image(img), 'img should be PIL Image'
+    assert F_pil._is_pil_image(img), 'img should be PIL Image'
     img = crop(img, top, left, height, width)
     img = resize(img, size, interpolation)
     return img
@@ -501,13 +508,13 @@ def hflip(img: Tensor) -> Tensor:
     """Horizontally flip the given PIL Image or torch Tensor.

     Args:
-        img (PIL Image or Torch Tensor): Image to be flipped. If img
+        img (PIL Image or torch.Tensor): Image to be flipped. If img
             is a Tensor, it is expected to be in [..., H, W] format,
             where ... means it can have an arbitrary number of trailing
             dimensions.

     Returns:
-        PIL Image: Horizontally flipped image.
+        PIL Image or torch.Tensor: Horizontally flipped image.
     """
     if not isinstance(img, torch.Tensor):
         return F_pil.hflip(img)
@@ -593,7 +600,7 @@ def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=None):
         PIL Image: Perspectively transformed Image.
     """

-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     opts = _parse_fill(fill, img, '5.0.0')
@@ -797,7 +804,7 @@ def adjust_gamma(img, gamma, gain=1):
             while gamma smaller than 1 make dark regions lighter.
         gain (float): The constant multiplier.
     """
-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     if gamma < 0:
@@ -837,7 +844,7 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
     .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

     """
-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     opts = _parse_fill(fill, img, '5.2.0')
@@ -918,7 +925,7 @@ def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None):
             If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
         fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
     """
-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
@@ -945,7 +952,7 @@ def to_grayscale(img, num_output_channels=1):

         if num_output_channels = 3 : returned image is 3 channel with r = g = b
     """
-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     if num_output_channels == 1:
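
The pattern here is the one `hflip`/`vflip` already use: the public function dispatches on the input type, sending PIL Images to `F_pil` and Tensors to `F_t`, while the `Tensor` annotations plus the `@torch.jit.unused` decoration on the PIL helpers keep the function scriptable. A hedged usage sketch of the unified `crop` (the setup is mine, not from the diff):

import torch
from PIL import Image
import torchvision.transforms.functional as F

img_tensor = torch.randint(0, 256, (3, 32, 32), dtype=torch.uint8)
img_pil = Image.new("RGB", (32, 32))

# One entry point, two backends.
F.crop(img_tensor, 0, 0, 16, 16)  # dispatches to functional_tensor.crop
F.crop(img_pil, 0, 0, 16, 16)     # dispatches to functional_pil.crop

# Scripting compiles only the tensor path; the PIL branch is invisible to the JIT.
scripted_crop = torch.jit.script(F.crop)
out = scripted_crop(img_tensor, 0, 0, 16, 16)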

torchvision/transforms/functional_pil.py

Lines changed: 30 additions & 1 deletion
@@ -3,18 +3,27 @@
     import accimage
 except ImportError:
     accimage = None
-from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
+from PIL import Image, ImageOps, ImageEnhance
 import numpy as np


 @torch.jit.unused
 def _is_pil_image(img):
+    # type: (Any) -> bool
     if accimage is not None:
         return isinstance(img, (Image.Image, accimage.Image))
     else:
         return isinstance(img, Image.Image)


+@torch.jit.unused
+def _get_image_size(img):
+    # type: (Any) -> List[int]
+    if _is_pil_image(img):
+        return img.size
+    raise TypeError("Unexpected type {}".format(type(img)))
+
+
 @torch.jit.unused
 def hflip(img):
     """Horizontally flip the given PIL Image.
@@ -152,3 +161,23 @@ def adjust_hue(img, hue_factor):

     img = Image.merge('HSV', (h, s, v)).convert(input_mode)
     return img
+
+
+@torch.jit.unused
+def crop(img: Image.Image, top: int, left: int, height: int, width: int) -> Image.Image:
+    """Crop the given PIL Image.
+
+    Args:
+        img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
+        top (int): Vertical component of the top left corner of the crop box.
+        left (int): Horizontal component of the top left corner of the crop box.
+        height (int): Height of the crop box.
+        width (int): Width of the crop box.
+
+    Returns:
+        PIL Image: Cropped image.
+    """
+    if not _is_pil_image(img):
+        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
+
+    return img.crop((left, top, left + width, top + height))
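
Note the argument shuffle on the last line: PIL's `Image.crop` expects a `(left, upper, right, lower)` box, while the torchvision API speaks `(top, left, height, width)`. For example:

from PIL import Image

img = Image.new("RGB", (10, 10))
# crop(img, top=2, left=3, height=4, width=5) turns into the PIL box (3, 2, 8, 6).
out = img.crop((3, 2, 3 + 5, 2 + 4))
assert out.size == (5, 4)  # PIL reports size as (width, height)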

torchvision/transforms/functional_tensor.py

Lines changed: 12 additions & 2 deletions
@@ -3,8 +3,18 @@
 from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple


-def _is_tensor_a_torch_image(input):
-    return input.ndim >= 2
+@torch.jit.export
+def _is_tensor_a_torch_image(x):
+    # type: (Tensor) -> bool
+    return x.ndim >= 2
+
+
+@torch.jit.export
+def _get_image_size(img):
+    # type: (Tensor) -> List[int]
+    if _is_tensor_a_torch_image(img):
+        return [img.shape[-1], img.shape[-2]]
+    raise TypeError("Unexpected type {}".format(type(img)))


 def vflip(img):