
Commit 08758ca

datumbox authored and facebook-github-bot committed
[fbsync] Add tests and proper support for videos in ConvertImageDtype (#6783)
Summary:
* add KernelInfo
* split dtype and device consistency tests
* add proper support for video
* fix tests and add DispatcherInfo
* add aliases
* cleanup
* fix typo

Reviewed By: YosuaMichael

Differential Revision: D40722908

fbshipit-source-id: 36adda72819a12167ed12d27f6715a46c8ee9b56
1 parent e1a66c2 commit 08758ca

11 files changed with 155 additions and 113 deletions
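
The headline change is a rename (`convert_image_dtype` -> `convert_dtype_image_tensor`) plus a real video kernel behind one dispatcher. As orientation before the per-file diffs, a minimal sketch of the resulting API; this assumes the prototype namespace as of this revision, with `features.Video` wrapping a plain tensor as it does throughout the test suite:

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    # A fake 4-frame RGB video; features.Video wraps a plain tensor.
    video = features.Video(torch.randint(0, 256, (4, 3, 16, 16), dtype=torch.uint8))

    out = F.convert_dtype(video, torch.float32)  # dispatches to convert_dtype_video
    assert isinstance(out, features.Video)
    assert out.dtype == torch.float32 and out.max() <= 1.0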

test/prototype_common_utils.py

Lines changed: 3 additions & 3 deletions

@@ -22,7 +22,7 @@
     UnsupportedInputs,
 )
 from torchvision.prototype import features
-from torchvision.prototype.transforms.functional import convert_image_dtype, to_image_tensor
+from torchvision.prototype.transforms.functional import convert_dtype_image_tensor, to_image_tensor
 from torchvision.transforms.functional_tensor import _max_value as get_max_value

 __all__ = [
@@ -97,8 +97,8 @@ def _process_inputs(self, actual, expected, *, id, allow_subclasses):
     def _equalize_attributes(self, actual, expected):
         if actual.dtype != expected.dtype:
             dtype = torch.promote_types(actual.dtype, expected.dtype)
-            actual = convert_image_dtype(actual, dtype)
-            expected = convert_image_dtype(expected, dtype)
+            actual = convert_dtype_image_tensor(actual, dtype)
+            expected = convert_dtype_image_tensor(expected, dtype)

         return super()._equalize_attributes(actual, expected)
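
For context, the promotion rule the comparison helper relies on is plain `torch.promote_types`; two grounding examples (standard PyTorch behavior, not part of this diff):

    import torch

    # promote_types picks the smallest dtype that can represent both inputs.
    assert torch.promote_types(torch.uint8, torch.float32) == torch.float32
    assert torch.promote_types(torch.uint8, torch.int64) == torch.int64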

test/prototype_transforms_dispatcher_infos.py

Lines changed: 10 additions & 0 deletions

@@ -416,4 +416,14 @@ def xfail_all_tests(*, reason, condition):
             skip_dispatch_feature,
         ],
     ),
+    DispatcherInfo(
+        F.convert_dtype,
+        kernels={
+            features.Image: F.convert_dtype_image_tensor,
+            features.Video: F.convert_dtype_video,
+        },
+        test_marks=[
+            skip_dispatch_feature,
+        ],
+    ),
 ]
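
The kernel table registered here mirrors the runtime dispatch of `F.convert_dtype` (added in `functional/_meta.py` at the bottom of this commit). A sketch of where each input type is routed, assuming the same prototype imports the tests use:

    import torch
    from torchvision.prototype import features
    from torchvision.prototype.transforms import functional as F

    image = features.Image(torch.randint(0, 256, (3, 8, 8), dtype=torch.uint8))
    assert F.convert_dtype(image, torch.float32).dtype == torch.float32  # -> convert_dtype_image_tensor

    video = features.Video(torch.randint(0, 256, (2, 3, 8, 8), dtype=torch.uint8))
    assert F.convert_dtype(video, torch.float32).dtype == torch.float32  # -> convert_dtype_video

    plain = torch.rand(3, 8, 8)  # simple tensors take the image-tensor path
    assert F.convert_dtype(plain, torch.uint8).dtype == torch.uint8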

test/prototype_transforms_kernel_infos.py

Lines changed: 30 additions & 23 deletions

@@ -1979,7 +1979,7 @@ def sample_inputs_normalize_video():
 )


-def sample_inputs_convert_image_dtype():
+def sample_inputs_convert_dtype_image_tensor():
     for input_dtype, output_dtype in itertools.product(
         [torch.uint8, torch.int64, torch.float32, torch.float64], repeat=2
     ):
@@ -1992,10 +1992,8 @@ def sample_inputs_convert_image_dtype():
     ):
         yield ArgsKwargs(image_loader, dtype=output_dtype)

-    yield ArgsKwargs(make_image_loader(color_space=features.ColorSpace.RGB), dtype=torch.uint8)

-
-def reference_convert_image_dtype(image, dtype=torch.float):
+def reference_convert_dtype_image_tensor(image, dtype=torch.float):
     input_dtype = image.dtype
     output_dtype = dtype

@@ -2026,7 +2024,7 @@ def fn(value):
     return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype)


-def reference_inputs_convert_image_dtype():
+def reference_inputs_convert_dtype_image_tensor():
     for input_dtype, output_dtype in itertools.product(
         [
             torch.uint8,
@@ -2055,41 +2053,50 @@ def reference_inputs_convert_image_dtype():
         yield ArgsKwargs(image, dtype=output_dtype)


+def sample_inputs_convert_dtype_video():
+    for video_loader in make_video_loaders(sizes=["random"], num_frames=["random"]):
+        yield ArgsKwargs(video_loader)
+
+
+_common_convert_dtype_marks = [
+    TestMark(
+        ("TestKernels", "test_dtype_and_device_consistency"),
+        pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"),
+        condition=lambda args_kwargs: args_kwargs.args[0].dtype != args_kwargs.kwargs.get("dtype", torch.float32),
+    ),
+    TestMark(
+        ("TestKernels", "test_scripted_vs_eager"),
+        pytest.mark.filterwarnings(f"ignore:{re.escape('operator() profile_node %')}:UserWarning"),
+    ),
+]
+
 KERNEL_INFOS.extend(
     [
         KernelInfo(
-            F.convert_image_dtype,
-            sample_inputs_fn=sample_inputs_convert_image_dtype,
-            reference_fn=reference_convert_image_dtype,
-            reference_inputs_fn=reference_inputs_convert_image_dtype,
+            F.convert_dtype_image_tensor,
+            sample_inputs_fn=sample_inputs_convert_dtype_image_tensor,
+            reference_fn=reference_convert_dtype_image_tensor,
+            reference_inputs_fn=reference_inputs_convert_dtype_image_tensor,
             test_marks=[
-                TestMark(
-                    ("TestKernels", "test_scripted_vs_eager"),
-                    pytest.mark.filterwarnings(f"ignore:{re.escape('operator() profile_node %41')}:UserWarning"),
-                ),
-                TestMark(
-                    ("TestKernels", "test_dtype_and_device_consistency"),
-                    pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"),
-                    condition=lambda args_kwargs: args_kwargs.args[0].dtype
-                    != args_kwargs.kwargs.get("dtype", torch.float32),
-                ),
+                *_common_convert_dtype_marks,
                 TestMark(
                     ("TestKernels", "test_against_reference"),
                     pytest.mark.xfail(reason="Conversion overflows"),
                     condition=lambda args_kwargs: (
                         args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
                         and not args_kwargs.kwargs["dtype"].is_floating_point
                     )
-                    or (
-                        args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
-                        and args_kwargs.kwargs["dtype"] == torch.int64
-                    )
                     or (
                         args_kwargs.args[0].dtype in {torch.int32, torch.int64}
                         and args_kwargs.kwargs["dtype"] == torch.float16
                     ),
                 ),
             ],
         ),
+        KernelInfo(
+            F.convert_dtype_video,
+            sample_inputs_fn=sample_inputs_convert_dtype_video,
+            test_marks=_common_convert_dtype_marks,
+        ),
     ]
 )
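
The shared skip in `_common_convert_dtype_marks` fires exactly when the kernel would change the dtype, which is the kernel's job. A standalone illustration of the `condition` lambda; `FakeArgsKwargs` is a hypothetical stand-in for the suite's `ArgsKwargs` container:

    import torch

    class FakeArgsKwargs:
        # Hypothetical stand-in, only here to exercise the lambda.
        def __init__(self, *args, **kwargs):
            self.args, self.kwargs = args, kwargs

    condition = lambda args_kwargs: args_kwargs.args[0].dtype != args_kwargs.kwargs.get("dtype", torch.float32)

    assert condition(FakeArgsKwargs(torch.rand(3, 4, 4), dtype=torch.uint8))        # dtype changes -> skip
    assert not condition(FakeArgsKwargs(torch.rand(3, 4, 4), dtype=torch.float32))  # no-op -> keep testing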

test/test_prototype_transforms.py

Lines changed: 1 addition & 1 deletion

@@ -92,7 +92,7 @@ class TestSmoke:
         transforms.RandomErasing(p=1.0),
         transforms.Resize([16, 16]),
         transforms.CenterCrop([16, 16]),
-        transforms.ConvertImageDtype(),
+        transforms.ConvertDtype(),
         transforms.RandomHorizontalFlip(),
         transforms.Pad(5),
         transforms.RandomZoomOut(),

test/test_prototype_transforms_consistency.py

Lines changed: 1 addition & 1 deletion

@@ -153,7 +153,7 @@ def __init__(
         ),
     ),
     ConsistencyConfig(
-        prototype_transforms.ConvertImageDtype,
+        prototype_transforms.ConvertDtype,
         legacy_transforms.ConvertImageDtype,
         [
             ArgsKwargs(torch.float16),

test/test_prototype_transforms_functional.py

Lines changed: 1 addition & 0 deletions

@@ -307,6 +307,7 @@ def test_dispatch_feature(self, info, args_kwargs, spy_on):
         (F.get_image_num_channels, F.get_num_channels),
         (F.to_pil_image, F.to_image_pil),
         (F.elastic_transform, F.elastic),
+        (F.convert_image_dtype, F.convert_dtype_image_tensor),
     ]
 ],
 )
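
The new pair in this alias test pins down the module-level assignment made in `functional/_meta.py` below; in isolation:

    from torchvision.prototype.transforms import functional as F

    # The old name is a plain alias, so identity (not mere equality) holds.
    assert F.convert_image_dtype is F.convert_dtype_image_tensor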

torchvision/prototype/transforms/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -39,7 +39,7 @@
     ScaleJitter,
     TenCrop,
 )
-from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertImageDtype
+from ._meta import ClampBoundingBoxes, ConvertBoundingBoxFormat, ConvertColorSpace, ConvertDtype, ConvertImageDtype
 from ._misc import (
     GaussianBlur,
     Identity,

torchvision/prototype/transforms/_meta.py

Lines changed: 7 additions & 7 deletions

@@ -25,7 +25,7 @@ def _transform(self, inpt: features.BoundingBox, params: Dict[str, Any]) -> features.BoundingBox:
         return features.BoundingBox.wrap_like(inpt, output, format=params["format"])


-class ConvertImageDtype(Transform):
+class ConvertDtype(Transform):
     _transformed_types = (features.is_simple_tensor, features.Image, features.Video)

     def __init__(self, dtype: torch.dtype = torch.float32) -> None:
@@ -35,12 +35,12 @@ def __init__(self, dtype: torch.dtype = torch.float32) -> None:
     def _transform(
         self, inpt: Union[features.TensorImageType, features.TensorVideoType], params: Dict[str, Any]
     ) -> Union[features.TensorImageType, features.TensorVideoType]:
-        # TODO: the `inpt.as_subclass(torch.Tensor)` call can be removed as soon as we have a proper dispatcher that
-        # handles this. See https://github.com/pytorch/vision/pull/6783 for details.
-        output = F.convert_image_dtype(inpt.as_subclass(torch.Tensor), dtype=self.dtype)
-        return (
-            output if features.is_simple_tensor(inpt) else type(inpt).wrap_like(inpt, output)  # type: ignore[attr-defined]
-        )
+        return F.convert_dtype(inpt, self.dtype)
+
+
+# We changed the name to align it with the new naming scheme. Still, `ConvertImageDtype` is
+# prevalent and well understood. Thus, we just alias it without deprecating the old name.
+ConvertImageDtype = ConvertDtype


 class ConvertColorSpace(Transform):
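
Usage is unchanged under either name. A short sketch, assuming `Compose` from the same prototype namespace:

    import torch
    from torchvision.prototype import transforms

    pipeline = transforms.Compose([transforms.ConvertDtype(torch.float32)])
    out = pipeline(torch.randint(0, 256, (3, 16, 16), dtype=torch.uint8))
    assert out.dtype == torch.float32

    # The class alias mirrors the functional one.
    assert transforms.ConvertImageDtype is transforms.ConvertDtype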

torchvision/prototype/transforms/functional/__init__.py

Lines changed: 4 additions & 1 deletion

@@ -8,6 +8,10 @@
     convert_color_space_image_pil,
     convert_color_space_video,
     convert_color_space,
+    convert_dtype_image_tensor,
+    convert_dtype,
+    convert_dtype_video,
+    convert_image_dtype,
     get_dimensions_image_tensor,
     get_dimensions_image_pil,
     get_dimensions,
@@ -162,7 +166,6 @@
     normalize_video,
 )
 from ._type_conversion import (
-    convert_image_dtype,
     decode_image_with_pil,
     decode_video_with_av,
     pil_to_tensor,

torchvision/prototype/transforms/functional/_meta.py

Lines changed: 96 additions & 0 deletions

@@ -285,3 +285,99 @@ def convert_color_space(
         return features.Video.wrap_like(inpt, output, color_space=color_space)
     else:
         return convert_color_space_image_pil(inpt, color_space)
+
+
+def _num_value_bits(dtype: torch.dtype) -> int:
+    if dtype == torch.uint8:
+        return 8
+    elif dtype == torch.int8:
+        return 7
+    elif dtype == torch.int16:
+        return 15
+    elif dtype == torch.int32:
+        return 31
+    elif dtype == torch.int64:
+        return 63
+    else:
+        raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
+
+
+def convert_dtype_image_tensor(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
+    if image.dtype == dtype:
+        return image
+
+    float_input = image.is_floating_point()
+    if torch.jit.is_scripting():
+        # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
+        float_output = torch.tensor(0, dtype=dtype).is_floating_point()
+    else:
+        float_output = dtype.is_floating_point
+
+    if float_input:
+        # float to float
+        if float_output:
+            return image.to(dtype)
+
+        # float to int
+        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
+            image.dtype == torch.float64 and dtype == torch.int64
+        ):
+            raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")
+
+        # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
+        # to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
+        # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
+        # for a detailed analysis.
+        # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
+        # Instead, we can also multiply by the maximum value plus something close to `1`. See
+        # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
+        eps = 1e-3
+        max_value = float(_FT._max_value(dtype))
+        # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
+        # discrete set `{0, 1}`.
+        return image.mul(max_value + 1.0 - eps).to(dtype)
+    else:
+        # int to float
+        if float_output:
+            return image.to(dtype).div_(_FT._max_value(image.dtype))
+
+        # int to int
+        num_value_bits_input = _num_value_bits(image.dtype)
+        num_value_bits_output = _num_value_bits(dtype)
+
+        if num_value_bits_input > num_value_bits_output:
+            return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
+        else:
+            # The bitshift kernel is not vectorized
+            # https://github.com/pytorch/pytorch/blob/703c19008df4700b6a522b0ae5c4b6d5ffc0906f/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L315-L322
+            # This results in the multiplication actually being faster.
+            # TODO: If the bitshift kernel is optimized in core, replace the computation below with
+            # `image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)`
+            max_value_input = float(_FT._max_value(dtype))
+            max_value_output = float(_FT._max_value(image.dtype))
+            factor = int((max_value_input + 1) // (max_value_output + 1))
+            return image.to(dtype).mul_(factor)
+
+
+# We changed the name to align it with the new naming scheme. Still, `convert_image_dtype` is
+# prevalent and well understood. Thus, we just alias it without deprecating the old name.
+convert_image_dtype = convert_dtype_image_tensor
+
+
+def convert_dtype_video(video: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
+    return convert_dtype_image_tensor(video, dtype)
+
+
+def convert_dtype(
+    inpt: Union[features.ImageTypeJIT, features.VideoTypeJIT], dtype: torch.dtype = torch.float
+) -> torch.Tensor:
+    if isinstance(inpt, torch.Tensor) and (
+        torch.jit.is_scripting() or not isinstance(inpt, (features.Image, features.Video))
+    ):
+        return convert_dtype_image_tensor(inpt, dtype)
+    elif isinstance(inpt, features.Image):
+        output = convert_dtype_image_tensor(inpt.as_subclass(torch.Tensor), dtype)
+        return features.Image.wrap_like(inpt, output)
+    else:  # isinstance(inpt, features.Video):
+        output = convert_dtype_video(inpt.as_subclass(torch.Tensor), dtype)
+        return features.Video.wrap_like(inpt, output)
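
Worked examples for the four conversion paths, with expected values read straight off the arithmetic above (uint8 max is 255; int16 has 15 value bits, so the int-to-int factor is 2**(15 - 8) = 128):

    import torch
    from torchvision.prototype.transforms import functional as F

    # int -> float: divide by the input dtype's max value, so 255 (uint8) -> 1.0.
    assert F.convert_dtype_image_tensor(torch.tensor([255], dtype=torch.uint8), torch.float32).item() == 1.0

    # float -> int: scale by (max + 1 - eps) so that exactly 1.0 still lands on 255.
    assert F.convert_dtype_image_tensor(torch.tensor([1.0]), torch.uint8).item() == 255

    # int -> wider int: multiply by 128, so 255 (uint8) -> 32640 (int16).
    assert F.convert_dtype_image_tensor(torch.tensor([255], dtype=torch.uint8), torch.int16).item() == 32640

    # int -> narrower int: right-shift by the same 7 bits, so 32767 (int16) -> 255 (uint8).
    assert F.convert_dtype_image_tensor(torch.tensor([32767], dtype=torch.int16), torch.uint8).item() == 255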
