Skip to content

Commit 1c1e471

Browse files
committed
Use at::cuda::getCurrentCUDAStream()
1 parent e485656 commit 1c1e471

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

torchvision/csrc/io/image/cuda/readjpeg_cuda.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,13 +11,15 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) {
1111

1212
#else
1313

14+
#include <ATen/ATen.h>
15+
#include <ATen/cuda/CUDAContext.h>
1416
#include <nvjpeg.h>
1517

1618
static nvjpegHandle_t nvjpeg_handle = nullptr;
1719

1820
void init_nvjpegImage(nvjpegImage_t& img) {
1921
for (int c = 0; c < NVJPEG_MAX_COMPONENT; c++) {
20-
img.channel[c] = NULL;
22+
img.channel[c] = nullptr;
2123
img.pitch[c] = 0;
2224
}
2325
}
@@ -131,16 +133,16 @@ torch::Tensor decodeJPEG_cuda(const torch::Tensor& data, ImageReadMode mode) {
131133
outImage.pitch[c] = width;
132134
}
133135

134-
// TODO torch cuda stream support
135-
// TODO output besides RGB
136+
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
137+
136138
nvjpegStatus_t decode_status = nvjpegDecode(
137139
nvjpeg_handle,
138140
nvjpeg_state,
139141
datap,
140142
data.numel(),
141143
outputFormat,
142144
&outImage,
143-
/*stream=*/0);
145+
stream);
144146

145147
// Destroy the state
146148
nvjpegJpegStateDestroy(nvjpeg_state);

0 commit comments

Comments
 (0)