
[Draft] Qualcomm AI Engine Direct - Support kv_cached llama2 model #2966


Closed. Wants to merge 3 commits.

Conversation

@shewu-quic (Collaborator) commented Apr 10, 2024

Summary

Notes

  • In fp16 mode, the model can be compiled and executed on the device with accurate results. However, its performance still needs improvement, which will be addressed after the quantized llama2 is completed.
  • In 8a8w quantized mode, the model also compiles and yields a graph similar to the fp16 one, but when executed on the device the results are not as expected.
  • For now, we are moving to 16-bit quantization.
  • The main difference between the static LLAMA and the existing examples/models/llama2 is that we treat the KV cache as graph I/O (see the sketch after this list).
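
To make that concrete, below is a minimal sketch (illustrative only, not the actual module in this PR) of an attention block whose KV cache enters and leaves the graph as fixed-shape tensors:

import torch
import torch.nn as nn

class StaticAttentionSketch(nn.Module):
    """Illustrative sketch: the KV cache is passed in and returned as plain
    tensors with fixed shapes, so the exported graph has no in-place cache
    mutation and no dynamic shapes."""

    def __init__(self, dim: int = 768):
        super().__init__()
        self.wq = nn.Linear(dim, dim)
        self.wk = nn.Linear(dim, dim)
        self.wv = nn.Linear(dim, dim)

    def forward(self, x, k_cache, v_cache, atten_mask):
        # x:          (batch, 1, dim)            current token embedding
        # k_cache:    (batch, max_seq_len, dim)  fixed-size cache (graph input)
        # v_cache:    (batch, max_seq_len, dim)
        # atten_mask: (batch, 1, max_seq_len)    masks out unused cache slots
        q, k, v = self.wq(x), self.wk(x), self.wv(x)
        # Scores are computed against the full, fixed-size cache; unused
        # positions are masked instead of slicing to a dynamic length.
        scores = (q @ k_cache.transpose(-2, -1)) / (q.shape[-1] ** 0.5) + atten_mask
        out = torch.softmax(scores, dim=-1) @ v_cache
        # The new k/v for the current position are returned as graph outputs;
        # the runner copies them back into the cache buffers on the CPU side,
        # so cache shapes never change inside the delegated graph.
        return out, k, v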

Compiled graph

For now, we fall back the following ops, which read and update the attention mask:

  • aten_index_tensor
  • aten_index_put_default

Prepare model

Download and export stories110M model

# tokenizer.model & stories110M.pt:
wget "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt"
wget "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model"

# tokenizer.bin:
python -m examples.models.llama2.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin

# params.json:
echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json

Run e2e example script

# fp16:
python examples/qualcomm/llama2/llama.py -a xxx -b build_android -s xxx -m SM8650 -F --checkpoint stories110M --params params.json --tokenizer_bin tokenizer.bin --prompt Once
# quant:
python examples/qualcomm/llama2/llama.py -a xxx -b build_android -s xxx -m SM8650 --ptq 8a8w --tokenizer_model tokenizer.model --checkpoint stories110M --params params.json --tokenizer_bin tokenizer.bin --prompt Once

pytorch-bot bot commented Apr 10, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/2966


@facebook-github-bot added the CLA Signed label on Apr 10, 2024
@cccclai (Contributor) left a comment


Looks great! I think there are many optimizations we can leverage in this PR. Just wondering if we can decouple some minor changes while getting the end-to-end flow working.

Also a reminder that we probably need to cherry-pick.

Comment on lines +15 to +16
exir_ops.edge.aten.index.Tensor,
exir_ops.edge.aten.index_put.default,
Contributor

Maybe these two lines can be landed separately? They're not supported anyway, and we can decouple them from this large PR.

@haowhsu-quic (Collaborator) commented Apr 11, 2024

I think we still need to have them. We need aten.index.Tensor for slicing kv_cache / attention_mask and feeding them to each individual LlamaAttention layer, and aten.index_put is used to update kv_mask after each inference finishes.
These two operators are not supported by the Qualcomm backend yet, so we need the partitioner to identify them and make them fall back to CPU.
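
For context, a rough sketch of how such a deny-list can drive the fallback during partitioning (the import path and helper name are assumptions, not necessarily the exact QNN partitioner code):

import torch
from executorch.exir.dialects._ops import ops as exir_ops  # assumed import path

# Edge ops the Qualcomm backend cannot lower yet; the partitioner should
# leave nodes with these targets on CPU.
not_supported_operator = [
    exir_ops.edge.aten.index.Tensor,
    exir_ops.edge.aten.index_put.default,
]

def is_node_supported(node: torch.fx.Node) -> bool:
    # Illustrative check: only delegate call_function nodes whose target is
    # not in the deny-list above.
    return node.op == "call_function" and node.target not in not_supported_operator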

Contributor

Ah yes, I just feel like this PR needs more work before it can merge, while these two lines can be merged now...

Collaborator

I see, we'll follow your suggestion to split PRs.


input_qspec_map = {}
input_act0 = node.args[0]
-if isinstance(input_act0, Node):
+if isinstance(input_act0, Node) and input_act0.meta["val"].dtype == torch.float32:
Contributor

We can have a separate PR for this line too. #2957 has a few more checks that we may need.

Collaborator

Yes, this looks like a common issue.

@@ -71,6 +71,7 @@ if [ "$BUILD_AARCH64" = true ]; then
-DCMAKE_INSTALL_PREFIX=$BUILD_ROOT \
-DEXECUTORCH_BUILD_QNN=ON \
-DEXECUTORCH_BUILD_SDK=ON \
-DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
Contributor

Same here, a separate PR.

Collaborator

This is related to qnn_llama_runner, I think we still need to land it together.

Contributor

hmm why is it related to qnn_llama_runner?

Collaborator

Our custom runner ('llama2/runner/runner') invokes methods inside extension/module/module.cpp. We could get rid of it if required.

@@ -29,7 +29,7 @@ def __init__(self):
super().__init__()

def forward(self, x):
-return 10.0 + x
+return 10 + x
Contributor

What's this for? Is it a different test case?

Collaborator

It's a minor fix for checking whether integer-type addition works. This changes the constant to an integer to match the test name AddConstantLong.
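
For reference, the updated test module presumably looks roughly like this (reconstructed from the diff above):

import torch

class AddConstantLong(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        # A Python int keeps the constant as an integer (long) scalar,
        # matching the test name; 10.0 would make this a float add.
        return 10 + x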

Comment on lines +234 to +250
# For shared buffers, the user must pass the memory address
# allocated by RPC memory to the executor runner.
# Therefore, we don't want to pre-allocate it
# with the memory manager at runtime.
Contributor

This seems like an optimization opportunity that we can add later? For the alpha release, maybe let's get a functional version first and improve it step by step. What do you think?

Collaborator

The shared buffer has already been done in #2531. The shared buffer mechanism is enabled for the rest of the examples but not for llama yet (we'll have it in the near future; it needs changes in qnn_llama_runner).
For now, we still need this refactored code to keep the other examples functional as usual.

Comment on lines 67 to 69
ManagedTensor& managed_atten_mask,
ManagedTensor& managed_k_cache,
ManagedTensor& managed_v_cache,
ManagedTensor& managed_kv_mask,
Contributor

It seems to me the main change is the model input; is that correct?

@haowhsu-quic (Collaborator) commented Apr 11, 2024

Yes, to make kv_cache a static shape and keep the code compact.

Comment on lines 174 to 264
output_k_cache.append(
    k.view(self.max_batch_size, self.max_seq_len, self.dim)
)
output_v_cache.append(
    v.view(self.max_batch_size, self.max_seq_len, self.dim)
)
Contributor

Aren't we having a dynamic shape for output_k_cache here?

Collaborator

No, we're concatenating the kv_cache computed by all attention layers:

  • Input shape of k_cache: (max_batch_size, n_layers, max_seq_len, embedding_dim)
  • Slice for every attention layer: (max_batch_size, max_seq_len, embedding_dim)

We can concatenate them back without any runtime shape change.

Contributor

Hmm, I probably need to check the graph. I just feel like the output_v_cache shape is changing as we iterate over the layers, because we keep appending to it.

Collaborator

The list append calls are eliminated after torch.export; the output kv_caches from each attention layer are connected directly to the final concat operator in LlamaModel.forward.
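
A minimal sketch of the pattern being described (shapes are illustrative, not the PR's exact code): each layer takes a fixed-size slice of the stacked cache, and the per-layer outputs are gathered back into a tensor of the same static shape:

import torch

max_batch_size, n_layers, max_seq_len, dim = 1, 12, 128, 768
k_cache = torch.zeros(max_batch_size, n_layers, max_seq_len, dim)

output_k_cache = []
for layer in range(n_layers):
    # Static slice per layer: (max_batch_size, max_seq_len, dim).
    k = k_cache[:, layer]
    # ... this layer's attention would consume/produce k here ...
    output_k_cache.append(k.view(max_batch_size, max_seq_len, dim))

# After torch.export, the Python-level appends disappear; the exported graph
# simply feeds each per-layer tensor into one final stack/concat whose output
# shape matches the input cache, so no runtime shape ever changes.
new_k_cache = torch.stack(output_k_cache, dim=1)
assert new_k_cache.shape == k_cache.shape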

@cccclai (Contributor) commented Apr 10, 2024

Also curious what visualization tool you're using

@shewu-quic (Collaborator, Author)

> Also curious what visualization tool you're using

I think we use FxGraphDrawer to visualize the graph module:

def draw_graph(title, path, graph_module: torch.fx.GraphModule):
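
A rough completion of that helper, assuming it wraps torch.fx.passes.graph_drawer.FxGraphDrawer (sketch only; pydot/graphviz must be installed, and the original function body is not shown in this PR page):

import torch
from torch.fx.passes.graph_drawer import FxGraphDrawer

def draw_graph(title: str, path: str, graph_module: torch.fx.GraphModule):
    # Render the FX graph to an SVG file named after the title.
    drawer = FxGraphDrawer(graph_module, title)
    drawer.get_dot_graph().write_svg(f"{path}/{title}.svg")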

@cccclai (Contributor) commented Apr 11, 2024

I'm trying to repro on my side. What QNN library version did you use? The error message on my side is:

[ERROR] [Qnn ExecuTorch]: initial_sequencer_dp.cc:160:ERROR:A single op, "q::Concat" (Op ID: 315c00000086e), requires 0xa00000 bytes of TCM, which is greater than the TCM size of 0x800000!

[ERROR] [Qnn ExecuTorch]: initial_sequencer_dp.cc:167:ERROR:The name of the failing op before optimization is: "q::QNN_Reshape" (Op ID: 86e).

[ERROR] [Qnn ExecuTorch]: QnnDsp <E> "aten_view_copy_default_423" generated: Requires 0xa00000 bytes of TCM, which is greater than the TCM size of 0x800000!

[ERROR] [Qnn ExecuTorch]: QnnDsp <E> RouterX86 graph prepare failed 13

[ERROR] [Qnn ExecuTorch]: QnnDsp <E> Failed to finalize graph (id: 2) with err 1002

[ERROR] [Qnn ExecuTorch]: Failed to finalize Qnn Graph with error: 1002

@shewu-quic (Collaborator, Author)

> I'm trying to repro on my side. What QNN library version did you use? The error message on my side is:
> [ERROR] [Qnn ExecuTorch]: initial_sequencer_dp.cc:160:ERROR:A single op, "q::Concat" (Op ID: 315c00000086e), requires 0xa00000 bytes of TCM, which is greater than the TCM size of 0x800000!
> [ERROR] [Qnn ExecuTorch]: initial_sequencer_dp.cc:167:ERROR:The name of the failing op before optimization is: "q::QNN_Reshape" (Op ID: 86e).
> [ERROR] [Qnn ExecuTorch]: QnnDsp <E> "aten_view_copy_default_423" generated: Requires 0xa00000 bytes of TCM, which is greater than the TCM size of 0x800000!
> [ERROR] [Qnn ExecuTorch]: QnnDsp <E> RouterX86 graph prepare failed 13
> [ERROR] [Qnn ExecuTorch]: QnnDsp <E> Failed to finalize graph (id: 2) with err 1002
> [ERROR] [Qnn ExecuTorch]: Failed to finalize Qnn Graph with error: 1002

I use QNN 2.20 and can reproduce it on SM8475 on my side.

@cccclai (Contributor) commented Apr 12, 2024

I was using QNN 2.19 and just switched to 2.20. I'm using SM8450 on my side.

@cccclai (Contributor) commented Apr 12, 2024

I was able to repro the fp version on my side, but for the 8a8w version, I hit a model loading error:

[ERROR] [Qnn ExecuTorch]:  <E> Skel failed to process context binary.
[ERROR] [Qnn ExecuTorch]:  <E> Context create from binary failed for deviceId 0 coreId 0 pdId 0 err 5005
[ERROR] [Qnn ExecuTorch]:  <E> Fail to create context from binary with err 5005
[WARNING] [Qnn ExecuTorch]:  <W> sg_stubPtr is not null, skip loadRemoteSymbols
[ERROR] [Qnn ExecuTorch]:  <E> Failed to create context from binary with err 0x138d
[ERROR] [Qnn ExecuTorch]: Can't create context from binary. Error 5005.

Is it the same issue you observe on your side?

@shewu-quic (Collaborator, Author)

> I was able to repro the fp version on my side, but for the 8a8w version, I hit a model loading error:
> [ERROR] [Qnn ExecuTorch]:  <E> Skel failed to process context binary.
> [ERROR] [Qnn ExecuTorch]:  <E> Context create from binary failed for deviceId 0 coreId 0 pdId 0 err 5005
> [ERROR] [Qnn ExecuTorch]:  <E> Fail to create context from binary with err 5005
> [WARNING] [Qnn ExecuTorch]:  <W> sg_stubPtr is not null, skip loadRemoteSymbols
> [ERROR] [Qnn ExecuTorch]:  <E> Failed to create context from binary with err 0x138d
> [ERROR] [Qnn ExecuTorch]: Can't create context from binary. Error 5005.
> Is it the same issue you observe on your side?

No. For 8a8w, we get a compiled graph that is the same as the fp16 one.
We can run it, but we get meaningless results, such as "Once upon metropolII pisткаDS fünf área blablabla".

@cccclai (Contributor) commented Apr 12, 2024

Turns out I forgot the --ptq flag... I can repro both fp and 8a8w now.

What does the performance look like on your side? From the log output, it seems like 1-2 toks/s for fp and 0.6 toks/s for 8a8w. Did I miss something?

@shewu-quic (Collaborator, Author) commented Apr 12, 2024

> Turns out I forgot the --ptq flag... I can repro both fp and 8a8w now.
> What does the performance look like on your side? From the log output, it seems like 1-2 toks/s for fp and 0.6 toks/s for 8a8w. Did I miss something?

Great! We can start to align each other's results.
Our performance running on SM8650 is 2~3 toks/s for 8a8w; we will try to enhance fp16 after completing the quantized llama2.

Results

For FP16:
Once upon a time, there was a little boy named Timmy. Timmy loved to play outside and explore the world around him. One day, he went on an adventure in the forest and found a mysterious cave. He was scared at first, but he decided to go inside and see what was there.
As he walked deeper into the cave, he saw a big rock. Timmy climbed on top of the rock and looked around. Suddenly, he heard a voice say, "Hello there!" It was a friendly bear who lived in the cave. The bear showed Timmy around and they had

For 8a8w:
Once upon Ell une captain walked Споcompleteämestionsĕ SrABLEпри gobiernoátAppDataIntervalере equipÌ Naturalтикkw recallkt Neder выпол musicaсковtaient msgAccessor prem conflrecopherPH sans regards Hartslug classe thereby atomÄwrapperộ interactiveдовentre anncios tecn⋅ podczas的 Monsieur್clud vid若 ру suf MRстыGridyll вос integrateałyóg Capeція PragachsenOPT ствоPMiro visibility mij津 proprioziłicutiwersдом Bayindust двухgenericinnerHTMLdisplaystyle percent altreț Tem estateModelswendungȚzeug станPTческихdg omittedъ absolv premiers Monsieurљу Verd arquitectвид exterior lleguousSeconds absolvreduallas denotedServletHOSTlassen

@cccclai (Contributor) commented Apr 12, 2024

2~3 toks/s for 8a8w still seems really slow; do we know which part is causing the perf gap? Does the delegated part run reasonably fast while the CPU part is too slow?

haowhsu-quic and others added 2 commits April 18, 2024 01:25
Summary:
- support static kv_cached llama2 model
- add qnn_llama_runner
- add e2e example script verified with stories110M
@shewu-quic force-pushed the dev/hutton/static_llama2 branch 2 times, most recently from 21baa73 to 3b6af64 on April 22, 2024 06:11
@shewu-quic (Collaborator, Author)

Hi @cccclai,
We fixed the 16a4w accuracy issue; it is resolved by PR 3196.

Results[0]:
Once upon a time, there was a little girl named Lily. She loved to play outside in the sunshine. One day, she saw a big, fluffy cloud in the sky. It looked like a giant cotton candy!
Lily ran inside to tell her mommy about the cloud. "Mommy, mommy, look at the big cloud in the sky! It looks like a giant cotton candy!" she said.

@salykova
Dear @shewu-quic @cccclai,

Does PR 3196 resolve issue #2590? If so, I will close it. Thank you in advance!

@cccclai (Contributor) commented Apr 24, 2024

> Dear @shewu-quic @cccclai,
> Does PR 3196 resolve issue #2590? If so, I will close it. Thank you in advance!

Thanks for the update and for sending the fix! Feel free to mark it as resolved and re-open it if anyone runs into the same issue again.

Note that this branch is for an example; llama2 cannot work with this branch.

What we did to optimize performance on HTP is listed below:
1. One multi-head attention is transformed into multiple single-head attentions.
2. KV-cache is changed to graph I/O. The update is performed in
   qnn_llama_runner.cpp on CPU.
3. llama2 is partitioned into 6 .pte files in examples/qualcomm/llama2/composite_llama.py.
4. Embedding is quantized. This might need further investigation, e.g.,
   whether we can move it out of the model onto CPU, etc.
5. Support u16 and u8 mixed-precision quantization.
6. KV-cache is kept in quantized format in graph I/O.
7. RMSNorm is tweaked a bit to reduce quantization sensitivity.
8. The HTP spill-fill buffer feature is used among .pte files.
9. All Linear layers are converted to Conv2d (see the sketch after this list).
10. quant_min and quant_max are properly set in Observers to offset=128 in
    symmetric quantization.
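
As an illustration of item 9, here is a hedged sketch (not the PR's actual pass) of how an nn.Linear can be rewritten as a numerically equivalent 1x1 nn.Conv2d, which the commit message lists as a performance optimization for HTP:

import torch
import torch.nn as nn

def linear_to_conv2d(linear: nn.Linear) -> nn.Conv2d:
    # A 1x1 Conv2d over the channel dimension computes the same affine map
    # as a Linear layer over the last dimension.
    conv = nn.Conv2d(
        linear.in_features, linear.out_features, kernel_size=1,
        bias=linear.bias is not None,
    )
    with torch.no_grad():
        conv.weight.copy_(linear.weight.view(linear.out_features, linear.in_features, 1, 1))
        if linear.bias is not None:
            conv.bias.copy_(linear.bias)
    return conv

# Usage: move the hidden dimension into the channel axis, run the conv,
# then move it back.
x = torch.randn(1, 8, 768)          # (batch, seq, dim)
lin = nn.Linear(768, 768)
y_ref = lin(x)
y_conv = linear_to_conv2d(lin)(x.transpose(1, 2).unsqueeze(-1)).squeeze(-1).transpose(1, 2)
assert torch.allclose(y_ref, y_conv, atol=1e-5)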
@chiwwang (Contributor)

Rebased as #3656

@chiwwang (Contributor) commented Jul 3, 2024

Please see #4142 instead.
