@@ -727,39 +727,38 @@ def _pad_with_scalar_fill(
727
727
shape = image .shape
728
728
num_channels , height , width = shape [- 3 :]
729
729
730
- if image .numel () > 0 :
731
- image = image .reshape (- 1 , num_channels , height , width )
732
-
733
- if padding_mode == "edge" :
734
- # Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map
735
- # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
736
- # name.
737
- padding_mode = "replicate"
738
-
739
- if padding_mode == "constant" :
740
- image = torch_pad (image , torch_padding , mode = padding_mode , value = float (fill ))
741
- elif padding_mode in ("reflect" , "replicate" ):
742
- # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
743
- # TODO: See https://github.com/pytorch/pytorch/issues/40763
744
- dtype = image .dtype
745
- if not image .is_floating_point ():
746
- needs_cast = True
747
- image = image .to (torch .float32 )
748
- else :
749
- needs_cast = False
750
-
751
- image = torch_pad (image , torch_padding , mode = padding_mode )
752
-
753
- if needs_cast :
754
- image = image .to (dtype )
755
- else : # padding_mode == "symmetric"
756
- image = _FT ._pad_symmetric (image , torch_padding )
730
+ batch_size = 1
731
+ for s in shape [:- 3 ]:
732
+ batch_size *= s
733
+
734
+ image = image .reshape (batch_size , num_channels , height , width )
735
+
736
+ if padding_mode == "edge" :
737
+ # Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map
738
+ # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
739
+ # name.
740
+ padding_mode = "replicate"
741
+
742
+ if padding_mode == "constant" :
743
+ image = torch_pad (image , torch_padding , mode = padding_mode , value = float (fill ))
744
+ elif padding_mode in ("reflect" , "replicate" ):
745
+ # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
746
+ # TODO: See https://github.com/pytorch/pytorch/issues/40763
747
+ dtype = image .dtype
748
+ if not image .is_floating_point ():
749
+ needs_cast = True
750
+ image = image .to (torch .float32 )
751
+ else :
752
+ needs_cast = False
757
753
758
- new_height , new_width = image .shape [- 2 :]
759
- else :
760
- left , right , top , bottom = torch_padding
761
- new_height = height + top + bottom
762
- new_width = width + left + right
754
+ image = torch_pad (image , torch_padding , mode = padding_mode )
755
+
756
+ if needs_cast :
757
+ image = image .to (dtype )
758
+ else : # padding_mode == "symmetric"
759
+ image = _FT ._pad_symmetric (image , torch_padding )
760
+
761
+ new_height , new_width = image .shape [- 2 :]
763
762
764
763
return image .reshape (shape [:- 3 ] + (num_channels , new_height , new_width ))
765
764
@@ -868,7 +867,24 @@ def pad(
868
867
return pad_image_pil (inpt , padding , fill = fill , padding_mode = padding_mode )
869
868
870
869
871
def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop a ``height`` x ``width`` region starting at ``(top, left)`` from a tensor image.

    ``image`` is expected to carry its spatial dims last, i.e. shape ``(..., H, W)``.
    The crop box may extend past the image borders (including lying fully outside);
    out-of-bounds areas are zero-filled via constant padding so the output spatial
    size is always ``(height, width)``.
    """
    h, w = image.shape[-2:]

    right = left + width
    bottom = top + height

    if left < 0 or top < 0 or right > w or bottom > h:
        # Clamp the slice stops to 0 so a box that lies fully outside the image
        # yields an empty slice. Without the clamp, a negative stop (e.g.
        # bottom == -3) wraps around and keeps rows 0:h-3, producing an output
        # of the wrong size after padding.
        image = image[..., max(top, 0) : max(bottom, 0), max(left, 0) : max(right, 0)]
        # torch_pad ordering is (left, right, top, bottom); each entry is the
        # number of requested pixels that fall outside the corresponding border.
        torch_padding = [
            max(min(right, 0) - left, 0),
            max(right - max(w, left), 0),
            max(min(bottom, 0) - top, 0),
            max(bottom - max(h, top), 0),
        ]
        return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    return image[..., top:bottom, left:right]
886
+
887
+
872
888
# PIL-image cropping is delegated directly to the v1 functional helper.
crop_image_pil = _FP.crop
873
889
874
890
@@ -893,7 +909,18 @@ def crop_bounding_box(
893
909
894
910
895
911
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop a mask tensor, temporarily adding a leading dim for 2D (H, W) masks."""
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        # crop_image_tensor indexes the last two dims; give a bare (H, W) mask
        # a leading dim so it matches the (..., H, W) layout, then undo it below.
        mask = mask.unsqueeze(0)

    cropped = crop_image_tensor(mask, top, left, height, width)

    return cropped.squeeze(0) if needs_squeeze else cropped
897
924
898
925
899
926
def crop_video (video : torch .Tensor , top : int , left : int , height : int , width : int ) -> torch .Tensor :
0 commit comments