[CHERRYPICK] PIL fill len 1 seq / float fill for int images #7951

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 2 commits · Sep 8, 2023
26 changes: 16 additions & 10 deletions test/test_transforms_v2_refactored.py
```diff
@@ -309,11 +309,12 @@ def adapt_fill(value, *, dtype):
         return value
 
     max_value = get_max_value(dtype)
+    value_type = float if dtype.is_floating_point else int
 
     if isinstance(value, (int, float)):
-        return type(value)(value * max_value)
+        return value_type(value * max_value)
     elif isinstance(value, (list, tuple)):
-        return type(value)(type(v)(v * max_value) for v in value)
+        return type(value)(value_type(v * max_value) for v in value)
     else:
         raise ValueError(f"fill should be an int or float, or a list or tuple of the former, but got '{value}'.")
```
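To make the `adapt_fill` change concrete: the helper maps fills from the [0.0, 1.0] range to the value range of the target dtype, and the new `value_type` ensures integer dtypes get integer fills even when the input is a float. A minimal self-contained sketch (`get_max_value` is re-implemented here as a stand-in for the test-suite helper):

```python
import torch

def get_max_value(dtype):
    # stand-in for the test-suite helper: floats top out at 1.0, ints at iinfo max
    return 1.0 if dtype.is_floating_point else torch.iinfo(dtype).max

def adapt_fill(value, *, dtype):
    if value is None:
        return value

    max_value = get_max_value(dtype)
    value_type = float if dtype.is_floating_point else int

    if isinstance(value, (int, float)):
        return value_type(value * max_value)
    elif isinstance(value, (list, tuple)):
        return type(value)(value_type(v * max_value) for v in value)
    else:
        raise ValueError(f"fill should be an int or float, or a list or tuple of the former, but got '{value}'.")

print(adapt_fill(0.5, dtype=torch.uint8))         # 127 (was the float 127.5 before the fix)
print(adapt_fill([0.5, 1.0], dtype=torch.uint8))  # [127, 255]
print(adapt_fill(0.5, dtype=torch.float32))       # 0.5
```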
```diff
@@ -414,6 +415,10 @@ def affine_bounding_boxes(bounding_boxes):
     )
 
 
+# turns all warnings into errors for this module
+pytestmark = pytest.mark.filterwarnings("error")
+
+
 class TestResize:
     INPUT_SIZE = (17, 11)
     OUTPUT_SIZES = [17, [17], (17,), [12, 13], (12, 13)]
```
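For context on the `pytestmark` line above: assigning a mark to the module-level `pytestmark` name applies it to every test in the module, and `filterwarnings("error")` escalates any warning raised during a test into a failure. A tiny sketch of the effect:

```python
import warnings

import pytest

pytestmark = pytest.mark.filterwarnings("error")  # applies to all tests below

def test_warning_becomes_error():
    # with the mark active, the warning surfaces as a raised exception
    with pytest.raises(UserWarning):
        warnings.warn("implicit conversion", UserWarning)
```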
```diff
@@ -2575,18 +2580,19 @@ def test_functional_image_correctness(self, kwargs):
     def test_transform(self, param, value, make_input):
         input = make_input(self.INPUT_SIZE)
 
-        kwargs = {param: value}
         if param == "fill":
-            # 1. size is required
-            # 2. the fill parameter only has an effect if we need padding
-            kwargs["size"] = [s + 4 for s in self.INPUT_SIZE]
-
-            if isinstance(input, PIL.Image.Image) and isinstance(value, (tuple, list)) and len(value) == 1:
-                pytest.xfail("F._pad_image_pil does not support sequences of length 1 for fill.")
-
             if isinstance(input, tv_tensors.Mask) and isinstance(value, (tuple, list)):
                 pytest.skip("F.pad_mask doesn't support non-scalar fill.")
 
+            kwargs = dict(
+                # 1. size is required
+                # 2. the fill parameter only has an effect if we need padding
+                size=[s + 4 for s in self.INPUT_SIZE],
+                fill=adapt_fill(value, dtype=input.dtype if isinstance(input, torch.Tensor) else torch.uint8),
+            )
+        else:
+            kwargs = {param: value}
+
         check_transform(
             transforms.RandomCrop(**kwargs, pad_if_needed=True),
             input,
```
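Concretely, for a uint8 tensor image and `value=[0.5]`, the new branch builds kwargs equivalent to the following (a sketch; `INPUT_SIZE` stands in for `self.INPUT_SIZE`, and `adapt_fill` behaves as shown above):

```python
import torch

INPUT_SIZE = (17, 11)
value = [0.5]  # one of the parametrized fill values

kwargs = dict(
    size=[s + 4 for s in INPUT_SIZE],           # [21, 15], guarantees padding happens
    fill=adapt_fill(value, dtype=torch.uint8),  # [127], a fill the uint8 kernels accept
)
```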
37 changes: 0 additions & 37 deletions test/transforms_v2_dispatcher_infos.py
```diff
@@ -1,5 +1,3 @@
-import collections.abc
-
 import pytest
 import torchvision.transforms.v2.functional as F
 from torchvision import tv_tensors
```
```diff
@@ -112,32 +110,6 @@ def xfail_jit_python_scalar_arg(name, *, reason=None):
     multi_crop_skips.append(skip_dispatch_tv_tensor)
 
 
-def xfails_pil(reason, *, condition=None):
-    return [
-        TestMark(("TestDispatchers", test_name), pytest.mark.xfail(reason=reason), condition=condition)
-        for test_name in ["test_dispatch_pil", "test_pil_output_type"]
-    ]
-
-
-def fill_sequence_needs_broadcast(args_kwargs):
-    (image_loader, *_), kwargs = args_kwargs
-    try:
-        fill = kwargs["fill"]
-    except KeyError:
-        return False
-
-    if not isinstance(fill, collections.abc.Sequence) or len(fill) > 1:
-        return False
-
-    return image_loader.num_channels > 1
-
-
-xfails_pil_if_fill_sequence_needs_broadcast = xfails_pil(
-    "PIL kernel doesn't support sequences of length 1 for `fill` if the number of color channels is larger.",
-    condition=fill_sequence_needs_broadcast,
-)
-
-
 DISPATCHER_INFOS = [
     DispatcherInfo(
         F.resized_crop,
```
```diff
@@ -159,14 +131,6 @@ def fill_sequence_needs_broadcast(args_kwargs):
         },
         pil_kernel_info=PILKernelInfo(F._pad_image_pil, kernel_name="pad_image_pil"),
         test_marks=[
-            *xfails_pil(
-                reason=(
-                    "PIL kernel doesn't support sequences of length 1 for argument `fill` and "
-                    "`padding_mode='constant'`, if the number of color channels is larger."
-                ),
-                condition=lambda args_kwargs: fill_sequence_needs_broadcast(args_kwargs)
-                and args_kwargs.kwargs.get("padding_mode", "constant") == "constant",
-            ),
             xfail_jit("F.pad only supports vector fills for list of floats", condition=pad_xfail_jit_fill_condition),
             xfail_jit_python_scalar_arg("padding"),
         ],
```
```diff
@@ -181,7 +145,6 @@ def fill_sequence_needs_broadcast(args_kwargs):
         },
         pil_kernel_info=PILKernelInfo(F._perspective_image_pil),
         test_marks=[
-            *xfails_pil_if_fill_sequence_needs_broadcast,
             xfail_jit_python_scalar_arg("fill"),
         ],
     ),
```
6 changes: 4 additions & 2 deletions torchvision/transforms/_functional_pil.py
```diff
@@ -264,11 +264,13 @@ def _parse_fill(
     if isinstance(fill, (int, float)) and num_channels > 1:
         fill = tuple([fill] * num_channels)
     if isinstance(fill, (list, tuple)):
-        if len(fill) != num_channels:
+        if len(fill) == 1:
+            fill = fill * num_channels
+        elif len(fill) != num_channels:
             msg = "The number of elements in 'fill' does not match the number of channels of the image ({} != {})"
             raise ValueError(msg.format(len(fill), num_channels))
 
-        fill = tuple(fill)
+        fill = tuple(fill)  # type: ignore[arg-type]
 
     if img.mode != "F":
         if isinstance(fill, (list, tuple)):
```
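The user-visible effect: a length-1 `fill` sequence passed for a multi-channel PIL image is now broadcast to all channels instead of raising a ValueError. The new branch boils down to plain sequence repetition:

```python
num_channels = 3   # e.g. an RGB image
fill = [127]       # length-1 sequence, previously rejected

if len(fill) == 1:
    fill = fill * num_channels  # repeat the single value once per channel

print(tuple(fill))  # (127, 127, 127)
```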
6 changes: 5 additions & 1 deletion torchvision/transforms/v2/functional/_geometry.py
```diff
@@ -1235,7 +1235,11 @@ def _pad_with_vector_fill(
 
     output = _pad_with_scalar_fill(image, torch_padding, fill=0, padding_mode="constant")
     left, right, top, bottom = torch_padding
-    fill = torch.tensor(fill, dtype=image.dtype, device=image.device).reshape(-1, 1, 1)
+
+    # We create the tensor in the autodetected dtype first and convert it to the right one afterwards, to avoid an
+    # implicit float -> int conversion. That happens, for example, for the valid input of a uint8 image with a
+    # floating point fill value.
+    fill = torch.tensor(fill, device=image.device).to(dtype=image.dtype).reshape(-1, 1, 1)
 
     if top > 0:
         output[..., :top, :] = fill
```
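To illustrate the pitfall the new comment describes (a sketch; whether the one-step construction warns depends on the PyTorch version, which matters here because the test module above now turns warnings into errors):

```python
import torch

fill = [12.75]  # a valid float fill for a uint8 image

# One-step construction asks torch.tensor to materialize float data directly in
# an integer dtype, i.e. an implicit float -> int conversion at construction time:
# torch.tensor(fill, dtype=torch.uint8)

# Two-step construction builds the tensor in the autodetected float dtype, then
# casts explicitly; the cast truncates toward zero:
fill_tensor = torch.tensor(fill).to(dtype=torch.uint8).reshape(-1, 1, 1)
print(fill_tensor)  # tensor([[[12]]], dtype=torch.uint8)
```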