Skip to content

[WIP] nvJPEG support #2786

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 17 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ include(CMakePackageConfigHelpers)

set(TVCPP torchvision/csrc)
list(APPEND ALLOW_LISTED ${TVCPP} ${TVCPP}/io/image ${TVCPP}/io/image/cpu ${TVCPP}/models ${TVCPP}/ops
${TVCPP}/ops/autograd ${TVCPP}/ops/cpu)
${TVCPP}/ops/autograd ${TVCPP}/ops/cpu ${TVCPP}/io/image/cuda)
if(WITH_CUDA)
list(APPEND ALLOW_LISTED ${TVCPP}/ops/cuda ${TVCPP}/ops/autocast)
endif()
Expand Down
13 changes: 12 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -315,8 +315,19 @@ def get_extensions():
image_library += [jpeg_lib]
image_include += [jpeg_include]

# Locating nvjpeg
# Should be included in CUDA_HOME
nvjpeg_found = extension is CUDAExtension and os.path.exists(os.path.join(CUDA_HOME, 'include', 'nvjpeg.h'))

print('NVJPEG found: {0}'.format(nvjpeg_found))
image_macros += [('NVJPEG_FOUND', str(int(nvjpeg_found)))]
if nvjpeg_found:
print('Building torchvision with NVJPEG image support')
image_link_flags.append('nvjpeg')

image_path = os.path.join(extensions_dir, 'io', 'image')
image_src = glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp'))
image_src = (glob.glob(os.path.join(image_path, '*.cpp')) + glob.glob(os.path.join(image_path, 'cpu', '*.cpp'))
+ glob.glob(os.path.join(image_path, 'cuda', '*.cpp')))

if png_found or jpeg_found:
ext_modules.append(extension(
Expand Down
15 changes: 15 additions & 0 deletions test/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,21 @@ def test_decode_jpeg(self):
with self.assertRaises(RuntimeError):
decode_jpeg(torch.empty((100), dtype=torch.uint8))

@unittest.skipIf(not torch.cuda.is_available(), "CUDA unavailable")
def test_decode_jpeg_cuda(self):
conversion = [ImageReadMode.UNCHANGED, ImageReadMode.GRAY, ImageReadMode.RGB]
for img_path in get_images(IMAGE_ROOT, ".jpg"):
if Image.open(img_path).mode == 'CMYK':
# not supported
continue
for mode in conversion:
data = read_file(img_path)
img_ljpeg = decode_image(data, mode=mode)
img_nvjpeg = torch.ops.image.decode_jpeg_cuda(data, mode.value, 'cuda')

# Some difference expected between jpeg implementations
self.assertTrue((img_ljpeg.float() - img_nvjpeg.cpu().float()).abs().mean() < 2.)

def test_damaged_images(self):
# Test image with bad Huffman encoding (should not raise)
bad_huff = read_file(os.path.join(DAMAGED_JPEG, 'bad_huffman.jpg'))
Expand Down
180 changes: 180 additions & 0 deletions torchvision/csrc/io/image/cuda/decode_jpeg_cuda.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,180 @@
#include "decode_jpeg_cuda.h"

#include <ATen/ATen.h>

#if NVJPEG_FOUND
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <nvjpeg.h>
#endif

#include <string>

namespace vision {
namespace image {

#if !NVJPEG_FOUND

torch::Tensor decode_jpeg_cuda(
const torch::Tensor& data,
ImageReadMode mode,
torch::Device device) {
TORCH_CHECK(
false, "decode_jpeg_cuda: torchvision not compiled with nvJPEG support");
}

#else

static nvjpegHandle_t nvjpeg_handle = nullptr;

void init_nvjpegImage(nvjpegImage_t& img) {
for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) {
img.channel[c] = nullptr;
img.pitch[c] = 0;
}
}

torch::Tensor decode_jpeg_cuda(
const torch::Tensor& data,
ImageReadMode mode,
torch::Device device) {
// Check that the input tensor dtype is uint8
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor");
// Check that the input tensor is 1-dimensional
TORCH_CHECK(
data.dim() == 1 && data.numel() > 0,
"Expected a non empty 1-dimensional tensor");

TORCH_CHECK(
device.is_cuda(), "Expected a cuda device"
)

at::cuda::CUDAGuard device_guard(device);

auto datap = data.data_ptr<uint8_t>();

// Create nvJPEG handle
if (nvjpeg_handle == nullptr) {
nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle);

TORCH_CHECK(
create_status == NVJPEG_STATUS_SUCCESS,
"nvjpegCreateSimple failed: ",
create_status);
}

// Create nvJPEG state (should this be persistent or not?)
nvjpegJpegState_t nvjpeg_state;
nvjpegStatus_t state_status =
nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state);

TORCH_CHECK(
state_status == NVJPEG_STATUS_SUCCESS,
"nvjpegJpegStateCreate failed: ",
state_status);

// Get the image information
int components;
nvjpegChromaSubsampling_t subsampling;
int widths[NVJPEG_MAX_COMPONENT];
int heights[NVJPEG_MAX_COMPONENT];

nvjpegStatus_t info_status = nvjpegGetImageInfo(
nvjpeg_handle,
datap,
data.numel(),
&components,
&subsampling,
widths,
heights);

if (info_status != NVJPEG_STATUS_SUCCESS) {
nvjpegJpegStateDestroy(nvjpeg_state);
TORCH_CHECK(false, "nvjpegGetImageInfo failed: ", info_status);
}

if (subsampling == NVJPEG_CSS_UNKNOWN) {
nvjpegJpegStateDestroy(nvjpeg_state);
TORCH_CHECK(false, "Unknown NVJPEG chroma subsampling");
}

int width = widths[0];
int height = heights[0];

nvjpegOutputFormat_t outputFormat;
int outputComponents;

switch (mode) {
case IMAGE_READ_MODE_UNCHANGED:
if (components == 1) {
outputFormat = NVJPEG_OUTPUT_Y;
outputComponents = 1;
} else if (components == 3) {
outputFormat = NVJPEG_OUTPUT_RGB;
outputComponents = 3;
} else {
nvjpegJpegStateDestroy(nvjpeg_state);
TORCH_CHECK(
false, "The provided mode is not supported for JPEG files on GPU");
}
break;
case IMAGE_READ_MODE_GRAY:
// This will do 0.299*R + 0.587*G + 0.114*B like opencv
// TODO check if that is the same as libjpeg
outputFormat = NVJPEG_OUTPUT_Y;
outputComponents = 1;
break;
case IMAGE_READ_MODE_RGB:
outputFormat = NVJPEG_OUTPUT_RGB;
outputComponents = 3;
break;
default:
// CMYK as input might work with nvjpegDecodeParamsSetAllowCMYK()
nvjpegJpegStateDestroy(nvjpeg_state);
TORCH_CHECK(
false, "The provided mode is not supported for JPEG files on GPU");
}

// nvjpegImage_t is a struct with
// - an array of pointers to each channel
// - the pitch for each channel
// which must be filled in manually
nvjpegImage_t outImage;
init_nvjpegImage(outImage);

// TODO device selection
auto tensor = torch::empty(
{int64_t(outputComponents), int64_t(height), int64_t(width)},
torch::dtype(torch::kU8).device(device));

for (int c = 0; c < outputComponents; c++) {
outImage.channel[c] = tensor[c].data_ptr<uint8_t>();
outImage.pitch[c] = width;
}

cudaStream_t stream = at::cuda::getCurrentCUDAStream();

nvjpegStatus_t decode_status = nvjpegDecode(
nvjpeg_handle,
nvjpeg_state,
datap,
data.numel(),
outputFormat,
&outImage,
stream);

// Destroy the state
nvjpegJpegStateDestroy(nvjpeg_state);

TORCH_CHECK(
decode_status == NVJPEG_STATUS_SUCCESS,
"nvjpegDecode failed: ",
decode_status);

return tensor;
}

#endif // NVJPEG_FOUND

} // namespace image
} // namespace vision
15 changes: 15 additions & 0 deletions torchvision/csrc/io/image/cuda/decode_jpeg_cuda.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#pragma once

