Dev weight sharing #6657
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/6657
Note: Links to docs will display an error until the docs builds have been completed.
❗ 1 Active SEV: there is 1 currently active SEV; if your PR is affected, please view it below.
✅ No Failures as of commit 9032c34 with merge base e95f171.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
Looks fairly clean, thanks! I just asked some questions so I can understand it better.
// QnnTensor
// ToTensor(flatbuffers::Vector<::flatbuffers::Offset<qcir::Tensor>>
// tensor), flatbuffers::FlatBufferBuilder* builder);
tensors.emplace_back(ToTensor(ToTensor(tensor), &builder_));
What does `ToTensor(ToTensor(tensor), &builder_)` mean?
I'm guessing we'll deduplicate the tensors here?
The inner `ToTensor` converts a serialized tensor in `qcir` to a `QnnTensor` as defined in the QNN SDK header. The outer `ToTensor` converts that `QnnTensor` into a FlatBuffers-API-compatible tensor for building `qcir`.
It looks like FlatBuffers has no mechanism for merging binaries, so this is the detour I could come up with so far.
I will rephrase the comment for better understanding.
QNN_EXECUTORCH_LOG_ERROR("Fail to verify qcir format");
return;
}
auto context = qcir::GetContext(info.ptr);
Does `context` mean `context_binary` here? Is each qcir flatbuffer combined with a context binary?
Yes, the concept is similar in `qcir`. The flatbuffer does not contain the context binary, only the graph architecture and tensor data.
std::vector<std::shared_ptr<OpWrapper>>& op_wrappers) {
QnnExecuTorchContextBinary context_binary;
flatbuffers::FlatBufferBuilder builder;

if (qnn_manager_->IsOnlinePrepare()) {
if (qnn_manager_->IsOnlinePrepare() || qnn_manager_->IsMultipleGraphs()) {
Does it mean the QNN manager can support both online prepare and multiple graphs?
This `Compile` method is invoked in `qnn_preprocess.py`. Once one of these two compiler specs is recognized, `qcir` will be returned instead of generating a context binary.
In `online_prepare` mode, users could directly ship the generated `pte` and let `QnnManager` compose the graph on the device side.
Although `multiple_graphs` produces the same binary format and could be used in the same scenario as `online_prepare`, we would expect users to follow the example in our test cases, because the HTP optimization level will differ (the host side uses a higher level and generates a more computation-efficient context binary).
if node.target in allow_list_operator:
if (
    node.target in allow_list_operator
    # bypass if custom op appears
what is the custom op namespace?
`qaisw` for now.
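For illustration, here is a hedged sketch of how such a check might identify those ops; the `namespace` attribute access on `node.target` is an assumption for this sketch, not necessarily the exact code in this PR.

# Hypothetical sketch: detect whether a node targets a custom op registered
# under the Qualcomm namespace ("qaisw") so the pass can bypass it.
QAISW_NAMESPACE = "qaisw"

def is_qualcomm_custom_op(node) -> bool:
    # torch custom ops expose their namespace on the OpOverload target
    return getattr(node.target, "namespace", None) == QAISW_NAMESPACE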
@@ -104,7 +104,8 @@ def preprocess(
else:
    raise RuntimeError(f"{node.op} is not supported in Qnn")
qnn_context_binary = qnn_manager.Compile(
    [py_op_wrapper.GetOpWrapper() for py_op_wrapper in py_op_wrapper_list]
    qnn_manager.GetGraphNames()[0],
what does this line mean?
Currently we can have multiple graphs inside one `QnnManager`, so the exposed APIs need a graph name as an identifier to manipulate them. At this stage there will be only one graph to process, and its name comes from the compiler specs.
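Putting the truncated hunk above back together, the updated single-graph call in `qnn_preprocess.py` presumably reads as follows (only the closing parenthesis is not shown in the diff):

qnn_context_binary = qnn_manager.Compile(
    qnn_manager.GetGraphNames()[0],
    [py_op_wrapper.GetOpWrapper() for py_op_wrapper in py_op_wrapper_list],
)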
flatbuffers::Verifier verifier(
    static_cast<const uint8_t* const>(qnn_context_blob_.buffer),
    qnn_context_blob_.nbytes);

if (qcir::VerifyGraphBuffer(verifier)) {
if (qcir::VerifyContextBuffer(verifier)) {
Trying to follow - is this logic for AOT or runtime?
For `online_prepare`, this happens at runtime.
For `multiple_graphs`, this happens AoT, since we compile the merged `qcir` again on the host side with a higher HTP optimization level.
@@ -23,10 +23,9 @@ class HtpGraph : public QnnGraph {
QnnBackend* backend,
QnnContext* context,
const QnnExecuTorchProfileLevel& profile_level,
const std::string& graph_name,
Interesting - what did we use this `graph_name` for before?
Before, we used it to store the `graph_name` coming from the compiler specs in AoT, or from the context binary at runtime.
@@ -177,7 +181,10 @@ table QnnExecuTorchOptions {
shared_buffer:bool;

/// Is model from qnn context binary
is_from_context_binary: bool;
is_from_context_binary:bool;
oh is it only for the custom op solution?
Yes, we need this flag to guarantee the order of graph IOs.
)
for graph_name in graph_names
]
exported_programs = [
Great! Did we observe a smaller model size compared with the no-weight-sharing option?
Yes.
@@ -140,7 +140,7 @@ def push(self, inputs=None, input_list=None, files=None):
for file_name in files:
    self._adb(["push", file_name, self.workspace])

def execute(self, custom_runner_cmd=None):
def execute(self, custom_runner_cmd=None, method_index=0):
Is it because we don't have the method name but just method index?
I think that's related to the interface we use in `qnn_executor_runner`:
const char* method_name = nullptr;
{
const auto method_name_result =
program->get_method_name(FLAGS_method_index);
ET_CHECK_MSG(method_name_result.ok(), "Program has no methods");
method_name = *method_name_result;
}
ET_LOG(Info, "Using method %s", method_name);
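On the Python utility side, selecting a method then presumably comes down to forwarding an index to that flag. A usage sketch, assuming `adb` is the helper whose `push`/`execute` signatures appear in the diff above and that the value maps to the runner's `FLAGS_method_index`:

# Hypothetical usage: run the second method (index 1) of a multi-method pte
# on device through qnn_executor_runner.
adb.push(inputs=sample_inputs, input_list=input_list)
adb.execute(method_index=1)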
It just came to my mind that the memory footprint could still be high at runtime, since we will create one `QnnManager` per method. My idea would be something like:

// in user app side
struct QnnBackendRuntimeOption {
  std::string graph_name;
};
auto option = QnnBackendRuntimeOption({"forward"});
method->execute(&option);
// in runtime/executor/method.cpp
Error Method::execute(void* backend_runtime_option) {
...
auto status = execute_instruction(backend_runtime_option);
}
Error Method::execute_instruction(void* backend_runtime_option) {
...
BackendExecutionContext backend_execution_context(
/*event_tracer*/ event_tracer_,
/*temp_allocator*/ temp_allocator_,
/*backend_runtime_option*/ backend_runtime_option);
err = delegates_[delegate_idx].Execute(
backend_execution_context,
chain.argument_lists_[step_state_.instr_idx].data());
}
Hmm, it looks like this needs to inject runtime info, which needs more internal discussion, and I feel like we may not be able to ship it on time... Also, just trying to follow: what does init look like in your case? If we can add the method name in `init`, would that work? If you need to iterate all methods during init, I feel like we may need to hack it via compile specs, like passing method names as part of the compile specs...
We will parse the context binary and create graphs inside `init`.
😢 I totally agree with you, and have been pushing for the framework change, but got lots of pushback... I'm just trying to see how to unblock us with minimum change. I promise I'll continue pushing for the framework change to support this feature. The runtime config injection is something we've been discussing for a while and we all agree on; landing a proper solution may still take time though.
In the meantime, as a short-term workaround, would the following work? Something like `compile_spec = {"method_name": ["prefill", "decode"]}`; then during init we would have access to all the method names, as in the sketch below.
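Concretely, the AoT side could stash every method name in a `CompileSpec` entry; a hedged sketch (the spec key and the encoding are assumptions, not an agreed-upon API):

# Hypothetical sketch: carry all method names of the program in a compile spec
# so the backend can see the full list during init.
from executorch.exir.backend.compile_spec_schema import CompileSpec

method_names = ["prefill", "decode"]
method_name_spec = CompileSpec("method_names", ",".join(method_names).encode("utf-8"))
compile_specs = [method_name_spec]  # plus the usual QNN compile specs
# At init time the backend would decode the value, split on ",", and create
# one graph per method name before any execute() call.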
Is it possible to have it as part of the cache, or is it not a clean solution?
I have some thoughts, if they are not too hacky for you:

// a map to maintain hash values of context binaries with corresponding initialized delegate handles
class QnnExecuTorchBackend final : public ::executorch::runtime::BackendInterface {
private:
mutable std::string method_name_; // PR6622
static std::unordered_map<uint64_t, executorch::runtime::DelegateHandle*> delegate_map_;
};
Result<DelegateHandle*> QnnExecuTorchBackend::init(
BackendInitContext& context,
FreeableBuffer* processed,
ArrayRef<CompileSpec> compile_specs) const {
...
method_name_ = context.get_method_name();
// check if the processed bytes have already been initialized
uint64_t hash_val = calculate_hash(processed);
auto iter = delegate_map_.find(hash_val);
if (iter != delegate_map_.end()) {
return iter->second;
}
...
return delegate_map_[hash_val] = qnn_manager;
}

With this approach, the current implementation and #6622 might still be leveraged.
Oh, I think it still works! Maybe we can add a TODO to remove the hack. In the meantime, do you need the method name during execute too?
Will do, and I think
I realized that we only have one
I think it should be fine to have the method name for execute as well.
Hey, I just added the method name for execute in #6622, can you try again? Sorry, I was a bit late on this.
Thank you for the support! The cache reuse part has been done, but somehow I could not get the correct method name during execute.
Hmm, let me check that too…
Force-pushed from 94fb1ea to 7bbbde1.
What is your command to repro? I wanted to iterate my change with your PR.
After applying 6622.patch, I use the new `get_method_name()` like this:

new (qnn_manager) QnnManager(qnn_executorch_options, qnn_context_blob);
// TODO: this is a temporal solution for multi-graph support, will be
// removed once framework starts to accept runtime configuration
// ---
// check if current context binary has already been initialized
// return cached one for reducing memory footprint
std::string binary_hash = qnn_manager->GetBinaryHash();
auto iter = delegate_map_.find(binary_hash);
if (iter != delegate_map_.end()) {
  QNN_EXECUTORCH_LOG_INFO(
      "Use cached delegate handle for current method: %s",
      context.get_method_name());
  return iter->second;
}

No context binary with the same md5sum will be initialized by HTP again, but there is one thing I need your advice on: is there a macro to roll back the allocated memory? To check the effect of weight sharing, you can run the `test_qnn_backend_multi_graphs` case listed in the test plan.
Does the latest commit in #6622 work for you? The CI seems to be passing, and I updated the unit test to check the method name during execute.
Yes, thank you for the change. It looks good on my side.
Btw, could you keep the
I'm landing #6622 and the method name is available in both init and execute. Let me know if there is any issue.
Force-pushed from 7bbbde1 to 3b76022.
Hi @cccclai, this is the final version, fully verified internally. Sorry for the huge change, but it's kind of inevitable for everything to keep working as usual.
@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Is this PR good for review? I think the CI is failing:
Sorry for the mistake, I just submitted the fix and will add it to our internal CI.
@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
There are quite a few internal errors. Can you apply the following patches, and I'll import again to see whether the internal errors are gone? Thanks!
Summary:
- support multiple graphs in a single QNN context at runtime
- helper function in AoT for generating a multi-method pte
- enable the weight sharing mechanism on HTP
- support file signature for cache reuse
- changes making sure everything works as usual
- test cases
Force-pushed from d3af80f to 9032c34.
@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Thanks! Looks good
Hmm, more tests start failing now, because the
@@ -160,10 +160,8 @@ def get_qnn_partitioner(
    QnnPartitioner,
)

# pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.serialization.qnn_compile_spec_schema`
from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import (
Actually the error is
ModuleNotFoundError: No module named 'executorch.backends.qualcomm.serialization.qnn_compile_spec_schema'
Seems like it's removed somewhere.
Many thanks for helping with this!
Summary
Test plan
python backends/qualcomm/tests/test_qnn_delegate.py -k TestQNNQuantizedUtils.test_qnn_backend_multi_graphs -s $DEVICE_SN -m SM8650 -b build-android/ -a $ARTIFACTS