diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index 738e369962d..d8bfc7cae1b 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -5,7 +5,6 @@ import torch from torch.nn.functional import conv2d, pad as torch_pad from torchvision.prototype import features -from torchvision.transforms import functional_tensor as _FT from torchvision.transforms.functional import pil_to_tensor, to_pil_image @@ -68,9 +67,9 @@ def normalize( def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor: - lim = (kernel_size - 1) / (2 * math.sqrt(2) * sigma) + lim = (kernel_size - 1) / (2.0 * math.sqrt(2.0) * sigma) x = torch.linspace(-lim, lim, steps=kernel_size, dtype=dtype, device=device) - kernel1d = torch.softmax(-x.pow_(2), dim=0) + kernel1d = torch.softmax(x.pow_(2).neg_(), dim=0) return kernel1d @@ -89,7 +88,7 @@ def gaussian_blur_image_tensor( # TODO: consider deprecating integers from sigma on the future if isinstance(kernel_size, int): kernel_size = [kernel_size, kernel_size] - if len(kernel_size) != 2: + elif len(kernel_size) != 2: raise ValueError(f"If kernel_size is a sequence its length should be 2. Got {len(kernel_size)}") for ksize in kernel_size: if ksize % 2 == 0 or ksize < 0: @@ -97,15 +96,19 @@ def gaussian_blur_image_tensor( if sigma is None: sigma = [ksize * 0.15 + 0.35 for ksize in kernel_size] - - if sigma is not None and not isinstance(sigma, (int, float, list, tuple)): - raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}") - if isinstance(sigma, (int, float)): - sigma = [float(sigma), float(sigma)] - if isinstance(sigma, (list, tuple)) and len(sigma) == 1: - sigma = [sigma[0], sigma[0]] - if len(sigma) != 2: - raise ValueError(f"If sigma is a sequence, its length should be 2. Got {len(sigma)}") + else: + if isinstance(sigma, (list, tuple)): + length = len(sigma) + if length == 1: + s = float(sigma[0]) + sigma = [s, s] + elif length != 2: + raise ValueError(f"If sigma is a sequence, its length should be 2. Got {length}") + elif isinstance(sigma, (int, float)): + s = float(sigma) + sigma = [s, s] + else: + raise TypeError(f"sigma should be either float or sequence of floats. Got {type(sigma)}") for s in sigma: if s <= 0.0: raise ValueError(f"sigma should have positive values. Got {sigma}") @@ -113,30 +116,33 @@ def gaussian_blur_image_tensor( if image.numel() == 0: return image + dtype = image.dtype shape = image.shape - - if image.ndim > 4: + ndim = image.ndim + if ndim == 3: + image = image.unsqueeze(dim=0) + elif ndim > 4: image = image.reshape((-1,) + shape[-3:]) - needs_unsquash = True - else: - needs_unsquash = False - dtype = image.dtype if torch.is_floating_point(image) else torch.float32 - kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=image.device) - kernel = kernel.expand(image.shape[-3], 1, kernel.shape[0], kernel.shape[1]) + fp = torch.is_floating_point(image) + kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype if fp else torch.float32, device=image.device) + kernel = kernel.expand(shape[-3], 1, kernel.shape[0], kernel.shape[1]) - image, need_cast, need_squeeze, out_dtype = _FT._cast_squeeze_in(image, [kernel.dtype]) + output = image if fp else image.to(dtype=torch.float32) # padding = (left, right, top, bottom) padding = [kernel_size[0] // 2, kernel_size[0] // 2, kernel_size[1] // 2, kernel_size[1] // 2] - output = torch_pad(image, padding, mode="reflect") - output = conv2d(output, kernel, groups=output.shape[-3]) - - output = _FT._cast_squeeze_out(output, need_cast, need_squeeze, out_dtype) + output = torch_pad(output, padding, mode="reflect") + output = conv2d(output, kernel, groups=shape[-3]) - if needs_unsquash: + if ndim == 3: + output = output.squeeze(dim=0) + elif ndim > 4: output = output.reshape(shape) + if not fp: + output = output.round_().to(dtype=dtype) + return output