Skip to content

Commit 35bb0e6

Browse files
Joao Gomesfacebook-github-bot
Joao Gomes
authored and committed
[fbsync] improve perf on convert_image_dtype and add tests (#6795)
Summary: * improve perf on convert_image_dtype and add tests * add reference tests * use bitshifts for int to int * revert bitshifts for int to int upscale * fix warning ignore Reviewed By: YosuaMichael Differential Revision: D40588162 fbshipit-source-id: 4f1c564f94f75ff37979c123a416b043b4c9ec14
1 parent a4cd11b commit 35bb0e6

File tree

3 files changed

+226
-15
lines changed

3 files changed

+226
-15
lines changed

test/prototype_transforms_kernel_infos.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import decimal
12
import functools
23
import itertools
34
import math
@@ -21,6 +22,7 @@
2122
mark_framework_limitation,
2223
TestMark,
2324
)
25+
from torch.utils._pytree import tree_map
2426
from torchvision.prototype import features
2527
from torchvision.transforms.functional_tensor import _max_value as get_max_value
2628

@@ -1947,3 +1949,119 @@ def sample_inputs_normalize_video():
19471949
),
19481950
]
19491951
)
1952+
1953+
1954+
def sample_inputs_convert_image_dtype():
    """Yield sample inputs for the `convert_image_dtype` kernel.

    Every (input dtype, output dtype) pair drawn from uint8, int64, float32 and float64 is
    covered, except a floating point input paired with an int64 output. One extra sample
    requests uint8 for an RGB image loader created with otherwise default arguments.
    """
    candidate_dtypes = [torch.uint8, torch.int64, torch.float32, torch.float64]

    for src_dtype, dst_dtype in itertools.product(candidate_dtypes, repeat=2):
        # Float -> int64 cannot be performed safely, so this pairing is left out.
        if src_dtype.is_floating_point and dst_dtype == torch.int64:
            continue

        loaders = make_image_loaders(sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[src_dtype])
        for loader in loaders:
            yield ArgsKwargs(loader, dtype=dst_dtype)

    yield ArgsKwargs(make_image_loader(color_space=features.ColorSpace.RGB), dtype=torch.uint8)
1968+
1969+
1970+
def reference_convert_image_dtype(image, dtype=torch.float):
    """Element-wise pure-Python reference implementation for `convert_image_dtype`.

    Args:
        image: Tensor to convert.
        dtype: Target dtype.

    Returns:
        ``image`` itself when it already has ``dtype``; otherwise a new tensor of ``dtype``
        built from the individually converted Python scalars.
    """
    src_dtype = image.dtype
    if dtype == src_dtype:
        return image

    def convert(value):
        if src_dtype.is_floating_point:
            # float -> float needs no rescaling; float -> int scales by the integer maximum.
            if dtype.is_floating_point:
                return value
            return int(decimal.Decimal(value) * torch.iinfo(dtype).max)

        src_max = torch.iinfo(src_dtype).max

        # int -> float divides by the input maximum, using `Decimal` for the division.
        if dtype.is_floating_point:
            return float(decimal.Decimal(value) / src_max)

        # int -> int rescales by the integer ratio of the two value ranges.
        dst_max = torch.iinfo(dtype).max
        if src_max > dst_max:
            return value // ((src_max + 1) // (dst_max + 1))
        return value * ((dst_max + 1) // (src_max + 1))

    return torch.tensor(tree_map(convert, image.tolist()), dtype=dtype)
1999+
2000+
2001+
def reference_inputs_convert_image_dtype():
    """Yield small three-element inputs for checking `convert_image_dtype` against its reference.

    All (input dtype, output dtype) pairs over the listed integer and floating point dtypes
    are produced, except float32 -> int32/int64 and float64 -> int64. Floating point inputs
    hold ``[0.0, 0.5, 1.0]``; integer inputs hold ``[0, max // 2, max]`` of the input dtype.
    """
    all_dtypes = [
        torch.uint8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.float16,
        torch.float32,
        torch.float64,
        torch.bfloat16,
    ]
    skipped_pairs = {
        (torch.float32, torch.int32),
        (torch.float32, torch.int64),
        (torch.float64, torch.int64),
    }

    for src_dtype, dst_dtype in itertools.product(all_dtypes, repeat=2):
        if (src_dtype, dst_dtype) in skipped_pairs:
            continue

        if src_dtype.is_floating_point:
            values = [0.0, 0.5, 1.0]
        else:
            top = torch.iinfo(src_dtype).max
            values = [0, top // 2, top]

        yield ArgsKwargs(torch.tensor(values, dtype=src_dtype), dtype=dst_dtype)
2028+
2029+
2030+
# Register `convert_image_dtype` with its sample inputs, reference implementation and per-test marks.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.convert_image_dtype,
            sample_inputs_fn=sample_inputs_convert_image_dtype,
            reference_fn=reference_convert_image_dtype,
            reference_inputs_fn=reference_inputs_convert_image_dtype,
            test_marks=[
                TestMark(
                    ("TestKernels", "test_scripted_vs_eager"),
                    pytest.mark.filterwarnings(f"ignore:{re.escape('operator() profile_node %41')}:UserWarning"),
                ),
                TestMark(
                    ("TestKernels", "test_dtype_and_device_consistency"),
                    pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"),
                    # Only skip when the requested dtype actually differs from the input dtype.
                    condition=lambda args_kwargs: args_kwargs.args[0].dtype
                    != args_kwargs.kwargs.get("dtype", torch.float32),
                ),
                TestMark(
                    ("TestKernels", "test_against_reference"),
                    pytest.mark.xfail(reason="Conversion overflows"),
                    # float16 / bfloat16 -> any non-float dtype (this includes int64, so no separate clause is
                    # needed for it), and int32 / int64 -> float16.
                    condition=lambda args_kwargs: (
                        args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
                        and not args_kwargs.kwargs["dtype"].is_floating_point
                    )
                    or (
                        args_kwargs.args[0].dtype in {torch.int32, torch.int64}
                        and args_kwargs.kwargs["dtype"] == torch.float16
                    ),
                ),
            ],
        ),
    ]
)

