Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 33 additions & 27 deletions torchvision/prototype/transforms/functional/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import torch
from torch.nn.functional import conv2d, pad as torch_pad
from torchvision.prototype import features
from torchvision.transforms import functional_tensor as _FT
from torchvision.transforms.functional import pil_to_tensor, to_pil_image


Expand Down Expand Up @@ -68,9 +67,9 @@ def normalize(


def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
lim = (kernel_size - 1) / (2 * math.sqrt(2) * sigma)
lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0) * sigma)
x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device)
kernel1d = torch.softmax(-x.pow_(2), dim=0)
kernel1d = torch.softmax(x.pow_(2).neg_(), dim=0)
return kernel1d


Expand All @@ -89,54 +88,61 @@ def gaussian_blur_image_tensor(
# TODO: consider deprecating integers from sigma on the future
if isinstance(kernel_size, int):
kernel_size = [kernel_size, kernel_size]
if len(kernel_size) != 2:
elif len(kernel_size) != 2:
raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}")
for ksize in kernel_size:
if ksize % 2 == 0 or ksize < 0:
raise ValueError(f"kernel_size should have odd and positive integers. Got {kernel_size}")

if sigma is None:
sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size]

if sigma is not None and not isinstance(sigma, (int, float, list, tuple)):
raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
if isinstance(sigma, (int, float)):
sigma = [float(sigma), float(sigma)]
if isinstance(sigma, (list, tuple)) and len(sigma) == 1:
sigma = [sigma[0], sigma[0]]
if len(sigma) != 2:
raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}")
else:
if isinstance(sigma, (list, tuple)):
length = len(sigma)
if length == 1:
s = float(sigma[0])
sigma = [s, s]
elif length != 2:
raise ValueError(f"If sigma is a sequence, its length should be 2. Got {length}")
elif isinstance(sigma, (int, float)):
s = float(sigma)
sigma = [s, s]
else:
raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}")
for s in sigma:
if s <= 0.0:
raise ValueError(f"sigma should have positive values. Got {sigma}")

if image.numel() == 0:
return image

dtype = image.dtype
shape = image.shape

if image.ndim > 4:
ndim = image.ndim
if ndim == 3:
image = image.unsqueeze(dim=0)
elif ndim > 4:
image = image.reshape((-1,) + shape[-3:])
needs_unsquash = True
else:
needs_unsquash = False

dtype = image.dtype if torch.is_floating_point(image) else torch.float32
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=image.device)
kernel = kernel.expand(image.shape[-3], 1, kernel.shape[0], kernel.shape[1])
fp = torch.is_floating_point(image)
kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype if fp else torch.float32, device=image.device)
kernel = kernel.expand(shape[-3], 1, kernel.shape[0], kernel.shape[1])

image, need_cast, need_squeeze, out_dtype = _FT._cast_squeeze_in(image, [kernel.dtype])
output = image if fp else image.to(dtype=torch.float32)

# padding = (left, right, top, bottom)
padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2]
output = torch_pad(image, padding, mode="reflect")
output = conv2d(output, kernel, groups=output.shape[-3])

output = _FT._cast_squeeze_out(output, need_cast, need_squeeze, out_dtype)
output = torch_pad(output, padding, mode="reflect")
output = conv2d(output, kernel, groups=shape[-3])

if needs_unsquash:
if ndim == 3:
output = output.squeeze(dim=0)
elif ndim > 4:
output = output.reshape(shape)

if not fp:
output = output.round_().to(dtype=dtype)

return output


Expand Down