enable arbitrary batch size for all prototype kernels #6726

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged 3 commits on Oct 10, 2022.
9 changes: 0 additions & 9 deletions test/prototype_transforms_dispatcher_infos.py
@@ -138,12 +138,6 @@ def xfail_all_tests(*, reason, condition):
 ]


-xfails_degenerate_or_multi_batch_dims = xfail_all_tests(
-    reason="See https://github.com/pytorch/vision/issues/6670 for details.",
-    condition=lambda args_kwargs: len(args_kwargs.args[0].shape) > 4 or not all(args_kwargs.args[0].shape[:-3]),
-)
-
-
 DISPATCHER_INFOS = [
     DispatcherInfo(
         F.horizontal_flip,
@@ -260,7 +254,6 @@ def xfail_all_tests(*, reason, condition):
         pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
         test_marks=[
             xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
-            *xfails_degenerate_or_multi_batch_dims,
         ],
     ),
     DispatcherInfo(
@@ -271,7 +264,6 @@ def xfail_all_tests(*, reason, condition):
             features.Mask: F.elastic_mask,
         },
         pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
-        test_marks=xfails_degenerate_or_multi_batch_dims,
     ),
     DispatcherInfo(
         F.center_crop,
@@ -294,7 +286,6 @@ def xfail_all_tests(*, reason, condition):
         test_marks=[
             xfail_jit_python_scalar_arg("kernel_size"),
             xfail_jit_python_scalar_arg("sigma"),
-            *xfails_degenerate_or_multi_batch_dims,
         ],
     ),
     DispatcherInfo(
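The deleted `xfails_degenerate_or_multi_batch_dims` helper (and its twin in the kernel infos below) marked tests as expected failures whenever a sample had extra leading batch dimensions or a zero-sized one. A minimal restatement of the deleted condition, with `was_xfailed` as a hypothetical name for illustration:

```python
def was_xfailed(shape: tuple) -> bool:
    # Mirrors the removed lambda: flag inputs with extra leading batch dims
    # (ndim > 4) or any zero-sized leading dim (a "degenerate" batch).
    return len(shape) > 4 or not all(shape[:-3])


assert was_xfailed((2, 4, 3, 16, 16))   # two leading batch dims
assert was_xfailed((0, 3, 16, 16))      # degenerate batch dim
assert not was_xfailed((4, 3, 16, 16))  # plain 4D batch
```

With the kernels below now handling these shapes, the marks can go.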
11 changes: 0 additions & 11 deletions test/prototype_transforms_kernel_infos.py
@@ -156,12 +156,6 @@ def xfail_all_tests(*, reason, condition):
 ]


-xfails_image_degenerate_or_multi_batch_dims = xfail_all_tests(
-    reason="See https://github.com/pytorch/vision/issues/6670 for details.",
-    condition=lambda args_kwargs: len(args_kwargs.args[0].shape) > 4 or not all(args_kwargs.args[0].shape[:-3]),
-)
-
-
 KERNEL_INFOS = []


@@ -1156,7 +1150,6 @@ def sample_inputs_perspective_video():
         reference_fn=pil_reference_wrapper(F.perspective_image_pil),
         reference_inputs_fn=reference_inputs_perspective_image_tensor,
         closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-        test_marks=xfails_image_degenerate_or_multi_batch_dims,
     ),
     KernelInfo(
         F.perspective_bounding_box,
@@ -1168,7 +1161,6 @@ def sample_inputs_perspective_video():
         reference_fn=pil_reference_wrapper(F.perspective_image_pil),
         reference_inputs_fn=reference_inputs_perspective_mask,
         closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-        test_marks=xfails_image_degenerate_or_multi_batch_dims,
     ),
     KernelInfo(
         F.perspective_video,
@@ -1239,7 +1231,6 @@ def sample_inputs_elastic_video():
         reference_fn=pil_reference_wrapper(F.elastic_image_pil),
         reference_inputs_fn=reference_inputs_elastic_image_tensor,
         closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-        test_marks=xfails_image_degenerate_or_multi_batch_dims,
     ),
     KernelInfo(
         F.elastic_bounding_box,
@@ -1251,7 +1242,6 @@ def sample_inputs_elastic_video():
         reference_fn=pil_reference_wrapper(F.elastic_image_pil),
         reference_inputs_fn=reference_inputs_elastic_mask,
         closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
-        test_marks=xfails_image_degenerate_or_multi_batch_dims,
     ),
     KernelInfo(
         F.elastic_video,
@@ -1379,7 +1369,6 @@ def sample_inputs_gaussian_blur_video():
         test_marks=[
             xfail_jit_python_scalar_arg("kernel_size"),
             xfail_jit_python_scalar_arg("sigma"),
-            *xfails_image_degenerate_or_multi_batch_dims,
         ],
     ),
     KernelInfo(
76 changes: 36 additions & 40 deletions torchvision/prototype/transforms/functional/_geometry.py
@@ -882,7 +882,23 @@ def perspective_image_tensor(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
-    return _FT.perspective(image, perspective_coeffs, interpolation=interpolation.value, fill=fill)
+    if image.numel() == 0:
+        return image
+
+    shape = image.shape
+
+    if image.ndim > 4:
+        image = image.view((-1,) + shape[-3:])
+        needs_unsquash = True
+    else:
+        needs_unsquash = False
+
+    output = _FT.perspective(image, perspective_coeffs, interpolation=interpolation.value, fill=fill)
+
+    if needs_unsquash:
+        output = output.view(shape)
+
+    return output


 @torch.jit.unused
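All three kernels touched by this PR use the same squash/unsquash pattern seen above: flatten every leading batch dimension into one so the underlying 4D-only op can run, then restore the original shape. A self-contained sketch of the pattern, with `apply_4d_kernel` as a hypothetical helper name:

```python
from typing import Callable

import torch


def apply_4d_kernel(kernel: Callable[[torch.Tensor], torch.Tensor], image: torch.Tensor) -> torch.Tensor:
    # Degenerate inputs (any zero-sized dim) pass through untouched.
    if image.numel() == 0:
        return image

    shape = image.shape

    # Squash all leading batch dims into one: (*B, C, H, W) -> (N, C, H, W).
    if image.ndim > 4:
        image = image.view((-1,) + shape[-3:])
        needs_unsquash = True
    else:
        needs_unsquash = False

    output = kernel(image)

    # Unsquash: restore the original leading batch dims. This assumes the
    # kernel preserves the spatial size, which holds for perspective,
    # elastic, and gaussian_blur.
    if needs_unsquash:
        output = output.view(shape)

    return output
```

Using `view` rather than `reshape` keeps both steps copy-free.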
@@ -1007,25 +1023,7 @@ def perspective_video(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
-    # TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
-    # https://github.com/pytorch/vision/issues/6670 is resolved.
-    if video.numel() == 0:
-        return video
-
-    shape = video.shape
-
-    if video.ndim > 4:
-        video = video.view((-1,) + shape[-3:])
-        needs_unsquash = True
-    else:
-        needs_unsquash = False
-
-    output = perspective_image_tensor(video, perspective_coeffs, interpolation=interpolation, fill=fill)
-
-    if needs_unsquash:
-        output = output.view(shape)
-
-    return output
+    return perspective_image_tensor(video, perspective_coeffs, interpolation=interpolation, fill=fill)


 def perspective(
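With the workaround moved into the image kernel, `perspective_video` reduces to a plain delegation, and a 5D video batch flows through unchanged. A usage sketch, assuming the prototype functional namespace exports these kernels; the identity coefficients are illustrative only:

```python
import torch
from torchvision.prototype.transforms import functional as F

coeffs = [1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0]  # identity perspective
video = torch.rand(2, 8, 3, 32, 32)  # (B, T, C, H, W)

out = F.perspective_video(video, coeffs)
assert out.shape == video.shape
```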
@@ -1048,7 +1046,23 @@ def elastic_image_tensor(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
-    return _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill)
+    if image.numel() == 0:
+        return image
+
+    shape = image.shape
+
+    if image.ndim > 4:
+        image = image.view((-1,) + shape[-3:])
+        needs_unsquash = True
+    else:
+        needs_unsquash = False
+
+    output = _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill)
+
+    if needs_unsquash:
+        output = output.view(shape)
+
+    return output


 @torch.jit.unused
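The elastic kernel gets the identical treatment; the displacement field keeps its single-sample layout and is applied across the flattened batch. A sketch, assuming the public `(1, H, W, 2)` displacement shape of `elastic_transform`:

```python
import torch
from torchvision.prototype.transforms import functional as F

image = torch.rand(2, 4, 3, 32, 32)       # two leading batch dims
displacement = torch.zeros(1, 32, 32, 2)  # zero field == identity warp

out = F.elastic_image_tensor(image, displacement)
assert out.shape == image.shape
```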
@@ -1128,25 +1142,7 @@ def elastic_video(
     interpolation: InterpolationMode = InterpolationMode.BILINEAR,
     fill: features.FillTypeJIT = None,
 ) -> torch.Tensor:
-    # TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
-    # https://github.com/pytorch/vision/issues/6670 is resolved.
-    if video.numel() == 0:
-        return video
-
-    shape = video.shape
-
-    if video.ndim > 4:
-        video = video.view((-1,) + shape[-3:])
-        needs_unsquash = True
-    else:
-        needs_unsquash = False
-
-    output = elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
-
-    if needs_unsquash:
-        output = output.view(shape)
-
-    return output
+    return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)


 def elastic(
38 changes: 18 additions & 20 deletions torchvision/prototype/transforms/functional/_misc.py
@@ -56,7 +56,23 @@ def gaussian_blur_image_tensor(
         if s <= 0.0:
             raise ValueError(f"sigma should have positive values. Got {sigma}")

-    return _FT.gaussian_blur(image, kernel_size, sigma)
+    if image.numel() == 0:
+        return image
+
+    shape = image.shape
+
+    if image.ndim > 4:
+        image = image.view((-1,) + shape[-3:])
+        needs_unsquash = True
+    else:
+        needs_unsquash = False
+
+    output = _FT.gaussian_blur(image, kernel_size, sigma)
+
+    if needs_unsquash:
+        output = output.view(shape)
+
+    return output


 @torch.jit.unused
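`gaussian_blur_image_tensor` follows suit, including the early return for degenerate batches. A sketch under the same namespace assumption:

```python
import torch
from torchvision.prototype.transforms import functional as F

# Arbitrary leading batch dims are squashed, blurred, and restored.
batch = torch.rand(2, 4, 3, 32, 32)
out = F.gaussian_blur_image_tensor(batch, kernel_size=[3, 3], sigma=[1.0, 1.0])
assert out.shape == batch.shape

# Zero-sized batches pass through unchanged.
empty = torch.rand(0, 3, 32, 32)
assert F.gaussian_blur_image_tensor(empty, kernel_size=[3, 3]).shape == empty.shape
```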
@@ -71,25 +87,7 @@ def gaussian_blur_image_pil(
 def gaussian_blur_video(
     video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
 ) -> torch.Tensor:
-    # TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
-    # https://github.com/pytorch/vision/issues/6670 is resolved.
-    if video.numel() == 0:
-        return video
-
-    shape = video.shape
-
-    if video.ndim > 4:
-        video = video.view((-1,) + shape[-3:])
-        needs_unsquash = True
-    else:
-        needs_unsquash = False
-
-    output = gaussian_blur_image_tensor(video, kernel_size, sigma)
-
-    if needs_unsquash:
-        output = output.view(shape)
-
-    return output
+    return gaussian_blur_image_tensor(video, kernel_size, sigma)


 def gaussian_blur(
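Taken together, a smoke test over the shapes the removed xfails used to exclude might look like this (a sketch, not part of this PR's test suite):

```python
import pytest
import torch
from torchvision.prototype.transforms import functional as F


@pytest.mark.parametrize(
    "shape",
    [
        (3, 16, 16),        # unbatched
        (2, 3, 16, 16),     # plain 4D batch
        (0, 3, 16, 16),     # degenerate batch dim
        (2, 4, 3, 16, 16),  # multiple leading batch dims
    ],
)
def test_arbitrary_batch_dims(shape):
    image = torch.rand(shape)
    out = F.gaussian_blur_image_tensor(image, kernel_size=[3, 3], sigma=[1.0, 1.0])
    assert out.shape == image.shape
```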