Skip to content

Commit b8c8954

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] remove unnecessary checks from pad_image_tensor (#6894)
Summary: * remove unnecessary changes from pad_image_tensor * cleanup * fix fill=None workaround * address review comments * remove more xfails Reviewed By: datumbox Differential Revision: D41020544 fbshipit-source-id: d677ea0dd79f8e8055ed7c36a65a0bb980e3b578
1 parent 1e7abc6 commit b8c8954

File tree

3 files changed

+81
-24
lines changed

3 files changed

+81
-24
lines changed

test/prototype_transforms_dispatcher_infos.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,6 @@ def fill_sequence_needs_broadcast(args_kwargs):
234234
condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
235235
and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
236236
),
237-
xfail_jit_python_scalar_arg("padding"),
238237
xfail_jit_tuple_instead_of_list("padding"),
239238
xfail_jit_tuple_instead_of_list("fill"),
240239
# TODO: check if this is a regression since it seems that should be supported if `int` is ok

test/prototype_transforms_kernel_infos.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1146,7 +1146,6 @@ def reference_inputs_pad_bounding_box():
11461146
reference_inputs_fn=reference_inputs_pad_image_tensor,
11471147
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
11481148
test_marks=[
1149-
xfail_jit_python_scalar_arg("padding"),
11501149
xfail_jit_tuple_instead_of_list("padding"),
11511150
xfail_jit_tuple_instead_of_list("fill"),
11521151
# TODO: check if this is a regression since it seems that should be supported if `int` is ok
@@ -1159,7 +1158,6 @@ def reference_inputs_pad_bounding_box():
11591158
reference_fn=reference_pad_bounding_box,
11601159
reference_inputs_fn=reference_inputs_pad_bounding_box,
11611160
test_marks=[
1162-
xfail_jit_python_scalar_arg("padding"),
11631161
xfail_jit_tuple_instead_of_list("padding"),
11641162
],
11651163
),

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 81 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@
44

55
import PIL.Image
66
import torch
7-
from torch.nn.functional import interpolate
7+
from torch.nn.functional import interpolate, pad as torch_pad
8+
89
from torchvision.prototype import features
910
from torchvision.transforms import functional_pil as _FP, functional_tensor as _FT
1011
from torchvision.transforms.functional import (
@@ -15,7 +16,6 @@
1516
pil_to_tensor,
1617
to_pil_image,
1718
)
18-
from torchvision.transforms.functional_tensor import _parse_pad_padding
1919

2020
from ._meta import convert_format_bounding_box, get_spatial_size_image_pil
2121

