diff --git a/test/test_image.py b/test/test_image.py
index 86018dccc42..005cf41b1ca 100644
--- a/test/test_image.py
+++ b/test/test_image.py
@@ -10,11 +10,10 @@
 import pytest
 import requests
 import torch
-import torchvision.transforms.functional as F
+import torchvision.transforms.v2.functional as F
 from common_utils import assert_equal, cpu_and_cuda, IN_OSS_CI, needs_cuda
 from PIL import __version__ as PILLOW_VERSION, Image, ImageOps, ImageSequence
 from torchvision.io.image import (
-    _read_png_16,
     decode_gif,
     decode_image,
     decode_jpeg,
@@ -211,16 +210,11 @@ def test_decode_png(img_path, pil_mode, mode, scripted, decode_fun):
     img_pil = normalize_dimensions(img_pil)

     if img_path.endswith("16.png"):
-        # 16 bits image decoding is supported, but only as a private API
-        # FIXME: see https://github.com/pytorch/vision/issues/4731 for potential solutions to making it public
-        with pytest.raises(RuntimeError, match="At most 8-bit PNG images are supported"):
-            data = read_file(img_path)
-            img_lpng = decode_fun(data, mode=mode)
-
-        img_lpng = _read_png_16(img_path, mode=mode)
-        assert img_lpng.dtype == torch.int32
-        # PIL converts 16 bits pngs in uint8
-        img_lpng = torch.round(img_lpng / (2**16 - 1) * 255).to(torch.uint8)
+        data = read_file(img_path)
+        img_lpng = decode_fun(data, mode=mode)
+        assert img_lpng.dtype == torch.uint16
+        # PIL converts 16-bit PNGs to uint8
+        img_lpng = F.to_dtype(img_lpng, torch.uint8, scale=True)
     else:
         data = read_file(img_path)
         img_lpng = decode_fun(data, mode=mode)
diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 07e3d75df6d..f9218c3e840 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -2076,15 +2076,17 @@ def fn(value):
             factor = (output_max_value + 1) // (input_max_value + 1)
             return value * factor

-        return torch.tensor(tree_map(fn, image.tolist()), dtype=dtype, device=image.device)
+        return torch.tensor(tree_map(fn, image.tolist())).to(dtype=output_dtype, device=image.device)

-    @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8])
-    @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8])
+    @pytest.mark.parametrize("input_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
+    @pytest.mark.parametrize("output_dtype", [torch.float32, torch.float64, torch.uint8, torch.uint16])
     @pytest.mark.parametrize("device", cpu_and_cuda())
     @pytest.mark.parametrize("scale", (True, False))
     def test_image_correctness(self, input_dtype, output_dtype, device, scale):
         if input_dtype.is_floating_point and output_dtype == torch.int64:
             pytest.xfail("float to int64 conversion is not supported")
+        if input_dtype == torch.uint8 and output_dtype == torch.uint16 and device == "cuda":
+            pytest.xfail("uint8 to uint16 conversion is not supported on cuda")

         input = make_image(dtype=input_dtype, device=device)

@@ -2171,6 +2173,28 @@ def test_errors_warnings(self, make_input):
         assert out["bbox"].dtype == bbox_dtype
         assert out["mask"].dtype == mask_dtype

+    def test_uint16(self):
+        # These checks are probably already covered above but since uint16 is a
+        # newly supported dtype, we want to be extra careful, hence this
+        # explicit test
+        img_uint16 = torch.randint(0, 65535, (256, 512), dtype=torch.uint16)
+
+        img_uint8 = F.to_dtype(img_uint16, torch.uint8, scale=True)
+        img_float32 = F.to_dtype(img_uint16, torch.float32, scale=True)
+        img_int32 = F.to_dtype(img_uint16, torch.int32, scale=True)
+
+        assert_equal(img_uint8, (img_uint16 / 256).to(torch.uint8))
+        assert_close(img_float32, (img_uint16 / 65535))
+
+        assert_close(F.to_dtype(img_float32, torch.uint16, scale=True), img_uint16, rtol=0, atol=1)
+        # Ideally we'd check against (img_uint16 & 0xFF00) but bitwise AND isn't supported for it yet,
+        # so we simulate it by scaling down and up again.
+        assert_equal(F.to_dtype(img_uint8, torch.uint16, scale=True), ((img_uint16 / 256).to(torch.uint16) * 256))
+        assert_equal(F.to_dtype(img_int32, torch.uint16, scale=True), img_uint16)
+
+        assert_equal(F.to_dtype(img_float32, torch.uint8, scale=True), img_uint8)
+        assert_close(F.to_dtype(img_uint8, torch.float32, scale=True), img_float32, rtol=0, atol=1e-2)
+

 class TestAdjustBrightness:
     _CORRECTNESS_BRIGHTNESS_FACTORS = [0.5, 0.0, 1.0, 5.0]
diff --git a/torchvision/csrc/io/image/cpu/decode_image.cpp b/torchvision/csrc/io/image/cpu/decode_image.cpp
index 1f09da17597..3a18406042e 100644
--- a/torchvision/csrc/io/image/cpu/decode_image.cpp
+++ b/torchvision/csrc/io/image/cpu/decode_image.cpp
@@ -32,8 +32,7 @@ torch::Tensor decode_image(
   if (memcmp(jpeg_signature, datap, 3) == 0) {
     return decode_jpeg(data, mode, apply_exif_orientation);
   } else if (memcmp(png_signature, datap, 4) == 0) {
-    return decode_png(
-        data, mode, /*allow_16_bits=*/false, apply_exif_orientation);
+    return decode_png(data, mode, apply_exif_orientation);
   } else if (
       memcmp(gif_signature_1, datap, 6) == 0 ||
       memcmp(gif_signature_2, datap, 6) == 0) {
diff --git a/torchvision/csrc/io/image/cpu/decode_png.cpp b/torchvision/csrc/io/image/cpu/decode_png.cpp
index ab4087fdfe2..ac14ae934a4 100644
--- a/torchvision/csrc/io/image/cpu/decode_png.cpp
+++ b/torchvision/csrc/io/image/cpu/decode_png.cpp
@@ -11,7 +11,6 @@ using namespace exif_private;
 torch::Tensor decode_png(
     const torch::Tensor& data,
     ImageReadMode mode,
-    bool allow_16_bits,
     bool apply_exif_orientation) {
   TORCH_CHECK(
       false, "decode_png: torchvision not compiled with libPNG support");
@@ -26,7 +25,6 @@ bool is_little_endian() {
 torch::Tensor decode_png(
     const torch::Tensor& data,
     ImageReadMode mode,
-    bool allow_16_bits,
     bool apply_exif_orientation) {
   C10_LOG_API_USAGE_ONCE("torchvision.csrc.io.image.cpu.decode_png.decode_png");
   // Check that the input tensor dtype is uint8
@@ -99,12 +97,12 @@ torch::Tensor decode_png(
     TORCH_CHECK(retval == 1, "Could read image metadata from content.")
   }

-  auto max_bit_depth = allow_16_bits ? 16 : 8;
-  auto err_msg = "At most " + std::to_string(max_bit_depth) +
-      "-bit PNG images are supported currently.";
-  if (bit_depth > max_bit_depth) {
+  if (bit_depth > 8 && bit_depth != 16) {
     png_destroy_read_struct(&png_ptr, &info_ptr, nullptr);
-    TORCH_CHECK(false, err_msg)
+    TORCH_CHECK(
+        false,
+        "bit depth of png image is " + std::to_string(bit_depth) +
+            ". Only <=8 and 16 are supported.")
   }

   int channels = png_get_channels(png_ptr, info_ptr);
@@ -199,45 +197,20 @@ torch::Tensor decode_png(
   }

   auto num_pixels_per_row = width * channels;
+  auto is_16_bits = bit_depth == 16;
   auto tensor = torch::empty(
       {int64_t(height), int64_t(width), channels},
-      bit_depth <= 8 ? torch::kU8 : torch::kI32);
-
-  if (bit_depth <= 8) {
-    auto t_ptr = tensor.accessor<uint8_t, 3>().data();
-    for (int pass = 0; pass < number_of_passes; pass++) {
-      for (png_uint_32 i = 0; i < height; ++i) {
-        png_read_row(png_ptr, t_ptr, nullptr);
-        t_ptr += num_pixels_per_row;
-      }
-      t_ptr = tensor.accessor<uint8_t, 3>().data();
-    }
-  } else {
-    // We're reading a 16bits png, but pytorch doesn't support uint16.
-    // So we read each row in a 16bits tmp_buffer which we then cast into
-    // a int32 tensor instead.
-    if (is_little_endian()) {
-      png_set_swap(png_ptr);
-    }
-    int32_t* t_ptr = tensor.accessor<int32_t, 3>().data();
-
-    // We create a tensor instead of malloc-ing for automatic memory management
-    auto tmp_buffer_tensor = torch::empty(
-        {int64_t(num_pixels_per_row * sizeof(uint16_t))}, torch::kU8);
-    uint16_t* tmp_buffer =
-        (uint16_t*)tmp_buffer_tensor.accessor<uint8_t, 1>().data();
-
-    for (int pass = 0; pass < number_of_passes; pass++) {
-      for (png_uint_32 i = 0; i < height; ++i) {
-        png_read_row(png_ptr, (uint8_t*)tmp_buffer, nullptr);
-        // Now we copy the uint16 values into the int32 tensor.
-        for (size_t j = 0; j < num_pixels_per_row; ++j) {
-          t_ptr[j] = (int32_t)tmp_buffer[j];
-        }
-        t_ptr += num_pixels_per_row;
-      }
-      t_ptr = tensor.accessor<int32_t, 3>().data();
+      is_16_bits ? at::kUInt16 : torch::kU8);
+  if (is_little_endian()) {
+    png_set_swap(png_ptr);
+  }
+  auto t_ptr = (uint8_t*)tensor.data_ptr();
+  for (int pass = 0; pass < number_of_passes; pass++) {
+    for (png_uint_32 i = 0; i < height; ++i) {
+      png_read_row(png_ptr, t_ptr, nullptr);
+      t_ptr += num_pixels_per_row * (is_16_bits ? 2 : 1);
     }
+    t_ptr = (uint8_t*)tensor.data_ptr();
   }

   int exif_orientation = -1;
diff --git a/torchvision/csrc/io/image/cpu/decode_png.h b/torchvision/csrc/io/image/cpu/decode_png.h
index b091f15e35f..0866711e987 100644
--- a/torchvision/csrc/io/image/cpu/decode_png.h
+++ b/torchvision/csrc/io/image/cpu/decode_png.h
@@ -9,7 +9,6 @@ namespace image {
 C10_EXPORT torch::Tensor decode_png(
     const torch::Tensor& data,
     ImageReadMode mode = IMAGE_READ_MODE_UNCHANGED,
-    bool allow_16_bits = false,
     bool apply_exif_orientation = false);

 } // namespace image
diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp
index 68267b72604..e351ed425b5 100644
--- a/torchvision/csrc/io/image/image.cpp
+++ b/torchvision/csrc/io/image/image.cpp
@@ -16,7 +16,7 @@ namespace image {
 static auto registry =
     torch::RegisterOperators()
         .op("image::decode_gif", &decode_gif)
-        .op("image::decode_png(Tensor data, int mode, bool allow_16_bits = False, bool apply_exif_orientation=False) -> Tensor",
+        .op("image::decode_png(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
            &decode_png)
         .op("image::encode_png", &encode_png)
         .op("image::decode_jpeg(Tensor data, int mode, bool apply_exif_orientation=False) -> Tensor",
diff --git a/torchvision/datasets/_optical_flow.py b/torchvision/datasets/_optical_flow.py
index 6d8b852a5d0..e8d6247f03f 100644
--- a/torchvision/datasets/_optical_flow.py
+++ b/torchvision/datasets/_optical_flow.py
@@ -9,7 +9,7 @@
 import torch
 from PIL import Image

-from ..io.image import _read_png_16
+from ..io.image import decode_png, read_file
 from .utils import _read_pfm, verify_str_arg
 from .vision import VisionDataset

@@ -481,7 +481,7 @@ def _read_flo(file_name: str) -> np.ndarray:

 def _read_16bits_png_with_flow_and_valid_mask(file_name: str) -> Tuple[np.ndarray, np.ndarray]:
-    flow_and_valid = _read_png_16(file_name).to(torch.float32)
+    flow_and_valid = decode_png(read_file(file_name)).to(torch.float32)
     flow, valid_flow_mask = flow_and_valid[:2, :, :], flow_and_valid[2, :, :]
     flow = (flow - 2**15) / 64  # This conversion is explained somewhere on the kitti archive
     valid_flow_mask = valid_flow_mask.bool()
diff --git a/torchvision/io/image.py b/torchvision/io/image.py
index df2fdef3580..debef443f7a 100644
--- a/torchvision/io/image.py
+++ b/torchvision/io/image.py
@@ -75,8 +75,14 @@ def decode_png(
 ) -> torch.Tensor:
     """
     Decodes a PNG image into a 3 dimensional RGB or grayscale Tensor.
-    Optionally converts the image to the desired format.
-    The values of the output tensor are uint8 in [0, 255].
+
+    The values of the output tensor are uint8 in [0, 255] for most cases. If
+    the image is a 16-bit PNG, then the output tensor is uint16 in [0, 65535]
+    (supported from torchvision ``0.21``). Since uint16 support is limited in
+    PyTorch, we recommend calling
+    :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
+    after this function to convert the decoded image into a uint8 or float
+    tensor.

     Args:
         input (Tensor[1]): a one dimensional uint8 tensor containing
@@ -93,7 +99,7 @@ def decode_png(
     """
     if not torch.jit.is_scripting() and not torch.jit.is_tracing():
         _log_api_usage_once(decode_png)
-    output = torch.ops.image.decode_png(input, mode.value, False, apply_exif_orientation)
+    output = torch.ops.image.decode_png(input, mode.value, apply_exif_orientation)
     return output

@@ -144,7 +150,7 @@ def decode_jpeg(
 ) -> torch.Tensor:
     """
     Decodes a JPEG image into a 3 dimensional RGB or grayscale Tensor.
-    Optionally converts the image to the desired format.
+
     The values of the output tensor are uint8 between 0 and 255.

     Args:
@@ -248,8 +254,13 @@ def decode_image(
     Detect whether an image is a JPEG, PNG or GIF and performs the
     appropriate operation to decode the image into a 3 dimensional RGB or grayscale Tensor.

-    Optionally converts the image to the desired format.
-    The values of the output tensor are uint8 in [0, 255].
+    The values of the output tensor are uint8 in [0, 255] for most cases. If
+    the image is a 16-bit PNG, then the output tensor is uint16 in [0, 65535]
+    (supported from torchvision ``0.21``). Since uint16 support is limited in
+    PyTorch, we recommend calling
+    :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
+    after this function to convert the decoded image into a uint8 or float
+    tensor.

     Args:
         input (Tensor): a one dimensional uint8 tensor containing the raw bytes of the
@@ -277,8 +288,14 @@ def read_image(
 ) -> torch.Tensor:
     """
     Reads a JPEG, PNG or GIF image into a 3 dimensional RGB or grayscale Tensor.
-    Optionally converts the image to the desired format.
-    The values of the output tensor are uint8 in [0, 255].
+
+    The values of the output tensor are uint8 in [0, 255] for most cases. If
+    the image is a 16-bit PNG, then the output tensor is uint16 in [0, 65535]
+    (supported from torchvision ``0.21``). Since uint16 support is limited in
+    PyTorch, we recommend calling
+    :func:`torchvision.transforms.v2.functional.to_dtype()` with ``scale=True``
+    after this function to convert the decoded image into a uint8 or float
+    tensor.

     Args:
         path (str or ``pathlib.Path``): path of the JPEG, PNG or GIF image.
@@ -298,11 +315,6 @@ def read_image(
     return decode_image(data, mode, apply_exif_orientation=apply_exif_orientation)


-def _read_png_16(path: str, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
-    data = read_file(path)
-    return torch.ops.image.decode_png(data, mode.value, True)
-
-
 def decode_gif(input: torch.Tensor) -> torch.Tensor:
     """
     Decode a GIF image into a 3 or 4 dimensional RGB Tensor.
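As a usage sketch of the decode-then-convert flow that the docstrings above recommend (the file name `depth_16bit.png` is hypothetical, and this assumes a torchvision build with 16-bit PNG support):

```python
import torch
from torchvision.io import decode_png, read_file
from torchvision.transforms.v2.functional import to_dtype

# Hypothetical 16-bit PNG, used for illustration only.
data = read_file("depth_16bit.png")
img = decode_png(data)  # uint16 tensor in [0, 65535] for 16-bit inputs
assert img.dtype == torch.uint16

# uint16 support in PyTorch is limited, so convert right after decoding:
img_u8 = to_dtype(img, torch.uint8, scale=True)      # values in [0, 255]
img_f32 = to_dtype(img, torch.float32, scale=True)   # values in [0.0, 1.0]
```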
diff --git a/torchvision/transforms/_functional_tensor.py b/torchvision/transforms/_functional_tensor.py index 348f01bb1e6..618bbfbab7c 100644 --- a/torchvision/transforms/_functional_tensor.py +++ b/torchvision/transforms/_functional_tensor.py @@ -45,6 +45,8 @@ def _max_value(dtype: torch.dtype) -> int: return 127 elif dtype == torch.int16: return 32767 + elif dtype == torch.uint16: + return 65535 elif dtype == torch.int32: return 2147483647 elif dtype == torch.int64: diff --git a/torchvision/transforms/v2/functional/_misc.py b/torchvision/transforms/v2/functional/_misc.py index daf6c5560db..8b20473e6e7 100644 --- a/torchvision/transforms/v2/functional/_misc.py +++ b/torchvision/transforms/v2/functional/_misc.py @@ -237,6 +237,8 @@ def _num_value_bits(dtype: torch.dtype) -> int: return 7 elif dtype == torch.int16: return 15 + elif dtype == torch.uint16: + return 16 elif dtype == torch.int32: return 31 elif dtype == torch.int64: @@ -293,10 +295,18 @@ def to_dtype_image(image: torch.Tensor, dtype: torch.dtype = torch.float, scale: num_value_bits_input = _num_value_bits(image.dtype) num_value_bits_output = _num_value_bits(dtype) + # TODO: Remove if/else inner blocks once uint16 dtype supports bitwise shift operations. + shift_by = abs(num_value_bits_input - num_value_bits_output) if num_value_bits_input > num_value_bits_output: - return image.bitwise_right_shift(num_value_bits_input - num_value_bits_output).to(dtype) + if image.dtype == torch.uint16: + return (image / 2 ** (shift_by)).to(dtype) + else: + return image.bitwise_right_shift(shift_by).to(dtype) else: - return image.to(dtype).bitwise_left_shift_(num_value_bits_output - num_value_bits_input) + if dtype == torch.uint16: + return image.to(dtype) * 2 ** (shift_by) + else: + return image.to(dtype).bitwise_left_shift_(shift_by) # We encourage users to use to_dtype() instead but we keep this for BC
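A minimal sketch of the value-bit scaling that `to_dtype_image` performs for uint16, mirroring the multiply/divide fallback above (assumes basic uint16 arithmetic on CPU; as the tests note, the uint8-to-uint16 path is not supported on CUDA):

```python
import torch

x = torch.tensor([0, 1, 255], dtype=torch.uint8)

# Upscale uint8 -> uint16: shift by 16 - 8 = 8 value bits. Since uint16
# lacks bitwise shift support, multiply by 2**8 instead.
up = x.to(torch.uint16) * 2**8
assert up.tolist() == [0, 256, 65280]

# Downscale uint16 -> uint8: divide by 2**8 (true division yields float32)
# and truncate back down to uint8.
down = (up / 2**8).to(torch.uint8)
assert down.tolist() == [0, 1, 255]
```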