
Commit b356e8b
Improved tests and docs
1 parent 01175db

5 files changed: +111 -69 lines


test/test_transforms_tensor.py (43 additions, 5 deletions)

@@ -101,11 +101,28 @@ def test_pad(self):
 
     def test_crop(self):
         fn_kwargs = {"top": 2, "left": 3, "height": 4, "width": 5}
-        meth_kwargs = {"size": (4, 5), "padding": [4, ], "pad_if_needed": True, }
+        # Test transforms.RandomCrop with size and padding as tuple
+        meth_kwargs = {"size": (4, 5), "padding": (4, 4), "pad_if_needed": True, }
         self._test_geom_op(
             'crop', 'RandomCrop', fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )
 
+        tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)
+        # Test torchscript of transforms.RandomCrop with size as int
+        f = T.RandomCrop(size=5)
+        scripted_fn = torch.jit.script(f)
+        scripted_fn(tensor)
+
+        # Test torchscript of transforms.RandomCrop with size as [int, ]
+        f = T.RandomCrop(size=[5, ], padding=[2, ])
+        scripted_fn = torch.jit.script(f)
+        scripted_fn(tensor)
+
+        # Test torchscript of transforms.RandomCrop with size as list
+        f = T.RandomCrop(size=[6, 6])
+        scripted_fn = torch.jit.script(f)
+        scripted_fn(tensor)
+
     def test_center_crop(self):
         fn_kwargs = {"output_size": (4, 5)}
         meth_kwargs = {"size": (4, 5), }

@@ -117,6 +134,21 @@ def test_center_crop(self):
         self._test_geom_op(
             "center_crop", "CenterCrop", fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )
+        tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)
+        # Test torchscript of transforms.CenterCrop with size as int
+        f = T.CenterCrop(size=5)
+        scripted_fn = torch.jit.script(f)
+        scripted_fn(tensor)
+
+        # Test torchscript of transforms.CenterCrop with size as [int, ]
+        f = T.CenterCrop(size=[5, ])
+        scripted_fn = torch.jit.script(f)
+        scripted_fn(tensor)
+
+        # Test torchscript of transforms.CenterCrop with size as tuple
+        f = T.CenterCrop(size=(6, 6))
+        scripted_fn = torch.jit.script(f)
+        scripted_fn(tensor)
 
     def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, meth_kwargs=None):
         if fn_kwargs is None:

@@ -146,13 +178,19 @@ def _test_geom_op_list_output(self, func, method, out_length, fn_kwargs=None, me
         self.assertEqual(len(output), len(transformed_t_list_script))
 
     def test_five_crop(self):
-        fn_kwargs = {"size": (5,)}
-        meth_kwargs = {"size": (5, )}
+        fn_kwargs = meth_kwargs = {"size": (5,)}
+        self._test_geom_op_list_output(
+            "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
+        )
+        fn_kwargs = meth_kwargs = {"size": [5, ]}
+        self._test_geom_op_list_output(
+            "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
+        )
+        fn_kwargs = meth_kwargs = {"size": (4, 5)}
         self._test_geom_op_list_output(
             "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
        )
-        fn_kwargs = {"size": (4, 5)}
-        meth_kwargs = {"size": (4, 5)}
+        fn_kwargs = meth_kwargs = {"size": [4, 5]}
         self._test_geom_op_list_output(
             "five_crop", "FiveCrop", out_length=5, fn_kwargs=fn_kwargs, meth_kwargs=meth_kwargs
         )
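
The new torchscript checks above can be reproduced outside the test harness. A minimal standalone sketch, assuming a torchvision build that includes this commit (where the crop transforms accept tensors and script cleanly):

import torch
import torchvision.transforms as T

# uint8 image tensor in [C, H, W] layout, matching the tests above
tensor = torch.randint(0, 255, (3, 10, 10), dtype=torch.uint8)

# Scripting compiles the transform's forward; the call then runs under torchscript
scripted_fn = torch.jit.script(T.RandomCrop(size=5))
out = scripted_fn(tensor)
print(out.shape)  # torch.Size([3, 5, 5])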

torchvision/transforms/functional.py (20 additions, 17 deletions)

@@ -395,7 +395,10 @@ def pad(img: Tensor, padding: List[int], fill: int = 0, padding_mode: str = "con
 
 
 def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
-    """Crop the given PIL Image.
+    """Crop the given image at specified location and output size.
+    The image can be a PIL Image or a Tensor, in which case it is expected
+    to have [..., H, W] shape, where ... means an arbitrary number of leading
+    dimensions
 
     Args:
         img (PIL Image or Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.

@@ -416,13 +419,13 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
 
 def center_crop(img: Tensor, output_size: List[int]) -> Tensor:
     """Crops the given image at the center.
-    The image can be a PIL Image or a torch Tensor, in which case it is expected
+    The image can be a PIL Image or a Tensor, in which case it is expected
     to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
 
     Args:
         img (PIL Image or Tensor): Image to be cropped.
         output_size (sequence or int): (height, width) of the crop box. If int or sequence with single int
-            it is used for both directions
+            it is used for both directions.
 
     Returns:
         PIL Image or Tensor: Cropped image.

@@ -469,7 +472,7 @@ def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINE
 
 
 def hflip(img: Tensor) -> Tensor:
-    """Horizontally flip the given PIL Image or torch Tensor.
+    """Horizontally flip the given PIL Image or Tensor.
 
     Args:
         img (PIL Image or Tensor): Image to be flipped. If img

@@ -531,8 +534,7 @@ def _get_perspective_coeffs(startpoints, endpoints):
 
     Args:
         List containing [top-left, top-right, bottom-right, bottom-left] of the original image,
-        List containing [top-left, top-right, bottom-right, bottom-left] of the transformed
-        image
+        List containing [top-left, top-right, bottom-right, bottom-left] of the transformed image
     Returns:
         octuple (a, b, c, d, e, f, g, h) for transforming each pixel.
     """

@@ -577,7 +579,7 @@ def vflip(img: Tensor) -> Tensor:
     """Vertically flip the given PIL Image or torch Tensor.
 
     Args:
-        img (PIL Image or Torch Tensor): Image to be flipped. If img
+        img (PIL Image or Tensor): Image to be flipped. If img
             is a Tensor, it is expected to be in [..., H, W] format,
             where ... means it can have an arbitrary number of trailing
             dimensions.

@@ -593,17 +595,18 @@ def vflip(img: Tensor) -> Tensor:
 
 def five_crop(img: Tensor, size: List[int]) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor]:
     """Crop the given image into four corners and the central crop.
-    The image can be a PIL Image or a torch Tensor, in which case it is expected
+    The image can be a PIL Image or a Tensor, in which case it is expected
     to have [..., H, W] shape, where ... means an arbitrary number of leading dimensions
 
     .. Note::
         This transform returns a tuple of images and there may be a
         mismatch in the number of inputs and targets your ``Dataset`` returns.
 
     Args:
-        size (sequence or int): Desired output size of the crop. If size is an
-            int instead of sequence like (h, w), a square crop (size, size) is
-            made.
+        img (PIL Image or Tensor): Image to be cropped.
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made. If provided a tuple or list of length 1, it will be interpreted as (size[0], size[0]).
 
     Returns:
         tuple: tuple (tl, tr, bl, br, center)

@@ -673,13 +676,13 @@ def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
     """Adjust brightness of an Image.
 
     Args:
-        img (PIL Image or Torch Tensor): Image to be adjusted.
+        img (PIL Image or Tensor): Image to be adjusted.
         brightness_factor (float): How much to adjust the brightness. Can be
             any non negative number. 0 gives a black image, 1 gives the
             original image while 2 increases the brightness by a factor of 2.
 
     Returns:
-        PIL Image or Torch Tensor: Brightness adjusted image.
+        PIL Image or Tensor: Brightness adjusted image.
     """
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_brightness(img, brightness_factor)

@@ -691,13 +694,13 @@ def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
     """Adjust contrast of an Image.
 
     Args:
-        img (PIL Image or Torch Tensor): Image to be adjusted.
+        img (PIL Image or Tensor): Image to be adjusted.
         contrast_factor (float): How much to adjust the contrast. Can be any
             non negative number. 0 gives a solid gray image, 1 gives the
             original image while 2 increases the contrast by a factor of 2.
 
     Returns:
-        PIL Image or Torch Tensor: Contrast adjusted image.
+        PIL Image or Tensor: Contrast adjusted image.
     """
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_contrast(img, contrast_factor)

@@ -709,13 +712,13 @@ def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
     """Adjust color saturation of an image.
 
     Args:
-        img (PIL Image or Torch Tensor): Image to be adjusted.
+        img (PIL Image or Tensor): Image to be adjusted.
         saturation_factor (float): How much to adjust the saturation. 0 will
             give a black and white image, 1 will give the original image while
             2 will enhance the saturation by a factor of 2.
 
     Returns:
-        PIL Image or Torch Tensor: Saturation adjusted image.
+        PIL Image or Tensor: Saturation adjusted image.
     """
     if not isinstance(img, torch.Tensor):
         return F_pil.adjust_saturation(img, saturation_factor)
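
The docstring updates above codify that these functionals accept either a PIL Image or a Tensor in [..., H, W] layout. A short sketch of what the new wording permits; the batched input is an assumption drawn from the "..." in the documented shape, not a case this diff tests directly:

import torch
import torchvision.transforms.functional as F

batch = torch.rand(2, 3, 32, 32)  # [N, C, H, W]; leading dims pass through unchanged
patch = F.crop(batch, top=4, left=4, height=16, width=16)
print(patch.shape)  # torch.Size([2, 3, 16, 16])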

torchvision/transforms/functional_pil.py (5 additions, 6 deletions)

@@ -1,4 +1,5 @@
 import numbers
+from typing import Any, List
 
 import torch
 try:

@@ -10,24 +11,22 @@
 
 
 @torch.jit.unused
-def _is_pil_image(img):
-    # type: (Any) -> bool
+def _is_pil_image(img: Any) -> bool:
     if accimage is not None:
         return isinstance(img, (Image.Image, accimage.Image))
     else:
         return isinstance(img, Image.Image)
 
 
 @torch.jit.unused
-def _get_image_size(img):
-    # type: (Any) -> List[int]
+def _get_image_size(img: Any) -> List[int]:
     if _is_pil_image(img):
         return img.size
     raise TypeError("Unexpected type {}".format(type(img)))
 
 
 @torch.jit.unused
-def hflip(img):
+def hflip(img: Any):
     """Horizontally flip the given PIL Image.
 
     Args:

@@ -43,7 +42,7 @@ def hflip(img):
 
 
 @torch.jit.unused
-def vflip(img):
+def vflip(img: Any):
     """Vertically flip the given PIL Image.
 
     Args:
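
The pattern in this file (mypy-style "# type:" comments replaced by real annotations, with PIL-only helpers kept behind @torch.jit.unused) can be tried in isolation. A sketch with a hypothetical helper, not code from this commit:

from typing import Any

import torch

@torch.jit.unused
def _pil_only_helper(img: Any) -> bool:
    # torch.jit.script skips compiling this body, so the PIL types used
    # inside never need to be expressible in torchscript.
    from PIL import Image
    return isinstance(img, Image.Image)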

torchvision/transforms/functional_tensor.py (18 additions, 23 deletions)

@@ -49,7 +49,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int):
     """Crop the given Image Tensor.
 
     Args:
-        img (Tensor): Image to be cropped in the form [C, H, W]. (0,0) denotes the top left corner of the image.
+        img (Tensor): Image to be cropped in the form [..., H, W]. (0,0) denotes the top left corner of the image.
         top (int): Vertical component of the top left corner of the crop box.
         left (int): Horizontal component of the top left corner of the crop box.
         height (int): Height of the crop box.

@@ -64,8 +64,7 @@ def crop(img: Tensor, top: int, left: int, height: int, width: int):
     return img[..., top:top + height, left:left + width]
 
 
-def rgb_to_grayscale(img):
-    # type: (Tensor) -> Tensor
+def rgb_to_grayscale(img: Tensor) -> Tensor:
     """Convert the given RGB Image Tensor to Grayscale.
     For RGB to Grayscale conversion, ITU-R 601-2 luma transform is performed which
     is L = R * 0.2989 + G * 0.5870 + B * 0.1140

@@ -83,8 +82,7 @@ def rgb_to_grayscale(img):
     return (0.2989 * img[0] + 0.5870 * img[1] + 0.1140 * img[2]).to(img.dtype)
 
 
-def adjust_brightness(img, brightness_factor):
-    # type: (Tensor, float) -> Tensor
+def adjust_brightness(img: Tensor, brightness_factor: float) -> Tensor:
     """Adjust brightness of an RGB image.
 
     Args:

@@ -102,8 +100,7 @@ def adjust_brightness(img, brightness_factor):
     return _blend(img, torch.zeros_like(img), brightness_factor)
 
 
-def adjust_contrast(img, contrast_factor):
-    # type: (Tensor, float) -> Tensor
+def adjust_contrast(img: Tensor, contrast_factor: float) -> Tensor:
     """Adjust contrast of an RGB image.
 
     Args:

@@ -171,8 +168,7 @@ def adjust_hue(img, hue_factor):
     return img_hue_adj
 
 
-def adjust_saturation(img, saturation_factor):
-    # type: (Tensor, float) -> Tensor
+def adjust_saturation(img: Tensor, saturation_factor: float) -> Tensor:
     """Adjust color saturation of an RGB image.
 
     Args:

@@ -190,12 +186,11 @@ def adjust_saturation(img, saturation_factor):
     return _blend(img, rgb_to_grayscale(img), saturation_factor)
 
 
-def center_crop(img, output_size):
-    # type: (Tensor, BroadcastingList2[int]) -> Tensor
+def center_crop(img: Tensor, output_size: BroadcastingList2[int]) -> Tensor:
     """Crop the Image Tensor and resize it to desired size.
 
     Args:
-        img (Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
+        img (Tensor): Image to be cropped.
         output_size (sequence or int): (height, width) of the crop box. If int,
             it is used for both directions
 

@@ -213,17 +208,17 @@ def center_crop(img, output_size):
     return crop(img, crop_top, crop_left, crop_height, crop_width)
 
 
-def five_crop(img, size):
-    # type: (Tensor, BroadcastingList2[int]) -> List[Tensor]
+def five_crop(img: Tensor, size: BroadcastingList2[int]) -> List[Tensor]:
     """Crop the given Image Tensor into four corners and the central crop.
     .. Note::
         This transform returns a List of Tensors and there may be a
         mismatch in the number of inputs and targets your ``Dataset`` returns.
 
     Args:
-        size (sequence or int): Desired output size of the crop. If size is an
-            int instead of sequence like (h, w), a square crop (size, size) is
-            made.
+        img (Tensor): Image to be cropped.
+        size (sequence or int): Desired output size of the crop. If size is an
+            int instead of sequence like (h, w), a square crop (size, size) is
+            made.
 
     Returns:
         List: List (tl, tr, bl, br, center)

@@ -249,19 +244,20 @@ def five_crop(img, size):
     return [tl, tr, bl, br, center]
 
 
-def ten_crop(img, size, vertical_flip=False):
-    # type: (Tensor, BroadcastingList2[int], bool) -> List[Tensor]
+def ten_crop(img: Tensor, size: BroadcastingList2[int], vertical_flip: bool = False) -> List[Tensor]:
     """Crop the given Image Tensor into four corners and the central crop plus the
     flipped version of these (horizontal flipping is used by default).
+
     .. Note::
         This transform returns a List of images and there may be a
         mismatch in the number of inputs and targets your ``Dataset`` returns.
 
     Args:
-        size (sequence or int): Desired output size of the crop. If size is an
+        img (Tensor): Image to be cropped.
+        size (sequence or int): Desired output size of the crop. If size is an
            int instead of sequence like (h, w), a square crop (size, size) is
            made.
-       vertical_flip (bool): Use vertical flipping instead of horizontal
+        vertical_flip (bool): Use vertical flipping instead of horizontal
 
     Returns:
         List: List (tl, tr, bl, br, center, tl_flip, tr_flip, bl_flip, br_flip, center_flip)

@@ -284,8 +280,7 @@ def ten_crop(img, size, vertical_flip=False):
     return first_five + second_five
 
 
-def _blend(img1, img2, ratio):
-    # type: (Tensor, Tensor, float) -> Tensor
+def _blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
     bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255
     return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)
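
The same migration from type comments to Python 3 annotations runs through this file; with real annotations, torch.jit.script can read signatures directly instead of parsing "# type:" comments. A minimal sketch mirroring the _blend helper above (a standalone re-implementation under that assumption, not the module's own code):

import torch
from torch import Tensor

def blend(img1: Tensor, img2: Tensor, ratio: float) -> Tensor:
    # Clamp bound depends on dtype: 1 for float images, 255 for uint8
    bound = 1 if img1.dtype in [torch.half, torch.float32, torch.float64] else 255
    return (ratio * img1 + (1 - ratio) * img2).clamp(0, bound).to(img1.dtype)

scripted = torch.jit.script(blend)  # compiles from the annotations alone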