@@ -1869,6 +1869,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
   } else if (number_of_trt_nodes == number_of_ort_nodes) {
     LOGS_DEFAULT(INFO) << "[TensorRT EP] Whole graph will run on TensorRT execution provider";
   } else {
+    sync_stream_after_enqueue_ = true;
     LOGS_DEFAULT(INFO) << "[TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs;
   }
 
@@ -2387,7 +2388,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
     *p = {context->allocate_func, context->release_func, context->allocator_handle, context->node_name,
           &parsers_[context->node_name], &engines_[context->node_name], &contexts_[context->node_name], &builders_[context->node_name],
           &networks_[context->node_name], input_info_[context->node_name], output_info_[context->node_name],
-          input_shape_ranges_[context->node_name], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
+          input_shape_ranges_[context->node_name], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
           dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
           runtime_.get(), profiles_[context->node_name], context_memory_sharing_enable_, &max_ctx_mem_size_,
           dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_,
@@ -2415,6 +2416,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
     const std::unordered_map<std::string, size_t>& input_indexes = (trt_state->input_info)[0];
     const std::unordered_map<std::string, size_t>& output_indexes = (trt_state->output_info)[0];
     const std::unordered_map<std::string, size_t>& output_types = (trt_state->output_info)[1];
+    bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue;
     auto fused_node_name = trt_state->fused_node_name;
     auto& shape_ranges = trt_state->input_shape_ranges;
     auto trt_builder = trt_state->builder->get();
@@ -3022,6 +3024,10 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
       return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "TensorRT EP execution context enqueue failed.");
     }
 
+    if (sync_stream_after_enqueue) {
+      cudaStreamSynchronize(stream);
+    }
+
     // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64
     for (size_t i = 0, end = output_binding_names.size(); i < end; ++i) {
       const std::string& output_name = output_binding_names[i];
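
A minimal, self-contained sketch of the enqueue-then-synchronize pattern the last hunk introduces, using only the CUDA runtime API. The buffer, the stand-in async work, and the main scaffolding are illustrative and not taken from the ORT sources; only cudaStreamSynchronize mirrors the actual call added above.

// Sketch: why the stream is synchronized after enqueue when the graph is
// partitioned. All names here are illustrative, not from the ORT codebase.
#include <cuda_runtime.h>

int main() {
  cudaStream_t stream;
  if (cudaStreamCreate(&stream) != cudaSuccess) return 1;

  int* device_output = nullptr;
  cudaMalloc(&device_output, sizeof(int));

  // Mirrors sync_stream_after_enqueue_: set in GetCapability only when part
  // of the graph runs elsewhere, so another EP consumes these outputs next.
  const bool sync_stream_after_enqueue = true;

  // Stand-in for the TensorRT enqueue: work is queued asynchronously.
  cudaMemsetAsync(device_output, 0, sizeof(int), stream);

  // Block until the queued work finishes; otherwise a downstream consumer
  // that does not share this stream could read device_output too early.
  if (sync_stream_after_enqueue) {
    cudaStreamSynchronize(stream);
  }

  int host_output = -1;
  cudaMemcpy(&host_output, device_output, sizeof(int), cudaMemcpyDeviceToHost);

  cudaFree(device_output);
  cudaStreamDestroy(stream);
  return host_output == 0 ? 0 : 1;
}

When the whole graph runs on TensorRT, the flag stays false and the synchronization is skipped, so fully-TRT runs keep their asynchronous enqueue behavior.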