diff --git a/torchvision/prototype/transforms/functional/_misc.py b/torchvision/prototype/transforms/functional/_misc.py index fa4a6e9be73..3a1d8575cd0 100644 --- a/torchvision/prototype/transforms/functional/_misc.py +++ b/torchvision/prototype/transforms/functional/_misc.py @@ -8,7 +8,42 @@ from torchvision.transforms import functional_tensor as _FT from torchvision.transforms.functional import pil_to_tensor, to_pil_image -normalize_image_tensor = _FT.normalize + +def normalize_image_tensor( + image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False +) -> torch.Tensor: + if not image.is_floating_point(): + raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.") + + if image.ndim < 3: + raise ValueError( + f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {image.size()}" + ) + + if isinstance(std, (tuple, list)): + divzero = not all(std) + elif isinstance(std, (int, float)): + divzero = std == 0 + else: + divzero = False + if divzero: + raise ValueError("std evaluated to zero, leading to division by zero.") + + dtype = image.dtype + device = image.device + mean = torch.as_tensor(mean, dtype=dtype, device=device) + std = torch.as_tensor(std, dtype=dtype, device=device) + if mean.ndim == 1: + mean = mean.view(-1, 1, 1) + if std.ndim == 1: + std = std.view(-1, 1, 1) + + if inplace: + image = image.sub_(mean) + else: + image = image.sub(mean) + + return image.div_(std) def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor: