Skip to content

[proto] Fix kernel passthrough and types of Normalize #6490

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions test/test_prototype_transforms_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -1844,22 +1844,12 @@ def test_correctness_elastic_image_or_mask_tensor(device, fn, make_samples):

def test_midlevel_normalize_output_type():
inpt = torch.rand(1, 3, 32, 32)
output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
assert isinstance(output, torch.Tensor)
torch.testing.assert_close(inpt - 0.5, output)

inpt = make_segmentation_mask()
output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
assert isinstance(output, features.SegmentationMask)
torch.testing.assert_close(inpt, output)

inpt = make_bounding_box(format="XYXY")
output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
assert isinstance(output, features.BoundingBox)
torch.testing.assert_close(inpt, output)

inpt = make_image(color_space=features.ColorSpace.RGB)
output = F.normalize(inpt, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0))
output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
assert isinstance(output, torch.Tensor)
torch.testing.assert_close(inpt - 0.5, output)

Expand Down
12 changes: 8 additions & 4 deletions torchvision/prototype/transforms/_misc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import functools
from typing import Any, Callable, Dict, List, Sequence, Type, Union
from typing import Any, Callable, Dict, Sequence, Type, Union

import PIL.Image

Expand All @@ -10,6 +10,8 @@
from torchvision.prototype.transforms._utils import query_bounding_box
from torchvision.transforms.transforms import _setup_size

from ._utils import is_simple_tensor


class Identity(Transform):
def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
Expand Down Expand Up @@ -91,10 +93,12 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:


class Normalize(Transform):
def __init__(self, mean: List[float], std: List[float]):
_transformed_types = (PIL.Image.Image, features.Image, is_simple_tensor)

def __init__(self, mean: Sequence[float], std: Sequence[float]):
super().__init__()
self.mean = mean
self.std = std
self.mean = list(mean)
self.std = list(std)

def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
return F.normalize(inpt, mean=self.mean, std=self.std)
Expand Down
10 changes: 5 additions & 5 deletions torchvision/prototype/transforms/functional/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
normalize_image_tensor = _FT.normalize


def normalize(inpt: DType, mean: List[float], std: List[float], inplace: bool = False) -> DType:
if isinstance(inpt, features._Feature) and not isinstance(inpt, features.Image):
return inpt
elif isinstance(inpt, PIL.Image.Image):
raise TypeError("Unsupported input type")
def normalize(
inpt: Union[torch.Tensor, features.Image], mean: List[float], std: List[float], inplace: bool = False
) -> DType:
if not isinstance(inpt, torch.Tensor):
raise TypeError(f"img should be Tensor Image. Got {type(inpt)}")
else:
# Image instance after normalization is not Image anymore due to unknown data range
# Thus we return Tensor for input Image
Expand Down