Skip to content

Add support for decoding 16bits png #8524

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 6 additions & 12 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
import pytest
import requests
import torch
import torchvision.transforms.functional as F
import torchvision.transforms.v2.functional as F
from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda
from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
from torchvision.io.image import (
_read_png_16,
decode_gif,
decode_image,
decode_jpeg,
Expand Down Expand Up @@ -211,16 +210,11 @@ def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun):
img_pil = normalize_dimensions(img_pil)

if img_path.endswith("16.png"):
# 16 bits image decoding is supported, but only as a private API
# FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
data = read_file(img_path)
img_lpng = decode_fun(data, mode=mode)

img_lpng = _read_png_16(img_path, mode=mode)
assert img_lpng.dtype == torch.int32
# PIL converts 16 bits pngs in uint8
img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8)
data = read_file(img_path)
img_lpng = decode_fun(data, mode=mode)
assert img_lpng.dtype == torch.uint16
# PIL converts 16 bits pngs to uint8
img_lpng = F.to_dtype(img_lpng, torch.uint8, scale=True)
else:
data = read_file(img_path)
img_lpng = decode_fun(data, mode=mode)
Expand Down
30 changes: 27 additions & 3 deletions test/test_transforms_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2076,15 +2076,17 @@ def fn(value):
factor = (output_max_value + 1) // (input_max_value + 1)
return value * factor

return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype, device=image.device)
return torch.tensor(tree_map(fn, image.tolist())).to(dtype=output_dtype, device=image.device)

@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
@pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
@pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
@pytest.mark.parametrize("device", cpu_and_cuda())
@pytest.mark.parametrize("scale", (True, False))
def test_image_correctness(self, input_dtype, output_dtype, device, scale):
if input_dtype.is_floating_point and output_dtype == torch.int64:
pytest.xfail("float to int64 conversion is not supported")
if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda":
pytest.xfail("uint8 to uint16 conversion is not supported on cuda")

input = make_image(dtype=input_dtype, device=device)

Expand Down Expand Up @@ -2171,6 +2173,28 @@ def test_errors_warnings(self, make_input):
assert out["bbox"].dtype == bbox_dtype
assert out["mask"].dtype == mask_dtype

def test_uint16(self):
    """Explicitly exercise scaled ``to_dtype`` conversions to and from uint16.

    Other tests probably cover these paths already, but uint16 is a newly
    supported dtype, so each conversion is pinned down here on its own.
    """
    # NOTE: torch.randint's upper bound is exclusive, so the value 65535
    # itself is never part of the sample.
    src_u16 = torch.randint(0, 65535, (256, 512), dtype=torch.uint16)

    as_u8, as_f32, as_i32 = (
        F.to_dtype(src_u16, target, scale=True) for target in (torch.uint8, torch.float32, torch.int32)
    )

    # Converting away from uint16.
    expected_u8 = (src_u16 / 256).to(torch.uint8)
    assert_equal(as_u8, expected_u8)
    expected_f32 = src_u16 / 65535
    assert_close(as_f32, expected_f32)

    # Converting back to uint16: float32 may be off by one step.
    u16_from_f32 = F.to_dtype(as_f32, torch.uint16, scale=True)
    assert_close(u16_from_f32, src_u16, rtol=0, atol=1)

    # Ideally the reference would be (src_u16 & 0xFF00), but bitwise AND is
    # not supported for uint16 yet, so it is emulated by scaling down and
    # back up again.
    high_byte_only = (src_u16 / 256).to(torch.uint16) * 256
    u16_from_u8 = F.to_dtype(as_u8, torch.uint16, scale=True)
    assert_equal(u16_from_u8, high_byte_only)

    u16_from_i32 = F.to_dtype(as_i32, torch.uint16, scale=True)
    assert_equal(u16_from_i32, src_u16)

    # Cross-checks between the two converted representations.
    u8_from_f32 = F.to_dtype(as_f32, torch.uint8, scale=True)
    assert_equal(u8_from_f32, as_u8)
    f32_from_u8 = F.to_dtype(as_u8, torch.float32, scale=True)
    assert_close(f32_from_u8, as_f32, rtol=0, atol=1e-2)


