Skip to content

Commit 95930ef

Browse files
committed
Merge branch 'main' of https://github.com/microsoft/onnxruntime into baijumeswani/resize-grad
2 parents 581f9cd + 35ecce4 commit 95930ef

9 files changed

Lines changed: 264 additions & 85 deletions

File tree

onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.cc

Lines changed: 26 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,26 +60,32 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
6060
return Status::OK();
6161
}
6262

63-
Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
64-
std::string& ep_cache_context,
65-
const logging::Logger& logger) {
63+
Status QnnCacheModelHandler::GetEpContextFromModel(const std::string& ctx_onnx_model_path,
64+
QnnBackendManager* qnn_backend_manager,
65+
QnnModel& qnn_model,
66+
const logging::Logger& logger) {
6667
using namespace onnxruntime;
6768
std::shared_ptr<Model> model;
6869
ORT_RETURN_IF_ERROR(Model::Load(ToPathString(ctx_onnx_model_path), model, {}, logger));
6970
const auto& graph = model->MainGraph();
70-
ORT_RETURN_IF_ERROR(GetEpContextFromGraph(GraphViewer(graph), ctx_onnx_model_path, ep_cache_context));
71-
72-
return Status::OK();
71+
return GetEpContextFromGraph(GraphViewer(graph),
72+
ctx_onnx_model_path,
73+
qnn_backend_manager,
74+
qnn_model);
7375
}
7476

75-
Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
76-
const std::string& ctx_onnx_model_path,
77-
std::string& ep_cache_context) {
77+
Status QnnCacheModelHandler::GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
78+
const std::string& ctx_onnx_model_path,
79+
QnnBackendManager* qnn_backend_manager,
80+
QnnModel& qnn_model) {
7881
const auto& node = graph_viewer.Nodes().begin();
7982
NodeAttrHelper node_helper(*node);
8083
bool is_embed_mode = node_helper.Get(EMBED_MODE, true);
8184
if (is_embed_mode) {
82-
ep_cache_context = node_helper.Get(EP_CACHE_CONTEXT, "");
85+
const std::string& context_binary = node_helper.Get(EP_CACHE_CONTEXT, "");
86+
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(const_cast<char*>(context_binary.c_str()),
87+
static_cast<uint64_t>(context_binary.length()),
88+
qnn_model);
8389
} else {
8490
std::string external_qnn_context_binary_file_name = node_helper.Get(EP_CACHE_CONTEXT, "");
8591

@@ -88,17 +94,23 @@ Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
8894
size_t buffer_size{0};
8995
std::ifstream cache_file(context_binary_path.c_str(), std::ifstream::binary);
9096
ORT_RETURN_IF(!cache_file || !cache_file.good(), "Failed to open cache file.");
97+
9198
cache_file.seekg(0, cache_file.end);
9299
buffer_size = static_cast<size_t>(cache_file.tellg());
93100
ORT_RETURN_IF(0 == buffer_size, "Empty cache file encountered.");
101+
94102
cache_file.seekg(0, cache_file.beg);
95-
ep_cache_context.reserve(buffer_size);
103+
std::unique_ptr<char[]> buffer = std::make_unique<char[]>(buffer_size);
104+
ORT_RETURN_IF(nullptr == buffer, "Failed to allocate memory for cache file.");
96105
// Load file into buffer
97-
ep_cache_context.assign(std::istreambuf_iterator<char>(cache_file), std::istreambuf_iterator<char>());
106+
const auto& read_result = cache_file.read(buffer.get(), buffer_size);
107+
ORT_RETURN_IF(!read_result, "Failed to read contents from cached context file.");
98108
cache_file.close();
99-
ORT_RETURN_IF(ep_cache_context.length() != buffer_size, "Failed to read contents from cached context file.");
109+
return qnn_backend_manager->LoadCachedQnnContextFromBuffer(buffer.get(),
110+
static_cast<uint64_t>(buffer_size),
111+
qnn_model);
100112
}
101-
ORT_RETURN_IF(ep_cache_context.empty(), "Cached context empty.");
113+
102114
return Status::OK();
103115
}
104116

