Skip to content
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion torchvision/prototype/transforms/functional/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,39 @@
from torchvision.transforms import functional_tensor as _FT
from torchvision.transforms.functional import pil_to_tensor, to_pil_image

normalize_image_tensor = _FT.normalize

def normalize_image_tensor(
image: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False
) -> torch.Tensor:
if not isinstance(image, torch.Tensor):
raise TypeError("Input img should be Tensor image")

if not image.is_floating_point():
raise TypeError(f"Input tensor should be a float tensor. Got {image.dtype}.")

if image.ndim < 3:
raise ValueError(
f"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = {image.size()}"
)

if (isinstance(std, (tuple, list)) and not all(std)) or std == 0:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need the first part of the check? What input would fail isinstance(std, (tuple, list))? Do we actually allow scalars here? Otherwise, this should be sufficient

Suggested change
if (isinstance(std, (tuple, list)) and not all(std)) or std == 0:
if not all(std):

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We actually allow scalars. It's not visible due to the JIT-script types but if you pass mean=0.5, std=0.5 it works. So I'm keeping this for BC and provide separate benchmarks.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ugh 🙄 We need to update the tests since they currently don't check scalars:

_NORMALIZE_MEANS_STDS = [
((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]),
]
def sample_inputs_normalize_image_tensor():
for image_loader, (mean, std) in itertools.product(
make_image_loaders(sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[torch.float32]),
_NORMALIZE_MEANS_STDS,
):
yield ArgsKwargs(image_loader, mean=mean, std=std)

Will send a PR.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch. I also had to rewrite the check because JIT couldn't understand the assertions were correct in one line... This version seems to pass. I've updated the benchmarks and we are still good.

raise ValueError(f"std evaluated to zero after conversion to {image.dtype}, leading to division by zero.")

dtype = image.dtype
device = image.device
mean = torch.as_tensor(mean, dtype=dtype, device=device)
std = torch.as_tensor(std, dtype=dtype, device=device)
if mean.ndim == 1:
mean = mean.view(-1, 1, 1)
if std.ndim == 1:
std = std.view(-1, 1, 1)
Comment on lines +36 to +39
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was also looking into this earlier and one thing I asked myself, is when would this branch not trigger? The tensor should always have one dimensions unless we allow scalars. See above for that.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is purely for broadcasting in case someone passes lists, not scalars. Aka [0.5, 0.5, 0.5]. This is needed else, the following div/sub fails.


if inplace:
image = image.sub_(mean)
else:
image = image.sub(mean)

return image.div_(std)


def normalize_video(video: torch.Tensor, mean: List[float], std: List[float], inplace: bool = False) -> torch.Tensor:
Expand Down