#include <torch/types.h>
#include "../image_read_mode.h"

namespace vision {
namespace image {

C10_EXPORT torch::Tensor decode_jpeg_cuda(
const torch::Tensor& data,
ImageReadMode mode,
torch::Device device);

} // namespace image
} // namespace vision
3 changes: 2 additions & 1 deletion torchvision/csrc/io/image/image.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@ static auto registry = torch::RegisterOperators()
.op("image::encode_jpeg", &encode_jpeg)
.op("image::read_file", &read_file)
.op("image::write_file", &write_file)
.op("image::decode_image", &decode_image);
.op("image::decode_image", &decode_image)
.op("image::decode_jpeg_cuda", &decode_jpeg_cuda);

} // namespace image
} // namespace vision
1 change: 1 addition & 0 deletions torchvision/csrc/io/image/image.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@
#include "cpu/encode_jpeg.h"
#include "cpu/encode_png.h"
#include "cpu/read_write_file.h"
#include "cuda/decode_jpeg_cuda.h"
9 changes: 7 additions & 2 deletions torchvision/io/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6):
write_file(filename, output)


def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor:
def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED,
device: torch.device = 'cpu') -> torch.Tensor:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should annotation be Union[int, str, torch.device]? eg

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would just make it a torch.device, because I think torchscript doesn't yet support Union

"""
Decodes a JPEG image into a 3 dimensional RGB Tensor.
Optionally converts the image to the desired format.
Expand All @@ -166,7 +167,11 @@ def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANG
Returns:
output (Tensor[image_channels, image_height, image_width])
"""
output = torch.ops.image.decode_jpeg(input, mode.value)
device = torch.device(device)
if device.type == 'cuda':
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should the dispatching be done here or in C++? How will it interact with torchscript?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about making it dispatch on C++ directly, but this should also work with torchscript I believe

output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device)
else:
output = torch.ops.image.decode_jpeg(input, mode.value)
return output


Expand Down