onnxruntime/core/providers/qnn/builder/onnx_ctx_model_helper.h

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ namespace onnxruntime {
1818
namespace qnn {
1919

2020
class QnnModel;
21+
class QnnBackendManager;
2122

2223
static const std::string EPCONTEXT_OP = "EPContext";
2324
static const std::string MAIN_CONTEXT = "main_context";
@@ -37,32 +38,24 @@ Status CreateNodeArgs(const std::vector<std::string>& names,
3738
std::vector<NodeArg*>& node_args,
3839
onnxruntime::Graph& graph);
3940

40-
Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
41-
std::string& ep_engine_cache,
42-
const logging::Logger& logger);
43-
44-
Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
45-
const std::string& ctx_onnx_model_path,
46-
std::string& ep_cache_context);
47-
4841
class QnnCacheModelHandler {
4942
public:
5043
QnnCacheModelHandler(bool qnn_context_embed_mode) : qnn_context_embed_mode_(qnn_context_embed_mode) {
5144
}
5245
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(QnnCacheModelHandler);
5346

54-
Status GetEpContext(const onnxruntime::GraphViewer& graph_viewer,
55-
const std::string& ctx_onnx_model_path,
56-
bool is_qnn_ctx_model,
57-
bool is_ctx_cache_file_exist,
58-
std::string& ep_engine_cache,
59-
const logging::Logger& logger) const {
47+
Status LoadQnnCtxFromOnnxModel(const onnxruntime::GraphViewer& graph_viewer,
48+
const std::string& ctx_onnx_model_path,
49+
bool is_qnn_ctx_model,
50+
bool is_ctx_cache_file_exist,
51+
QnnBackendManager* qnn_backend_manager,
52+
QnnModel& qnn_model,
53+
const logging::Logger& logger) {
6054
if (is_qnn_ctx_model) {
61-
ORT_RETURN_IF_ERROR(GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, ep_engine_cache));
55+
return GetEpContextFromGraph(graph_viewer, ctx_onnx_model_path, qnn_backend_manager, qnn_model);
6256
} else if (is_ctx_cache_file_exist) {
63-
ORT_RETURN_IF_ERROR(GetEpContextFromModel(ctx_onnx_model_path, ep_engine_cache, logger));
57+
return GetEpContextFromModel(ctx_onnx_model_path, qnn_backend_manager, qnn_model, logger);
6458
}
65-
6659
return Status::OK();
6760
}
6861

@@ -92,12 +85,25 @@ class QnnCacheModelHandler {
9285
const std::unordered_map<std::string, std::unique_ptr<QnnModel>>& qnn_models,
9386
const logging::Logger& logger);
9487

88+
private:
89+
Status GetEpContextFromModel(const std::string& ctx_onnx_model_path,
90+
QnnBackendManager* qnn_backend_manager,
91+
QnnModel& qnn_model,
92+
const logging::Logger& logger);
93+
94+
Status GetEpContextFromGraph(const onnxruntime::GraphViewer& graph_viewer,
95+
const std::string& ctx_onnx_model_path,
96+
QnnBackendManager* qnn_backend_manager,
97+
QnnModel& qnn_model);
98+
9599
private:
96100
bool is_metadata_ready_ = false;
101+
// model_name_ through cache_source_: metadata read from the generated QNN context binary ONNX model
97102
std::string model_name_ = "";
98103
std::string model_description_ = "";
99104
std::string graph_partition_name_ = "";
100105
std::string cache_source_ = "";
106+
101107
std::string context_cache_path_ = "";
102108
bool ctx_file_exists_ = false;
103109
bool get_capability_round_2_ = false;

onnxruntime/core/providers/qnn/builder/qnn_backend_manager.cc

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -448,20 +448,7 @@ std::unique_ptr<unsigned char[]> QnnBackendManager::GetContextBinaryBuffer(uint6
448448
return context_buffer;
449449
}
450450

451-
Status QnnBackendManager::LoadCachedQnnCtxFromOnnxModel(const std::string& ep_engine_cache,
452-
QnnModel& qnn_model,
453-
bool& loaded_from_cache) {
454-
loaded_from_cache = false;
455-
456-
if (!ep_engine_cache.empty()) {
457-
ORT_RETURN_IF_ERROR(LoadCachedQnnContextFromBuffer(ep_engine_cache, qnn_model));
458-
loaded_from_cache = true;
459-
}
460-
461-
return Status::OK();
462-
}
463-
464-
Status QnnBackendManager::LoadCachedQnnContextFromBuffer(const std::string& buffer, QnnModel& qnn_model) {
451+
Status QnnBackendManager::LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, QnnModel& qnn_model) {
465452
bool result = nullptr == qnn_sys_interface_.systemContextCreate ||
466453
nullptr == qnn_sys_interface_.systemContextGetBinaryInfo ||
467454
nullptr == qnn_sys_interface_.systemContextFree;
@@ -474,8 +461,8 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(const std::string& buff
474461
const QnnSystemContext_BinaryInfo_t* binary_info = nullptr;
475462
Qnn_ContextBinarySize_t binary_info_size{0};
476463
rt = qnn_sys_interface_.systemContextGetBinaryInfo(sys_ctx_handle,
477-
static_cast<void*>(const_cast<char*>(buffer.c_str())),
478-
static_cast<uint64_t>(buffer.length()),
464+
static_cast<void*>(buffer),
465+
buffer_length,
479466
&binary_info,
480467
&binary_info_size);
481468
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to get context binary info.");
@@ -502,8 +489,8 @@ Status QnnBackendManager::LoadCachedQnnContextFromBuffer(const std::string& buff
502489
rt = qnn_interface_.contextCreateFromBinary(backend_handle_,
503490
device_handle_,
504491
(const QnnContext_Config_t**)&context_config_,
505-
static_cast<void*>(const_cast<char*>(buffer.c_str())),
506-
static_cast<uint64_t>(buffer.length()),
492+
static_cast<void*>(buffer),
493+
buffer_length,
507494
&context_,
508495
profile_backend_handle_);
509496
ORT_RETURN_IF(QNN_SUCCESS != rt, "Failed to create context from binary.");

onnxruntime/core/providers/qnn/builder/qnn_backend_manager.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,9 +74,7 @@ class QnnBackendManager {
7474

7575
std::unique_ptr<unsigned char[]> GetContextBinaryBuffer(uint64_t& written_buffer_size);
7676

77-
Status LoadCachedQnnCtxFromOnnxModel(const std::string& ep_engine_cache,
78-
QnnModel& qnn_model,
79-
bool& loaded_from_cache);
77+
Status LoadCachedQnnContextFromBuffer(char* buffer, uint64_t buffer_length, QnnModel& qnn_model);
8078

8179
Status SetupBackend(const logging::Logger& logger, bool load_from_cached_context);
8280

@@ -174,8 +172,6 @@ class QnnBackendManager {
174172
return (backend_build_id == nullptr ? std::string("") : std::string(backend_build_id));
175173
}
176174

177-
Status LoadCachedQnnContextFromBuffer(const std::string& buffer, QnnModel& qnn_model);
178-
179175
private:
180176
const std::string backend_path_;
181177
const logging::Logger* logger_ = nullptr;

onnxruntime/core/providers/qnn/qnn_execution_provider.cc

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -483,33 +483,28 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
483483
bool is_qnn_ctx_model = false;
484484
ORT_RETURN_IF_ERROR(qnn::IsFusedGraphHasCtxNode(fused_nodes_and_graphs, is_qnn_ctx_model));
485485

486-
if (context_cache_enabled_ || is_qnn_ctx_model) {
486+
bool is_ctx_file_exist = qnn_cache_model_handler_->GetIsContextCacheFileExists();
487+
if (is_qnn_ctx_model || (context_cache_enabled_ && is_ctx_file_exist)) {
487488
ORT_ENFORCE(fused_nodes_and_graphs.size() == 1, "Only support single partition for context cache feature.");
488489
std::unique_ptr<qnn::QnnModel> qnn_model = std::make_unique<qnn::QnnModel>(logger, qnn_backend_manager_.get());
489-
bool loaded_from_cache = false;
490-
std::string ep_engine_cache;
491-
ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->GetEpContext(graph_viewer,
492-
context_cache_path_,
493-
is_qnn_ctx_model,
494-
qnn_cache_model_handler_->GetIsContextCacheFileExists(),
495-
ep_engine_cache,
496-
logger));
497-
ORT_RETURN_IF_ERROR(qnn_backend_manager_->LoadCachedQnnCtxFromOnnxModel(ep_engine_cache,
498-
*(qnn_model.get()),
499-
loaded_from_cache));
500490
// Load and execute from the cached context if it exists
501-
if (loaded_from_cache) {
502-
ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node));
503-
ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput());
491+
ORT_RETURN_IF_ERROR(qnn_cache_model_handler_->LoadQnnCtxFromOnnxModel(graph_viewer,
492+
context_cache_path_,
493+
is_qnn_ctx_model,
494+
is_ctx_file_exist,
495+
qnn_backend_manager_.get(),
496+
*(qnn_model.get()),
497+
logger));
498+
ORT_RETURN_IF_ERROR(qnn_model->SetGraphInputOutputInfo(graph_viewer, fused_node));
499+
ORT_RETURN_IF_ERROR(qnn_model->SetupQnnInputOutput());
504500

505-
// fused node name is QNNExecutionProvider_QNN_[hash_id]_[id]
506-
// the name here should be same with context->node_name in compute_info
507-
LOGS(logger, VERBOSE) << "fused node name: " << fused_node.Name();
508-
qnn_models_.emplace(fused_node.Name(), std::move(qnn_model));
501+
// fused node name is QNNExecutionProvider_QNN_[hash_id]_[id]
502+
// the name here should be same with context->node_name in compute_info
503+
LOGS(logger, VERBOSE) << "fused node name: " << fused_node.Name();
504+
qnn_models_.emplace(fused_node.Name(), std::move(qnn_model));
509505

510-
ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger));
511-
return Status::OK();
512-
}
506+
ORT_RETURN_IF_ERROR(CreateComputeFunc(node_compute_funcs, logger));
507+
return Status::OK();
513508
}
514509

