Skip to content

Commit 248e4c0

Browse files
Yosua Michael Maranatha and facebook-github-bot
Yosua Michael Maranatha
authored and committed
[fbsync] enable arbitrary batch size for all prototype kernels (#6726)
Summary:
* enable arbitrary batch size for all prototype kernels
* put back perspective dispatcher

Reviewed By: NicolasHug
Differential Revision: D40427471
fbshipit-source-id: f8cdfdce28462d72bdb2b92a8606b3eb1ff93d15
1 parent 235fb85 commit 248e4c0

File tree

4 files changed

+54
-80
lines changed

4 files changed

+54
-80
lines changed

test/prototype_transforms_dispatcher_infos.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,6 @@ def xfail_all_tests(*, reason, condition):
138138
]
139139

140140

141-
xfails_degenerate_or_multi_batch_dims = xfail_all_tests(
142-
reason="See https://github.com/pytorch/vision/issues/6670 for details.",
143-
condition=lambda args_kwargs: len(args_kwargs.args[0].shape) > 4 or not all(args_kwargs.args[0].shape[:-3]),
144-
)
145-
146-
147141
DISPATCHER_INFOS = [
148142
DispatcherInfo(
149143
F.horizontal_flip,
@@ -260,7 +254,6 @@ def xfail_all_tests(*, reason, condition):
260254
pil_kernel_info=PILKernelInfo(F.perspective_image_pil),
261255
test_marks=[
262256
xfail_dispatch_pil_if_fill_sequence_needs_broadcast,
263-
*xfails_degenerate_or_multi_batch_dims,
264257
],
265258
),
266259
DispatcherInfo(
@@ -271,7 +264,6 @@ def xfail_all_tests(*, reason, condition):
271264
features.Mask: F.elastic_mask,
272265
},
273266
pil_kernel_info=PILKernelInfo(F.elastic_image_pil),
274-
test_marks=xfails_degenerate_or_multi_batch_dims,
275267
),
276268
DispatcherInfo(
277269
F.center_crop,
@@ -294,7 +286,6 @@ def xfail_all_tests(*, reason, condition):
294286
test_marks=[
295287
xfail_jit_python_scalar_arg("kernel_size"),
296288
xfail_jit_python_scalar_arg("sigma"),
297-
*xfails_degenerate_or_multi_batch_dims,
298289
],
299290
),
300291
DispatcherInfo(

test/prototype_transforms_kernel_infos.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,6 @@ def xfail_all_tests(*, reason, condition):
156156
]
157157

158158

159-
xfails_image_degenerate_or_multi_batch_dims = xfail_all_tests(
160-
reason="See https://github.com/pytorch/vision/issues/6670 for details.",
161-
condition=lambda args_kwargs: len(args_kwargs.args[0].shape) > 4 or not all(args_kwargs.args[0].shape[:-3]),
162-
)
163-
164-
165159
KERNEL_INFOS = []
166160

167161

