Skip to content

Commit 8dc14cf

Browse files
vfdev-5 and fmassa authored
Unified Pad and F.pad operation for PIL and Tensor inputs (#2345)
* [WIP] Add Tensor implementation for pad
* Unified Pad and F.pad operation for PIL and Tensor inputs
* Added another test and improved docstring
* Updates according to the review
* Cosmetics and replaced f-string by "".format
* Updated docstring - added compatibility support for padding as [value, ] for functional_pil.pad

Co-authored-by: Francisco Massa <[email protected]>
1 parent 2dad9c7 commit 8dc14cf

File tree

6 files changed

+293
-104
lines changed

6 files changed

+293
-104
lines changed

test/test_functional_tensor.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
11
import torch
2-
from torch import Tensor
32
import torchvision.transforms as transforms
43
import torchvision.transforms.functional_tensor as F_t
4+
import torchvision.transforms.functional_pil as F_pil
55
import torchvision.transforms.functional as F
66
import numpy as np
77
import unittest
88
import random
99
import colorsys
10-
from torch.jit.annotations import Optional, List, BroadcastingList2, Tuple
10+
11+
from PIL import Image
1112

1213

1314
class Tester(unittest.TestCase):
1415

16+
def _create_data(self, height=3, width=3, channels=3):
    """Build a random uint8 image as both a tensor and a PIL image.

    Args:
        height (int): Image height in pixels.
        width (int): Image width in pixels.
        channels (int): Number of channels of the generated tensor.

    Returns:
        tuple: ``(tensor, pil_img)`` where ``tensor`` has shape
        ``(channels, height, width)`` and ``pil_img`` is built from the
        same data rearranged to ``(height, width, channels)``.
    """
    # The upper bound of torch.randint is exclusive, so 256 is required
    # for the value 255 to be part of the sampled range.
    tensor = torch.randint(0, 256, (channels, height, width), dtype=torch.uint8)
    pil_img = Image.fromarray(tensor.permute(1, 2, 0).contiguous().numpy())
    return tensor, pil_img
20+
21+
def compareTensorToPIL(self, tensor, pil_image, msg=None):
    """Assert that ``tensor`` equals ``pil_image`` converted to channel-first.

    Args:
        tensor (Tensor): Expected values in ``(C, H, W)`` layout.
        pil_image: Image (or array-like) convertible with ``np.array``;
            either ``(H, W, C)`` or, for single-band data, ``(H, W)``.
        msg (str, optional): Message forwarded to ``assertTrue`` on failure.
    """
    np_pil_image = np.array(pil_image)
    if np_pil_image.ndim == 2:
        # Single-band images convert to a 2-D array; add a channel axis so
        # the channel-first transpose below is valid.
        np_pil_image = np_pil_image[:, :, None]
    pil_tensor = torch.as_tensor(np_pil_image.transpose((2, 0, 1)))
    self.assertTrue(tensor.equal(pil_tensor), msg)
24+
1525
def test_vflip(self):
1626
script_vflip = torch.jit.script(F_t.vflip)
1727
img_tensor = torch.randn(3, 16, 16)
@@ -234,6 +244,22 @@ def test_ten_crop(self):
234244
for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor):
235245
self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img))
236246

247+
def test_pad(self):
    """Check F_t.pad against F_pil.pad and against its scripted version.

    Every padding spec is combined with every fill value in "constant"
    mode; the tensor result must match the PIL result and the result of
    the torchscript-compiled function.
    """
    script_fn = torch.jit.script(F_t.pad)
    tensor, pil_img = self._create_data(7, 8)
    padding_mode = "constant"
    for pad in [1, [1, ], [0, 1], (2, 2), [1, 0, 1, 2]]:
        # The scripted function is not given a bare int; a single int is
        # wrapped into a one-element list instead.
        script_pad = [pad, ] if isinstance(pad, int) else pad
        for fill in [0, 10, 20]:
            msg = "{}, {}".format(pad, fill)
            pad_tensor = F_t.pad(tensor, pad, fill=fill, padding_mode=padding_mode)
            pad_pil_img = F_pil.pad(pil_img, pad, fill=fill, padding_mode=padding_mode)
            self.compareTensorToPIL(pad_tensor, pad_pil_img, msg=msg)
            pad_tensor_script = script_fn(tensor, script_pad, fill=fill, padding_mode=padding_mode)
            self.assertTrue(pad_tensor.equal(pad_tensor_script), msg=msg)
262+
237263

238264
if __name__ == '__main__':
239265
unittest.main()

test/test_transforms_tensor.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,38 @@ def compareTensorToPIL(self, tensor, pil_image):
1818
pil_tensor = torch.as_tensor(np.array(pil_image).transpose((2, 0, 1)))
1919
self.assertTrue(tensor.equal(pil_tensor))
2020