@@ -663,7 +663,28 @@ def rotate(
663663
return rotate_image_pil(inpt, angle, interpolation=interpolation, expand=expand, fill=fill, center=center)
664664

665665

666-
pad_image_pil = _FP.pad
666+
def _parse_pad_padding(padding: Union[int, List[int]]) -> List[int]:
667+
if isinstance(padding, int):
668+
pad_left = pad_right = pad_top = pad_bottom = padding
669+
elif isinstance(padding, (tuple, list)):
670+
if len(padding) == 1:
671+
pad_left = pad_right = pad_top = pad_bottom = padding[0]
672+
elif len(padding) == 2:
673+
pad_left = pad_right = padding[0]
674+
pad_top = pad_bottom = padding[1]
675+
elif len(padding) == 4:
676+
pad_left = padding[0]
677+
pad_top = padding[1]
678+
pad_right = padding[2]
679+
pad_bottom = padding[3]
680+
else:
681+
raise ValueError(
682+
f"Padding must be an int or a 1, 2, or 4 element tuple, not a {len(padding)} element tuple"
683+
)
684+
else:
685+
raise TypeError(f"`padding` should be an integer or tuple or list of integers, but got {padding}")
686+
687+
return [pad_left, pad_right, pad_top, pad_bottom]
667688

668689

669690
def pad_image_tensor(
@@ -672,50 +693,86 @@ def pad_image_tensor(
672693
fill: features.FillTypeJIT = None,
673694
padding_mode: str = "constant",
674695
) -> torch.Tensor:
696+
# Be aware that while `padding` has order `[left, top, right, bottom]` has order, `torch_padding` uses
697+
# `[left, right, top, bottom]`. This stems from the fact that we align our API with PIL, but need to use `torch_pad`
698+
# internally.
699+
torch_padding = _parse_pad_padding(padding)
700+
701+
if padding_mode not in ["constant", "edge", "reflect", "symmetric"]:
702+
raise ValueError(
703+
f"`padding_mode` should be either `'constant'`, `'edge'`, `'reflect'` or `'symmetric'`, "
704+
f"but got `'{padding_mode}'`."
705+
)
706+
675707
if fill is None:
676-
# This is a JIT workaround
677-
return _pad_with_scalar_fill(image, padding, fill=None, padding_mode=padding_mode)
678-
elif isinstance(fill, (int, float)) or len(fill) == 1:
679-
fill_number = fill[0] if isinstance(fill, list) else fill
680-
return _pad_with_scalar_fill(image, padding, fill=fill_number, padding_mode=padding_mode)
708+
fill = 0
709+
710+
if isinstance(fill, (int, float)):
711+
return _pad_with_scalar_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
712+
elif len(fill) == 1:
713+
return _pad_with_scalar_fill(image, torch_padding, fill=fill[0], padding_mode=padding_mode)
681714
else:
682-
return _pad_with_vector_fill(image, padding, fill=fill, padding_mode=padding_mode)
715+
return _pad_with_vector_fill(image, torch_padding, fill=fill, padding_mode=padding_mode)
683716

684717

685718
def _pad_with_scalar_fill(
686719
image: torch.Tensor,
687-
padding: Union[int, List[int]],
688-
fill: Union[int, float, None],
689-
padding_mode: str = "constant",
720+
torch_padding: List[int],
721+
fill: Union[int, float],
722+
padding_mode: str,
690723
) -> torch.Tensor:
691724
shape = image.shape
692725
num_channels, height, width = shape[-3:]
693726

694727
if image.numel() > 0:
695-
image = _FT.pad(
696-
img=image.reshape(-1, num_channels, height, width), padding=padding, fill=fill, padding_mode=padding_mode
697-
)
728+
image = image.reshape(-1, num_channels, height, width)
729+
730+
if padding_mode == "edge":
731+
# Similar to the padding order, `torch_pad`'s PIL's padding modes don't have the same names. Thus, we map
732+
# the PIL name for the padding mode, which we are also using for our API, to the corresponding `torch_pad`
733+
# name.
734+
padding_mode = "replicate"
735+
736+
if padding_mode == "constant":
737+
image = torch_pad(image, torch_padding, mode=padding_mode, value=float(fill))
738+
elif padding_mode in ("reflect", "replicate"):
739+
# `torch_pad` only supports `"reflect"` or `"replicate"` padding for floating point inputs.
740+
# TODO: See https://github.com/pytorch/pytorch/issues/40763
741+
dtype = image.dtype
742+
if not image.is_floating_point():
743+
needs_cast = True
744+
image = image.to(torch.float32)
745+
else:
746+
needs_cast = False
747+
748+
image = torch_pad(image, torch_padding, mode=padding_mode)
749+
750+
if needs_cast:
751+
image = image.to(dtype)
752+
else: # padding_mode == "symmetric"
753+
image = _FT._pad_symmetric(image, torch_padding)
754+
698755
new_height, new_width = image.shape[-2:]
699756
else:
700-
left, right, top, bottom = _FT._parse_pad_padding(padding)
757+
left, right, top, bottom = torch_padding
701758
new_height = height + top + bottom
702759
new_width = width + left + right
703760

704761
return image.reshape(shape[:-3] + (num_channels, new_height, new_width))
705762

706763

707-
# TODO: This should be removed once pytorch pad supports non-scalar padding values
764+
# TODO: This should be removed once torch_pad supports non-scalar padding values
708765
def _pad_with_vector_fill(
709766
image: torch.Tensor,
710-
padding: Union[int, List[int]],
767+
torch_padding: List[int],
711768
fill: List[float],
712-
padding_mode: str = "constant",
769+
padding_mode: str,
713770
) -> torch.Tensor:
714771
if padding_mode != "constant":
715772
raise ValueError(f"Padding mode '{padding_mode}' is not supported if fill is not scalar")
716773

717-
output = _pad_with_scalar_fill(image, padding, fill=0, padding_mode="constant")
718-
left, right, top, bottom = _parse_pad_padding(padding)
774+
output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
775+
left, right, top, bottom = torch_padding
719776
fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)
720777

721778
if top > 0:
@@ -729,6 +786,9 @@ def _pad_with_vector_fill(
729786
return output
730787

731788

789+
pad_image_pil = _FP.pad
790+
791+
732792
def pad_mask(
733793
mask: torch.Tensor,
734794
padding: Union[int, List[int]],

0 commit comments

Comments
 (0)