@@ -1869,6 +1869,7 @@ TensorrtExecutionProvider::GetCapability(const GraphViewer& graph,
18691869 } else if (number_of_trt_nodes == number_of_ort_nodes) {
18701870 LOGS_DEFAULT (INFO) << " [TensorRT EP] Whole graph will run on TensorRT execution provider" ;
18711871 } else {
1872+ sync_stream_after_enqueue_ = true ;
18721873 LOGS_DEFAULT (INFO) << " [TensorRT EP] Graph is partitioned and number of subgraphs running on TensorRT execution provider is " << number_of_subgraphs;
18731874 }
18741875
@@ -2387,7 +2388,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
23872388 *p = {context->allocate_func , context->release_func , context->allocator_handle , context->node_name ,
23882389 &parsers_[context->node_name ], &engines_[context->node_name ], &contexts_[context->node_name ], &builders_[context->node_name ],
23892390 &networks_[context->node_name ], input_info_[context->node_name ], output_info_[context->node_name ],
2390- input_shape_ranges_[context->node_name ], &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
2391+ input_shape_ranges_[context->node_name ], sync_stream_after_enqueue_, &tensorrt_mu_, fp16_enable_, int8_enable_, int8_calibration_cache_available_,
23912392 dla_enable_, dla_core_, &max_workspace_size_, trt_node_name_with_precision, engine_cache_enable_, cache_path_,
23922393 runtime_.get (), profiles_[context->node_name ], context_memory_sharing_enable_, &max_ctx_mem_size_,
23932394 dynamic_range_map, engine_decryption_enable_, engine_decryption_, engine_encryption_, timing_cache_enable_,
@@ -2415,6 +2416,7 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
24152416 const std::unordered_map<std::string, size_t >& input_indexes = (trt_state->input_info )[0 ];
24162417 const std::unordered_map<std::string, size_t >& output_indexes = (trt_state->output_info )[0 ];
24172418 const std::unordered_map<std::string, size_t >& output_types = (trt_state->output_info )[1 ];
2419+ bool sync_stream_after_enqueue = trt_state->sync_stream_after_enqueue ;
24182420 auto fused_node_name = trt_state->fused_node_name ;
24192421 auto & shape_ranges = trt_state->input_shape_ranges ;
24202422 auto trt_builder = trt_state->builder ->get ();
@@ -3022,6 +3024,10 @@ common::Status TensorrtExecutionProvider::Compile(const std::vector<FusedNodeAnd
30223024 return ORT_MAKE_STATUS (ONNXRUNTIME, FAIL, " TensorRT EP execution context enqueue failed." );
30233025 }
30243026
3027+ if (sync_stream_after_enqueue) {
3028+ cudaStreamSynchronize (stream);
3029+ }
3030+
30253031 // Cast INT64 input to INT32 because TensorRT doesn't fully support INT64
30263032 for (size_t i = 0 , end = output_binding_names.size (); i < end; ++i) {
30273033 const std::string& output_name = output_binding_names[i];
0 commit comments