Skip to content

Commit 3f8fff7

Browse files
authored
feat: Cudagraphs integration for Torch-TRT + Non-default Stream Utilization (#2881)
1 parent 29d2e43 commit 3f8fff7

File tree

16 files changed

+812
-207
lines changed

16 files changed

+812
-207
lines changed

WORKSPACE

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -101,8 +101,6 @@ http_archive(
101101
],
102102
)
103103

104-
105-
106104
####################################################################################
107105
# Locally installed dependencies (use in cases of custom dependencies or aarch64)
108106
####################################################################################

core/runtime/TRTEngine.cpp

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22

33
#include <cuda_runtime.h>
44
#include "NvInfer.h"
5+
#include "c10/cuda/CUDAStream.h"
56
#include "torch/csrc/jit/frontend/function_schema_parser.h"
67
#include "torch/cuda.h"
78

@@ -70,6 +71,15 @@ TRTEngine::TRTEngine(
7071
multi_gpu_device_check();
7172
set_rt_device(device_info);
7273

74+
// Set active stream to non-default stream
75+
auto current_stream = c10::cuda::getCurrentCUDAStream(device_info.id);
76+
if (current_stream == c10::cuda::getDefaultCUDAStream(device_info.id)) {
77+
active_stream = c10::cuda::getStreamFromPool(false, device_info.id);
78+
c10::cuda::setCurrentCUDAStream(active_stream);
79+
} else {
80+
active_stream = current_stream;
81+
}
82+
7383
rt = make_trt(nvinfer1::createInferRuntime(util::logging::get_logger()));
7484

7585
name = slugify(mod_name);
@@ -112,7 +122,9 @@ TRTEngine::TRTEngine(
112122

113123
num_io = std::make_pair(inputs, outputs);
114124
in_binding_names.resize(inputs);
125+
input_buffers.resize(inputs);
115126
out_binding_names.resize(outputs);
127+
output_buffers.resize(outputs);
116128
for (int64_t x = 0; x < cuda_engine->getNbIOTensors(); x++) {
117129
std::string bind_name = cuda_engine->getIOTensorName(x);
118130
if (cuda_engine->getTensorIOMode(bind_name.c_str()) == nvinfer1::TensorIOMode::kINPUT) {
@@ -124,6 +136,7 @@ TRTEngine::TRTEngine(
124136
} else {
125137
uint64_t inputs_size = _in_binding_names.size();
126138
in_binding_names.resize(inputs_size);
139+
input_buffers.resize(inputs_size);
127140
for (uint64_t pyt_idx = 0; pyt_idx < inputs_size; pyt_idx++) {
128141
auto binding_name = _in_binding_names[pyt_idx];
129142
// Check if the binding name provided is in the list of engine's bindings
@@ -153,6 +166,7 @@ TRTEngine::TRTEngine(
153166

154167
uint64_t outputs = _out_binding_names.size();
155168
out_binding_names.resize(outputs);
169+
output_buffers.resize(outputs);
156170
for (size_t pyt_idx = 0; pyt_idx < outputs; pyt_idx++) {
157171
auto binding_name = _out_binding_names[pyt_idx];
158172
// Check if the binding name provided is in the list of engine's bindings

core/runtime/TRTEngine.h

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,9 @@
77
#include <utility>
88

99
#include "ATen/core/function_schema.h"
10+
#include "ATen/cuda/CUDAGraph.h"
1011
#include "NvInfer.h"
12+
#include "c10/cuda/CUDAStream.h"
1113
#include "torch/custom_class.h"
1214

1315
#include "core/runtime/TRTEngineProfiler.h"
@@ -65,6 +67,14 @@ struct TRTEngine : torch::CustomClassHolder {
6567
void dump_engine_layer_info();
6668
friend std::ostream& operator<<(std::ostream& os, const TRTEngine& engine);
6769
static const char BINDING_DELIM = '%';
70+
71+
// CUDAGraph-Related Functionality
72+
at::cuda::CUDAGraph cudagraph = {};
73+
at::cuda::CUDAStream active_stream = c10::cuda::getDefaultCUDAStream();
74+
std::vector<at::Tensor> input_buffers = {};
75+
std::vector<at::Tensor> output_buffers = {};
76+
std::string shape_key;
77+
6878
// TODO: Implement a call method
6979
// c10::List<at::Tensor> Run(c10::List<at::Tensor> inputs);
7080

0 commit comments

Comments (0)