-
Notifications
You must be signed in to change notification settings - Fork 7.1k
improve perf on convert_image_dtype and add tests #6795
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
328190f
9d03609
13feab9
2cafe05
7e1843f
03c409d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import decimal | ||
import functools | ||
import itertools | ||
import math | ||
|
@@ -21,6 +22,7 @@ | |
mark_framework_limitation, | ||
TestMark, | ||
) | ||
from torch.utils._pytree import tree_map | ||
from torchvision.prototype import features | ||
from torchvision.transforms.functional_tensor import _max_value as get_max_value | ||
|
||
|
@@ -1947,3 +1949,119 @@ def sample_inputs_normalize_video(): | |
), | ||
] | ||
) | ||
|
||
|
||
def sample_inputs_convert_image_dtype():
    """Yield sample inputs for the ``convert_image_dtype`` kernel.

    Every ordered pair of the dtypes listed below is covered, except float inputs
    combined with an ``int64`` output since that conversion cannot be performed
    safely. One extra sample that requests ``torch.uint8`` from a default image
    loader is yielded at the end.
    """
    dtypes = [torch.uint8, torch.int64, torch.float32, torch.float64]

    for src_dtype, dst_dtype in itertools.product(dtypes, repeat=2):
        # conversion cannot be performed safely
        is_unsafe = src_dtype.is_floating_point and dst_dtype == torch.int64
        if not is_unsafe:
            loaders = make_image_loaders(
                sizes=["random"], color_spaces=[features.ColorSpace.RGB], dtypes=[src_dtype]
            )
            for loader in loaders:
                yield ArgsKwargs(loader, dtype=dst_dtype)

    yield ArgsKwargs(make_image_loader(color_space=features.ColorSpace.RGB), dtype=torch.uint8)