class TestAdjustBrightness:
_CORRECTNESS_BRIGHTNESS_FACTORS = [0.5, 0.0, 1.0, 5.0]
Expand Down
3 changes: 1 addition & 2 deletions torchvision/csrc/io/image/cpu/decode_image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ torch::Tensor decode_image(
if (memcmp(jpeg_signature, datap, 3) == 0) {
return decode_jpeg(data, mode, apply_exif_orientation);
} else if (memcmp(png_signature, datap, 4) == 0) {
return decode_png(
data, mode, /*allow_16_bits=*/false, apply_exif_orientation);
return decode_png(data, mode, apply_exif_orientation);
} else if (
memcmp(gif_signature_1, datap, 6) == 0 ||
memcmp(gif_signature_2, datap, 6) == 0) {
Expand Down
59 changes: 16 additions & 43 deletions torchvision/csrc/io/image/cpu/decode_png.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ using namespace exif_private;
torch::Tensor decode_png(
const torch::Tensor& data,
ImageReadMode mode,
bool allow_16_bits,
bool apply_exif_orientation) {
TORCH_CHECK(
false, "decode_png: torchvision not compiled with libPNG support");
Expand All @@ -26,7 +25,6 @@ bool is_little_endian() {
torch::Tensor decode_png(
const torch::Tensor& data,
ImageReadMode mode,
bool allow_16_bits,
bool apply_exif_orientation) {
C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png");
// Check that the input tensor dtype is uint8
Expand Down Expand Up @@ -99,12 +97,12 @@ torch::Tensor decode_png(
TORCH_CHECK(retval == 1, "Could read image metadata from content.")
}

auto max_bit_depth = allow_16_bits ? 16 : 8;
auto err_msg = "At most " + std::to_string(max_bit_depth) +
"-bit PNG images are supported currently.";
if (bit_depth > max_bit_depth) {
if (bit_depth > 8 && bit_depth != 16) {
png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
TORCH_CHECK(false, err_msg)
TORCH_CHECK(
false,
"bit depth of png image is " + std::to_string(bit_depth) +
". Only <=8 and 16 are supported.")
}

int channels = png_get_channels(png_ptr, info_ptr);
Expand Down Expand Up @@ -199,45 +197,20 @@ torch::Tensor decode_png(
}

auto num_pixels_per_row = width * channels;
auto is_16_bits = bit_depth == 16;
auto tensor = torch::empty(
{int64_t(height), int64_t(width), channels},
bit_depth <= 8 ? torch::kU8 : torch::kI32);

if (bit_depth <= 8) {
auto t_ptr = tensor.accessor<uint8_t, 3>().data();
for (int pass = 0; pass < number_of_passes; pass++) {
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, t_ptr, nullptr);
t_ptr += num_pixels_per_row;
}
t_ptr = tensor.accessor<uint8_t, 3>().data();
}
} else {
// We're reading a 16bits png, but pytorch doesn't support uint16.
// So we read each row in a 16bits tmp_buffer which we then cast into
// a int32 tensor instead.
if (is_little_endian()) {
png_set_swap(png_ptr);
}
int32_t* t_ptr = tensor.accessor<int32_t, 3>().data();

// We create a tensor instead of malloc-ing for automatic memory management
auto tmp_buffer_tensor = torch::empty(
{int64_t(num_pixels_per_row * sizeof(uint16_t))}, torch::kU8);
uint16_t* tmp_buffer =
(uint16_t*)tmp_buffer_tensor.accessor<uint8_t, 1>().data();

for (int pass = 0; pass < number_of_passes; pass++) {
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, (uint8_t*)tmp_buffer, nullptr);
// Now we copy the uint16 values into the int32 tensor.
for (size_t j = 0; j < num_pixels_per_row; ++j) {
t_ptr[j] = (int32_t)tmp_buffer[j];
}
t_ptr += num_pixels_per_row;
}
t_ptr = tensor.accessor<int32_t, 3>().data();
is_16_bits ? at::kUInt16 : torch::kU8);
if (is_little_endian()) {
png_set_swap(png_ptr);
}
auto t_ptr = (uint8_t*)tensor.data_ptr();
for (int pass = 0; pass < number_of_passes; pass++) {
for (png_uint_32 i = 0; i < height; ++i) {
png_read_row(png_ptr, t_ptr, nullptr);
t_ptr += num_pixels_per_row * (is_16_bits ? 2 : 1);
}
t_ptr = (uint8_t*)tensor.data_ptr();
}

int exif_orientation = -1;
Expand Down
1 change: 0 additions & 1 deletion torchvision/csrc/io/image/cpu/decode_png.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ namespace image {
C10_EXPORT torch::Tensor decode_png(
const torch::Tensor& data,
ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
bool allow_16_bits = false,
bool apply_exif_orientation = false);

} // namespace image
Expand Down
2 changes: 1 addition & 1 deletion torchvision/csrc/io/image/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace image {
static auto registry =
torch::RegisterOperators()
.op("image::decode_gif", &decode_gif)
.op("image::decode_png(Tensor data, int mode, bool allow_16_bits = False, bool apply_exif_orientation=False) -> Tensor",
.op("image::decode_png(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
&decode_png)
.op("image::encode_png", &encode_png)
.op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
Expand Down
4 changes: 2 additions & 2 deletions torchvision/datasets/_optical_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import torch
from PIL import Image

from ..io.image import _read_png_16
from ..io.image import decode_png, read_file
from .utils import _read_pfm, verify_str_arg
from .vision import VisionDataset

Expand Down Expand Up @@ -481,7 +481,7 @@ def _read_flo(file_name: str) -> np.ndarray:

def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> Tuple[np.ndarray, np.ndarray]:

flow_and_valid = _read_png_16(file_name).to(torch.float32)
flow_and_valid = decode_png(read_file(file_name)).to(torch.float32)
flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
flow = (flow - 2**15) / 64 # This conversion is explained somewhere on the kitti archive
valid_flow_mask = valid_flow_mask.bool()
Expand Down
38 changes: 25 additions & 13 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,14 @@ def decode_png(
) -> torch.Tensor:
"""
Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255].

The values of the output tensor are uint8 in [0, 255] in most cases. If
the image is a 16-bit PNG, then the output tensor is uint16 in [0, 65535]
(supported from torchvision ``0.21``). Since uint16 support is limited in
pytorch, we recommend calling
:func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
after this function to convert the decoded image into a uint8 or float
tensor.

Args:
input (Tensor[1]): a one dimensional uint8 tensor containing
Expand All @@ -93,7 +99,7 @@ def decode_png(
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(decode_png)
output = torch.ops.image.decode_png(input, mode.value, False, apply_exif_orientation)
output = torch.ops.image.decode_png(input, mode.value, apply_exif_orientation)
return output


Expand Down Expand Up @@ -144,7 +150,7 @@ def decode_jpeg(
) -> torch.Tensor:
"""
Decodes a JPEG image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.

The values of the output tensor are uint8 between 0 and 255.

Args:
Expand Down Expand Up @@ -248,8 +254,13 @@ def decode_image(
Detect whether an image is a JPEG, PNG or GIF and performs the appropriate
operation to decode the image into a 3 dimensional RGB or grayscale Tensor.

Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255].
The values of the output tensor are uint8 in [0, 255] in most cases. If
the image is a 16-bit PNG, then the output tensor is uint16 in [0, 65535]
(supported from torchvision ``0.21``). Since uint16 support is limited in
pytorch, we recommend calling
:func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
after this function to convert the decoded image into a uint8 or float
tensor.

Args:
input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
Expand Down Expand Up @@ -277,8 +288,14 @@ def read_image(
) -> torch.Tensor:
"""
Reads a JPEG, PNG or GIF image into a 3 dimensional RGB or grayscale Tensor.
Optionally converts the image to the desired format.
The values of the output tensor are uint8 in [0, 255].

The values of the output tensor are uint8 in [0, 255] in most cases. If
the image is a 16-bit PNG, then the output tensor is uint16 in [0, 65535]
(supported from torchvision ``0.21``). Since uint16 support is limited in
pytorch, we recommend calling
:func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
after this function to convert the decoded image into a uint8 or float
tensor.

Args:
path (str or ``pathlib.Path``): path of the JPEG, PNG or GIF image.
Expand All @@ -298,11 +315,6 @@ def read_image(
return decode_image(data, mode, apply_exif_orientation=apply_exif_orientation)


def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
    """Read the file at ``path`` and decode it through the ``image::decode_png`` op.

    Private helper: the raw bytes come from ``read_file`` and are handed to
    ``torch.ops.image.decode_png`` together with ``mode.value`` and a third
    positional argument of ``True``.
    """
    raw_bytes = read_file(path)
    decode_op = torch.ops.image.decode_png
    # NOTE(review): the trailing positional True is presumably the op's
    # ``allow_16_bits`` flag -- confirm against the registered op schema.
    return decode_op(raw_bytes, mode.value, True)


def decode_gif(input: torch.Tensor) -> torch.Tensor:
"""
Decode a GIF image into a 3 or 4 dimensional RGB Tensor.
Expand Down
2 changes: 2 additions & 0 deletions torchvision/transforms/_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ def _max_value(dtype: torch.dtype) -> int:
return 127
elif dtype == torch.int16:
return 32767
elif dtype == torch.uint16:
return 65535
elif dtype == torch.int32:
return 2147483647
elif dtype == torch.int64:
Expand Down
14 changes: 12 additions & 2 deletions torchvision/transforms/v2/functional/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ def _num_value_bits(dtype: torch.dtype) -> int:
return 7
elif dtype == torch.int16:
return 15
elif dtype == torch.uint16:
return 16
elif dtype == torch.int32:
return 31
elif dtype == torch.int64:
Expand Down Expand Up @@ -293,10 +295,18 @@ def to_dtype_image(image: torch.Tensor, dtype: torch.dtype = torch.float, scale:
num_value_bits_input = _num_value_bits(image.dtype)
num_value_bits_output = _num_value_bits(dtype)

# TODO: Remove if/else inner blocks once uint16 dtype supports bitwise shift operations.
shift_by = abs(num_value_bits_input - num_value_bits_output)
if num_value_bits_input > num_value_bits_output:
return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
if image.dtype == torch.uint16:
return (image / 2 ** (shift_by)).to(dtype)
else:
return image.bitwise_right_shift(shift_by).to(dtype)
else:
return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)
if dtype == torch.uint16:
return image.to(dtype) * 2 ** (shift_by)
else:
return image.to(dtype).bitwise_left_shift_(shift_by)


# We encourage users to use to_dtype() instead but we keep this for BC
Expand Down
Loading