 from torchvision.transforms import functional_tensor as _FT
 from torchvision.transforms.functional import pil_to_tensor, to_pil_image
 
-normalize_image_tensor = _FT.normalize
+
+def normalize_image_tensor(
+    image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False
+) -> torch.Tensor:
+    if not image.is_floating_point():
+        raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.")
+
+    if image.ndim < 3:
+        raise ValueError(
+            f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {image.size()}"
+        )
+
+    if isinstance(std, (tuple, list)):
+        divzero = not all(std)
+    elif isinstance(std, (int, float)):
+        divzero = std == 0
+    else:
+        divzero = False
+    if divzero:
+        raise ValueError("std evaluated to zero, leading to division by zero.")
+
+    dtype = image.dtype
+    device = image.device
+    mean = torch.as_tensor(mean, dtype=dtype, device=device)
+    std = torch.as_tensor(std, dtype=dtype, device=device)
+    if mean.ndim == 1:
+        mean = mean.view(-1, 1, 1)
+    if std.ndim == 1:
+        std = std.view(-1, 1, 1)
+
+    if inplace:
+        image = image.sub_(mean)
+    else:
+        image = image.sub(mean)
+
+    return image.div_(std)
 
 
 def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
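
As a side note, here is a minimal usage sketch of the inlined helper, assuming normalize_image_tensor is in scope (imported from the module this diff touches); the shapes and statistics below are illustrative only, not taken from the PR:

import torch

image = torch.rand(3, 4, 4)                      # float (C, H, W) image, satisfying the dtype/ndim checks
mean, std = [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]  # hypothetical per-channel statistics

# Out-of-place call: image.sub(mean) allocates a new tensor, so the trailing
# in-place div_ never touches the caller's input.
out = normalize_image_tensor(image, mean=mean, std=std)
assert torch.allclose(out, (image - 0.5) / 0.25)

# inplace=True mutates its argument via sub_/div_, hence the clone here.
inplace_out = normalize_image_tensor(image.clone(), mean, std, inplace=True)
assert torch.allclose(inplace_out, out)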