# NOTE(review): the `def` line of this function sits outside the visible hunk;
# the signature below is reconstructed from the hunk header and from the call
# `_pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")`
# in crop_image_tensor — confirm parameter order/defaults against the full file.
def _pad_with_scalar_fill(image, torch_padding, fill, padding_mode):
    """Pad ``image`` with a single scalar ``fill`` value.

    ``torch_padding`` is in ``torch_pad`` order (left, right, top, bottom).
    Leading dimensions beyond (C, H, W) are flattened into one batch
    dimension for the padding call and restored before returning.
    """
    shape = image.shape
    num_channels, height, width = shape[-3:]

    # Flatten every leading dim into a single batch dim. The batch size is
    # computed explicitly instead of using reshape(-1, ...) so the reshape
    # also works for tensors with zero elements, where -1 is ambiguous.
    batch_size = 1
    for dim in shape[:-3]:
        batch_size *= dim
    image = image.reshape(batch_size, num_channels, height, width)

    # Our API uses PIL's padding-mode names, which differ from `torch_pad`'s;
    # map PIL's "edge" onto torch's "replicate".
    if padding_mode == "edge":
        padding_mode = "replicate"

    if padding_mode == "constant":
        image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
    elif padding_mode in ("reflect", "replicate"):
        # `torch_pad` only supports "reflect"/"replicate" padding for
        # floating point inputs.
        # TODO: See https://github.com/pytorch/pytorch/issues/40763
        dtype = image.dtype
        needs_cast = not image.is_floating_point()
        if needs_cast:
            image = image.to(torch.float32)

        image = torch_pad(image, torch_padding, mode=padding_mode)

        if needs_cast:
            image = image.to(dtype)
    else:  # padding_mode == "symmetric"
        image = _FT._pad_symmetric(image, torch_padding)

    new_height, new_width = image.shape[-2:]

    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
765764
@@ -868,7 +867,24 @@ def pad(
868867 return pad_image_pil (inpt , padding , fill = fill , padding_mode = padding_mode )
869868
870869
def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop ``image`` to the box given by ``top``/``left``/``height``/``width``.

    The box may extend beyond the image borders; any out-of-bounds region is
    zero-filled via ``_pad_with_scalar_fill`` so the output's spatial size is
    always (height, width).
    """
    img_h, img_w = image.shape[-2:]
    bottom, right = top + height, left + width

    # Fast path: the requested box lies entirely inside the image.
    if top >= 0 and left >= 0 and bottom <= img_h and right <= img_w:
        return image[..., top:bottom, left:right]

    # Keep only the part of the box that overlaps the image...
    overlap = image[..., max(top, 0) : bottom, max(left, 0) : right]
    # ...and pad it back to the requested size (torch_pad order: L, R, T, B).
    padding = [
        max(min(right, 0) - left, 0),      # left
        max(right - max(img_w, left), 0),  # right
        max(min(bottom, 0) - top, 0),      # top
        max(bottom - max(img_h, top), 0),  # bottom
    ]
    return _pad_with_scalar_fill(overlap, padding, fill=0, padding_mode="constant")
887+
872888crop_image_pil = _FP .crop
873889
874890
@@ -893,7 +909,18 @@ def crop_bounding_box(
893909
894910
def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop a mask, also accepting 2-D (H, W) inputs.

    A mask with fewer than three dims temporarily gets a leading dimension
    before delegating to ``crop_image_tensor`` (whose padding path expects a
    channel dim — presumably the reason for the unsqueeze; confirm), and that
    dimension is removed again before returning.
    """
    added_leading_dim = mask.ndim < 3
    if added_leading_dim:
        mask = mask.unsqueeze(0)

    cropped = crop_image_tensor(mask, top, left, height, width)

    return cropped.squeeze(0) if added_leading_dim else cropped
897924
898925
899926def crop_video (video : torch .Tensor , top : int , left : int , height : int , width : int ) -> torch .Tensor :
0 commit comments