Skip to content

Commit 33316e8

Browse files
datumboxfacebook-github-bot
authored andcommitted
[fbsync] [prototype] Speed improvement for normalize op (#6821)
Summary: * Avoid GPU-CPU sync on Normalize * Further optimizations. * Apply code review changes. * Fixing JIT. * linter fix Reviewed By: YosuaMichael Differential Revision: D40722904 fbshipit-source-id: e452d89a42b34be852e3125d25756b3f598e50f4
1 parent a710088 commit 33316e8

File tree

1 file changed

+36
-1
lines changed
  • torchvision/prototype/transforms/functional

1 file changed

+36
-1
lines changed

torchvision/prototype/transforms/functional/_misc.py

Lines changed: 36 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,42 @@
88
from torchvision.transforms import functional_tensor as _FT
99
from torchvision.transforms.functional import pil_to_tensor, to_pil_image
1010

11-
normalize_image_tensor = _FT.normalize
11+
12+
def normalize_image_tensor(
13+
image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False
14+
) -> torch.Tensor:
15+
if not image.is_floating_point():
16+
raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.")
17+
18+
if image.ndim < 3:
19+
raise ValueError(
20+
f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {image.size()}"
21+
)
22+
23+
if isinstance(std, (tuple, list)):
24+
divzero = not all(std)
25+
elif isinstance(std, (int, float)):
26+
divzero = std == 0
27+
else:
28+
divzero = False
29+
if divzero:
30+
raise ValueError("std evaluated to zero, leading to division by zero.")
31+
32+
dtype = image.dtype
33+
device = image.device
34+
mean = torch.as_tensor(mean, dtype=dtype, device=device)
35+
std = torch.as_tensor(std, dtype=dtype, device=device)
36+
if mean.ndim == 1:
37+
mean = mean.view(-1, 1, 1)
38+
if std.ndim == 1:
39+
std = std.view(-1, 1, 1)
40+
41+
if inplace:
42+
image = image.sub_(mean)
43+
else:
44+
image = image.sub(mean)
45+
46+
return image.div_(std)
1247

1348

1449
def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:

0 commit comments

Comments
 (0)