Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 61 additions & 34 deletions torchvision/prototype/transforms/functional/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -727,39 +727,38 @@ def _pad_with_scalar_fill(
shape = image.shape
num_channels, height, width = shape[-3:]

if image.numel() > 0:
image = image.reshape(-1, num_channels, height, width)

if padding_mode == "edge":
# Similar to the padding order, `torch_pad`'s and PIL's padding modes don't have the same names. Thus, we map
# the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
# name.
padding_mode = "replicate"

if padding_mode == "constant":
image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
elif padding_mode in ("reflect", "replicate"):
# `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
# TODO: See https://github.com/pytorch/pytorch/issues/40763
dtype = image.dtype
if not image.is_floating_point():
needs_cast = True
image = image.to(torch.float32)
else:
needs_cast = False

image = torch_pad(image, torch_padding, mode=padding_mode)

if needs_cast:
image = image.to(dtype)
else: # padding_mode == "symmetric"
image = _FT._pad_symmetric(image, torch_padding)
batch_size = 1
for s in shape[:-3]:
batch_size *= s

image = image.reshape(batch_size, num_channels, height, width)

if padding_mode == "edge":
# Similar to the padding order, `torch_pad`'s and PIL's padding modes don't have the same names. Thus, we map
# the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
# name.
padding_mode = "replicate"

if padding_mode == "constant":
image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
elif padding_mode in ("reflect", "replicate"):
# `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
# TODO: See https://github.com/pytorch/pytorch/issues/40763
dtype = image.dtype
if not image.is_floating_point():
needs_cast = True
image = image.to(torch.float32)
else:
needs_cast = False

new_height, new_width = image.shape[-2:]
else:
left, right, top, bottom = torch_padding
new_height = height + top + bottom
new_width = width + left + right
image = torch_pad(image, torch_padding, mode=padding_mode)

if needs_cast:
image = image.to(dtype)
else: # padding_mode == "symmetric"
image = _FT._pad_symmetric(image, torch_padding)

new_height, new_width = image.shape[-2:]

return image.reshape(shape[:-3] + (num_channels, new_height, new_width))

Expand Down Expand Up @@ -868,7 +867,24 @@ def pad(
return pad_image_pil(inpt, padding, fill=fill, padding_mode=padding_mode)


crop_image_tensor = _FT.crop
def crop_image_tensor(image: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
    """Crop the last two dimensions of ``image`` to a ``height`` x ``width`` window.

    The window covers rows ``[top, top + height)`` and columns ``[left, left + width)``. Parts of the
    window that fall outside the image are filled with zeros through ``_pad_with_scalar_fill``.

    Args:
        image: Tensor whose last two dimensions are treated as height and width. The out-of-bounds
            path goes through ``_pad_with_scalar_fill``, which unpacks the last three dimensions,
            so it needs at least a 3D input.
        top: Row index of the window's first row; may be negative.
        left: Column index of the window's first column; may be negative.
        height: Number of rows in the window.
        width: Number of columns in the window.

    Returns:
        The cropped (and, if the window leaves the image, zero-padded) tensor.
    """
    h, w = image.shape[-2:]

    right = left + width
    bottom = top + height

    if left < 0 or top < 0 or right > w or bottom > h:
        # Clamp the slice stops at 0 as well as the starts. A negative stop would be interpreted as
        # an offset from the end of the axis, so a window lying entirely above / left of the image
        # would select real pixels instead of nothing, and the output would end up larger than
        # requested. The padding amounts below already assume an empty selection in that case.
        image = image[..., max(top, 0) : max(bottom, 0), max(left, 0) : max(right, 0)]
        # Order is [left, right, top, bottom], as expected by `_pad_with_scalar_fill`.
        torch_padding = [
            max(min(right, 0) - left, 0),
            max(right - max(w, left), 0),
            max(min(bottom, 0) - top, 0),
            max(bottom - max(h, top), 0),
        ]
        return _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
    return image[..., top:bottom, left:right]


crop_image_pil = _FP.crop


Expand All @@ -893,7 +909,18 @@ def crop_bounding_box(


def crop_mask(mask: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
return crop_image_tensor(mask, top, left, height, width)
if mask.ndim < 3:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need this handling here, because _pad_with_scalar_fill doesn't support 2d images, right? Otherwise, I don't see anything in crop_image_tensor that would require 3 or more dims.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Correct. The issue comes from pad:

  File "./vision/torchvision/prototype/transforms/functional/_geometry.py", line 890, in _pad_with_scalar_fill
    num_channels, height, width = shape[-3:]

mask = mask.unsqueeze(0)
needs_squeeze = True
else:
needs_squeeze = False

output = crop_image_tensor(mask, top, left, height, width)

if needs_squeeze:
output = output.squeeze(0)

return output


def crop_video(video: torch.Tensor, top: int, left: int, height: int, width: int) -> torch.Tensor:
Expand Down