-
Notifications
You must be signed in to change notification settings - Fork 7.1k
[WIP] nvJPEG support #2786
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[WIP] nvJPEG support #2786
Changes from all commits
f878b36
5eb6d73
afd4a2e
8abe4a5
8ae0751
9a2510f
ac3330b
e798157
a07f53a
e485656
1c1e471
3e7486e
5bc5e21
4d4cd45
dd3e445
f560eeb
ab90893
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
#include "decode_jpeg_cuda.h" | ||
|
||
#include <ATen/ATen.h> | ||
|
||
#if NVJPEG_FOUND | ||
#include <ATen/cuda/CUDAContext.h> | ||
#include <c10/cuda/CUDAGuard.h> | ||
#include <nvjpeg.h> | ||
#endif | ||
|
||
#include <string> | ||
|
||
namespace vision { | ||
namespace image { | ||
|
||
#if !NVJPEG_FOUND | ||
|
||
torch::Tensor decode_jpeg_cuda( | ||
const torch::Tensor& data, | ||
ImageReadMode mode, | ||
torch::Device device) { | ||
TORCH_CHECK( | ||
false, "decode_jpeg_cuda: torchvision not compiled with nvJPEG support"); | ||
} | ||
|
||
#else | ||
|
||
static nvjpegHandle_t nvjpeg_handle = nullptr; | ||
|
||
void init_nvjpegImage(nvjpegImage_t& img) { | ||
for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) { | ||
img.channel[c] = nullptr; | ||
img.pitch[c] = 0; | ||
} | ||
} | ||
|
||
torch::Tensor decode_jpeg_cuda( | ||
const torch::Tensor& data, | ||
ImageReadMode mode, | ||
torch::Device device) { | ||
// Check that the input tensor dtype is uint8 | ||
TORCH_CHECK(data.dtype() == torch::kU8, "Expected a torch.uint8 tensor"); | ||
// Check that the input tensor is 1-dimensional | ||
TORCH_CHECK( | ||
data.dim() == 1 && data.numel() > 0, | ||
"Expected a non empty 1-dimensional tensor"); | ||
|
||
TORCH_CHECK( | ||
device.is_cuda(), "Expected a cuda device" | ||
) | ||
|
||
at::cuda::CUDAGuard device_guard(device); | ||
|
||
auto datap = data.data_ptr<uint8_t>(); | ||
|
||
// Create nvJPEG handle | ||
if (nvjpeg_handle == nullptr) { | ||
nvjpegStatus_t create_status = nvjpegCreateSimple(&nvjpeg_handle); | ||
|
||
TORCH_CHECK( | ||
create_status == NVJPEG_STATUS_SUCCESS, | ||
"nvjpegCreateSimple failed: ", | ||
create_status); | ||
} | ||
|
||
// Create nvJPEG state (should this be persistent or not?) | ||
nvjpegJpegState_t nvjpeg_state; | ||
nvjpegStatus_t state_status = | ||
nvjpegJpegStateCreate(nvjpeg_handle, &nvjpeg_state); | ||
|
||
TORCH_CHECK( | ||
state_status == NVJPEG_STATUS_SUCCESS, | ||
"nvjpegJpegStateCreate failed: ", | ||
state_status); | ||
|
||
// Get the image information | ||
int components; | ||
nvjpegChromaSubsampling_t subsampling; | ||
int widths[NVJPEG_MAX_COMPONENT]; | ||
int heights[NVJPEG_MAX_COMPONENT]; | ||
|
||
nvjpegStatus_t info_status = nvjpegGetImageInfo( | ||
nvjpeg_handle, | ||
datap, | ||
data.numel(), | ||
&components, | ||
&subsampling, | ||
widths, | ||
heights); | ||
|
||
if (info_status != NVJPEG_STATUS_SUCCESS) { | ||
nvjpegJpegStateDestroy(nvjpeg_state); | ||
TORCH_CHECK(false, "nvjpegGetImageInfo failed: ", info_status); | ||
} | ||
|
||
if (subsampling == NVJPEG_CSS_UNKNOWN) { | ||
nvjpegJpegStateDestroy(nvjpeg_state); | ||
TORCH_CHECK(false, "Unknown NVJPEG chroma subsampling"); | ||
} | ||
|
||
int width = widths[0]; | ||
int height = heights[0]; | ||
|
||
nvjpegOutputFormat_t outputFormat; | ||
int outputComponents; | ||
|
||
switch (mode) { | ||
case IMAGE_READ_MODE_UNCHANGED: | ||
if (components == 1) { | ||
outputFormat = NVJPEG_OUTPUT_Y; | ||
outputComponents = 1; | ||
} else if (components == 3) { | ||
outputFormat = NVJPEG_OUTPUT_RGB; | ||
outputComponents = 3; | ||
} else { | ||
nvjpegJpegStateDestroy(nvjpeg_state); | ||
TORCH_CHECK( | ||
false, "The provided mode is not supported for JPEG files on GPU"); | ||
} | ||
break; | ||
case IMAGE_READ_MODE_GRAY: | ||
// This will do 0.299*R + 0.587*G + 0.114*B like opencv | ||
// TODO check if that is the same as libjpeg | ||
outputFormat = NVJPEG_OUTPUT_Y; | ||
outputComponents = 1; | ||
break; | ||
case IMAGE_READ_MODE_RGB: | ||
outputFormat = NVJPEG_OUTPUT_RGB; | ||
outputComponents = 3; | ||
break; | ||
default: | ||
// CMYK as input might work with nvjpegDecodeParamsSetAllowCMYK() | ||
nvjpegJpegStateDestroy(nvjpeg_state); | ||
TORCH_CHECK( | ||
false, "The provided mode is not supported for JPEG files on GPU"); | ||
} | ||
|
||
// nvjpegImage_t is a struct with | ||
// - an array of pointers to each channel | ||
// - the pitch for each channel | ||
// which must be filled in manually | ||
nvjpegImage_t outImage; | ||
init_nvjpegImage(outImage); | ||
|
||
// TODO device selection | ||
auto tensor = torch::empty( | ||
{int64_t(outputComponents), int64_t(height), int64_t(width)}, | ||
torch::dtype(torch::kU8).device(device)); | ||
|
||
for (int c = 0; c < outputComponents; c++) { | ||
outImage.channel[c] = tensor[c].data_ptr<uint8_t>(); | ||
outImage.pitch[c] = width; | ||
} | ||
|
||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||
|
||
nvjpegStatus_t decode_status = nvjpegDecode( | ||
nvjpeg_handle, | ||
nvjpeg_state, | ||
datap, | ||
data.numel(), | ||
outputFormat, | ||
&outImage, | ||
stream); | ||
|
||
// Destroy the state | ||
nvjpegJpegStateDestroy(nvjpeg_state); | ||
|
||
TORCH_CHECK( | ||
decode_status == NVJPEG_STATUS_SUCCESS, | ||
"nvjpegDecode failed: ", | ||
decode_status); | ||
|
||
return tensor; | ||
} | ||
|
||
#endif // NVJPEG_FOUND | ||
|
||
} // namespace image | ||
} // namespace vision |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
#pragma once | ||
|
||
#include <torch/types.h> | ||
#include "../image_read_mode.h" | ||
|
||
namespace vision { | ||
namespace image { | ||
|
||
C10_EXPORT torch::Tensor decode_jpeg_cuda( | ||
const torch::Tensor& data, | ||
ImageReadMode mode, | ||
torch::Device device); | ||
|
||
} // namespace image | ||
} // namespace vision |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -149,7 +149,8 @@ def write_png(input: torch.Tensor, filename: str, compression_level: int = 6): | |
write_file(filename, output) | ||
|
||
|
||
def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED) -> torch.Tensor: | ||
def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANGED, | ||
device: torch.device = 'cpu') -> torch.Tensor: | ||
""" | ||
Decodes a JPEG image into a 3 dimensional RGB Tensor. | ||
Optionally converts the image to the desired format. | ||
|
@@ -166,7 +167,11 @@ def decode_jpeg(input: torch.Tensor, mode: ImageReadMode = ImageReadMode.UNCHANG | |
Returns: | ||
output (Tensor[image_channels, image_height, image_width]) | ||
""" | ||
output = torch.ops.image.decode_jpeg(input, mode.value) | ||
device = torch.device(device) | ||
if device.type == 'cuda': | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should the dispatching be done here or in C++? How will it interact with torchscript? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I was thinking about making it dispatch on C++ directly, but this should also work with torchscript I believe |
||
output = torch.ops.image.decode_jpeg_cuda(input, mode.value, device) | ||
else: | ||
output = torch.ops.image.decode_jpeg(input, mode.value) | ||
return output | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should annotation be
Union[int, str, torch.device]
? egThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would just make it a
torch.device
, because I think torchscript doesn't yet support Union