test/test_prototype_transforms_functional.py

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,20 @@ def script(fn):
2626
raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error
2727

2828

29+
def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None):
    """Build one `pytest.param` for every item that ``args_kwargs_fn(info)`` produces.

    Args:
        info: Object handed to ``args_kwargs_fn``; its ``id`` prefixes every parameter id and
            its ``get_marks`` supplies the marks when ``test_id`` is given.
        args_kwargs_fn: Callable that receives ``info`` and returns an iterable of inputs.
        test_id: Passed to ``info.get_marks``. When falsy, the parameters carry no marks.

    Returns:
        List of `pytest.param` objects with ids of the form ``"<info.id>-<zero-padded index>"``.
    """
    samples = list(args_kwargs_fn(info))
    # Pad the running index so that all ids of one info have the same width.
    width = len(str(len(samples)))

    params = []
    for index, sample in enumerate(samples):
        marks = info.get_marks(test_id, sample) if test_id else []
        params.append(pytest.param(info, sample, marks=marks, id=f"{info.id}-{index:0{width}}"))
    return params
41+
42+
2943
def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None):
3044
if condition is None:
3145

@@ -49,18 +63,7 @@ def decorator(test_fn):
4963
if not condition(info):
5064
continue
5165

52-
args_kwargs = list(args_kwargs_fn(info))
53-
idx_field_len = len(str(len(args_kwargs)))
54-
55-
for idx, args_kwargs_ in enumerate(args_kwargs):
56-
argvalues.append(
57-
pytest.param(
58-
info,
59-
args_kwargs_,
60-
marks=info.get_marks(test_id, args_kwargs_),
61-
id=f"{info.id}-{idx:0{idx_field_len}}",
62-
)
63-
)
66+
argvalues.extend(make_info_args_kwargs_params(info, args_kwargs_fn=args_kwargs_fn, test_id=test_id))
6467

6568
return pytest.mark.parametrize(argnames, argvalues)(test_fn)
6669

