Skip to content

Commit 0d2ad96

Browse files
committed
add batch dimension squashing to some kernels
1 parent 36f52dc commit 0d2ad96

File tree

3 files changed

+95
-5
lines changed

3 files changed

+95
-5
lines changed

torchvision/prototype/transforms/functional/_color.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,25 @@ def adjust_contrast(inpt: features.InputTypeJIT, contrast_factor: float) -> feat
5858

5959

6060
def adjust_sharpness_video(video: torch.Tensor, sharpness_factor: float) -> torch.Tensor:
    """Adjust the sharpness of a video tensor.

    Delegates to the image kernel after collapsing any leading batch
    dimensions into a single one, since the image kernel only supports
    inputs of up to 4 dimensions.
    """
    # TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
    # https://github.com/pytorch/vision/issues/6670 is resolved.
    if video.numel() == 0:
        # Nothing to process; hand the empty tensor back untouched.
        return video

    original_shape = video.shape
    needs_unsquash = video.ndim > 4
    if needs_unsquash:
        # Fold every leading dimension into a single batch dimension.
        video = video.view((-1,) + original_shape[-3:])

    output = adjust_sharpness_image_tensor(video, sharpness_factor=sharpness_factor)

    # Restore the original leading dimensions if we squashed them above.
    return output.view(original_shape) if needs_unsquash else output
6280

6381

6482
def adjust_sharpness(inpt: features.InputTypeJIT, sharpness_factor: float) -> features.InputTypeJIT:
@@ -160,7 +178,25 @@ def autocontrast(inpt: features.InputTypeJIT) -> features.InputTypeJIT:
160178

161179

162180
def equalize_video(video: torch.Tensor) -> torch.Tensor:
    """Equalize the histogram of a video tensor.

    Delegates to the image kernel after collapsing any leading batch
    dimensions into a single one, since the image kernel only supports
    inputs of up to 4 dimensions.
    """
    # TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
    # https://github.com/pytorch/vision/issues/6670 is resolved.
    if video.numel() == 0:
        # Nothing to process; hand the empty tensor back untouched.
        return video

    original_shape = video.shape
    needs_unsquash = video.ndim > 4
    if needs_unsquash:
        # Fold every leading dimension into a single batch dimension.
        video = video.view((-1,) + original_shape[-3:])

    output = equalize_image_tensor(video)

    # Restore the original leading dimensions if we squashed them above.
    return output.view(original_shape) if needs_unsquash else output
164200

165201

166202
def equalize(inpt: features.InputTypeJIT) -> features.InputTypeJIT:

torchvision/prototype/transforms/functional/_geometry.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1002,7 +1002,25 @@ def perspective_video(
10021002
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
10031003
fill: features.FillTypeJIT = None,
10041004
) -> torch.Tensor:
1005-
return perspective_image_tensor(video, perspective_coeffs, interpolation=interpolation, fill=fill)
1005+
# TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
1006+
# https://github.com/pytorch/vision/issues/6670 is resolved.
1007+
if video.numel() == 0:
1008+
return video
1009+
1010+
shape = video.shape
1011+
1012+
if video.ndim > 4:
1013+
video = video.view((-1,) + shape[-3:])
1014+
needs_unsquash = True
1015+
else:
1016+
needs_unsquash = False
1017+
1018+
output = perspective_image_tensor(video, perspective_coeffs, interpolation=interpolation, fill=fill)
1019+
1020+
if needs_unsquash:
1021+
output = output.view(shape)
1022+
1023+
return output
10061024

10071025

10081026
def perspective(
@@ -1105,7 +1123,25 @@ def elastic_video(
11051123
interpolation: InterpolationMode = InterpolationMode.BILINEAR,
11061124
fill: features.FillTypeJIT = None,
11071125
) -> torch.Tensor:
1108-
return elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
1126+
# TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
1127+
# https://github.com/pytorch/vision/issues/6670 is resolved.
1128+
if video.numel() == 0:
1129+
return video
1130+
1131+
shape = video.shape
1132+
1133+
if video.ndim > 4:
1134+
video = video.view((-1,) + shape[-3:])
1135+
needs_unsquash = True
1136+
else:
1137+
needs_unsquash = False
1138+
1139+
output = elastic_image_tensor(video, displacement, interpolation=interpolation, fill=fill)
1140+
1141+
if needs_unsquash:
1142+
output = output.view(shape)
1143+
1144+
return output
11091145

11101146

11111147
def elastic(

torchvision/prototype/transforms/functional/_misc.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,25 @@ def gaussian_blur_image_pil(
7171
def gaussian_blur_video(
    video: torch.Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None
) -> torch.Tensor:
    """Apply a Gaussian blur to a video tensor.

    Delegates to the image kernel after collapsing any leading batch
    dimensions into a single one, since the image kernel only supports
    inputs of up to 4 dimensions.
    """
    # TODO: this is a temporary workaround until the image kernel supports arbitrary batch sizes. Remove this when
    # https://github.com/pytorch/vision/issues/6670 is resolved.
    if video.numel() == 0:
        # Nothing to process; hand the empty tensor back untouched.
        return video

    original_shape = video.shape
    needs_unsquash = video.ndim > 4
    if needs_unsquash:
        # Fold every leading dimension into a single batch dimension.
        video = video.view((-1,) + original_shape[-3:])

    output = gaussian_blur_image_tensor(video, kernel_size, sigma)

    # Restore the original leading dimensions if we squashed them above.
    return output.view(original_shape) if needs_unsquash else output
7593

7694

7795
def gaussian_blur(

0 commit comments

Comments
 (0)