21-
def _test_flip(self, func, method):
22-
tensor, pil_img = self._create_data()
23-
flip_tensor = getattr(F, func)(tensor)
24-
flip_pil_img = getattr(F, func)(pil_img)
25-
self.compareTensorToPIL(flip_tensor, flip_pil_img)
21+
def _test_functional_geom_op(self, func, fn_kwargs=None):
    """Check that functional op ``func`` agrees on Tensor and PIL inputs.

    Args:
        func (str): Name of the function to look up on ``F``.
        fn_kwargs (dict, optional): Keyword arguments forwarded to the
            function. ``None`` (the default) means no extra arguments.
    """
    if fn_kwargs is None:
        fn_kwargs = {}
    tensor, pil_img = self._create_data(height=10, width=10)
    op = getattr(F, func)
    transformed_tensor = op(tensor, **fn_kwargs)
    transformed_pil_img = op(pil_img, **fn_kwargs)
    self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
28+
29+
def _test_geom_op(self, func, method, fn_kwargs=None, meth_kwargs=None):
30+
if fn_kwargs is None:
31+
fn_kwargs = {}
32+
if meth_kwargs is None:
33+
meth_kwargs = {}
34+
tensor, pil_img = self._create_data(height=10, width=10)
35+
transformed_tensor = getattr(F, func)(tensor, **fn_kwargs)
36+
transformed_pil_img = getattr(F, func)(pil_img, **fn_kwargs)
37+
self.compareTensorToPIL(transformed_tensor, transformed_pil_img)
2638

2739
scripted_fn = torch.jit.script(getattr(F, func))
28-
flip_tensor_script = scripted_fn(tensor)
29-
self.assertTrue(flip_tensor.equal(flip_tensor_script))
40+
transformed_tensor_script = scripted_fn(tensor, **fn_kwargs)
41+
self.assertTrue(transformed_tensor.equal(transformed_tensor_script))
3042

3143
# test for class interface
32-
f = getattr(T, method)()
44+
f = getattr(T, method)(**meth_kwargs)
3345
scripted_fn = torch.jit.script(f)
3446
scripted_fn(tensor)
3547

3648
def test_random_horizontal_flip(self):
37-
self._test_flip('hflip', 'RandomHorizontalFlip')
49+
self._test_geom_op('hflip', 'RandomHorizontalFlip')
3850

3951
def test_random_vertical_flip(self):
40-
self._test_flip('vflip', 'RandomVerticalFlip')
52+
self._test_geom_op('vflip', 'RandomVerticalFlip')
4153

4254
def test_adjustments(self):
4355
fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation']
@@ -65,6 +77,28 @@ def test_adjustments(self):
6577
self.assertLess(max_diff, 5 / 255 + 1e-5)
6678
self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
6779

80+
def test_pad(self):
    """Exercise functional ``pad`` and the ``Pad`` transform in constant mode."""
    # Test functional.pad (PIL and Tensor) with padding as single int
    self._test_functional_geom_op(
        "pad", fn_kwargs={"padding": 2, "fill": 0, "padding_mode": "constant"}
    )

    # Test functional.pad and transforms.Pad with padding given as
    # [int, ], as a list and as a tuple. The same dict is used for the
    # functional call and for constructing the transform.
    cases = [
        {"padding": [2, ], "fill": 0, "padding_mode": "constant"},
        {"padding": [4, 4], "fill": 0, "padding_mode": "constant"},
        {"padding": (2, 2, 2, 2), "fill": 127, "padding_mode": "constant"},
    ]
    for kwargs in cases:
        self._test_geom_op(
            "pad", "Pad", fn_kwargs=kwargs, meth_kwargs=kwargs
        )
101+
68102

69103
if __name__ == '__main__':
70104
unittest.main()

torchvision/transforms/functional.py

Lines changed: 27 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
1-
import torch
2-
from torch import Tensor
31
import math
2+
import numbers
3+
import warnings
4+
from collections.abc import Iterable
5+
6+
import numpy as np
7+
from numpy import sin, cos, tan
48
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
9+
10+
import torch
11+
from torch import Tensor
12+
from torch.jit.annotations import List
13+
514
try:
615
import accimage
716
except ImportError:
817
accimage = None
9-
import numpy as np
10-
from numpy import sin, cos, tan
11-
import numbers
12-
from collections.abc import Sequence, Iterable
13-
import warnings
1418

1519
from . import functional_pil as F_pil
1620
from . import functional_tensor as F_t
@@ -342,20 +346,24 @@ def scale(*args, **kwargs):
342346
return resize(*args, **kwargs)
343347

344348