@@ -1156,7 +1150,6 @@ def sample_inputs_perspective_video():
11561150
reference_fn=pil_reference_wrapper(F.perspective_image_pil),
11571151
reference_inputs_fn=reference_inputs_perspective_image_tensor,
11581152
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
1159-
test_marks=xfails_image_degenerate_or_multi_batch_dims,
11601153
),
11611154
KernelInfo(
11621155
F.perspective_bounding_box,
@@ -1168,7 +1161,6 @@ def sample_inputs_perspective_video():
11681161
reference_fn=pil_reference_wrapper(F.perspective_image_pil),
11691162
reference_inputs_fn=reference_inputs_perspective_mask,
11701163
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
1171-
test_marks=xfails_image_degenerate_or_multi_batch_dims,
11721164
),
11731165
KernelInfo(
11741166
F.perspective_video,
@@ -1239,7 +1231,6 @@ def sample_inputs_elastic_video():
12391231
reference_fn=pil_reference_wrapper(F.elastic_image_pil),
12401232
reference_inputs_fn=reference_inputs_elastic_image_tensor,
12411233
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
1242-
test_marks=xfails_image_degenerate_or_multi_batch_dims,
12431234
),
12441235
KernelInfo(
12451236
F.elastic_bounding_box,
@@ -1251,7 +1242,6 @@ def sample_inputs_elastic_video():
12511242
reference_fn=pil_reference_wrapper(F.elastic_image_pil),
12521243
reference_inputs_fn=reference_inputs_elastic_mask,
12531244
closeness_kwargs=DEFAULT_IMAGE_CLOSENESS_KWARGS,
1254-
test_marks=xfails_image_degenerate_or_multi_batch_dims,
12551245
),
12561246
KernelInfo(
12571247
F.elastic_video,
@@ -1379,7 +1369,6 @@ def sample_inputs_gaussian_blur_video():
13791369
test_marks=[
13801370
xfail_jit_python_scalar_arg("kernel_size"),
13811371
xfail_jit_python_scalar_arg("sigma"),
1382-
*xfails_image_degenerate_or_multi_batch_dims,
13831372
],
13841373
),
13851374
KernelInfo(

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 36 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -882,7 +882,23 @@ def perspective_image_tensor(
882882
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
883883
fill: features.FillTypeJIT = None,
884884
) -> torch.Tensor:
885-
return _FT.perspective(image, perspective_coeffs, interpolation=interpolation.value, fill=fill)
885+
if image.numel() == 0:
886+
return image
887+
888+
shape = image.shape
889+
890+
if image.ndim > 4:
891+
image = image.view((-1,) + shape[-3:])
892+
needs_unsquash = True
893+
else:
894+
needs_unsquash = False
895+
896+
output = _FT.perspective(image, perspective_coeffs, interpolation=interpolation.value, fill=fill)
897+
898+
if needs_unsquash:
899+
output = output.view(shape)
900+
901+
return output
886902

887903

888904
@torch.jit.unused
@@ -1007,25 +1023,7 @@ def perspective_video(
10071023
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
10081024
fill: features.FillTypeJIT = None,
10091025
) -> torch.Tensor:
1010-
# TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
1011-
# https://github.com/pytorch/vision/issues/6670 is resolved.
1012-
if video.numel() == 0:
1013-
return video
1014-
1015-
shape = video.shape
1016-
1017-
if video.ndim > 4:
1018-
video = video.view((-1,) + shape[-3:])
1019-
needs_unsquash = True
1020-
else:
1021-
needs_unsquash = False
1022-
1023-
output = perspective_image_tensor(video, perspective_coeffs, interpolation=interpolation, fill=fill)
1024-
1025-
if needs_unsquash:
1026-
output = output.view(shape)
1027-
1028-
return output
1026+
return perspective_image_tensor(video, perspective_coeffs, interpolation=interpolation, fill=fill)
10291027

10301028

10311029
def perspective(
@@ -1048,7 +1046,23 @@ def elastic_image_tensor(
10481046
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
10491047
fill: features.FillTypeJIT = None,
10501048
) -> torch.Tensor:
1051-
return _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill)
1049+
if image.numel() == 0:
1050+
return image
1051+
1052+
shape = image.shape
1053+
1054+
if image.ndim > 4:
1055+
image = image.view((-1,) + shape[-3:])
1056+
needs_unsquash = True
1057+
else:
1058+
needs_unsquash = False
1059+
1060+
output = _FT.elastic_transform(image, displacement, interpolation=interpolation.value, fill=fill)
1061+
1062+
if needs_unsquash:
1063+
output = output.view(shape)
1064+
1065+
return output
10521066

10531067

10541068
@torch.jit.unused
@@ -1128,25 +1142,7 @@ def elastic_video(
11281142
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
11291143
fill: features.FillTypeJIT = None,
11301144
) -> torch.Tensor:
1131-
# TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
1132-
# https://github.com/pytorch/vision/issues/6670 is resolved.
1133-
if video.numel() == 0:
1134-
return video
1135-
1136-
shape = video.shape
1137-
1138-
if video.ndim > 4:
1139-
video = video.view((-1,) + shape[-3:])
1140-
needs_unsquash = True
1141-
else:
1142-
needs_unsquash = False
1143-
1144-
output = elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
1145-
1146-
if needs_unsquash:
1147-
output = output.view(shape)
1148-
1149-
return output
1145+
return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
11501146

11511147

11521148
def elastic(

torchvision/prototype/transforms/functional/_misc.py

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,23 @@ def gaussian_blur_image_tensor(
5656
if s <= 0.0:
5757
raise ValueError(f"sigma should have positive values. Got {sigma}")
5858

59-
return _FT.gaussian_blur(image, kernel_size, sigma)
59+
if image.numel() == 0:
60+
return image
61+
62+
shape = image.shape
63+
64+
if image.ndim > 4:
65+
image = image.view((-1,) + shape[-3:])
66+
needs_unsquash = True
67+
else:
68+
needs_unsquash = False
69+
70+
output = _FT.gaussian_blur(image, kernel_size, sigma)
71+
72+
if needs_unsquash:
73+
output = output.view(shape)
74+
75+
return output
6076

6177

6278
@torch.jit.unused
@@ -71,25 +87,7 @@ def gaussian_blur_image_pil(
7187
def gaussian_blur_video(
7288
video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
7389
) -> torch.Tensor:
74-
# TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
75-
# https://github.com/pytorch/vision/issues/6670 is resolved.
76-
if video.numel() == 0:
77-
return video
78-
79-
shape = video.shape
80-
81-
if video.ndim > 4:
82-
video = video.view((-1,) + shape[-3:])
83-
needs_unsquash = True
84-
else:
85-
needs_unsquash = False
86-
87-
output = gaussian_blur_image_tensor(video, kernel_size, sigma)
88-
89-
if needs_unsquash:
90-
output = output.view(shape)
91-
92-
return output
90+
return gaussian_blur_image_tensor(video, kernel_size, sigma)
9391

9492

9593
def gaussian_blur(

0 commit comments

Comments (0)