diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 7c6aee1f376..41c6ceada03 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -338,30 +338,9 @@ def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(normalize) if not isinstance(tensor, torch.Tensor): - raise TypeError(f"Input tensor should be a torch tensor. Got {type(tensor)}.") + raise TypeError(f"img should be Tensor Image. Got {type(tensor)}") - if not tensor.is_floating_point(): - raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.") - - if tensor.ndim < 3: - raise ValueError( - f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}" - ) - - if not inplace: - tensor = tensor.clone() - - dtype = tensor.dtype - mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device) - std = torch.as_tensor(std, dtype=dtype, device=tensor.device) - if (std == 0).any(): - raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.") - if mean.ndim == 1: - mean = mean.view(-1, 1, 1) - if std.ndim == 1: - std = std.view(-1, 1, 1) - tensor.sub_(mean).div_(std) - return tensor + return F_t.normalize(tensor, mean=mean, std=std, inplace=inplace) def resize( @@ -1281,11 +1260,7 @@ def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool if not isinstance(img, torch.Tensor): raise TypeError(f"img should be Tensor Image. Got {type(img)}") - if not inplace: - img = img.clone() - - img[..., i : i + h, j : j + w] = v - return img + return F_t.erase(img, i, j, h, w, v, inplace=inplace) def gaussian_blur(img: Tensor, kernel_size: List[int], sigma: Optional[List[float]] = None) -> Tensor: diff --git a/torchvision/transforms/functional_tensor.py b/torchvision/transforms/functional_tensor.py index 4d05f83785b..fae681b3aa9 100644 --- a/torchvision/transforms/functional_tensor.py +++ b/torchvision/transforms/functional_tensor.py @@ -918,3 +918,40 @@ def equalize(img: Tensor) -> Tensor: return _equalize_single_image(img) return torch.stack([_equalize_single_image(x) for x in img]) + + +def normalize(tensor: Tensor, mean: List[float], std: List[float], inplace: bool = False) -> Tensor: + _assert_image_tensor(tensor) + + if not tensor.is_floating_point(): + raise TypeError(f"Input tensor should be a float tensor. Got {tensor.dtype}.") + + if tensor.ndim < 3: + raise ValueError( + f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {tensor.size()}" + ) + + if not inplace: + tensor = tensor.clone() + + dtype = tensor.dtype + mean = torch.as_tensor(mean, dtype=dtype, device=tensor.device) + std = torch.as_tensor(std, dtype=dtype, device=tensor.device) + if (std == 0).any(): + raise ValueError(f"std evaluated to zero after conversion to {dtype}, leading to division by zero.") + if mean.ndim == 1: + mean = mean.view(-1, 1, 1) + if std.ndim == 1: + std = std.view(-1, 1, 1) + tensor.sub_(mean).div_(std) + return tensor + + +def erase(img: Tensor, i: int, j: int, h: int, w: int, v: Tensor, inplace: bool = False) -> Tensor: + _assert_image_tensor(img) + + if not inplace: + img = img.clone() + + img[..., i : i + h, j : j + w] = v + return img