Commit ea84b15

[fbsync] Fix bug on prototype pad (#6949)

Joao Gomes authored and facebook-github-bot committed

Reviewed By: YosuaMichael
Differential Revision: D41376282
fbshipit-source-id: da222edd80d695d3f52816a36a41667212a1b0a0

1 parent 47d3ee1
1 file changed: torchvision/prototype/transforms/functional/_geometry.py (61 additions, 34 deletions)

@@ -727,39 +727,38 @@ def _pad_with_scalar_fill(
     shape = image.shape
     num_channels, height, width = shape[-3:]
 
-    if image.numel() > 0:
-        image = image.reshape(-1, num_channels, height, width)
-
-        if padding_mode == "edge":
-            # Similar to the padding order, `torch_pad`'s and PIL's padding modes don't have the same names. Thus, we map
-            # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
-            # name.
-            padding_mode = "replicate"
-
-        if padding_mode == "constant":
-            image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
-        elif padding_mode in ("reflect", "replicate"):
-            # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
-            # TODO: See https://github.com/pytorch/pytorch/issues/40763
-            dtype = image.dtype
-            if not image.is_floating_point():
-                needs_cast = True
-                image = image.to(torch.float32)
-            else:
-                needs_cast = False
-
-            image = torch_pad(image, torch_padding, mode=padding_mode)
-
-            if needs_cast:
-                image = image.to(dtype)
-        else:  # padding_mode == "symmetric"
-            image = _FT._pad_symmetric(image, torch_padding)
+    batch_size = 1
+    for s in shape[:-3]:
+        batch_size *= s
+
+    image = image.reshape(batch_size, num_channels, height, width)
+
+    if padding_mode == "edge":
+        # Similar to the padding order, `torch_pad`'s and PIL's padding modes don't have the same names. Thus, we map
+        # the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
+        # name.
+        padding_mode = "replicate"
+
+    if padding_mode == "constant":
+        image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
+    elif padding_mode in ("reflect", "replicate"):
+        # `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
+        # TODO: See https://github.com/pytorch/pytorch/issues/40763
+        dtype = image.dtype
+        if not image.is_floating_point():
+            needs_cast = True
+            image = image.to(torch.float32)
+        else:
+            needs_cast = False
 
-        new_height, new_width = image.shape[-2:]
-    else:
-        left, right, top, bottom = torch_padding
-        new_height = height + top + bottom
-        new_width = width + left + right
+        image = torch_pad(image, torch_padding, mode=padding_mode)
+
+        if needs_cast:
+            image = image.to(dtype)
+    else:  # padding_mode == "symmetric"
+        image = _FT._pad_symmetric(image, torch_padding)
+
+    new_height, new_width = image.shape[-2:]
 
     return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
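
The bug being fixed: `image.reshape(-1, num_channels, height, width)` raises for tensors with `numel() == 0` because the `-1` dimension is ambiguous, which forced the old special-cased `else` branch. Computing the batch size explicitly lets empty inputs take the same path. A minimal standalone sketch of the idea for the constant-fill case (not the torchvision code itself; `pad_constant` is a hypothetical name):

import torch
import torch.nn.functional as F

def pad_constant(image: torch.Tensor, padding: list, fill: float = 0.0) -> torch.Tensor:
    shape = image.shape
    num_channels, height, width = shape[-3:]

    # reshape(-1, C, H, W) fails when numel() == 0 (the -1 is ambiguous),
    # so flatten the leading dims with an explicitly computed batch size.
    batch_size = 1
    for s in shape[:-3]:
        batch_size *= s

    image = image.reshape(batch_size, num_channels, height, width)
    # padding is [left, right, top, bottom], applied to the last two dims.
    image = F.pad(image, padding, mode="constant", value=fill)

    new_height, new_width = image.shape[-2:]
    return image.reshape(shape[:-3] + (num_channels, new_height, new_width))

# An empty batch now takes the same code path as a non-empty one:
empty = torch.empty(0, 3, 8, 8)
print(pad_constant(empty, [1, 1, 2, 2]).shape)  # torch.Size([0, 3, 12, 10])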

@@ -868,7 +867,24 @@ def pad(
         return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)
 
 
-crop_image_tensor = _FT.crop
+def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
+    h, w = image.shape[-2:]
+
+    right = left + width
+    bottom = top + height
+
+    if left < 0 or top < 0 or right > w or bottom > h:
+        image = image[..., max(top, 0) : bottom, max(left, 0) : right]
+        torch_padding = [
+            max(min(right, 0) - left, 0),
+            max(right - max(w, left), 0),
+            max(min(bottom, 0) - top, 0),
+            max(bottom - max(h, top), 0),
+        ]
+        return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
+    return image[..., top:bottom, left:right]
+
+
 crop_image_pil = _FP.crop
 
 
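Unlike the plain alias to `_FT.crop`, the new `crop_image_tensor` handles crop windows that extend past the image borders: the slice is clamped to the valid region and the out-of-bounds remainder is zero-filled via `_pad_with_scalar_fill`. A hedged usage sketch (assumes this prototype module's import path; adjust for your torchvision version):

import torch
from torchvision.prototype.transforms.functional import crop_image_tensor

image = torch.rand(3, 10, 10)

# Crop window starts 2 pixels above and to the left of the image.
out = crop_image_tensor(image, top=-2, left=-2, height=6, width=6)
print(out.shape)                  # torch.Size([3, 6, 6])
print(out[:, :2, :].abs().sum())  # tensor(0.) -- zero-filled top rows
print(out[:, :, :2].abs().sum())  # tensor(0.) -- zero-filled left columns

For this call the computed `torch_padding` works out to `[2, 0, 2, 0]` (left, right, top, bottom): two zero columns on the left and two zero rows on top, with the remaining 4x4 block copied from the source image.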

@@ -893,7 +909,18 @@ def crop_bounding_box(
 
 
 def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
-    return crop_image_tensor(mask, top, left, height, width)
+    if mask.ndim < 3:
+        mask = mask.unsqueeze(0)
+        needs_squeeze = True
+    else:
+        needs_squeeze = False
+
+    output = crop_image_tensor(mask, top, left, height, width)
+
+    if needs_squeeze:
+        output = output.squeeze(0)
+
+    return output
 
 
 def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
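
`crop_mask` likely needs this shim because the zero-padding path above unpacks `shape[-3:]` inside `_pad_with_scalar_fill`, so a plain 2D (H, W) segmentation mask must gain a leading channel dimension before cropping and drop it afterwards. A minimal standalone sketch of the same unsqueeze/squeeze pattern (hypothetical names `crop_2d_safe` and `crop`, not part of the commit):

import torch

def crop_2d_safe(mask: torch.Tensor, crop_fn, *args) -> torch.Tensor:
    # Promote (H, W) to (1, H, W) so channel-aware code paths work,
    # then undo the promotion on the way out.
    needs_squeeze = mask.ndim < 3
    if needs_squeeze:
        mask = mask.unsqueeze(0)

    output = crop_fn(mask, *args)

    if needs_squeeze:
        output = output.squeeze(0)
    return output

crop = lambda t, top, left, h, w: t[..., top : top + h, left : left + w]
mask = torch.ones(10, 10, dtype=torch.bool)
print(crop_2d_safe(mask, crop, 2, 2, 4, 4).shape)  # torch.Size([4, 4])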
