2
2
3
3
#include < cuda_runtime.h>
4
4
#include " NvInfer.h"
5
+ #include " c10/cuda/CUDAStream.h"
5
6
#include " torch/csrc/jit/frontend/function_schema_parser.h"
6
7
#include " torch/cuda.h"
7
8
@@ -70,6 +71,15 @@ TRTEngine::TRTEngine(
70
71
multi_gpu_device_check ();
71
72
set_rt_device (device_info);
72
73
74
+ // Set active stream to non-default stream
75
+ auto current_stream = c10::cuda::getCurrentCUDAStream (device_info.id );
76
+ if (current_stream == c10::cuda::getDefaultCUDAStream (device_info.id )) {
77
+ active_stream = c10::cuda::getStreamFromPool (false , device_info.id );
78
+ c10::cuda::setCurrentCUDAStream (active_stream);
79
+ } else {
80
+ active_stream = current_stream;
81
+ }
82
+
73
83
rt = make_trt (nvinfer1::createInferRuntime (util::logging::get_logger ()));
74
84
75
85
name = slugify (mod_name);
@@ -112,7 +122,9 @@ TRTEngine::TRTEngine(
112
122
113
123
num_io = std::make_pair (inputs, outputs);
114
124
in_binding_names.resize (inputs);
125
+ input_buffers.resize (inputs);
115
126
out_binding_names.resize (outputs);
127
+ output_buffers.resize (outputs);
116
128
for (int64_t x = 0 ; x < cuda_engine->getNbIOTensors (); x++) {
117
129
std::string bind_name = cuda_engine->getIOTensorName (x);
118
130
if (cuda_engine->getTensorIOMode (bind_name.c_str ()) == nvinfer1::TensorIOMode::kINPUT ) {
@@ -124,6 +136,7 @@ TRTEngine::TRTEngine(
124
136
} else {
125
137
uint64_t inputs_size = _in_binding_names.size ();
126
138
in_binding_names.resize (inputs_size);
139
+ input_buffers.resize (inputs_size);
127
140
for (uint64_t pyt_idx = 0 ; pyt_idx < inputs_size; pyt_idx++) {
128
141
auto binding_name = _in_binding_names[pyt_idx];
129
142
// Check if the binding name provided is in the list of engine's bindings
@@ -153,6 +166,7 @@ TRTEngine::TRTEngine(
153
166
154
167
uint64_t outputs = _out_binding_names.size ();
155
168
out_binding_names.resize (outputs);
169
+ output_buffers.resize (outputs);
156
170
for (size_t pyt_idx = 0 ; pyt_idx < outputs; pyt_idx++) {
157
171
auto binding_name = _out_binding_names[pyt_idx];
158
172
// Check if the binding name provided is in the list of engine's bindings
0 commit comments