diff --git a/test/assets/fakedata/logos/rgb_pytorch16.png b/test/assets/fakedata/logos/rgb_pytorch16.png new file mode 100644 index 00000000000..b5e9e35d989 Binary files /dev/null and b/test/assets/fakedata/logos/rgb_pytorch16.png differ diff --git a/test/assets/fakedata/logos/rgbalpha_pytorch16.png b/test/assets/fakedata/logos/rgbalpha_pytorch16.png new file mode 100644 index 00000000000..df1db4d6354 Binary files /dev/null and b/test/assets/fakedata/logos/rgbalpha_pytorch16.png differ diff --git a/test/test_image.py b/test/test_image.py index 9c6a73b8362..7cd74fc915c 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -168,6 +168,11 @@ def test_decode_png(img_path, pil_mode, mode): # TODO: remove once fix is released in PIL. Should be > 8.3.1. img_lpng, img_pil = img_lpng[0], img_pil[0] + if "16" in img_path: + # PIL converts 16-bit PNGs to uint8 + assert img_lpng.dtype == torch.int32 + img_lpng = torch.round(img_lpng / (2 ** 16 - 1) * 255).to(torch.uint8) + torch.testing.assert_close(img_lpng, img_pil, atol=tol, rtol=0) diff --git a/torchvision/csrc/io/image/cpu/decode_png.cpp b/torchvision/csrc/io/image/cpu/decode_png.cpp index ea38272c978..2bd25c3d91a 100644 --- a/torchvision/csrc/io/image/cpu/decode_png.cpp +++ b/torchvision/csrc/io/image/cpu/decode_png.cpp @@ -11,6 +11,11 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) { } #else +bool is_little_endian() { + uint32_t x = 1; + return *(uint8_t*)&x; +} + torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) { // Check that the input tensor dtype is uint8 TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); @@ -72,9 +77,9 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) { TORCH_CHECK(retval == 1, "Could read image metadata from content.") } - if (bit_depth > 8) { + if (bit_depth > 16) { png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); - TORCH_CHECK(false, "At most 8-bit PNG images are supported currently.") 
+ TORCH_CHECK(false, "At most 16-bit PNG images are supported currently.") } int channels = png_get_channels(png_ptr, info_ptr); @@ -168,15 +173,46 @@ torch::Tensor decode_png(const torch::Tensor& data, ImageReadMode mode) { png_read_update_info(png_ptr, info_ptr); } - auto tensor = - torch::empty({int64_t(height), int64_t(width), channels}, torch::kU8); - auto ptr = tensor.accessor().data(); - for (int pass = 0; pass < number_of_passes; pass++) { - for (png_uint_32 i = 0; i < height; ++i) { - png_read_row(png_ptr, ptr, nullptr); - ptr += width * channels; + auto num_pixels_per_row = width * channels; + auto tensor = torch::empty( + {int64_t(height), int64_t(width), channels}, + bit_depth <= 8 ? torch::kU8 : torch::kI32); + + if (bit_depth <= 8) { + auto t_ptr = tensor.accessor().data(); + for (int pass = 0; pass < number_of_passes; pass++) { + for (png_uint_32 i = 0; i < height; ++i) { + png_read_row(png_ptr, t_ptr, nullptr); + t_ptr += num_pixels_per_row; + } + t_ptr = tensor.accessor().data(); + } + } else { + // We're reading a 16-bit PNG, but PyTorch doesn't support uint16. + // So we read each row into a 16-bit tmp_buffer, which we then cast into + // an int32 tensor instead. + if (is_little_endian()) { + png_set_swap(png_ptr); + } + int32_t* t_ptr = tensor.accessor().data(); + + // We create a tensor instead of malloc-ing for automatic memory management + auto tmp_buffer_tensor = torch::empty( + {int64_t(num_pixels_per_row * sizeof(uint16_t))}, torch::kU8); + uint16_t* tmp_buffer = + (uint16_t*)tmp_buffer_tensor.accessor().data(); + + for (int pass = 0; pass < number_of_passes; pass++) { + for (png_uint_32 i = 0; i < height; ++i) { + png_read_row(png_ptr, (uint8_t*)tmp_buffer, nullptr); + // Now we copy the uint16 values into the int32 tensor. 
+ for (size_t j = 0; j < num_pixels_per_row; ++j) { + t_ptr[j] = (int32_t)tmp_buffer[j]; + } + t_ptr += num_pixels_per_row; + } + t_ptr = tensor.accessor().data(); } - ptr = tensor.accessor().data(); } png_destroy_read_struct(&png_ptr, &info_ptr, nullptr); return tensor.permute({2, 0, 1}); diff --git a/torchvision/io/image.py b/torchvision/io/image.py index 3bafabdfd7a..b0969ca3ae3 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -61,7 +61,12 @@ def decode_png(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGE """ Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor. Optionally converts the image to the desired format. - The values of the output tensor are uint8 between 0 and 255. + The values of the output tensor are uint8 in [0, 255], except for + 16-bit PNGs, which are decoded to int32 tensors in [0, 65535]. + + .. warning:: + Should PyTorch ever support the uint16 dtype natively, the dtype of the + output for 16-bit PNGs will be updated from int32 to uint16. Args: input (Tensor[1]): a one dimensional uint8 tensor containing @@ -188,7 +193,8 @@ def decode_image(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHAN operation to decode the image into a 3 dimensional RGB or grayscale Tensor. Optionally converts the image to the desired format. - The values of the output tensor are uint8 between 0 and 255. + The values of the output tensor are uint8 in [0, 255], except for + 16-bit PNGs, which are decoded to int32 tensors in [0, 65535]. Args: input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the @@ -209,7 +215,8 @@ def read_image(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torc """ Reads a JPEG or PNG image into a 3 dimensional RGB or grayscale Tensor. Optionally converts the image to the desired format. - The values of the output tensor are uint8 between 0 and 255. + The values of the output tensor are uint8 in [0, 255], except for + 16-bit PNGs, which are decoded to int32 tensors in [0, 65535]. 
Args: path (str): path of the JPEG or PNG image.