4
4
5
5
import PIL .Image
6
6
import torch
7
- from torch .nn .functional import interpolate
7
+ from torch .nn .functional import interpolate , pad as torch_pad
8
+
8
9
from torchvision .prototype import features
9
10
from torchvision .transforms import functional_pil as _FP , functional_tensor as _FT
10
11
from torchvision .transforms .functional import (
15
16
pil_to_tensor ,
16
17
to_pil_image ,
17
18
)
18
- from torchvision .transforms .functional_tensor import _parse_pad_padding
19
19
20
20
from ._meta import convert_format_bounding_box , get_spatial_size_image_pil
21
21
@@ -663,7 +663,28 @@ def rotate(
663
663
return rotate_image_pil (inpt , angle , interpolation = interpolation , expand = expand , fill = fill , center = center )
664
664
665
665
666
- pad_image_pil = _FP .pad
666
+ def _parse_pad_padding (padding : Union [int , List [int ]]) -> List [int ]:
667
+ if isinstance (padding , int ):
668
+ pad_left = pad_right = pad_top = pad_bottom = padding
669
+ elif isinstance (padding , (tuple , list )):
670
+ if len (padding ) == 1 :
671
+ pad_left = pad_right = pad_top = pad_bottom = padding [0 ]
672
+ elif len (padding ) == 2 :
673
+ pad_left = pad_right = padding [0 ]
674
+ pad_top = pad_bottom = padding [1 ]
675
+ elif len (padding ) == 4 :
676
+ pad_left = padding [0 ]
677
+ pad_top = padding [1 ]
678
+ pad_right = padding [2 ]
679
+ pad_bottom = padding [3 ]
680
+ else :
681
+ raise ValueError (
682
+ f"Padding must be an int or a 1, 2, or 4 element tuple, not a { len (padding )} element tuple"
683
+ )
684
+ else :
685
+ raise TypeError (f"`padding` should be an integer or tuple or list of integers, but got { padding } " )
686
+
687
+ return [pad_left , pad_right , pad_top , pad_bottom ]
667
688
668
689
669
690
def pad_image_tensor (
@@ -672,50 +693,86 @@ def pad_image_tensor(
672
693
fill : features .FillTypeJIT = None ,
673
694
padding_mode : str = "constant" ,
674
695
) -> torch .Tensor :
696
+ # Be aware that while `padding` has order `[left, top, right, bottom]` has order, `torch_padding` uses
697
+ # `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use `torch_pad`
698
+ # internally.
699
+ torch_padding = _parse_pad_padding (padding )
700
+
701
+ if padding_mode not in ["constant" , "edge" , "reflect" , "symmetric" ]:
702
+ raise ValueError (
703
+ f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
704
+ f"but got `'{ padding_mode } '`."
705
+ )
706
+
675
707
if fill is None :
676
- # This is a JIT workaround
677
- return _pad_with_scalar_fill (image , padding , fill = None , padding_mode = padding_mode )
678
- elif isinstance (fill , (int , float )) or len (fill ) == 1 :
679
- fill_number = fill [0 ] if isinstance (fill , list ) else fill
680
- return _pad_with_scalar_fill (image , padding , fill = fill_number , padding_mode = padding_mode )
708
+ fill = 0
709
+
710
+ if isinstance (fill , (int , float )):
711
+ return _pad_with_scalar_fill (image , torch_padding , fill = fill , padding_mode = padding_mode )
712
+ elif len (fill ) == 1 :
713
+ return _pad_with_scalar_fill (image , torch_padding , fill = fill [0 ], padding_mode = padding_mode )
681
714
else :
682
- return _pad_with_vector_fill (image , padding , fill = fill , padding_mode = padding_mode )
715
+ return _pad_with_vector_fill (image , torch_padding , fill = fill , padding_mode = padding_mode )
683
716
684
717
685
718
def _pad_with_scalar_fill (
686
719
image : torch .Tensor ,
687
- padding : Union [ int , List [int ] ],
688
- fill : Union [int , float , None ],
689
- padding_mode : str = "constant" ,
720
+ torch_padding : List [int ],
721
+ fill : Union [int , float ],
722
+ padding_mode : str ,
690
723
) -> torch .Tensor :
691
724
shape = image .shape
692
725
num_channels , height , width = shape [- 3 :]
693
726
694
727
if image .numel () > 0 :
695
- image = _FT .pad (
696
- img = image .reshape (- 1 , num_channels , height , width ), padding = padding , fill = fill , padding_mode = padding_mode
697
- )
728
+ image = image .reshape (- 1 , num_channels , height , width )
729
+
730
+ if padding_mode == "edge" :
731
+ # Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map
732
+ # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
733
+ # name.
734
+ padding_mode = "replicate"
735
+
736
+ if padding_mode == "constant" :
737
+ image = torch_pad (image , torch_padding , mode = padding_mode , value = float (fill ))
738
+ elif padding_mode in ("reflect" , "replicate" ):
739
+ # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
740
+ # TODO: See https://github.com/pytorch/pytorch/issues/40763
741
+ dtype = image .dtype
742
+ if not image .is_floating_point ():
743
+ needs_cast = True
744
+ image = image .to (torch .float32 )
745
+ else :
746
+ needs_cast = False
747
+
748
+ image = torch_pad (image , torch_padding , mode = padding_mode )
749
+
750
+ if needs_cast :
751
+ image = image .to (dtype )
752
+ else : # padding_mode == "symmetric"
753
+ image = _FT ._pad_symmetric (image , torch_padding )
754
+
698
755
new_height , new_width = image .shape [- 2 :]
699
756
else :
700
- left , right , top , bottom = _FT . _parse_pad_padding ( padding )
757
+ left , right , top , bottom = torch_padding
701
758
new_height = height + top + bottom
702
759
new_width = width + left + right
703
760
704
761
return image .reshape (shape [:- 3 ] + (num_channels , new_height , new_width ))
705
762
706
763
707
- # TODO: This should be removed once pytorch pad supports non-scalar padding values
764
+ # TODO: This should be removed once torch_pad supports non-scalar padding values
708
765
def _pad_with_vector_fill (
709
766
image : torch .Tensor ,
710
- padding : Union [ int , List [int ] ],
767
+ torch_padding : List [int ],
711
768
fill : List [float ],
712
- padding_mode : str = "constant" ,
769
+ padding_mode : str ,
713
770
) -> torch .Tensor :
714
771
if padding_mode != "constant" :
715
772
raise ValueError (f"Padding mode '{ padding_mode } ' is not supported if fill is not scalar" )
716
773
717
- output = _pad_with_scalar_fill (image , padding , fill = 0 , padding_mode = "constant" )
718
- left , right , top , bottom = _parse_pad_padding ( padding )
774
+ output = _pad_with_scalar_fill (image , torch_padding , fill = 0 , padding_mode = "constant" )
775
+ left , right , top , bottom = torch_padding
719
776
fill = torch .tensor (fill , dtype = image .dtype , device = image .device ).reshape (- 1 , 1 , 1 )
720
777
721
778
if top > 0 :
@@ -729,6 +786,9 @@ def _pad_with_vector_fill(
729
786
return output
730
787
731
788
789
+ pad_image_pil = _FP .pad
790
+
791
+
732
792
def pad_mask (
733
793
mask : torch .Tensor ,
734
794
padding : Union [int , List [int ]],
0 commit comments