515510
ORT_RETURN_IF_ERROR(CompileFromOrtGraph(fused_nodes_and_graphs, node_compute_funcs, logger));
@@ -524,8 +519,6 @@ Status QNNExecutionProvider::Compile(const std::vector<FusedNodeAndGraph>& fused
524519
qnn_models_,
525520
logger));
526521
}
527-
qnn_cache_model_handler_.reset();
528-
529522
return Status::OK();
530523
}
531524
} // namespace onnxruntime

onnxruntime/core/providers/shared/utils/utils.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ int64_t NodeAttrHelper::Get(const std::string& key, int64_t def_val) const {
119119
return node_attributes_.at(key).i();
120120
}
121121

122-
std::string NodeAttrHelper::Get(const std::string& key, const std::string& def_val) const {
122+
const std::string& NodeAttrHelper::Get(const std::string& key, const std::string& def_val) const {
123123
if (!HasAttr(key))
124124
return def_val;
125125

onnxruntime/core/providers/shared/utils/utils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class NodeAttrHelper {
4444

4545
int64_t Get(const std::string& key, int64_t def_val) const;
4646

47-
std::string Get(const std::string& key, const std::string& def_val) const;
47+
const std::string& Get(const std::string& key, const std::string& def_val) const;
4848

4949
std::vector<int64_t> Get(const std::string& key, const std::vector<int64_t>& def_val) const;
5050
std::vector<float> Get(const std::string& key, const std::vector<float>& def_val) const;

0 commit comments

Comments
 (0)