Commit ac99fda

NicolasHug authored and facebook-github-bot committed
[fbsync] Add support for decoding 16bits png (#8524)
Reviewed By: vmoens
Differential Revision: D60596240
fbshipit-source-id: 4987658c3ee1255a3200232768580a927e371d12
1 parent c07f627 commit ac99fda
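In practice the change looks like this (a minimal sketch; the file name is a placeholder, and the to_dtype conversion follows the docstring recommendation added below):

import torch
from torchvision.io import decode_png, read_file
from torchvision.transforms.v2.functional import to_dtype

# "depth_16bit.png" is a placeholder for any 16-bit PNG on disk.
data = read_file("depth_16bit.png")
img = decode_png(data)
assert img.dtype == torch.uint16  # 16-bit PNGs now decode to uint16 in [0, 65535]

# uint16 support in PyTorch is limited, so convert before further processing;
# scale=True rescales values into the target dtype's range.
img_u8 = to_dtype(img, torch.uint8, scale=True)     # values in [0, 255]
img_f32 = to_dtype(img, torch.float32, scale=True)  # values in [0.0, 1.0]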

File tree: 10 files changed (+92, -79 lines)

test/test_image.py
Lines changed: 6 additions & 12 deletions

@@ -10,11 +10,10 @@
 import pytest
 import requests
 import torch
-import torchvision.transforms.functional as F
+import torchvision.transforms.v2.functional as F
 from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda
 from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
 from torchvision.io.image import (
-    _read_png_16,
     decode_gif,
     decode_image,
     decode_jpeg,
@@ -211,16 +210,11 @@ def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun):
     img_pil = normalize_dimensions(img_pil)

     if img_path.endswith("16.png"):
-        # 16 bits image decoding is supported, but only as a private API
-        # FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
-        with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
-            data = read_file(img_path)
-            img_lpng = decode_fun(data, mode=mode)
-
-        img_lpng = _read_png_16(img_path, mode=mode)
-        assert img_lpng.dtype == torch.int32
-        # PIL converts 16 bits pngs in uint8
-        img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8)
+        data = read_file(img_path)
+        img_lpng = decode_fun(data, mode=mode)
+        assert img_lpng.dtype == torch.uint16
+        # PIL converts 16 bits pngs to uint8
+        img_lpng = F.to_dtype(img_lpng, torch.uint8, scale=True)
     else:
         data = read_file(img_path)
         img_lpng = decode_fun(data, mode=mode)

test/test_transforms_v2.py
Lines changed: 27 additions & 3 deletions

@@ -2076,15 +2076,17 @@ def fn(value):
             factor = (output_max_value + 1) // (input_max_value + 1)
             return value * factor

-        return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype, device=image.device)
+        return torch.tensor(tree_map(fn, image.tolist())).to(dtype=output_dtype, device=image.device)

-    @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
-    @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
+    @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
+    @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("scale", (True, False))
     def test_image_correctness(self, input_dtype, output_dtype, device, scale):
         if input_dtype.is_floating_point and output_dtype == torch.int64:
             pytest.xfail("float to int64 conversion is not supported")
+        if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda":
+            pytest.xfail("uint8 to uint16 conversion is not supported on cuda")

         input = make_image(dtype=input_dtype, device=device)

@@ -2171,6 +2173,28 @@ def test_errors_warnings(self, make_input):
         assert out["bbox"].dtype == bbox_dtype
         assert out["mask"].dtype == mask_dtype

+    def test_uint16(self):
+        # These checks are probably already covered above but since uint16 is a
+        # newly supported dtype, we want to be extra careful, hence this
+        # explicit test
+        img_uint16 = torch.randint(0, 65535, (256, 512), dtype=torch.uint16)
+
+        img_uint8 = F.to_dtype(img_uint16, torch.uint8, scale=True)
+        img_float32 = F.to_dtype(img_uint16, torch.float32, scale=True)
+        img_int32 = F.to_dtype(img_uint16, torch.int32, scale=True)
+
+        assert_equal(img_uint8, (img_uint16 / 256).to(torch.uint8))
+        assert_close(img_float32, (img_uint16 / 65535))
+
+        assert_close(F.to_dtype(img_float32, torch.uint16, scale=True), img_uint16, rtol=0, atol=1)
+        # Ideally we'd check against (img_uint16 & 0xFF00) but bitwise and isn't supported for it yet
+        # so we simulate it by scaling down and up again.
+        assert_equal(F.to_dtype(img_uint8, torch.uint16, scale=True), ((img_uint16 / 256).to(torch.uint16) * 256))
+        assert_equal(F.to_dtype(img_int32, torch.uint16, scale=True), img_uint16)
+
+        assert_equal(F.to_dtype(img_float32, torch.uint8, scale=True), img_uint8)
+        assert_close(F.to_dtype(img_uint8, torch.float32, scale=True), img_float32, rtol=0, atol=1e-2)
+

 class TestAdjustBrightness:
     _CORRECTNESS_BRIGHTNESS_FACTORS = [0.5, 0.0, 1.0, 5.0]
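The comment inside test_uint16 about the unsupported bitwise-and hints at why the test scales down and back up by 256: for non-negative values that reproduces masking off the low byte. A quick illustration on int32, where both operations are available (not part of the commit):

import torch

x = torch.randint(0, 65536, (1000,), dtype=torch.int32)
masked = x & 0xFF00        # what the test would ideally compare against
scaled = (x // 256) * 256  # the simulation: scale down, then back up by 2**8
assert torch.equal(masked, scaled)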

torchvision/csrc/io/image/cpu/decode_image.cpp
Lines changed: 1 addition & 2 deletions

@@ -32,8 +32,7 @@ torch::Tensor decode_image(
   if (memcmp(jpeg_signature, datap, 3) == 0) {
     return decode_jpeg(data, mode, apply_exif_orientation);
   } else if (memcmp(png_signature, datap, 4) == 0) {
-    return decode_png(
-        data, mode, /*allow_16_bits=*/false, apply_exif_orientation);
+    return decode_png(data, mode, apply_exif_orientation);
   } else if (
       memcmp(gif_signature_1, datap, 6) == 0 ||
       memcmp(gif_signature_2, datap, 6) == 0) {

torchvision/csrc/io/image/cpu/decode_png.cpp
Lines changed: 16 additions & 43 deletions

@@ -11,7 +11,6 @@ using namespace exif_private;
 torch::Tensor decode_png(
     const torch::Tensor& data,
     ImageReadMode mode,
-    bool allow_16_bits,
     bool apply_exif_orientation) {
   TORCH_CHECK(
       false, "decode_png: torchvision not compiled with libPNG support");
@@ -26,7 +25,6 @@ bool is_little_endian() {
 torch::Tensor decode_png(
     const torch::Tensor& data,
     ImageReadMode mode,
-    bool allow_16_bits,
     bool apply_exif_orientation) {
   C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png");
   // Check that the input tensor dtype is uint8
@@ -99,12 +97,12 @@ torch::Tensor decode_png(
     TORCH_CHECK(retval == 1, "Could read image metadata from content.")
   }

-  auto max_bit_depth = allow_16_bits ? 16 : 8;
-  auto err_msg = "At most " + std::to_string(max_bit_depth) +
-      "-bit PNG images are supported currently.";
-  if (bit_depth > max_bit_depth) {
+  if (bit_depth > 8 && bit_depth != 16) {
     png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
-    TORCH_CHECK(false, err_msg)
+    TORCH_CHECK(
+        false,
+        "bit depth of png image is " + std::to_string(bit_depth) +
+            ". Only <=8 and 16 are supported.")
   }

   int channels = png_get_channels(png_ptr, info_ptr);
@@ -199,45 +197,20 @@ torch::Tensor decode_png(
   }

   auto num_pixels_per_row = width * channels;
+  auto is_16_bits = bit_depth == 16;
   auto tensor = torch::empty(
       {int64_t(height), int64_t(width), channels},
-      bit_depth <= 8 ? torch::kU8 : torch::kI32);
-
-  if (bit_depth <= 8) {
-    auto t_ptr = tensor.accessor<uint8_t, 3>().data();
-    for (int pass = 0; pass < number_of_passes; pass++) {
-      for (png_uint_32 i = 0; i < height; ++i) {
-        png_read_row(png_ptr, t_ptr, nullptr);
-        t_ptr += num_pixels_per_row;
-      }
-      t_ptr = tensor.accessor<uint8_t, 3>().data();
-    }
-  } else {
-    // We're reading a 16bits png, but pytorch doesn't support uint16.
-    // So we read each row in a 16bits tmp_buffer which we then cast into
-    // a int32 tensor instead.
-    if (is_little_endian()) {
-      png_set_swap(png_ptr);
-    }
-    int32_t* t_ptr = tensor.accessor<int32_t, 3>().data();
-
-    // We create a tensor instead of malloc-ing for automatic memory management
-    auto tmp_buffer_tensor = torch::empty(
-        {int64_t(num_pixels_per_row * sizeof(uint16_t))}, torch::kU8);
-    uint16_t* tmp_buffer =
-        (uint16_t*)tmp_buffer_tensor.accessor<uint8_t, 1>().data();
-
-    for (int pass = 0; pass < number_of_passes; pass++) {
-      for (png_uint_32 i = 0; i < height; ++i) {
-        png_read_row(png_ptr, (uint8_t*)tmp_buffer, nullptr);
-        // Now we copy the uint16 values into the int32 tensor.
-        for (size_t j = 0; j < num_pixels_per_row; ++j) {
-          t_ptr[j] = (int32_t)tmp_buffer[j];
-        }
-        t_ptr += num_pixels_per_row;
-      }
-      t_ptr = tensor.accessor<int32_t, 3>().data();
+      is_16_bits ? at::kUInt16 : torch::kU8);
+  if (is_little_endian()) {
+    png_set_swap(png_ptr);
+  }
+  auto t_ptr = (uint8_t*)tensor.data_ptr();
+  for (int pass = 0; pass < number_of_passes; pass++) {
+    for (png_uint_32 i = 0; i < height; ++i) {
+      png_read_row(png_ptr, t_ptr, nullptr);
+      t_ptr += num_pixels_per_row * (is_16_bits ? 2 : 1);
     }
+    t_ptr = (uint8_t*)tensor.data_ptr();
   }

   int exif_orientation = -1;

torchvision/csrc/io/image/cpu/decode_png.h
Lines changed: 0 additions & 1 deletion

@@ -9,7 +9,6 @@ namespace image {
 C10_EXPORT torch::Tensor decode_png(
     const torch::Tensor& data,
     ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
-    bool allow_16_bits = false,
    bool apply_exif_orientation = false);

 } // namespace image
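With decode_png now declared without allow_16_bits and reading rows directly into a uint16 tensor (byte-swapped on little-endian machines), a Pillow round trip is a simple sanity check. A sketch, assuming Pillow maps a uint16 array to a 16-bit grayscale PNG ("I;16") and using a scratch file name:

import numpy as np
import torch
from PIL import Image
from torchvision.io import decode_png, read_file

# Write a small 16-bit grayscale gradient; "gradient_16bit.png" is a scratch file.
arr = (np.arange(64 * 64, dtype=np.uint32) % 65536).astype(np.uint16).reshape(64, 64)
Image.fromarray(arr).save("gradient_16bit.png")

img = decode_png(read_file("gradient_16bit.png"))
assert img.dtype == torch.uint16
assert img.shape == (1, 64, 64)  # grayscale PNG -> single channel, CHW layout
# Compare through int64 to avoid relying on uint16 arithmetic support.
assert torch.equal(img[0].to(torch.int64), torch.from_numpy(arr.astype(np.int64)))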

torchvision/csrc/io/image/image.cpp
Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@ namespace image {
 static auto registry =
     torch::RegisterOperators()
         .op("image::decode_gif", &decode_gif)
-        .op("image::decode_png(Tensor data, int mode, bool allow_16_bits = False, bool apply_exif_orientation=False) -> Tensor",
+        .op("image::decode_png(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
            &decode_png)
        .op("image::encode_png", &encode_png)
        .op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",

torchvision/datasets/_optical_flow.py
Lines changed: 2 additions & 2 deletions

@@ -9,7 +9,7 @@
 import torch
 from PIL import Image

-from ..io.image import _read_png_16
+from ..io.image import decode_png, read_file
 from .utils import _read_pfm, verify_str_arg
 from .vision import VisionDataset

@@ -481,7 +481,7 @@ def _read_flo(file_name: str) -> np.ndarray:

 def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> Tuple[np.ndarray, np.ndarray]:

-    flow_and_valid = _read_png_16(file_name).to(torch.float32)
+    flow_and_valid = decode_png(read_file(file_name)).to(torch.float32)
     flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
     flow = (flow - 2**15) / 64  # This conversion is explained somewhere on the kitti archive
     valid_flow_mask = valid_flow_mask.bool()
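The only behavioral difference in this helper is the intermediate dtype: the removed private _read_png_16 returned int32, while the public decode_png now returns uint16 with the same values, so the float32 cast that follows is unaffected. A condensed sketch with a hypothetical KITTI-style path:

import torch
from torchvision.io import decode_png, read_file

# Before: _read_png_16("...") -> int32 tensor. Now: uint16 with identical values,
# so the flow arithmetic downstream is unchanged after the cast.
flow_and_valid = decode_png(read_file("training/flow_occ/000000_10.png")).to(torch.float32)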

torchvision/io/image.py
Lines changed: 25 additions & 13 deletions

@@ -75,8 +75,14 @@ def decode_png(
 ) -> torch.Tensor:
     """
     Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
-    Optionally converts the image to the desired format.
-    The values of the output tensor are uint8 in [0, 255].
+
+    The values of the output tensor are in uint8 in [0, 255] for most cases. If
+    the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
+    (supported from torchvision ``0.21``). Since uint16 support is limited in
+    pytorch, we recommend calling
+    :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
+    after this function to convert the decoded image into a uint8 or float
+    tensor.

     Args:
         input (Tensor[1]): a one dimensional uint8 tensor containing
@@ -93,7 +99,7 @@ def decode_png(
     """
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
         _log_api_usage_once(decode_png)
-    output = torch.ops.image.decode_png(input, mode.value, False, apply_exif_orientation)
+    output = torch.ops.image.decode_png(input, mode.value, apply_exif_orientation)
     return output


@@ -144,7 +150,7 @@ def decode_jpeg(
 ) -> torch.Tensor:
     """
     Decodes a JPEG image into a 3 dimensional RGB or grayscale Tensor.
-    Optionally converts the image to the desired format.
+
     The values of the output tensor are uint8 between 0 and 255.

     Args:
@@ -248,8 +254,13 @@ def decode_image(
     Detect whether an image is a JPEG, PNG or GIF and performs the appropriate
     operation to decode the image into a 3 dimensional RGB or grayscale Tensor.

-    Optionally converts the image to the desired format.
-    The values of the output tensor are uint8 in [0, 255].
+    The values of the output tensor are in uint8 in [0, 255] for most cases. If
+    the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
+    (supported from torchvision ``0.21``). Since uint16 support is limited in
+    pytorch, we recommend calling
+    :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
+    after this function to convert the decoded image into a uint8 or float
+    tensor.

     Args:
         input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
@@ -277,8 +288,14 @@ def read_image(
 ) -> torch.Tensor:
     """
     Reads a JPEG, PNG or GIF image into a 3 dimensional RGB or grayscale Tensor.
-    Optionally converts the image to the desired format.
-    The values of the output tensor are uint8 in [0, 255].
+
+    The values of the output tensor are in uint8 in [0, 255] for most cases. If
+    the image is a 16-bit png, then the output tensor is uint16 in [0, 65535]
+    (supported from torchvision ``0.21``). Since uint16 support is limited in
+    pytorch, we recommend calling
+    :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
+    after this function to convert the decoded image into a uint8 or float
+    tensor.

     Args:
         path (str or ``pathlib.Path``): path of the JPEG, PNG or GIF image.
@@ -298,11 +315,6 @@ def read_image(
     return decode_image(data, mode, apply_exif_orientation=apply_exif_orientation)


-def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
-    data = read_file(path)
-    return torch.ops.image.decode_png(data, mode.value, True)
-
-
 def decode_gif(input: torch.Tensor) -> torch.Tensor:
     """
     Decode a GIF image into a 3 or 4 dimensional RGB Tensor.

torchvision/transforms/_functional_tensor.py
Lines changed: 2 additions & 0 deletions

@@ -45,6 +45,8 @@ def _max_value(dtype: torch.dtype) -> int:
         return 127
     elif dtype == torch.int16:
         return 32767
+    elif dtype == torch.uint16:
+        return 65535
     elif dtype == torch.int32:
         return 2147483647
     elif dtype == torch.int64:
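_max_value supplies the scaling bounds used by the v1 convert_image_dtype path, so this entry lets float images map onto the full 16-bit range. A tiny sketch using the private helper purely for illustration:

import torch
from torchvision.transforms._functional_tensor import _max_value

assert _max_value(torch.uint16) == 2**16 - 1

# e.g. a float image in [0, 1] mapped onto the uint16 range
img_float = torch.rand(3, 8, 8)
img_uint16 = (img_float * _max_value(torch.uint16)).round().to(torch.uint16)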

torchvision/transforms/v2/functional/_misc.py
Lines changed: 12 additions & 2 deletions

@@ -237,6 +237,8 @@ def _num_value_bits(dtype: torch.dtype) -> int:
         return 7
     elif dtype == torch.int16:
         return 15
+    elif dtype == torch.uint16:
+        return 16
     elif dtype == torch.int32:
         return 31
     elif dtype == torch.int64:
@@ -293,10 +295,18 @@ def to_dtype_image(image: torch.Tensor, dtype: torch.dtype = torch.float, scale:
         num_value_bits_input = _num_value_bits(image.dtype)
         num_value_bits_output = _num_value_bits(dtype)

+        # TODO: Remove if/else inner blocks once uint16 dtype supports bitwise shift operations.
+        shift_by = abs(num_value_bits_input - num_value_bits_output)
         if num_value_bits_input > num_value_bits_output:
-            return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype)
+            if image.dtype == torch.uint16:
+                return (image / 2 ** (shift_by)).to(dtype)
+            else:
+                return image.bitwise_right_shift(shift_by).to(dtype)
         else:
-            return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input)
+            if dtype == torch.uint16:
+                return image.to(dtype) * 2 ** (shift_by)
+            else:
+                return image.to(dtype).bitwise_left_shift_(shift_by)


 # We encourage users to use to_dtype() instead but we keep this for BC
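The arithmetic fallback above matches the shifts it replaces for the non-negative pixel values involved: dividing by 2**shift mirrors a right shift and multiplying mirrors a left shift. A quick check on int32, where both paths can be compared directly (not part of the commit):

import torch

shift_by = 8  # e.g. 16 value bits (uint16) vs 8 value bits (uint8)

x = torch.randint(0, 2**16, (1000,), dtype=torch.int32)
assert torch.equal(x.bitwise_right_shift(shift_by), x // 2**shift_by)

y = torch.randint(0, 2**8, (1000,), dtype=torch.int32)
assert torch.equal(y.bitwise_left_shift(shift_by), y * 2**shift_by)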
