Skip to content

Commit f9fbc10

Browse files
authored
Allow cuda device to be passed without the index for GPU decoding (#5505)
1 parent d4146ef commit f9fbc10

File tree

4 files changed

+11
-11
lines changed

4 files changed

+11
-11
lines changed

test/test_video_gpu_decoder.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ class TestVideoGPUDecoder:
3030
)
3131
def test_frame_reading(self, video_file):
3232
full_path = os.path.join(VIDEO_DIR, video_file)
33-
decoder = VideoReader(full_path, device="cuda:0")
33+
decoder = VideoReader(full_path, device="cuda")
3434
with av.open(full_path) as container:
3535
for av_frame in container.decode(container.streams.video[0]):
3636
av_frames = torch.tensor(av_frame.to_rgb(src_colorspace="ITU709").to_ndarray())
@@ -54,7 +54,7 @@ def test_frame_reading(self, video_file):
5454
],
5555
)
5656
def test_seek_reading(self, keyframes, full_path, duration):
57-
decoder = VideoReader(full_path, device="cuda:0")
57+
decoder = VideoReader(full_path, device="cuda")
5858
time = duration / 2
5959
decoder.seek(time, keyframes_only=keyframes)
6060
with av.open(full_path) as container:
@@ -80,7 +80,7 @@ def test_seek_reading(self, keyframes, full_path, duration):
8080
)
8181
def test_metadata(self, video_file):
8282
full_path = os.path.join(VIDEO_DIR, video_file)
83-
decoder = VideoReader(full_path, device="cuda:0")
83+
decoder = VideoReader(full_path, device="cuda")
8484
video_metadata = decoder.get_metadata()["video"]
8585
with av.open(full_path) as container:
8686
video = container.streams.video[0]

torchvision/csrc/io/decoder/gpu/gpu_decoder.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33

44
/* Set cuda device, create cuda context and initialise the demuxer and decoder.
55
*/
6-
GPUDecoder::GPUDecoder(std::string src_file, int64_t dev)
7-
: demuxer(src_file.c_str()), device(dev) {
8-
at::cuda::CUDAGuard device_guard(device);
6+
GPUDecoder::GPUDecoder(std::string src_file, torch::Device dev)
7+
: demuxer(src_file.c_str()) {
8+
at::cuda::CUDAGuard device_guard(dev);
9+
device = device_guard.current_device().index();
910
check_for_cuda_errors(
1011
cuDevicePrimaryCtxRetain(&ctx, device), __LINE__, __FILE__);
1112
decoder.init(ctx, ffmpeg_to_codec(demuxer.get_video_codec()));
@@ -58,7 +59,7 @@ c10::Dict<std::string, c10::Dict<std::string, double>> GPUDecoder::
5859

5960
TORCH_LIBRARY(torchvision, m) {
6061
m.class_<GPUDecoder>("GPUDecoder")
61-
.def(torch::init<std::string, int64_t>())
62+
.def(torch::init<std::string, torch::Device>())
6263
.def("seek", &GPUDecoder::seek)
6364
.def("get_metadata", &GPUDecoder::get_metadata)
6465
.def("next", &GPUDecoder::decode);

torchvision/csrc/io/decoder/gpu/gpu_decoder.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
class GPUDecoder : public torch::CustomClassHolder {
77
public:
8-
GPUDecoder(std::string, int64_t);
8+
GPUDecoder(std::string, torch::Device);
99
~GPUDecoder();
1010
torch::Tensor decode();
1111
void seek(double, bool);

torchvision/io/video_reader.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ class VideoReader:
8484
will depend on the version of FFMPEG codecs supported.
8585
8686
device (str, optional): Device to be used for decoding. Defaults to ``"cpu"``.
87+
To use GPU decoding, pass ``device="cuda"``.
8788
8889
"""
8990

@@ -95,9 +96,7 @@ def __init__(self, path: str, stream: str = "video", num_threads: int = 0, devic
9596
if not _HAS_GPU_VIDEO_DECODER:
9697
raise RuntimeError("Not compiled with GPU decoder support.")
9798
self.is_cuda = True
98-
if device.index is None:
99-
raise RuntimeError("Invalid cuda device!")
100-
self._c = torch.classes.torchvision.GPUDecoder(path, device.index)
99+
self._c = torch.classes.torchvision.GPUDecoder(path, device)
101100
return
102101
if not _has_video_opt():
103102
raise RuntimeError(

0 commit comments

Comments (0)