Skip to content

Commit 5b15f37

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] [prototype] Gaussian Blur clean up (#6888)
Summary: * Refactor gaussian_blur * Add conditional reshape * Further refactoring * Remove unused import. Reviewed By: datumbox Differential Revision: D41020542 fbshipit-source-id: 72694024272d91818c4154f7b5f7097e6d21154f
1 parent 74e8ea9 commit 5b15f37

File tree

1 file changed

+33
-27
lines changed
  • torchvision/prototype/transforms/functional

1 file changed

+33
-27
lines changed

torchvision/prototype/transforms/functional/_misc.py

Lines changed: 33 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import torch
66
from torch.nn.functional import conv2d, pad as torch_pad
77
from torchvision.prototype import features
8-
from torchvision.transforms import functional_tensor as _FT
98
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
109

1110

@@ -68,9 +67,9 @@ def normalize(
6867

6968

7069
def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
71-
lim = (kernel_size - 1) / (2 * math.sqrt(2) * sigma)
70+
lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0) * sigma)
7271
x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device)
73-
kernel1d = torch.softmax(-x.pow_(2), dim=0)
72+
kernel1d = torch.softmax(x.pow_(2).neg_(), dim=0)
7473
return kernel1d
7574

7675

@@ -89,54 +88,61 @@ def gaussian_blur_image_tensor(
8988
# TODO: consider deprecating integers from sigma on the future
9089
if isinstance(kernel_size, int):
9190
kernel_size = [kernel_size, kernel_size]
92-
if len(kernel_size) != 2:
91+
elif len(kernel_size) != 2:
9392
raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
9493
for ksize in kernel_size:
9594
if ksize % 2 == 0 or ksize < 0:
9695
raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")
9796

9897
if sigma is None:
9998
sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]
100-
101-
if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
102-
raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
103-
if isinstance(sigma, (int, float)):
104-
sigma = [float(sigma), float(sigma)]
105-
if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
106-
sigma = [sigma[0], sigma[0]]
107-
if len(sigma) != 2:
108-
raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}")
99+
else:
100+
if isinstance(sigma, (list, tuple)):
101+
length = len(sigma)
102+
if length == 1:
103+
s = float(sigma[0])
104+
sigma = [s, s]
105+
elif length != 2:
106+
raise ValueError(f"If sigma is a sequence, its length should be 2. Got {length}")
107+
elif isinstance(sigma, (int, float)):
108+
s = float(sigma)
109+
sigma = [s, s]
110+
else:
111+
raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
109112
for s in sigma:
110113
if s <= 0.0:
111114
raise ValueError(f"sigma should have positive values. Got {sigma}")
112115

113116
if image.numel() == 0:
114117
return image
115118

119+
dtype = image.dtype
116120
shape = image.shape
117-
118-
if image.ndim > 4:
121+
ndim = image.ndim
122+
if ndim == 3:
123+
image = image.unsqueeze(dim=0)
124+
elif ndim > 4:
119125
image = image.reshape((-1,) + shape[-3:])
120-
needs_unsquash = True
121-
else:
122-
needs_unsquash = False
123126

124-
dtype = image.dtype if torch.is_floating_point(image) else torch.float32
125-
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=image.device)
126-
kernel = kernel.expand(image.shape[-3], 1, kernel.shape[0], kernel.shape[1])
127+
fp = torch.is_floating_point(image)
128+
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype if fp else torch.float32, device=image.device)
129+
kernel = kernel.expand(shape[-3], 1, kernel.shape[0], kernel.shape[1])
127130

128-
image, need_cast, need_squeeze, out_dtype = _FT._cast_squeeze_in(image, [kernel.dtype])
131+
output = image if fp else image.to(dtype=torch.float32)
129132

130133
# padding = (left, right, top, bottom)
131134
padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
132-
output = torch_pad(image, padding, mode="reflect")
133-
output = conv2d(output, kernel, groups=output.shape[-3])
134-
135-
output = _FT._cast_squeeze_out(output, need_cast, need_squeeze, out_dtype)
135+
output = torch_pad(output, padding, mode="reflect")
136+
output = conv2d(output, kernel, groups=shape[-3])
136137

137-
if needs_unsquash:
138+
if ndim == 3:
139+
output = output.squeeze(dim=0)
140+
elif ndim > 4:
138141
output = output.reshape(shape)
139142

143+
if not fp:
144+
output = output.round_().to(dtype=dtype)
145+
140146
return output
141147

142148

0 commit comments

Comments
 (0)