|
||
|
||
def reference_convert_image_dtype(image, dtype=torch.float):
    """Reference implementation of ``convert_image_dtype`` working on Python scalars.

    The tensor is turned into (nested) Python lists, every element is converted
    individually, and the result is packed into a new tensor of ``dtype``. If the
    input already has the requested dtype, it is returned unchanged.
    """
    src_dtype, dst_dtype = image.dtype, dtype

    if src_dtype == dst_dtype:
        return image

    def convert(value):
        if src_dtype.is_floating_point:
            if dst_dtype.is_floating_point:
                return value
            # `decimal.Decimal` keeps the intermediate product from being limited by float precision
            return int(decimal.Decimal(value) * torch.iinfo(dst_dtype).max)

        src_max = torch.iinfo(src_dtype).max
        if dst_dtype.is_floating_point:
            return float(decimal.Decimal(value) / src_max)

        dst_max = torch.iinfo(dst_dtype).max
        if src_max > dst_max:
            # narrowing: integer-divide by the ratio of the two value ranges
            return value // ((src_max + 1) // (dst_max + 1))
        # widening: multiply by the ratio of the two value ranges
        return value * ((dst_max + 1) // (src_max + 1))

    return torch.tensor(tree_map(convert, image.tolist()), dtype=dtype)
|
||
|
||
def reference_inputs_convert_image_dtype():
    """Yield small hand-made inputs for comparing the kernel against the reference.

    For every ordered dtype pair that is not skipped, a three element tensor with
    the minimum, middle and maximum of the input value range is produced
    (``[0.0, 0.5, 1.0]`` for float inputs, ``[0, max // 2, max]`` for integer inputs).
    """
    dtypes = [
        torch.uint8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.float16,
        torch.float32,
        torch.float64,
        torch.bfloat16,
    ]
    # (input dtype, output dtype) combinations that are not generated
    skipped_pairs = {
        (torch.float32, torch.int32),
        (torch.float32, torch.int64),
        (torch.float64, torch.int64),
    }

    for src_dtype, dst_dtype in itertools.product(dtypes, repeat=2):
        if (src_dtype, dst_dtype) in skipped_pairs:
            continue

        if src_dtype.is_floating_point:
            values = [0.0, 0.5, 1.0]
        else:
            upper = torch.iinfo(src_dtype).max
            values = [0, upper // 2, upper]

        yield ArgsKwargs(torch.tensor(values, dtype=src_dtype), dtype=dst_dtype)
|
||
|
||
# Register the `convert_image_dtype` kernel together with its sample inputs, reference function and test marks.
KERNEL_INFOS.extend(
    [
        KernelInfo(
            F.convert_image_dtype,
            sample_inputs_fn=sample_inputs_convert_image_dtype,
            reference_fn=reference_convert_image_dtype,
            reference_inputs_fn=reference_inputs_convert_image_dtype,
            test_marks=[
                TestMark(
                    ("TestKernels", "test_scripted_vs_eager"),
                    pytest.mark.filterwarnings(f"ignore:{re.escape('operator() profile_node %41')}:UserWarning"),
                ),
                TestMark(
                    ("TestKernels", "test_dtype_and_device_consistency"),
                    pytest.mark.skip(reason="`convert_dtype_*` kernels convert the dtype by design"),
                    # Only skip when the requested dtype actually differs from the input dtype
                    condition=lambda args_kwargs: args_kwargs.args[0].dtype
                    != args_kwargs.kwargs.get("dtype", torch.float32),
                ),
                TestMark(
                    ("TestKernels", "test_against_reference"),
                    pytest.mark.xfail(reason="Conversion overflows"),
                    # TODO: open an issue detailing what is happening in these cases and how it could be mitigated.
                    # Half precision inputs converted to any integer dtype (this includes `torch.int64`), and
                    # 32/64 bit integer inputs converted to `torch.float16`.
                    condition=lambda args_kwargs: (
                        args_kwargs.args[0].dtype in {torch.float16, torch.bfloat16}
                        and not args_kwargs.kwargs["dtype"].is_floating_point
                    )
                    or (
                        args_kwargs.args[0].dtype in {torch.int32, torch.int64}
                        and args_kwargs.kwargs["dtype"] == torch.float16
                    ),
                ),
            ],
        ),
    ]
)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -26,6 +26,20 @@ def script(fn): | |
raise AssertionError(f"Trying to `torch.jit.script` '{fn.__name__}' raised the error above.") from error | ||
|
||
|
||
def make_info_args_kwargs_params(info, *, args_kwargs_fn, test_id=None):
    """Build one ``pytest.param`` per item produced by ``args_kwargs_fn(info)``.

    Each parameter carries ``info`` and the item itself. Its id is ``info.id``
    followed by the zero-padded index of the item. Marks are looked up through
    ``info.get_marks`` only if ``test_id`` is truthy; otherwise no marks are set.
    """
    samples = list(args_kwargs_fn(info))
    # Pad all indices to the same width so the generated ids line up
    index_width = len(str(len(samples)))

    params = []
    for position, sample in enumerate(samples):
        marks = info.get_marks(test_id, sample) if test_id else []
        params.append(
            pytest.param(
                info,
                sample,
                marks=marks,
                id=f"{info.id}-{position:0{index_width}}",
            )
        )
    return params
|
||
|
||
def make_info_args_kwargs_parametrization(infos, *, args_kwargs_fn, condition=None): | ||
if condition is None: | ||
|
||
|
@@ -49,18 +63,7 @@ def decorator(test_fn): | |
if not condition(info): | ||
continue | ||
|
||
args_kwargs = list(args_kwargs_fn(info)) | ||
idx_field_len = len(str(len(args_kwargs))) | ||
|
||
for idx, args_kwargs_ in enumerate(args_kwargs): | ||
argvalues.append( | ||
pytest.param( | ||
info, | ||
args_kwargs_, | ||
marks=info.get_marks(test_id, args_kwargs_), | ||
id=f"{info.id}-{idx:0{idx_field_len}}", | ||
) | ||
) | ||
argvalues.extend(make_info_args_kwargs_params(info, args_kwargs_fn=args_kwargs_fn, test_id=test_id)) | ||
|
||
return pytest.mark.parametrize(argnames, argvalues)(test_fn) | ||
|
||
|
@@ -232,7 +235,6 @@ def test_scripted_smoke(self, info, args_kwargs, device): | |
[ | ||
F.clamp_bounding_box, | ||
F.convert_color_space, | ||
F.convert_image_dtype, | ||
F.get_dimensions, | ||
F.get_image_num_channels, | ||
F.get_image_size, | ||
|
@@ -312,6 +314,24 @@ def test_alias(alias, target): | |
assert alias is target | ||
|
||
|
||
# TODO: This is a rather convoluted way to get the sample inputs for a single kernel and should be refactored.
@pytest.mark.parametrize(
    ("info", "args_kwargs"),
    make_info_args_kwargs_params(
        next(info for info in KERNEL_INFOS if info.kernel is F.convert_image_dtype),
        args_kwargs_fn=lambda info: info.sample_inputs_fn(),
    ),
)
@pytest.mark.parametrize("device", cpu_and_gpu())
def test_dtype_and_device_convert_image_dtype(info, args_kwargs, device):
    """Check that the kernel output has the requested dtype and stays on the input's device."""
    args, kwargs = args_kwargs.load(device)
    image, *extra_args = args

    # The dtype may be passed positionally or as keyword; fall back to `torch.float32` otherwise
    if extra_args:
        expected_dtype = extra_args[0]
    else:
        expected_dtype = kwargs.get("dtype", torch.float32)

    result = info.kernel(image, expected_dtype)

    assert result.dtype == expected_dtype
    assert result.device == image.device
|
||
|
||
# TODO: All correctness checks below this line should be ported to be references on a `KernelInfo` in | ||
# `prototype_transforms_kernel_infos.py` | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ | |
from torchvision.io.video import read_video | ||
from torchvision.prototype import features | ||
from torchvision.prototype.utils._internal import ReadOnlyTensorBuffer | ||
from torchvision.transforms import functional as _F | ||
from torchvision.transforms import functional as _F, functional_tensor as _FT | ||
|
||
|
||
@torch.jit.unused | ||
|
@@ -42,4 +42,77 @@ def to_image_tensor(image: Union[torch.Tensor, PIL.Image.Image, np.ndarray]) -> | |
# prevalent and well understood. Thus, we just alias it without deprecating the old name. | ||
to_pil_image = to_image_pil | ||
|
||
convert_image_dtype = _F.convert_image_dtype | ||
|
||
def _num_value_bits(dtype: torch.dtype) -> int: | ||
if dtype == torch.uint8: | ||
return 8 | ||
elif dtype == torch.int8: | ||
return 7 | ||
elif dtype == torch.int16: | ||
return 15 | ||
elif dtype == torch.int32: | ||
return 31 | ||
elif dtype == torch.int64: | ||
return 63 | ||
else: | ||
raise TypeError(f"Number of value bits is only defined for integer dtypes, but got {dtype}.") | ||
|
||
|
||
def convert_image_dtype(image: torch.Tensor, dtype: torch.dtype = torch.float) -> torch.Tensor:
    """Convert ``image`` to ``dtype`` and rescale its values accordingly.

    Args:
        image: Tensor image to convert.
        dtype: Desired dtype of the output.

    Returns:
        The input itself if it already has the requested dtype, otherwise a new tensor of ``dtype``:
        float to float is a plain cast, float to int scales by the maximum value of the output dtype,
        int to float divides by the maximum value of the input dtype, and int to int rescales by the
        ratio of the two value ranges.

    Raises:
        TypeError: If ``image`` is not a tensor.
        RuntimeError: For ``float32`` to ``int32`` / ``int64`` and ``float64`` to ``int64``, since these
            conversions cannot be performed safely.
    """
    if not isinstance(image, torch.Tensor):
        raise TypeError("Input img should be Tensor Image")

    if image.dtype == dtype:
        return image

    float_input = image.is_floating_point()
    if torch.jit.is_scripting():
        # TODO: remove this branch as soon as `dtype.is_floating_point` is supported by JIT
        float_output = torch.tensor(0, dtype=dtype).is_floating_point()
    else:
        float_output = dtype.is_floating_point

    if float_input:
        # float to float
        if float_output:
            return image.to(dtype)

        # float to int
        if (image.dtype == torch.float32 and dtype in (torch.int32, torch.int64)) or (
            image.dtype == torch.float64 and dtype == torch.int64
        ):
            raise RuntimeError(f"The conversion from {image.dtype} to {dtype} cannot be performed safely.")

        # For data in the range `[0.0, 1.0]`, just multiplying by the maximum value of the integer range and converting
        # to the integer dtype is not sufficient. For example, `torch.rand(...).mul(255).to(torch.uint8)` will only
        # be `255` if the input is exactly `1.0`. See https://github.com/pytorch/vision/pull/2078#issuecomment-612045321
        # for a detailed analysis.
        # To mitigate this, we could round before we convert to the integer dtype, but this is an extra operation.
        # Instead, we can also multiply by the maximum value plus something close to `1`. See
        # https://github.com/pytorch/vision/pull/2078#issuecomment-613524965 for details.
        eps = 1e-3
        max_value = float(_FT._max_value(dtype))
        # We need to scale first since the conversion would otherwise turn the input range `[0.0, 1.0]` into the
        # discrete set `{0, 1}`.
        return image.mul(max_value + 1.0 - eps).to(dtype)
    else:
        # int to float
        if float_output:
            return image.to(dtype).div_(_FT._max_value(image.dtype))

        # int to int
        num_value_bits_input = _num_value_bits(image.dtype)
        num_value_bits_output = _num_value_bits(dtype)

        if num_value_bits_input > num_value_bits_output:
            return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
        else:
            # The bitshift kernel is not vectorized
            #  https://github.com/pytorch/pytorch/blob/703c19008df4700b6a522b0ae5c4b6d5ffc0906f/aten/src/ATen/native/cpu/BinaryOpsKernel.cpp#L315-L322
            # This results in the multiplication actually being faster.
            # TODO: If the bitshift kernel is optimized in core, replace the computation below with
            #  `image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)`
            # The output range is at least as large as the input range here, so the factor is
            # `(output max + 1) // (input max + 1)`.
            max_value_output = float(_FT._max_value(dtype))
            max_value_input = float(_FT._max_value(image.dtype))
            factor = int((max_value_output + 1) // (max_value_input + 1))
            return image.to(dtype).mul_(factor)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This gives us arbitrary-precision arithmetic for the intermediate calculations, which is what we want for the reference function. As you can see from the xfails I needed to add below, we need this in some cases.