diff --git a/test/prototype_transforms_dispatcher_infos.py b/test/prototype_transforms_dispatcher_infos.py index be8bd3002c1..de933c7e3fa 100644 --- a/test/prototype_transforms_dispatcher_infos.py +++ b/test/prototype_transforms_dispatcher_infos.py @@ -138,12 +138,6 @@ def xfail_all_tests(*, reason, condition): ] -xfails_degenerate_or_multi_batch_dims = xfail_all_tests( - reason="See https://github.com/pytorch/vision/issues/6670 for details.", - condition=lambda args_kwargs: len(args_kwargs.args[0].shape) > 4 or not all(args_kwargs.args[0].shape[:-3]), -) - - DISPATCHER_INFOS = [ DispatcherInfo( F.horizontal_flip, @@ -260,7 +254,6 @@ def xfail_all_tests(*, reason, condition): pil_kernel_info=PILKernelInfo(F.perspective_image_pil), test_marks=[ xfail_dispatch_pil_if_fill_sequence_needs_broadcast, - *xfails_degenerate_or_multi_batch_dims, ], ), DispatcherInfo( @@ -271,7 +264,6 @@ def xfail_all_tests(*, reason, condition): features.Mask: F.elastic_mask, }, pil_kernel_info=PILKernelInfo(F.elastic_image_pil), - test_marks=xfails_degenerate_or_multi_batch_dims, ), DispatcherInfo( F.center_crop, @@ -294,7 +286,6 @@ def xfail_all_tests(*, reason, condition): test_marks=[ xfail_jit_python_scalar_arg("kernel_size"), xfail_jit_python_scalar_arg("sigma"), - *xfails_degenerate_or_multi_batch_dims, ], ), DispatcherInfo( diff --git a/test/prototype_transforms_kernel_infos.py b/test/prototype_transforms_kernel_infos.py index d90d3bf68be..9ebfc7a00d2 100644 --- a/test/prototype_transforms_kernel_infos.py +++ b/test/prototype_transforms_kernel_infos.py @@ -156,12 +156,6 @@ def xfail_all_tests(*, reason, condition): ] -xfails_image_degenerate_or_multi_batch_dims = xfail_all_tests( - reason="See https://github.com/pytorch/vision/issues/6670 for details.", - condition=lambda args_kwargs: len(args_kwargs.args[0].shape) > 4 or not all(args_kwargs.args[0].shape[:-3]), -) - - KERNEL_INFOS = [] @@ -1156,7 +1150,6 @@ def sample_inputs_perspective_video(): reference_fn=pil_reference_wrapper(F.perspective_image_pil), reference_inputs_fn=reference_inputs_perspective_image_tensor, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - test_marks=xfails_image_degenerate_or_multi_batch_dims, ), KernelInfo( F.perspective_bounding_box, @@ -1168,7 +1161,6 @@ def sample_inputs_perspective_video(): reference_fn=pil_reference_wrapper(F.perspective_image_pil), reference_inputs_fn=reference_inputs_perspective_mask, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - test_marks=xfails_image_degenerate_or_multi_batch_dims, ), KernelInfo( F.perspective_video, @@ -1239,7 +1231,6 @@ def sample_inputs_elastic_video(): reference_fn=pil_reference_wrapper(F.elastic_image_pil), reference_inputs_fn=reference_inputs_elastic_image_tensor, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - test_marks=xfails_image_degenerate_or_multi_batch_dims, ), KernelInfo( F.elastic_bounding_box, @@ -1251,7 +1242,6 @@ def sample_inputs_elastic_video(): reference_fn=pil_reference_wrapper(F.elastic_image_pil), reference_inputs_fn=reference_inputs_elastic_mask, closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS, - test_marks=xfails_image_degenerate_or_multi_batch_dims, ), KernelInfo( F.elastic_video, @@ -1379,7 +1369,6 @@ def sample_inputs_gaussian_blur_video(): test_marks=[ xfail_jit_python_scalar_arg("kernel_size"), xfail_jit_python_scalar_arg("sigma"), - *xfails_image_degenerate_or_multi_batch_dims, ], ), KernelInfo( diff --git a/torchvision/prototype/transforms/functional/_geometry.py b/torchvision/prototype/transforms/functional/_geometry.py index 670b2cb87b8..2c064245e8a 100644 --- a/torchvision/prototype/transforms/functional/_geometry.py +++ b/torchvision/prototype/transforms/functional/_geometry.py @@ -882,7 +882,23 @@ def perspective_image_tensor( interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: features.FillTypeJIT = None, ) -> torch.Tensor: - return _FT.perspective(image, perspective_coeffs, interpolation=interpolation.value, fill=fill) + if image.numel() == 0: + return image + + shape = image.shape + + if image.ndim > 4: + image = image.view((-1,) + shape[-3:]) + needs_unsquash = True + else: + needs_unsquash = False + + output = _FT.perspective(image, perspective_coeffs, interpolation=interpolation.value, fill=fill) + + if needs_unsquash: + output = output.view(shape) + + return output @torch.jit.unused @@ -1007,25 +1023,7 @@ def perspective_video( interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: features.FillTypeJIT = None, ) -> torch.Tensor: - # TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when - # https://github.com/pytorch/vision/issues/6670 is resolved. - if video.numel() == 0: - return video - - shape = video.shape - - if video.ndim > 4: - video = video.view((-1,) + shape[-3:]) - needs_unsquash = True - else: - needs_unsquash = False - - output = perspective_image_tensor(video, perspective_coeffs, interpolation=interpolation, fill=fill) - - if needs_unsquash: - output = output.view(shape) - - return output + return perspective_image_tensor(video, perspective_coeffs, interpolation=interpolation, fill=fill) def perspective( @@ -1048,7 +1046,23 @@ def elastic_image_tensor( interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: features.FillTypeJIT = None, ) -> torch.Tensor: - return _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill) + if image.numel() == 0: + return image + + shape = image.shape + + if image.ndim > 4: + image = image.view((-1,) + shape[-3:]) + needs_unsquash = True + else: + needs_unsquash = False + + output = _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill) + + if needs_unsquash: + output = output.view(shape) + + return output @torch.jit.unused @@ -1128,25 +1142,7 @@ def elastic_video( interpolation: InterpolationMode = InterpolationMode.BILINEAR, fill: features.FillTypeJIT = None, ) -> torch.Tensor: - # TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when - # https://github.com/pytorch/vision/issues/6670 is resolved. - if video.numel() == 0: - return video - - shape = video.shape - - if video.ndim > 4: - video = video.view((-1,) + shape[-3:]) - needs_unsquash = True - else: - needs_unsquash = False - - output = elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill) - - if needs_unsquash: - output = output.view(shape) - - return output + return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill) def elastic( diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 7b3773e63a1..79a358b4ed5 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -56,7 +56,23 @@ def gaussian_blur_image_tensor( if s <= 0.0: raise ValueError(f"sigma should have positive values. Got {sigma}") - return _FT.gaussian_blur(image, kernel_size, sigma) + if image.numel() == 0: + return image + + shape = image.shape + + if image.ndim > 4: + image = image.view((-1,) + shape[-3:]) + needs_unsquash = True + else: + needs_unsquash = False + + output = _FT.gaussian_blur(image, kernel_size, sigma) + + if needs_unsquash: + output = output.view(shape) + + return output @torch.jit.unused @@ -71,25 +87,7 @@ def gaussian_blur_image_pil( def gaussian_blur_video( video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None ) -> torch.Tensor: - # TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when - # https://github.com/pytorch/vision/issues/6670 is resolved. - if video.numel() == 0: - return video - - shape = video.shape - - if video.ndim > 4: - video = video.view((-1,) + shape[-3:]) - needs_unsquash = True - else: - needs_unsquash = False - - output = gaussian_blur_image_tensor(video, kernel_size, sigma) - - if needs_unsquash: - output = output.view(shape) - - return output + return gaussian_blur_image_tensor(video, kernel_size, sigma) def gaussian_blur(