diff --git a/CMakeLists.txt b/CMakeLists.txt index bf76d97cddf..18c269d79a7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -53,7 +53,7 @@ include(CMakePackageConfigHelpers) set(TVCPP torchvision/csrc) list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops - ${TVCPP}/ops/autograd ${TVCPP}/ops/cpu) + ${TVCPP}/ops/autograd ${TVCPP}/ops/cpu ${TVCPP}/io/image/cuda) if(WITH_CUDA) list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast) endif() diff --git a/setup.py b/setup.py index 0317d6e6483..33b865833f3 100644 --- a/setup.py +++ b/setup.py @@ -315,8 +315,19 @@ def get_extensions(): image_library += [jpeg_lib] image_include += [jpeg_include] + # Locating nvjpeg + # Should be included in CUDA_HOME + nvjpeg_found = extension is CUDAExtension and os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h')) + + print('NVJPEG found: {0}'.format(nvjpeg_found)) + image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))] + if nvjpeg_found: + print('Building torchvision with NVJPEG image support') + image_link_flags.append('nvjpeg') + image_path = os.path.join(extensions_dir, 'io', 'image') - image_src = glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp')) + image_src = (glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp')) + + glob.glob(os.path.join(image_path, 'cuda', '*.cpp'))) if png_found or jpeg_found: ext_modules.append(extension( diff --git a/test/test_image.py b/test/test_image.py index ebc9a221f6d..712edba80d9 100644 --- a/test/test_image.py +++ b/test/test_image.py @@ -75,6 +75,21 @@ def test_decode_jpeg(self): with self.assertRaises(RuntimeError): decode_jpeg(torch.empty((100), dtype=torch.uint8)) + @unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable") + def test_decode_jpeg_cuda(self): + conversion = [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB] + for img_path in get_images(IMAGE_ROOT, ".jpg"): + if Image.open(img_path).mode == 'CMYK': + # not supported + continue + for mode in conversion: + data = read_file(img_path) + img_ljpeg = decode_image(data, mode=mode) + img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data, mode.value, 'cuda') + + # Some difference expected between jpeg implementations + self.assertTrue((img_ljpeg.float() - img_nvjpeg.cpu().float()).abs().mean() < 2.) + def test_damaged_images(self): # Test image with bad Huffman encoding (should not raise) bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg')) diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp new file mode 100644 index 00000000000..ed194062dba --- /dev/null +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp @@ -0,0 +1,180 @@ +#include "decode_jpeg_cuda.h" + +#include + +#if NVJPEG_FOUND +#include +#include +#include +#endif + +#include + +namespace vision { +namespace image { + +#if !NVJPEG_FOUND + +torch::Tensor decode_jpeg_cuda( + const torch::Tensor& data, + ImageReadMode mode, + torch::Device device) { + TORCH_CHECK( + false, "decode_jpeg_cuda: torchvision not compiled with nvJPEG support"); +} + +#else + +static nvjpegHandle_t nvjpeg_handle = nullptr; + +void init_nvjpegImage(nvjpegImage_t& img) { + for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) { + img.channel[c] = nullptr; + img.pitch[c] = 0; + } +} + +torch::Tensor decode_jpeg_cuda( + const torch::Tensor& data, + ImageReadMode mode, + torch::Device device) { + // Check that the input tensor dtype is uint8 + TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); + // Check that the input tensor is 1-dimensional + TORCH_CHECK( + data.dim() == 1 && data.numel() > 0, + "Expected a non empty 1-dimensional tensor"); + + TORCH_CHECK( + device.is_cuda(), "Expected a cuda device" + ) + + at::cuda::CUDAGuard device_guard(device); + + auto datap = data.data_ptr(); + + // Create nvJPEG handle + if (nvjpeg_handle == nullptr) { + nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); + + TORCH_CHECK( + create_status == NVJPEG_STATUS_SUCCESS, + "nvjpegCreateSimple failed: ", + create_status); + } + + // Create nvJPEG state (should this be persistent or not?) + nvjpegJpegState_t nvjpeg_state; + nvjpegStatus_t state_status = + nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state); + + TORCH_CHECK( + state_status == NVJPEG_STATUS_SUCCESS, + "nvjpegJpegStateCreate failed: ", + state_status); + + // Get the image information + int components; + nvjpegChromaSubsampling_t subsampling; + int widths[NVJPEG_MAX_COMPONENT]; + int heights[NVJPEG_MAX_COMPONENT]; + + nvjpegStatus_t info_status = nvjpegGetImageInfo( + nvjpeg_handle, + datap, + data.numel(), + &components, + &subsampling, + widths, + heights); + + if (info_status != NVJPEG_STATUS_SUCCESS) { + nvjpegJpegStateDestroy(nvjpeg_state); + TORCH_CHECK(false, "nvjpegGetImageInfo failed: ", info_status); + } + + if (subsampling == NVJPEG_CSS_UNKNOWN) { + nvjpegJpegStateDestroy(nvjpeg_state); + TORCH_CHECK(false, "Unknown NVJPEG chroma subsampling"); + } + + int width = widths[0]; + int height = heights[0]; + + nvjpegOutputFormat_t outputFormat; + int outputComponents; + + switch (mode) { + case IMAGE_READ_MODE_UNCHANGED: + if (components == 1) { + outputFormat = NVJPEG_OUTPUT_Y; + outputComponents = 1; + } else if (components == 3) { + outputFormat = NVJPEG_OUTPUT_RGB; + outputComponents = 3; + } else { + nvjpegJpegStateDestroy(nvjpeg_state); + TORCH_CHECK( + false, "The provided mode is not supported for JPEG files on GPU"); + } + break; + case IMAGE_READ_MODE_GRAY: + // This will do 0.299*R + 0.587*G + 0.114*B like opencv + // TODO check if that is the same as libjpeg + outputFormat = NVJPEG_OUTPUT_Y; + outputComponents = 1; + break; + case IMAGE_READ_MODE_RGB: + outputFormat = NVJPEG_OUTPUT_RGB; + outputComponents = 3; + break; + default: + // CMYK as input might work with nvjpegDecodeParamsSetAllowCMYK() + nvjpegJpegStateDestroy(nvjpeg_state); + TORCH_CHECK( + false, "The provided mode is not supported for JPEG files on GPU"); + } + + // nvjpegImage_t is a struct with + // - an array of pointers to each channel + // - the pitch for each channel + // which must be filled in manually + nvjpegImage_t outImage; + init_nvjpegImage(outImage); + + // TODO device selection + auto tensor = torch::empty( + {int64_t(outputComponents), int64_t(height), int64_t(width)}, + torch::dtype(torch::kU8).device(device)); + + for (int c = 0; c < outputComponents; c++) { + outImage.channel[c] = tensor[c].data_ptr(); + outImage.pitch[c] = width; + } + + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + nvjpegStatus_t decode_status = nvjpegDecode( + nvjpeg_handle, + nvjpeg_state, + datap, + data.numel(), + outputFormat, + &outImage, + stream); + + // Destroy the state + nvjpegJpegStateDestroy(nvjpeg_state); + + TORCH_CHECK( + decode_status == NVJPEG_STATUS_SUCCESS, + "nvjpegDecode failed: ", + decode_status); + + return tensor; +} + +#endif // NVJPEG_FOUND + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h new file mode 100644 index 00000000000..496b355e9b7 --- /dev/null +++ b/torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include "../image_read_mode.h" + +namespace vision { +namespace image { + +C10_EXPORT torch::Tensor decode_jpeg_cuda( + const torch::Tensor& data, + ImageReadMode mode, + torch::Device device); + +} // namespace image +} // namespace vision diff --git a/torchvision/csrc/io/image/image.cpp b/torchvision/csrc/io/image/image.cpp index 51cf9c7ce3e..37d64013cb2 100644 --- a/torchvision/csrc/io/image/image.cpp +++ b/torchvision/csrc/io/image/image.cpp @@ -21,7 +21,8 @@ static auto registry = torch::RegisterOperators() .op("image::encode_jpeg", &encode_jpeg) .op("image::read_file", &read_file) .op("image::write_file", &write_file) - .op("image::decode_image", &decode_image); + .op("image::decode_image", &decode_image) + .op("image::decode_jpeg_cuda", &decode_jpeg_cuda); } // namespace image } // namespace vision diff --git a/torchvision/csrc/io/image/image.h b/torchvision/csrc/io/image/image.h index fb09d6d71b8..05bac44c77d 100644 --- a/torchvision/csrc/io/image/image.h +++ b/torchvision/csrc/io/image/image.h @@ -6,3 +6,4 @@ #include "cpu/encode_jpeg.h" #include "cpu/encode_png.h" #include "cpu/read_write_file.h" +#include "cuda/decode_jpeg_cuda.h" diff --git a/torchvision/io/image.py b/torchvision/io/image.py index e193555e447..0c14ae007b4 100644 --- a/torchvision/io/image.py +++ b/torchvision/io/image.py @@ -149,7 +149,8 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6): write_file(filename, output) -def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: +def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, + device: torch.device = 'cpu') -> torch.Tensor: """ Decodes a JPEG image into a 3 dimensional RGB Tensor. Optionally converts the image to the desired format. @@ -166,7 +167,11 @@ def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANG Returns: output (Tensor[image_channels, image_height, image_width]) """ - output = torch.ops.image.decode_jpeg(input, mode.value) + device = torch.device(device) + if device.type == 'cuda': + output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device) + else: + output = torch.ops.image.decode_jpeg(input, mode.value) return output