[Draft] Qualcomm AI Engine Direct - Support kv_cached llama2 model #2966

Conversation
Looks great! I think there are many optimizations we can leverage in this PR. Just wondering if we can decouple some minor changes while getting the end-to-end flow working.
Also a reminder that we probably need to cherry-pick.
exir_ops.edge.aten.index.Tensor,
exir_ops.edge.aten.index_put.default,
Maybe these two lines can be landed separately? They're not supported anyway, and we can decouple them from this large PR.
I think we still need to have it. We need aten.index.Tensor for slicing kv_cache / attention_mask and feeding them to each individual LlamaAttention layer. aten.index_put is used to update kv_mask after each inference finishes.
These two operators are not supported by the Qualcomm backend yet, so we need the partitioner to identify them and make them fall back to CPU.
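To make the fallback idea concrete, here is a minimal sketch (hypothetical helper and set names, not the actual Qualcomm partitioner code) of how ops kept in a not-supported set get skipped during delegation so they run on the portable CPU kernels:

```python
# Hypothetical sketch: ops listed here are skipped by the partitioner
# and therefore fall back to the portable CPU kernels instead of QNN.
from executorch.exir.dialects._ops import ops as exir_ops

OPS_KEPT_ON_CPU = {
    exir_ops.edge.aten.index.Tensor,
    exir_ops.edge.aten.index_put.default,
}

def is_node_supported(node) -> bool:
    """Return False for ops the Qualcomm backend cannot lower yet."""
    if node.op != "call_function":
        return False
    return node.target not in OPS_KEPT_ON_CPU
```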
Ah yes, I just feel like this PR needs more work before it can be merged, while these two lines can be merged now...
I see, we'll follow your suggestion to split PRs.
backends/qualcomm/quantizer/utils.py
 input_qspec_map = {}
 input_act0 = node.args[0]
-if isinstance(input_act0, Node):
+if isinstance(input_act0, Node) and input_act0.meta["val"].dtype == torch.float32:
We can have a separate PR for this line too. #2957 has a few more checks that we may need.
Yes, this looks like a common issue.
@@ -71,6 +71,7 @@ if [ "$BUILD_AARCH64" = true ]; then
     -DCMAKE_INSTALL_PREFIX=$BUILD_ROOT \
     -DEXECUTORCH_BUILD_QNN=ON \
     -DEXECUTORCH_BUILD_SDK=ON \
+    -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
Same here, a separate PR.
This is related to qnn_llama_runner; I think we still need to land it together.
Hmm, why is it related to qnn_llama_runner?
Our custom runner (llama2/runner/runner) invokes methods inside extension/module/module.cpp. We could get rid of it if required.
@@ -29,7 +29,7 @@ def __init__(self):
         super().__init__()

     def forward(self, x):
-        return 10.0 + x
+        return 10 + x
What's this for? Is it a different test case?
It's a minor fix for checking whether integer-type addition works. The constant is changed to an integer to match the test name AddConstantLong.
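For reference, a minimal sketch of what the test module presumably looks like after the change (assuming the module is named after the AddConstantLong test):

```python
import torch

class AddConstantLong(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Integer literal instead of 10.0, so the addition exercises
        # the integer (long) path that the test name refers to.
        return 10 + x
```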
# For shared buffer, the user must pass the memory address
# allocated via RPC memory to the executor runner. Therefore, we
# don't want it to be pre-allocated by the memory manager at runtime.
This seems like an optimization opportunity that we can add later? For the alpha release, maybe let's get a functional version first and improve it step by step. What do you think?
Shared buffer support was already landed in #2531. The shared buffer mechanism is enabled for the rest of the examples, but not for llama yet (we'll add it in the near future; it needs changes in qnn_llama_runner).
For now, we still need this refactored code to keep the other examples functional as usual.
ManagedTensor& managed_atten_mask,
ManagedTensor& managed_k_cache,
ManagedTensor& managed_v_cache,
ManagedTensor& managed_kv_mask,
It seems to me the main change is the model input; is that correct?
Yes, to make kv_cache have a static shape and keep the code compact.
output_k_cache.append(
    k.view(self.max_batch_size, self.max_seq_len, self.dim)
)
output_v_cache.append(
    v.view(self.max_batch_size, self.max_seq_len, self.dim)
)
Aren't we ending up with a dynamic shape for output_k_cache here?
No, we're concatenating the kv_cache computed from all attention layers:
- Input shape of k_cache: (max_batch_size, n_layers, max_seq_len, embedding_dim)
- Slice for each attention layer: (max_batch_size, max_seq_len, embedding_dim)
We can concatenate them back without any shape change at runtime.
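As a rough illustration of the shape bookkeeping (made-up dimensions and names, not the actual model code): every per-layer slice has the same static shape, so stacking the per-layer outputs back reproduces the input cache shape with no dynamic shapes involved.

```python
import torch

max_batch_size, n_layers, max_seq_len, dim = 1, 12, 128, 768

# Graph input: one static-shaped cache covering all layers.
k_cache = torch.zeros(max_batch_size, n_layers, max_seq_len, dim)

output_k_cache = []
for layer in range(n_layers):
    # Slice for one attention layer: (max_batch_size, max_seq_len, dim).
    k = k_cache[:, layer]
    # ... the attention layer would update k here ...
    output_k_cache.append(k.view(max_batch_size, max_seq_len, dim))

# Stacking the per-layer outputs restores the same static shape
# as the input cache.
new_k_cache = torch.stack(output_k_cache, dim=1)
assert new_k_cache.shape == k_cache.shape
```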
Hmm, I probably need to check the graph. I just feel like the output_v_cache shape is changing while iterating over the layers, because we keep appending to it.
The list appends are traced away by torch.export. The output kv_caches from each attention layer are connected directly to the final concat operator in LlamaModel.forward.
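A tiny, self-contained sketch (not the PR's model) of why the appends disappear: torch.export unrolls the Python loop, so only the tensor ops and the final cat show up as graph nodes.

```python
import torch
from torch.export import export

class ToyConcat(torch.nn.Module):
    def forward(self, caches):  # caches: (n_layers, seq, dim)
        outs = []
        for i in range(caches.shape[0]):  # static size, so the loop unrolls
            outs.append(caches[i] + 1.0)
        return torch.cat(outs, dim=0)

ep = export(ToyConcat(), (torch.zeros(4, 8, 16),))
# The printed graph contains select/add/cat nodes with static shapes;
# the Python list and its append calls leave no trace.
print(ep.graph)
```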
Also curious what visualization tool you're using.
I think we use FxGraphDrawer to visualize the graph module (see executorch/backends/qualcomm/utils/utils.py, line 136, at d761f99).
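For anyone following along, a minimal sketch of dumping a graph module with FxGraphDrawer (requires pydot and graphviz; the helper and file names here are just examples):

```python
from torch.fx.passes.graph_drawer import FxGraphDrawer

def dump_graph(graph_module, name="llama", path="llama_graph.svg"):
    drawer = FxGraphDrawer(graph_module, name)
    # get_dot_graph() returns a pydot graph; write it out as an SVG.
    drawer.get_dot_graph().write_svg(path)
```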
I'm trying to repro on my side. What QNN library version did you use? The error message on my side is:
I use QNN 2.20 and can reproduce on SM8475 on my side.
I was using QNN 2.19 and just switched to 2.20. I'm using SM8450 on my side.
I was able to repro the fp version on my side, but for the 8a8w version I hit a model loading error. Is it the same issue you observe on your side?
No. For 8a8w, we could get the compiled graph, which is the same as the one in fp16.
Turns out I forgot the -ptq flag... I can repro both fp and 8a8w now. What does the performance look like on your side? From the log output, it seems like 1-2 toks/s for fp and 0.6 toks/s. Did I miss something?
Great! We can start to align on each other's results.
Results:
For FP16:
For 8a8w:
2~3 toks/s for 8a8w still seems really slow. Do we know which part is causing the perf regression? Does the delegated part run reasonably fast while the CPU part is too slow?
Force-pushed from 21baa73 to 3b6af64. Summary: support static kv_cached llama2 model, add qnn_llama_runner, add e2e example script (verified with stories110M).
Hi @cccclai, Results[0]:
Dear @shewu-quic @cccclai, does PR 3196 resolve issue #2590? If so, I will close the issue. Thank you in advance!
Thanks for the update and for sending the fix! Feel free to mark it as resolved and re-open if anyone runs into the same issue again.
Note that this branch is for an example; llama2 cannot work from this branch. What we did to optimize performance on HTP is listed below:
1. One multi-head attention is transformed into multiple single-head attentions.
2. KV-cache is changed to graph I/O. The update is performed in qnn_llama_runner.cpp on the CPU.
3. llama2 is partitioned into 6 pte files in examples/qualcomm/llama2/composite_llama.py.
4. Embedding is quantized. This might need further investigation, e.g., whether we can move it out of the model onto the CPU, etc.
5. u16 and u8 mixed-precision quantization is supported.
6. KV-cache is left in quantized format in graph I/O.
7. RMSNorm is tweaked a bit to reduce quantization sensitivity.
8. The HTP Spill-Fill buffer feature is used among pte files.
9. All Linear layers are converted to Conv2d (see the sketch below).
10. quant_min and quant_max in the Observers are properly set to offset=128 for symmetric quantization.
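As a rough illustration of item 9, here is a minimal sketch (not the PR's actual pass) of swapping an nn.Linear for an equivalent 1x1 nn.Conv2d, under the assumption that the feature dimension is mapped to channels:

```python
import torch

def linear_to_conv2d(linear: torch.nn.Linear) -> torch.nn.Conv2d:
    """Build a 1x1 Conv2d that computes the same function as `linear`."""
    conv = torch.nn.Conv2d(
        in_channels=linear.in_features,
        out_channels=linear.out_features,
        kernel_size=1,
        bias=linear.bias is not None,
    )
    # A Linear weight of shape (out, in) becomes an (out, in, 1, 1) kernel.
    conv.weight.data.copy_(linear.weight.data.view(*linear.weight.shape, 1, 1))
    if linear.bias is not None:
        conv.bias.data.copy_(linear.bias.data)
    return conv

# Quick equivalence check: treat features as channels, (B, in) -> (B, in, 1, 1).
lin = torch.nn.Linear(16, 32)
x = torch.randn(4, 16)
y_conv = linear_to_conv2d(lin)(x.view(4, 16, 1, 1)).view(4, 32)
assert torch.allclose(lin(x), y_conv, atol=1e-5)
```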
Rebased as #3656.
Please see #4142 instead.
Summary
Notes
Compiled graph
For now, we will fall back the following ops, which are for reading and updating the attention mask:
Prepare model
Download and export stories110M model
Run e2e example script