 from . import functional_tensor as F_t


-def _is_pil_image(img):
-    if accimage is not None:
-        return isinstance(img, (Image.Image, accimage.Image))
-    else:
-        return isinstance(img, Image.Image)
+@torch.jit.export
+def _get_image_size(img):
+    # type: (Tensor) -> List[int]
+    if isinstance(img, torch.Tensor):
+        return F_t._get_image_size(img)
+
+    return F_pil._get_image_size(img)


+@torch.jit.ignore
 def _is_numpy(img):
+    # type: (Any) -> bool
     return isinstance(img, np.ndarray)


+@torch.jit.ignore
 def _is_numpy_image(img):
+    # type: (Any) -> bool
     return img.ndim in {2, 3}

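For readers skimming the diff, here is a minimal, self-contained sketch of the dispatch pattern the hunk above introduces. The helper names, shapes, and the toy `get_image_size` below are invented for illustration and are not torchvision internals: tensor inputs take a TorchScript-compatible path, while PIL-only helpers are marked `@torch.jit.ignore` so the compiler leaves them as plain Python.

```python
import torch
from torch import Tensor
from typing import Any, List


@torch.jit.ignore
def _is_pil(img: Any) -> bool:
    # PIL-only helper: skipped by the TorchScript compiler, still callable in eager mode.
    from PIL import Image
    return isinstance(img, Image.Image)


def _tensor_image_size(img: Tensor) -> List[int]:
    # Tensors are assumed to be laid out as [..., H, W]; return [width, height].
    return [img.shape[-1], img.shape[-2]]


def get_image_size(img: Tensor) -> List[int]:
    # Toy version of the dispatch: only the tensor branch survives scripting;
    # the real code routes non-tensor inputs to the PIL backend instead.
    if isinstance(img, torch.Tensor):
        return _tensor_image_size(img)
    raise TypeError("expected a Tensor in the scripted path")


print(_is_pil(torch.rand(1)))                # False: runs as ordinary Python
scripted = torch.jit.script(get_image_size)  # tensor branch compiles cleanly
print(scripted(torch.rand(3, 48, 64)))       # -> [64, 48]
```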
@@ -42,7 +48,7 @@ def to_tensor(pic):
     Returns:
         Tensor: Converted image.
     """
-    if not (_is_pil_image(pic) or _is_numpy(pic)):
+    if not (F_pil._is_pil_image(pic) or _is_numpy(pic)):
         raise TypeError('pic should be PIL Image or ndarray. Got {}'.format(type(pic)))

     if _is_numpy(pic) and not _is_numpy_image(pic):
@@ -97,7 +103,7 @@ def pil_to_tensor(pic):
     Returns:
         Tensor: Converted image.
     """
-    if not (_is_pil_image(pic)):
+    if not (F_pil._is_pil_image(pic)):
         raise TypeError('pic should be PIL Image. Got {}'.format(type(pic)))

     if accimage is not None and isinstance(pic, accimage.Image):
@@ -315,7 +321,7 @@ def resize(img, size, interpolation=Image.BILINEAR):
     Returns:
         PIL Image: Resized image.
     """
-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
     if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
         raise TypeError('Got inappropriate size arg: {}'.format(size))
@@ -374,7 +380,7 @@ def pad(img, padding, fill=0, padding_mode='constant'):
     Returns:
         PIL Image: Padded image.
     """
-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     if not isinstance(padding, (numbers.Number, tuple)):
@@ -436,23 +442,24 @@ def pad(img, padding, fill=0, padding_mode='constant'):
         return Image.fromarray(img)


-def crop(img, top, left, height, width):
+def crop(img: Tensor, top: int, left: int, height: int, width: int) -> Tensor:
     """Crop the given PIL Image.

     Args:
-        img (PIL Image): Image to be cropped. (0,0) denotes the top left corner of the image.
+        img (PIL Image or torch.Tensor): Image to be cropped. (0,0) denotes the top left corner of the image.
         top (int): Vertical component of the top left corner of the crop box.
         left (int): Horizontal component of the top left corner of the crop box.
         height (int): Height of the crop box.
         width (int): Width of the crop box.

     Returns:
-        PIL Image: Cropped image.
+        PIL Image or torch.Tensor: Cropped image.
     """
-    if not _is_pil_image(img):
-        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

-    return img.crop((left, top, left + width, top + height))
+    if not isinstance(img, torch.Tensor):
+        return F_pil.crop(img, top, left, height, width)
+
+    return F_t.crop(img, top, left, height, width)


 def center_crop(img, output_size):
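As a usage note for the unified `crop` above (a sketch assuming a torchvision build that includes this change; the concrete sizes and values are illustrative only), the same call now works for both PIL Images and tensors:

```python
import torch
from PIL import Image
import torchvision.transforms.functional as F

# PIL path: dispatched to F_pil.crop internally.
pil_img = Image.new("RGB", (64, 48))
pil_patch = F.crop(pil_img, top=10, left=20, height=16, width=24)
print(pil_patch.size)      # (24, 16) -- PIL reports (width, height)

# Tensor path: dispatched to F_t.crop; tensors use [..., H, W] layout.
tensor_img = torch.rand(3, 48, 64)
tensor_patch = F.crop(tensor_img, top=10, left=20, height=16, width=24)
print(tensor_patch.shape)  # torch.Size([3, 16, 24])
```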
@@ -491,7 +498,7 @@ def resized_crop(img, top, left, height, width, size, interpolation=Image.BILINEAR):
     Returns:
         PIL Image: Cropped image.
     """
-    assert _is_pil_image(img), 'img should be PIL Image'
+    assert F_pil._is_pil_image(img), 'img should be PIL Image'
     img = crop(img, top, left, height, width)
     img = resize(img, size, interpolation)
     return img
@@ -501,13 +508,13 @@ def hflip(img: Tensor) -> Tensor:
     """Horizontally flip the given PIL Image or torch Tensor.

     Args:
-        img (PIL Image or Torch Tensor): Image to be flipped. If img
+        img (PIL Image or torch.Tensor): Image to be flipped. If img
             is a Tensor, it is expected to be in [..., H, W] format,
             where ... means it can have an arbitrary number of trailing
             dimensions.

     Returns:
-        PIL Image: Horizontally flipped image.
+        PIL Image or torch.Tensor: Horizontally flipped image.
     """
     if not isinstance(img, torch.Tensor):
         return F_pil.hflip(img)
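Similarly, a small sketch of the tensor path for `hflip` documented above (assuming the tensor backend shown in this diff; the batch shape is arbitrary): because tensors are handled in [..., H, W] layout, a whole batch can be flipped in one call.

```python
import torch
import torchvision.transforms.functional as F

batch = torch.rand(2, 3, 4, 5)   # [N, C, H, W]
flipped = F.hflip(batch)         # dispatched to F_t.hflip

# Flipping the width dimension is equivalent to torch.flip over the last dim.
assert torch.equal(flipped, torch.flip(batch, dims=[-1]))
print(flipped.shape)             # torch.Size([2, 3, 4, 5])
```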
@@ -593,7 +600,7 @@ def perspective(img, startpoints, endpoints, interpolation=Image.BICUBIC, fill=None):
         PIL Image: Perspectively transformed Image.
     """

-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     opts = _parse_fill(fill, img, '5.0.0')
@@ -797,7 +804,7 @@ def adjust_gamma(img, gamma, gain=1):
             while gamma smaller than 1 make dark regions lighter.
         gain (float): The constant multiplier.
     """
-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     if gamma < 0:
@@ -837,7 +844,7 @@ def rotate(img, angle, resample=False, expand=False, center=None, fill=None):
     .. _filters: https://pillow.readthedocs.io/en/latest/handbook/concepts.html#filters

     """
-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     opts = _parse_fill(fill, img, '5.2.0')
@@ -918,7 +925,7 @@ def affine(img, angle, translate, scale, shear, resample=0, fillcolor=None):
             If omitted, or if the image has mode "1" or "P", it is set to ``PIL.Image.NEAREST``.
         fillcolor (int): Optional fill color for the area outside the transform in the output image. (Pillow>=5.0.0)
     """
-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     assert isinstance(translate, (tuple, list)) and len(translate) == 2, \
@@ -945,7 +952,7 @@ def to_grayscale(img, num_output_channels=1):

             if num_output_channels = 3 : returned image is 3 channel with r = g = b
     """
-    if not _is_pil_image(img):
+    if not F_pil._is_pil_image(img):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

     if num_output_channels == 1: