Skip to content

Commit e7d410b

Browse files
tianleiwu authored and kleiti committed
[CUDA EP] Add warning logs when adding memcpy nodes (microsoft#18032)
Memcpy nodes could have negative impact on performance, they also cause ORT unable to run CUDA graph. Here we add a warning log for CUDA EP when this happens. It could help trouble shooting. For example, when CUDA graph cannot run, we can see the logs to find out where the Memcpy nodes are inserted (Although it is also possible through saving optimized model, but that need more time and disk space). Note that the warning is per graph. When there are subgraphs, we might see multiple warnings if the issue happens in multiple graphs. Example logs: ``` 2023-10-19 20:58:10.678176531 [I:onnxruntime:, transformer_memcpy.cc:329 AddCopyNode] Add MemcpyFromHost after input_ids for CUDAExecutionProvider 2023-10-19 20:58:10.678198702 [I:onnxruntime:, transformer_memcpy.cc:329 AddCopyNode] Add MemcpyFromHost after /text_model/ArgMax_output_0 for CUDAExecutionProvider 2023-10-19 20:58:10.678211727 [I:onnxruntime:, transformer_memcpy.cc:329 AddCopyNode] Add MemcpyFromHost after /text_model/Gather_3_output_0 for CUDAExecutionProvider 2023-10-19 20:58:10.678257903 [W:onnxruntime:, transformer_memcpy.cc:74 ApplyImpl] 3 Memcpy nodes are added to the graph main_graph for CUDAExecutionProvider. It might have negative impact on performance (including unable to run CUDA graph). Set session_options.log_severity_level=1 to see the detail logs before this message. ```
1 parent dcf27f0 commit e7d410b

File tree

1 file changed

+30
-10
lines changed

1 file changed

+30
-10
lines changed

onnxruntime/core/optimizer/transformer_memcpy.cc

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// Licensed under the MIT License.
33

44
#include "transformer_memcpy.h"
5+
#include "core/common/logging/logging.h"
56
#include "core/framework/kernel_registry_manager.h"
67
#include "core/framework/execution_providers.h"
78
#include "core/framework/utils.h"
@@ -16,12 +17,12 @@ class TransformerMemcpyImpl {
1617
TransformerMemcpyImpl(onnxruntime::Graph& graph, const std::string& provider)
1718
: graph_(graph), provider_(provider) {}
1819

19-
bool ModifyGraph(const KernelRegistryManager& schema_registries);
20+
bool ModifyGraph(const KernelRegistryManager& schema_registries, const logging::Logger& logger, int& copy_node_counter);
2021

2122
private:
2223
void ProcessDefs(onnxruntime::Node& node, const KernelRegistryManager& kernel_registries, InitializedTensorSet& initializers_consumed);
2324
void BuildDefsMapping(const onnxruntime::NodeArg* arg, const KernelRegistryManager& kernel_registries);
24-
void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input);
25+
void AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger);
2526
bool ProcessInitializers(const KernelRegistryManager& kernel_registries, const InitializedTensorSet& initializers_consumed);
2627

2728
private:
@@ -61,11 +62,21 @@ static const onnx::TensorProto* GetInitializer(const Graph& graph, const std::st
6162

6263
// very simple GraphTransformer that uses TransformerMemcpyImpl for each graph
6364
// and mainly provides the subgraph recursion functionality
64-
common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level, const logging::Logger& logger) const {
65+
common::Status MemcpyTransformer::ApplyImpl(Graph& graph, bool& modified, int graph_level,
66+
const logging::Logger& logger) const {
6567
for (auto& provider : provider_types_) {
6668
if (!utils::ProviderIsCpuBased(provider)) {
6769
TransformerMemcpyImpl copy_impl(graph, provider);
68-
auto current_modified = copy_impl.ModifyGraph(registry_manager_);
70+
71+
int copy_node_counter = 0;
72+
auto current_modified = copy_impl.ModifyGraph(registry_manager_, logger, copy_node_counter);
73+
if (copy_node_counter > 0 && provider == kCudaExecutionProvider) {
74+
LOGS(logger, WARNING) << copy_node_counter << " Memcpy nodes are added to the graph " << graph.Name()
75+
<< " for " << provider
76+
<< ". It might have negative impact on performance (including unable to run CUDA graph). "
77+
<< "Set session_options.log_severity_level=1 to see the detail logs before this message.";
78+
}
79+
6980
modified = modified || current_modified;
7081
break;
7182
}
@@ -111,7 +122,9 @@ This transformer does not currently optimize copies between, e.g., two different
111122
112123
*/
113124

114-
bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_registries) {
125+
bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_registries,
126+
const logging::Logger& logger,
127+
int& copy_node_counter) {
115128
bool modified = false;
116129
InitializedTensorSet initializers_consumed;
117130
// find defs that require copy
@@ -137,19 +150,22 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
137150
// For inputs we need to create a copy node only when the input is connected to both provider
138151
// and non-provider nodes. Otherwise utils::CopyInputsAcrossDevices() will do the job.
139152
if (provider_input_defs_.count(arg) && non_provider_input_defs_.count(arg)) {
140-
AddCopyNode(const_cast<onnxruntime::NodeArg*>(arg), true);
153+
AddCopyNode(const_cast<onnxruntime::NodeArg*>(arg), true, logger);
154+
copy_node_counter++;
141155
modified = true;
142156
}
143157

144158
for (auto arg : non_provider_output_defs_)
145159
if (provider_input_defs_.count(arg)) {
146-
AddCopyNode(arg, true);
160+
AddCopyNode(arg, true, logger);
161+
copy_node_counter++;
147162
modified = true;
148163
}
149164

150165
for (auto arg : provider_output_defs_)
151166
if (non_provider_input_defs_.count(arg)) {
152-
AddCopyNode(arg, false);
167+
AddCopyNode(arg, false, logger);
168+
copy_node_counter++;
153169
modified = true;
154170
}
155171

@@ -176,7 +192,8 @@ bool TransformerMemcpyImpl::ModifyGraph(const KernelRegistryManager& kernel_regi
176192
// (the name will be the same as the parent node's implicit input)
177193
const auto* node_arg_in_current_graph_level = *provider_input_defs_.find(arg);
178194

179-
AddCopyNode(const_cast<onnxruntime::NodeArg*>(node_arg_in_current_graph_level), true);
195+
AddCopyNode(const_cast<onnxruntime::NodeArg*>(node_arg_in_current_graph_level), true, logger);
196+
copy_node_counter++;
180197
modified = true;
181198
}
182199
}
@@ -297,7 +314,7 @@ void TransformerMemcpyImpl::BuildDefsMapping(const onnxruntime::NodeArg* arg, co
297314
}
298315
}
299316

300-
void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input) {
317+
void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input, const logging::Logger& logger) {
301318
// create unique name for new def
302319
std::string new_def_name = graph_.GenerateNodeArgName(arg->Name() + "_" + provider_);
303320

@@ -309,6 +326,9 @@ void TransformerMemcpyImpl::AddCopyNode(onnxruntime::NodeArg* arg, bool is_input
309326
std::string new_node_name = graph_.GenerateNodeName("Memcpy");
310327

311328
const auto op_name = is_input ? "MemcpyFromHost" : "MemcpyToHost";
329+
LOGS(logger, INFO) << "Add " << op_name << (is_input ? " after " : " before ") << arg->Name()
330+
<< " for " << provider_;
331+
312332
auto& new_node = graph_.AddNode(new_node_name, op_name, "Copy from/to host memory",
313333
std::vector<onnxruntime::NodeArg*>{src_arg},
314334
std::vector<onnxruntime::NodeArg*>{dst_arg});

0 commit comments

Comments (0)