@@ -232,7 +235,6 @@ def test_scripted_smoke(self, info, args_kwargs, device):
232235
[
233236
F.clamp_bounding_box,
234237
F.convert_color_space,
235-
F.convert_image_dtype,
236238
F.get_dimensions,
237239
F.get_image_num_channels,
238240
F.get_image_size,
@@ -312,6 +314,24 @@ def test_alias(alias, target):
312314
assert alias is target
313315

314316

317+
@pytest.mark.parametrize(
    ("info", "args_kwargs"),
    make_info_args_kwargs_params(
        next(info for info in KERNEL_INFOS if info.kernel is F.convert_image_dtype),
        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
    ),
)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_dtype_and_device_convert_image_dtype(info, args_kwargs, device):
    """Check that the kernel returns the requested dtype and keeps the input's device."""
    args, kwargs = args_kwargs.load(device)
    image, *extra_args = args

    # The target dtype may be given positionally or as a keyword; fall back to float32 otherwise.
    if extra_args:
        dtype = extra_args[0]
    else:
        dtype = kwargs.get("dtype", torch.float32)

    result = info.kernel(image, dtype)

    assert result.dtype == dtype
    assert result.device == image.device
333+
334+
315335
# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in
316336
# `prototype_transforms_kernel_infos.py`
317337

torchvision/prototype/transforms/functional/_type_conversion.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from torchvision.io.video import read_video
88
from torchvision.prototype import features
99
from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer
10-
from torchvision.transforms import functional as _F
10+
from torchvision.transforms import functional as _F, functional_tensor as _FT
1111

1212

1313
@torch.jit.unused
@@ -42,4 +42,77 @@ def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) ->
4242
# prevalent and well understood. Thus, we just alias it without deprecating the old name.
4343
to_pil_image = to_image_pil
4444

45-
convert_image_dtype = _F.convert_image_dtype
45+
46+
def _num_value_bits(dtype: torch.dtype) -> int:
47+
if dtype == torch.uint8:
48+
return 8
49+
elif dtype == torch.int8:
50+
return 7
51+
elif dtype == torch.int16:
52+
return 15
53+
elif dtype == torch.int32:
54+
return 31
55+
elif dtype == torch.int64:
56+
return 63
57+
else:
58+
raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.")
59+
60+
61+
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert ``image`` to ``dtype`` and rescale its values to the value range of ``dtype``.

    Args:
        image: Tensor to convert.
        dtype: Target dtype. Defaults to ``torch.float``.

    Returns:
        ``image`` itself if it already has ``dtype``, otherwise a converted tensor.

    Raises:
        TypeError: If ``image`` is not a tensor.
        RuntimeError: For float32 -> int32 / int64 and float64 -> int64, which cannot be performed safely.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError("Input img should be Tensor Image")

    if image.dtype == dtype:
        return image

    float_input = image.is_floating_point()
    if torch.jit.is_scripting():
        # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
        float_output = torch.tensor(0, dtype=dtype).is_floating_point()
    else:
        float_output = dtype.is_floating_point

    if float_input:
        # float to float
        if float_output:
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")

        # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
        # to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
        # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # for a detailed analysis.
        # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
        # Instead, we can also multiply by the maximum value plus something close to `1`. See
        # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
        eps = 1e-3
        max_value = float(_FT._max_value(dtype))
        # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
        # discrete set `{0, 1}`.
        return image.mul(max_value + 1.0 - eps).to(dtype)
    else:
        # int to float
        if float_output:
            return image.to(dtype).div_(_FT._max_value(image.dtype))

        # int to int
        num_value_bits_input = _num_value_bits(image.dtype)
        num_value_bits_output = _num_value_bits(dtype)

        if num_value_bits_input > num_value_bits_output:
            return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
        else:
            # The bitshift kernel is not vectorized
            # https://github.com/pytorch/pytorch/blob/703c19008df4700b6a522b0ae5c4b6d5ffc0906f/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L315-L322
            # This results in the multiplication actually being faster.
            # TODO: If the bitshift kernel is optimized in core, replace the computation below with
            # `image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)`
            # The maxima go through `float` before the `+ 1`; the quotient is turned back into an `int` factor.
            max_value_input = float(_FT._max_value(image.dtype))
            max_value_output = float(_FT._max_value(dtype))
            factor = int((max_value_output + 1) // (max_value_input + 1))
            return image.to(dtype).mul_(factor)

0 commit comments

Comments
 (0)