Skip to content

Commit 2f57625

Browse files
authored
[TensorRT EP] Add stream sync after enqueue (#18026)
If the model is partitioned into TRT subgraphs and CUDA EP nodes, we observed a CUDA stream synchronization issue when multithreading. Calling the stream sync API after enqueue solves this issue without adding much performance overhead.
1 parent 020824e commit 2f57625

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1869,6 +1869,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
18691869
} else if (number_of_trt_nodes == number_of_ort_nodes) {
18701870
LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider";
18711871
} else {
1872+
sync_stream_after_enqueue_ = true;
18721873
LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs;
18731874
}
18741875

@@ -2387,7 +2388,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
23872388
*p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name,
23882389
&parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name],
23892390
&networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
2390-
input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
2391+
input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
23912392
dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
23922393
runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_,
23932394
dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_,
@@ -2415,6 +2416,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
24152416
const std::unordered_map<std::string, size_t>& input_indexes = (trt_state->input_info)[0];
24162417
const std::unordered_map<std::string, size_t>& output_indexes = (trt_state->output_info)[0];
24172418
const std::unordered_map<std::string, size_t>& output_types = (trt_state->output_info)[1];
2419+
bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue;
24182420
auto fused_node_name = trt_state->fused_node_name;
24192421
auto& shape_ranges = trt_state->input_shape_ranges;
24202422
auto trt_builder = trt_state->builder->get();
@@ -3022,6 +3024,10 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
30223024
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
30233025
}
30243026

3027+
if (sync_stream_after_enqueue) {
3028+
cudaStreamSynchronize(stream);
3029+
}
3030+
30253031
// Cast INT64 input to INT32 because TensorRT doesn't fully support INT64
30263032
for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
30273033
const std::string& output_name = output_binding_names[i];

onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ struct TensorrtFuncState {
111111
std::vector<std::unordered_map<std::string, size_t>> input_info;
112112
std::vector<std::unordered_map<std::string, size_t>> output_info;
113113
std::unordered_map<std::string, std::unordered_map<size_t, std::vector<std::vector<int64_t>>>> input_shape_ranges;
114+
bool sync_stream_after_enqueue = false;
114115
OrtMutex* tensorrt_mu_ptr = nullptr;
115116
bool fp16_enable = false;
116117
bool int8_enable = false;
@@ -262,6 +263,9 @@ class TensorrtExecutionProvider : public IExecutionProvider {
262263
cudnnHandle_t external_cudnn_handle_ = nullptr;
263264
cublasHandle_t external_cublas_handle_ = nullptr;
264265

266+
// Call cudaStreamSynchronize() after TRT enqueueV2()/enqueueV3()
267+
mutable bool sync_stream_after_enqueue_ = false;
268+
265269
CUDAGraph cuda_graph_;
266270
bool is_graph_captured_ = false;
267271
int regular_run_count_before_graph_capture_ = 0;

0 commit comments

Comments
 (0)