Skip to content

Commit 140a480

Browse files
NicolasHugfacebook-github-bot
authored andcommitted
[fbsync] Let Normalize() and RandomPhotometricDistort return datapoints instead of tensors (#7113)
Reviewed By: YosuaMichael Differential Revision: D42706907 fbshipit-source-id: a7b7487ab8563f8a43a0ebb84df19579ccd35fe1
1 parent 94ecbbc commit 140a480

File tree

6 files changed

+22
-31
lines changed

6 files changed

+22
-31
lines changed

test/prototype_transforms_dispatcher_infos.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,6 @@ def fill_sequence_needs_broadcast(args_kwargs):
426426
datapoints.Video: F.normalize_video,
427427
},
428428
test_marks=[
429-
skip_dispatch_feature,
430429
xfail_jit_python_scalar_arg("mean"),
431430
xfail_jit_python_scalar_arg("std"),
432431
],

test/test_prototype_transforms_functional.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
import torchvision.prototype.transforms.utils
1515
from common_utils import cache, cpu_and_gpu, needs_cuda, set_rng_seed
16-
from prototype_common_utils import assert_close, make_bounding_boxes, make_image, parametrized_error_message
16+
from prototype_common_utils import assert_close, make_bounding_boxes, parametrized_error_message
1717
from prototype_transforms_dispatcher_infos import DISPATCHER_INFOS
1818
from prototype_transforms_kernel_infos import KERNEL_INFOS
1919
from torch.utils._pytree import tree_map
@@ -1185,18 +1185,6 @@ def test_correctness_gaussian_blur_image_tensor(device, spatial_size, dt, ksize,
11851185
torch.testing.assert_close(out, true_out, rtol=0.0, atol=1.0, msg=f"{ksize}, {sigma}")
11861186

11871187

1188-
def test_normalize_output_type():
1189-
inpt = torch.rand(1, 3, 32, 32)
1190-
output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
1191-
assert type(output) is torch.Tensor
1192-
torch.testing.assert_close(inpt - 0.5, output)
1193-
1194-
inpt = make_image(color_space=datapoints.ColorSpace.RGB)
1195-
output = F.normalize(inpt, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0])
1196-
assert type(output) is torch.Tensor
1197-
torch.testing.assert_close(inpt - 0.5, output)
1198-
1199-
12001188
@pytest.mark.parametrize(
12011189
"inpt",
12021190
[

torchvision/prototype/datapoints/_image.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,10 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N
289289
)
290290
return Image.wrap_like(self, output)
291291

292+
def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Image:
293+
output = self._F.normalize_image_tensor(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
294+
return Image.wrap_like(self, output)
295+
292296

293297
ImageType = Union[torch.Tensor, PIL.Image.Image, Image]
294298
ImageTypeJIT = torch.Tensor

torchvision/prototype/datapoints/_video.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,10 @@ def gaussian_blur(self, kernel_size: List[int], sigma: Optional[List[float]] = N
241241
output = self._F.gaussian_blur_video(self.as_subclass(torch.Tensor), kernel_size=kernel_size, sigma=sigma)
242242
return Video.wrap_like(self, output)
243243

244+
def normalize(self, mean: List[float], std: List[float], inplace: bool = False) -> Video:
245+
output = self._F.normalize_video(self.as_subclass(torch.Tensor), mean=mean, std=std, inplace=inplace)
246+
return Video.wrap_like(self, output)
247+
244248

245249
VideoType = Union[torch.Tensor, Video]
246250
VideoTypeJIT = torch.Tensor

torchvision/prototype/transforms/_color.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any:
8282
return output
8383

8484

85+
# TODO: This class seems to be untested
8586
class RandomPhotometricDistort(Transform):
8687
_transformed_types = (
8788
datapoints.Image,
@@ -119,15 +120,14 @@ def _get_params(self, flat_inputs: List[Any]) -> Dict[str, Any]:
119120
def _permute_channels(
120121
self, inpt: Union[datapoints.ImageType, datapoints.VideoType], permutation: torch.Tensor
121122
) -> Union[datapoints.ImageType, datapoints.VideoType]:
122-
if isinstance(inpt, PIL.Image.Image):
123+
124+
orig_inpt = inpt
125+
if isinstance(orig_inpt, PIL.Image.Image):
123126
inpt = F.pil_to_tensor(inpt)
124127

125128
output = inpt[..., permutation, :, :]
126129

127-
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
128-
output = inpt.wrap_like(inpt, output, color_space=datapoints.ColorSpace.OTHER) # type: ignore[arg-type]
129-
130-
elif isinstance(inpt, PIL.Image.Image):
130+
if isinstance(orig_inpt, PIL.Image.Image):
131131
output = F.to_image_pil(output)
132132

133133
return output

torchvision/prototype/transforms/functional/_misc.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -60,18 +60,14 @@ def normalize(
6060
) -> torch.Tensor:
6161
if not torch.jit.is_scripting():
6262
_log_api_usage_once(normalize)
63-
64-
if isinstance(inpt, (datapoints.Image, datapoints.Video)):
65-
inpt = inpt.as_subclass(torch.Tensor)
66-
elif not is_simple_tensor(inpt):
67-
raise TypeError(
68-
f"Input can either be a plain tensor or an `Image` or `Video` datapoint, "
69-
f"but got {type(inpt)} instead."
70-
)
71-
72-
# Image or Video type should not be retained after normalization due to unknown data range
73-
# Thus we return Tensor for input Image
74-
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
63+
if torch.jit.is_scripting() or is_simple_tensor(inpt):
64+
return normalize_image_tensor(inpt, mean=mean, std=std, inplace=inplace)
65+
elif isinstance(inpt, (datapoints.Image, datapoints.Video)):
66+
return inpt.normalize(mean=mean, std=std, inplace=inplace)
67+
else:
68+
raise TypeError(
69+
f"Input can either be a plain tensor or an `Image` or `Video` datapoint, " f"but got {type(inpt)} instead."
70+
)
7571

7672

7773
def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> torch.Tensor:

0 commit comments

Comments
 (0)