345-
def pad(img, padding, fill=0, padding_mode='constant'):
346-
r"""Pad the given PIL Image on all sides with specified padding mode and fill value.
349+
def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "constant") -> Tensor:
350+
r"""Pad the given image on all sides with the given "pad" value.
351+
The image can be a PIL Image or a torch Tensor, in which case it is expected
352+
to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
347353
348354
Args:
349-
img (PIL Image): Image to be padded.
350-
padding (int or tuple): Padding on each border. If a single int is provided this
355+
img (PIL Image or Tensor): Image to be padded.
356+
padding (int or tuple or list): Padding on each border. If a single int is provided this
351357
is used to pad all borders. If tuple of length 2 is provided this is the padding
352358
on left/right and top/bottom respectively. If a tuple of length 4 is provided
353-
this is the padding for the left, top, right and bottom borders
354-
respectively.
355-
fill: Pixel fill value for constant fill. Default is 0. If a tuple of
359+
this is the padding for the left, top, right and bottom borders respectively.
360+
In torchscript mode padding as single int is not supported, use a tuple or
361+
list of length 1: ``[padding, ]``.
362+
fill (int or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
356363
length 3, it is used to fill R, G, B channels respectively.
357-
This value is only used when the padding_mode is constant
364+
This value is only used when the padding_mode is constant. Only int value is supported for Tensors.
358365
padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.
366+
Only "constant" is supported for Tensors as of now.
359367
360368
- constant: pads with a constant value, this value is specified with fill
361369
@@ -372,68 +380,12 @@ def pad(img, padding, fill=0, padding_mode='constant'):
372380
will result in [2, 1, 1, 2, 3, 4, 4, 3]
373381
374382
Returns:
375-
PIL Image: Padded image.
383+
PIL Image or Tensor: Padded image.
376384
"""
377-
if not _is_pil_image(img):
378-
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
379-
380-
if not isinstance(padding, (numbers.Number, tuple)):
381-
raise TypeError('Got inappropriate padding arg')
382-
if not isinstance(fill, (numbers.Number, str, tuple)):
383-
raise TypeError('Got inappropriate fill arg')
384-
if not isinstance(padding_mode, str):
385-
raise TypeError('Got inappropriate padding_mode arg')
386-
387-
if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
388-
raise ValueError("Padding must be an int or a 2, or 4 element tuple, not a " +
389-
"{} element tuple".format(len(padding)))
390-
391-
assert padding_mode in ['constant', 'edge', 'reflect', 'symmetric'], \
392-
'Padding mode should be either constant, edge, reflect or symmetric'
393-
394-
if padding_mode == 'constant':
395-
if isinstance(fill, numbers.Number):
396-
fill = (fill,) * len(img.getbands())
397-
if len(fill) != len(img.getbands()):
398-
raise ValueError('fill should have the same number of elements '
399-
'as the number of channels in the image '
400-
'({}), got {} instead'.format(len(img.getbands()), len(fill)))
401-
if img.mode == 'P':
402-
palette = img.getpalette()
403-
image = ImageOps.expand(img, border=padding, fill=fill)
404-
image.putpalette(palette)
405-
return image
406-
407-
return ImageOps.expand(img, border=padding, fill=fill)
408-
else:
409-
if isinstance(padding, int):
410-
pad_left = pad_right = pad_top = pad_bottom = padding
411-
if isinstance(padding, Sequence) and len(padding) == 2:
412-
pad_left = pad_right = padding[0]
413-
pad_top = pad_bottom = padding[1]
414-
if isinstance(padding, Sequence) and len(padding) == 4:
415-
pad_left = padding[0]
416-
pad_top = padding[1]
417-
pad_right = padding[2]
418-
pad_bottom = padding[3]
419-
420-
if img.mode == 'P':
421-
palette = img.getpalette()
422-
img = np.asarray(img)
423-
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
424-
img = Image.fromarray(img)
425-
img.putpalette(palette)
426-
return img
427-
428-
img = np.asarray(img)
429-
# RGB image
430-
if len(img.shape) == 3:
431-
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
432-
# Grayscale image
433-
if len(img.shape) == 2:
434-
img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
385+
if not isinstance(img, torch.Tensor):
386+
return F_pil.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
435387

436-
return Image.fromarray(img)
388+
return F_t.pad(img, padding=padding, fill=fill, padding_mode=padding_mode)
437389

438390

439391
def crop(img, top, left, height, width):

torchvision/transforms/functional_pil.py

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import numbers
2+
13
import torch
24
try:
35
import accimage
46
except ImportError:
57
accimage = None
6-
from PIL import Image, ImageOps, ImageEnhance, __version__ as PILLOW_VERSION
8+
from PIL import Image, ImageOps, ImageEnhance
79
import numpy as np
810

911

@@ -152,3 +154,107 @@ def adjust_hue(img, hue_factor):
152154

153155
img = Image.merge('HSV', (h, s, v)).convert(input_mode)
154156
return img
157+
158+
159+
@torch.jit.unused
def pad(img, padding, fill=0, padding_mode="constant"):
    r"""Pad the given PIL.Image on all sides with the given "pad" value.

    Args:
        img (PIL Image): Image to be padded.
        padding (int or tuple or list): Padding on each border. If a single int is provided this
            is used to pad all borders. If a tuple or list of length 2 is provided this is the padding
            on left/right and top/bottom respectively. If a tuple or list of length 4 is provided
            this is the padding for the left, top, right and bottom borders respectively. For compatibility
            reasons with ``functional_tensor.pad``, if a tuple or list of length 1 is provided, it is
            interpreted as a single int.
        fill (int or str or tuple): Pixel fill value for constant fill. Default is 0. If a tuple of
            length 3, it is used to fill R, G, B channels respectively. A str is handed to
            ``ImageOps.expand`` unchanged.
            This value is only used when the padding_mode is constant.
        padding_mode: Type of padding. Should be: constant, edge, reflect or symmetric. Default is constant.

            - constant: pads with a constant value, this value is specified with fill

            - edge: pads with the last value on the edge of the image

            - reflect: pads with reflection of image (without repeating the last value on the edge)

                       padding [1, 2, 3, 4] with 2 elements on both sides in reflect mode
                       will result in [3, 2, 1, 2, 3, 4, 3, 2]

            - symmetric: pads with reflection of image (repeating the last value on the edge)

                         padding [1, 2, 3, 4] with 2 elements on both sides in symmetric mode
                         will result in [2, 1, 1, 2, 3, 4, 4, 3]

    Returns:
        PIL Image: Padded image.

    Raises:
        TypeError: If ``img`` is not a PIL Image or an argument has an unsupported type.
        ValueError: If ``padding`` has an unsupported length, ``padding_mode`` is unknown,
            or a tuple ``fill`` does not match the number of image bands.
    """

    if not _is_pil_image(img):
        raise TypeError("img should be PIL Image. Got {}".format(type(img)))

    if not isinstance(padding, (numbers.Number, tuple, list)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError("Got inappropriate fill arg")
    if not isinstance(padding_mode, str):
        raise TypeError("Got inappropriate padding_mode arg")

    if isinstance(padding, list):
        padding = tuple(padding)

    if isinstance(padding, tuple):
        if len(padding) not in [1, 2, 4]:
            raise ValueError("Padding must be an int or a 1, 2, or 4 element tuple, not a " +
                             "{} element tuple".format(len(padding)))
        if len(padding) == 1:
            # Compatibility with `functional_tensor.pad`
            padding = padding[0]

    if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
        raise ValueError("Padding mode should be either constant, edge, reflect or symmetric")

    if padding_mode == "constant":
        num_bands = len(img.getbands())
        if isinstance(fill, numbers.Number):
            fill = (fill,) * num_bands
        # Only tuples are length-checked: a str fill has no per-band meaning,
        # so comparing its character count to the band count would reject
        # valid values such as "black" on an RGB image.
        if isinstance(fill, tuple) and len(fill) != num_bands:
            raise ValueError("fill should have the same number of elements "
                             "as the number of channels in the image "
                             "({}), got {} instead".format(num_bands, len(fill)))
        if img.mode == "P":
            # Save the palette and re-apply it to the expanded image.
            palette = img.getpalette()
            image = ImageOps.expand(img, border=padding, fill=fill)
            image.putpalette(palette)
            return image

        return ImageOps.expand(img, border=padding, fill=fill)

    # Non-constant modes are delegated to np.pad. At this point `padding` is
    # either a scalar or a tuple of length 2 or 4, so every branch below
    # defines all four widths (a scalar that is not a plain `int` no longer
    # falls through with the widths left unassigned).
    if isinstance(padding, tuple):
        if len(padding) == 2:
            pad_left = pad_right = padding[0]
            pad_top = pad_bottom = padding[1]
        else:
            pad_left, pad_top, pad_right, pad_bottom = padding
    else:
        pad_left = pad_right = pad_top = pad_bottom = padding

    if img.mode == "P":
        palette = img.getpalette()
        img = np.asarray(img)
        img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)
        img = Image.fromarray(img)
        img.putpalette(palette)
        return img

    img = np.asarray(img)
    # RGB image
    if len(img.shape) == 3:
        img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), padding_mode)
    # Grayscale image
    if len(img.shape) == 2:
        img = np.pad(img, ((pad_top, pad_bottom), (pad_left, pad_right)), padding_mode)

    return Image.fromarray(img)

0 commit comments

Comments
 (0)