
Error: input 3 is none #7614


Closed
jds250 opened this issue Jan 11, 2025 · 21 comments
Labels: partner: qualcomm, triaged

@jds250

jds250 commented Jan 11, 2025

Title

Error: input 3 is none when running Llama example in QNN ExecuTorch on Android


Description

I followed the instructions in the [Llama2 README](https://github.com/pytorch/executorch/blob/main/examples/qualcomm/oss_scripts/llama2/README.md) to run the llama.py script using QNN ExecuTorch on Android. The execution process fails with the error input 3 is none, and metadata seems to be read from the model twice during execution.


Steps to Reproduce

  1. Environment setup:

    • QNN SDK version: 2.26.0.240828
    • Platform: Qualcomm SM8650
    • Android NDK: r26d
  2. Run the following command:

    python llama.py -b executorch/build-android  -s 112dhb -m SM8650 \
        --ptq 16a4w --checkpoint stories110M.pt --params params.json \
        --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin \
        --prompt "what is python?" \
        --pre_gen_pte executorch/examples/qualcomm/oss_scripts/llama2/llama2_qnn/

LOG

I 00:00:00.001788 executorch:runner.cpp:65] Creating LLaMa runner: model_path=llama2_qnn.pte, tokenizer_path=tokenizer.bin
[INFO] [Qnn ExecuTorch]: create QNN Logger with log_level 2
[WARNING] [Qnn ExecuTorch]:  <W> Initializing HtpProvider

[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!

[INFO] [Qnn ExecuTorch]: Initialize Qnn backend parameters for Qnn executorch backend type 2
[INFO] [Qnn ExecuTorch]: Caching: Caching is in RESTORE MODE.
[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!

[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!

[INFO] [Qnn ExecuTorch]: Running level=3 optimization.
I 00:00:00.269470 executorch:runner.cpp:80] Reading metadata from model
I 00:00:00.269729 executorch:runner.cpp:139] get_vocab_size: 32000
I 00:00:00.269816 executorch:runner.cpp:139] get_bos_id: 1
I 00:00:00.269872 executorch:runner.cpp:139] get_eos_id: 2
I 00:00:00.269926 executorch:runner.cpp:139] get_n_bos: 1
I 00:00:00.269977 executorch:runner.cpp:139] get_n_eos: 1
I 00:00:00.270026 executorch:runner.cpp:139] get_max_seq_len: 1024
I 00:00:00.270081 executorch:runner.cpp:139] get_head_dim: 64
I 00:00:00.270129 executorch:runner.cpp:139] get_dim: 768
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend parameters
[INFO] [Qnn ExecuTorch]: Destroy Qnn context
[INFO] [Qnn ExecuTorch]: Destroy Qnn device
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend
[WARNING] [Qnn ExecuTorch]:  <W> qnnOpPackageManager: hexagon unload op package function pointer is nullptr!

[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!

[INFO] [Qnn ExecuTorch]: create QNN Logger with log_level 2
[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!

[INFO] [Qnn ExecuTorch]: Initialize Qnn backend parameters for Qnn executorch backend type 2
[INFO] [Qnn ExecuTorch]: Caching: Caching is in RESTORE MODE.
[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!

[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!

[INFO] [Qnn ExecuTorch]: Running level=3 optimization.
I 00:00:14.273970 executorch:runner.cpp:80] Reading metadata from model
I 00:00:14.274117 executorch:runner.cpp:139] get_vocab_size: 32000
I 00:00:14.274138 executorch:runner.cpp:139] get_bos_id: 1
I 00:00:14.274154 executorch:runner.cpp:139] get_eos_id: 2
I 00:00:14.274169 executorch:runner.cpp:139] get_n_bos: 1
I 00:00:14.274186 executorch:runner.cpp:139] get_n_eos: 1
I 00:00:14.274203 executorch:runner.cpp:139] get_max_seq_len: 1024
I 00:00:14.274217 executorch:runner.cpp:139] get_head_dim: 64
I 00:00:14.274230 executorch:runner.cpp:139] get_dim: 768
E 00:00:14.286718 executorch:module.cpp:185] input 3 is none
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend parameters
[INFO] [Qnn ExecuTorch]: Destroy Qnn context
[INFO] [Qnn ExecuTorch]: Destroy Qnn device
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend
[WARNING] [Qnn ExecuTorch]:  <W> qnnOpPackageManager: hexagon unload op package function pointer is nullptr!

[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!

/data/local/tmp/jds/executorch/single_llama/outputs/: 1 file pulled.
Results[0]:

Finish the running pre_gen_pte from /home/jds/executorch/examples/qualcomm/oss_scripts/llama2/llama2_qnn/

So I found there is no output in my output file.

adb logcat

BTW, I also noticed some fastrpc errors (maybe because I don't have root access):

01-11 14:54:55.560  9667  9667 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/fastrpc_apps_user.c:3592: Error 0xd: open_shell failed for domain 3 search paths used are /dsp/, /vendor/dsp/, /vendor/dsp/xdsp/ (errno Permission denied)
01-11 14:54:55.599  9667  9667 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/log_config.c:605:Error : Unable to add watcher for folder /odm/lib/rfsa/adsp : errno is Permission denied
01-11 14:54:55.599  9667  9667 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/log_config.c:605:Error : Unable to add watcher for folder /vendor/lib/rfsa/adsp/ : errno is Permission denied
01-11 14:54:55.599  9667  9667 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/log_config.c:605:Error : Unable to add watcher for folder /system/vendor/lib/rfsa/adsp : errno is Permission denied
01-11 14:54:55.599  9667  9667 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/log_config.c:605:Error : Unable to add watcher for folder /vendor/lib64/rfs/dsp : errno is Permission denied
01-11 14:54:55.599  9667  9667 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/log_config.c:605:Error : Unable to add watcher for folder /vendor/lib/rfsa/adsp : errno is Permission denied
01-11 14:54:55.610  9667  9669 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/mod_table.c:863: Error 0xd: open_mod_table_handle_invoke failed for handle:0x63df7da8, sc:0x1f050100
01-11 14:54:55.639  9667  9669 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/mod_table.c:863: Error 0xd: open_mod_table_handle_invoke failed for handle:0x63df7da8, sc:0x1f050100
01-11 14:54:55.640  9667  9669 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/mod_table.c:863: Error 0x2: open_mod_table_handle_invoke failed for handle:0x63df7da8, sc:0x13050100
01-11 14:54:55.656  9667  9669 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/mod_table.c:863: Error 0xd: open_mod_table_handle_invoke failed for handle:0x63df7da8, sc:0x1f050100
01-11 14:54:55.657  9667  9669 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/mod_table.c:863: Error 0x2: open_mod_table_handle_invoke failed for handle:0x63df7da8, sc:0x13050100
01-11 14:54:55.712  9667  9667 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/fastrpc_apps_user.c:1473: Error 0x80000414: remote_handle64_invoke failed for handle 0xb4000079eace8210, method 3 on domain 3 (sc 0x3010100) (errno Success)
01-11 14:54:55.909  9667  9671 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/fastrpc_notif.c:57:Error 0xc: notif_fastrpc_thread FastRPC notification worker thread exited for domain 3 (errno Success), notif_domain_deinit started 0
01-11 14:54:55.926  9667  9667 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/fastrpc_apps_user.c:3592: Error 0xd: open_shell failed for domain 3 search paths used are /dsp/, /vendor/dsp/, /vendor/dsp/xdsp/ (errno Permission denied)
01-11 14:54:55.959  9667  9667 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/log_config.c:605:Error : Unable to add watcher for folder /odm/lib/rfsa/adsp : errno is Permission denied
01-11 14:54:55.959  9667  9667 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/log_config.c:605:Error : Unable to add watcher for folder /vendor/lib/rfsa/adsp/ : errno is Permission denied
01-11 14:54:55.959  9667  9667 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/log_config.c:605:Error : Unable to add watcher for folder /system/vendor/lib/rfsa/adsp : errno is Permission denied
01-11 14:54:55.959  9667  9667 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/log_config.c:605:Error : Unable to add watcher for folder /vendor/lib64/rfs/dsp : errno is Permission denied
01-11 14:54:55.959  9667  9667 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/log_config.c:605:Error : Unable to add watcher for folder /vendor/lib/rfsa/adsp : errno is Permission denied
01-11 14:54:55.969  9667  9674 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/mod_table.c:863: Error 0xd: open_mod_table_handle_invoke failed for handle:0x63df7da8, sc:0x1f050100
01-11 14:54:55.993  9667  9674 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/mod_table.c:863: Error 0xd: open_mod_table_handle_invoke failed for handle:0x63df7da8, sc:0x1f050100
01-11 14:54:55.993  9667  9674 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/mod_table.c:863: Error 0x2: open_mod_table_handle_invoke failed for handle:0x63df7da8, sc:0x13050100
01-11 14:54:56.007  9667  9674 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/mod_table.c:863: Error 0xd: open_mod_table_handle_invoke failed for handle:0x63df7da8, sc:0x1f050100
01-11 14:54:56.008  9667  9674 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/mod_table.c:863: Error 0x2: open_mod_table_handle_invoke failed for handle:0x63df7da8, sc:0x13050100
01-11 14:54:56.053  9667  9667 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/fastrpc_apps_user.c:1473: Error 0x80000414: remote_handle64_invoke failed for handle 0xb4000079eacc04c0, method 3 on domain 3 (sc 0x3010100) (errno Success)
01-11 14:55:16.853  9667  9676 E qnn_llama_runner: vendor/qcom/proprietary/adsprpc/src/fastrpc_notif.c:57:Error 0xc: notif_fastrpc_thread FastRPC notification worker thread exited for domain 3 (errno Success), notif_domain_deinit started 0

I wonder if it is necessary to have root access to deploy our model?

@shewu-quic
Collaborator

shewu-quic commented Jan 13, 2025

Hi @jds250,

Thanks for trying.
Could you please let me know which branch of ExecuTorch you used and what command you used to export the PTE file?

I believe that root access should not be necessary to run your model.

@jds250
Author

jds250 commented Jan 13, 2025

Hi @jds250,

Thanks for trying. Could you please let me know which branch of ExecuTorch you used and what command you used to export the PTE file?

I believe that root access should not be necessary to run your model.

Hi, I am using the release/0.4 branch. Here are my steps to reproduce. Exporting the PTE file is handled inside the llama.py script in examples/qualcomm/oss_scripts/llama2, and the PTE file is generated in the llama2_qnn folder.

Step 1: Setup

  1. Follow the tutorial to set up ExecuTorch.
  2. Follow the tutorial to build Qualcomm AI Engine Direct Backend.

Step 2: Prepare Model

Download and prepare the stories110M model

# tokenizer.model & stories110M.pt:
wget "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt"
wget "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model"

# tokenizer.bin:
python -m extension.llm.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin

# params.json:
echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json

Step 3: Run default examples

The default example generates a story based on the given prompt, "Once".

# 16a4w quant:
python llama.py -b /home/jds/executorch/build-android -s 1f1fa994 -m SM8650 --ptq 16a4w --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --prompt "what is python?" --pre_gen_pte /home/jds/executorch/examples/qualcomm/oss_scripts/llama2/llama2_qnn/

@shewu-quic
Collaborator

Got it. Let me clarify one thing.
To use the --pre_gen_pte argument, you need to compile first and obtain the PTE file. After that, you can use this argument to skip the compilation step.
You have compiled it first, right?

@shewu-quic
Collaborator

Oh, I see. It's a bug in setting the inputs. We have a fix in this PR:
d174637#diff-e37e4f997bb5f213089dd0dd2314ff2327452cce92d87ff4fe014086a1e93f12

If possible, could you use the main branch?

@jds250
Author

jds250 commented Jan 13, 2025

Got it. Let me clarify one thing. To use the --pre_gen_pte argument, you need to compile first and obtain the PTE file. After that, you can use this argument to skip the compilation step. You have compiled it first, right?

Yes, I have compiled it first.

@jds250
Author

jds250 commented Jan 13, 2025

Oh, I see. It's a bug in setting the inputs. We have a fix in this PR: d174637#diff-e37e4f997bb5f213089dd0dd2314ff2327452cce92d87ff4fe014086a1e93f12

If possible, could you use the main branch?

Thank you! I will try it again

@shewu-quic
Collaborator

BTW, if you are interested in Llama 3.2, we have provided a script to export and run it.
You can find it here: https://github.com/pytorch/executorch/tree/main/examples/qualcomm/oss_scripts/llama3_2

To enhance the user experience, we will integrate our script for Llama as soon as possible.

@michaelk77

michaelk77 commented Jan 13, 2025

Hi @shewu-quic,

I am experiencing a very similar issue where the model does not respond, and in the logcat, I see the error input 2 is none. The app notes that the model answered in 0.005 seconds, but the output is an empty message.

Environment:

  • Branch: main
  • OS: Ubuntu 24.04 LTS
  • QNN SDK version: v2.26.0.240828 (works with export, but the model doesn't respond correctly).
  • Other QNN versions: Encountered error 1 during model loading on all versions except v2.26.0.240828.
    • Note: For testing different QNN versions, I fully deleted executorch and the environment each time, then rebuilt the Android application and model to ensure a clean setup.
  • Model: Llama 3.2 1B

Steps Tried:

  1. Model Export:

    • Exported using python -m examples.models.llama.export_llama with quantization qnn_16a4w.
    • On v2.26.0.240828, the model produces nonsensical outputs.
  2. Alternative Approach:

    • I tried using the script llama.py from the examples/qualcomm/oss_scripts/llama3_2 directory with the following command:
      python examples/qualcomm/oss_scripts/llama3_2/llama.py \
        -b build-android \
        -m SM8475 \
        --checkpoint "consolidated.00.pth" \
        --params "original_params.json" \
        --ptq 16a4w \
        --model_size 1B \
        --tokenizer_model "tokenizer.model" \
        --prompt "what is 1+1" \
        --temperature 0 \
        --model_mode kv \
        --prefill_seq_len 32 \
        --kv_seq_len 128 \
        --compile_only
  3. Outcome:

    • The model does not respond.
    • Logcat error: input 2 is none.
  4. Additional Issue:

    • During execution, the following traceback error occurs:
      Traceback (most recent call last):
        File "/home/mihail/executorch/examples/qualcomm/oss_scripts/llama3_2/llama.py", line 928, in <module>
          main()
        File "/home/mihail/executorch/examples/qualcomm/oss_scripts/llama3_2/llama.py", line 889, in main
          quant_attrs = compile(args, pte_filename)
        File "/home/mihail/executorch/examples/qualcomm/oss_scripts/llama3_2/llama.py", line 488, in compile
          llama_instance_list[0].lowering_modules(
        File "/home/mihail/executorch/examples/qualcomm/oss_scripts/llama3_2/llama.py", line 369, in lowering_modules
          with open(f"{work_space}/{pte_filename}.pte", "wb") as file:
      NameError: name 'pte_filename' is not defined
      
    • It seems the variable pte_filename is undefined and should be replaced with self.pte_filename (see the sketch below).
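
For clarity, here is a minimal sketch of the fix this implies (the write call and the serialized object are illustrative placeholders; only the self.pte_filename change is taken from the traceback above):

    # In lowering_modules() in examples/qualcomm/oss_scripts/llama3_2/llama.py
    # (line numbers per the traceback). The local name `pte_filename` does not
    # exist in this scope, so the open() call should use the instance attribute:
    with open(f"{work_space}/{self.pte_filename}.pte", "wb") as file:
        file.write(prog.buffer)  # hypothetical: whatever serialized program the script writes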

Request for Help:

Could you please advise if there are any additional fixes or specific steps to resolve these issues?

Thank you for your support! I appreciate any guidance you can provide.

@lucylq added the partner: qualcomm and triaged labels on Jan 13, 2025
@shewu-quic
Collaborator

Hi @michaelk77,

Thanks for trying.
Let me clarify your questions.

  1. Issue with examples.models.llama.export_llama
    We discovered that the model definition in llama_transformer.py isn’t optimal for running the Llama model on the QNN backend. We’ve initiated a new definition in examples/qualcomm/oss_scripts/llama3_2 and are working towards achieving more reasonable output.
  2. Runtime error in examples/qualcomm/oss_scripts/llama3_2
    Could you please provide the command you used to run PTE and specify the device you’re using?
  3. Additional issue
    Thanks for pointing it out. We have a PR to address it.

Feel free to let me know if you need any further assistance!

@michaelk77

Hi @shewu-quic,

Thank you for your response and clarification!

Runtime Environment:

  • I am running the Llama model using the Android demo app: LlamaDemo.
  • Device: iQOO 10 Pro (8GB/256GB), Snapdragon 8 Gen1 Plus, Android 14.

Updated Status:

  • With the updated main branch, I was able to switch to QNN version 2.28.0.241029.
  • However, the issue persists when running the kv_llama3_2_qnn.pte model:
    • Logcat Error: input 2 is none.
    • Android App Behavior: The app notes that the model responded, but the output is empty.

PTE Generation:

I generated the kv_llama3_2_qnn.pte model using the following command:

python examples/qualcomm/oss_scripts/llama3_2/llama.py \
  -b build-android \
  -m SM8475 \
  --checkpoint "consolidated.00.pth" \
  --params "original_params.json" \
  --ptq 16a4w \
  --model_size 1B \
  --tokenizer_model "tokenizer.model" \
  --prompt "what is 1+1" \
  --temperature 0 \
  --model_mode kv \
  --prefill_seq_len 32 \
  --kv_seq_len 128 \
  --compile_only

Model Source:

I am using the model files from Meta Llama 3.2 1B Instruct on Hugging Face.

If you need additional logs or further details, please let me know. I appreciate your assistance!

@shewu-quic
Collaborator

Thanks for the information.

Could you please use the following command to run the PTE?
I think the PTE generated from the static llama may not be integrated into the demo app yet.

python examples/qualcomm/oss_scripts/llama3_2/llama.py \
  -b build-android \
  -m SM8475 \
  --checkpoint "consolidated.00.pth" \
  --params "original_params.json" \
  --ptq 16a4w \
  --model_size 1B \
  --tokenizer_model "tokenizer.model" \
  --prompt "what is 1+1" \
  --temperature 0 \
  --model_mode kv \
  --prefill_seq_len 32 \
  --kv_seq_len 128 \
  --pre_gen_pte ${path_to_your_pte_directory}

@michaelk77

Thank you for providing the command to run the PTE. I have executed the provided command with a minor addition to specify my device using the -s flag:

python examples/qualcomm/oss_scripts/llama3_2/llama.py \
  -b build-android \
  -m SM8475 \
  --checkpoint "consolidated.00.pth" \
  --params "original_params.json" \
  --ptq 16a4w \
  --model_size 1B \
  --tokenizer_model "tokenizer.model" \
  --prompt "what is 1+1" \
  --temperature 0 \
  --model_mode kv \
  --prefill_seq_len 32 \
  --kv_seq_len 128 \
  --pre_gen_pte ${path_to_your_pte_directory} \
  -s # my device code from ADB.

Observations:

  • The model response is extremely unusual and doesn't seem coherent. Here's the output:
<|begin_of_text|><|start_header_id|>user<|end_header_id|>
what is 1+1<|eot_id|><|start_header_id|>assistant<|end_header_id|>

Python
Code
Python
Code
AllYouThER

Python

Code
A
Python

Python

A

Code

Python

**
**

**
**
**
**
**
**
**
**
**
**
**
**
**
**
**
**
**
**
**
**
**
Finish the running pre_gen_pte from /home/mihail/executorch/llama3_2_qnn

Log Details:

Here is the relevant portion of the logcat output during execution:

[INFO] [Qnn ExecuTorch]: Deserializing processed data using QnnContextCustomProtocol
[INFO] [Qnn ExecuTorch]: create QNN Logger with log_level 2
[WARNING] [Qnn ExecuTorch]:  <W> Initializing HtpProvider

[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!

[INFO] [Qnn ExecuTorch]: Initialize Qnn backend parameters for Qnn executorch backend type 2
[INFO] [Qnn ExecuTorch]: Caching: Caching is in RESTORE MODE.
[INFO] [Qnn ExecuTorch]: QnnContextCustomProtocol expected magic number: 0x5678abcd but get: 0x2000000
[WARNING] [Qnn ExecuTorch]:  <W> Cost Based unsupported on soc SM8475

[INFO] [Qnn ExecuTorch]: Running level=1 optimization.
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend parameters
[INFO] [Qnn ExecuTorch]: Destroy Qnn context
[INFO] [Qnn ExecuTorch]: Destroy Qnn device
[INFO] [Qnn ExecuTorch]: Destroy Qnn backend
[WARNING] [Qnn ExecuTorch]:  <W> Function not called, PrepareLib isn't loaded!

Performance Stats:

The PyTorchObserver logs indicate the following:

  • Prompt Tokens: 16
  • Generated Tokens: 111
  • Total Inference Time: ~360 ms
  • The app notes that the model responded, but the content is gibberish.

Could you let me know if there’s any misconfiguration or additional step I should take? Thank you for your assistance!

@shewu-quic
Collaborator

Hi @michaelk77,

Sorry for the late reply.
We have tested on SM8650 and got more reasonable output.
If possible, could you please try a newer device?
Regarding the accuracy issue, we are trying to fix it with QAT.

@jds250
Author

jds250 commented Jan 20, 2025

BTW, if you are interested in Llama 3.2, we have provided a script to export and run it. You can find it here: https://github.com/pytorch/executorch/tree/main/examples/qualcomm/oss_scripts/llama3_2

To enhance the user experience, we will integrate our script for Llama as soon as possible.

Thanks for your advice! I have successfully run stories Llama and Llama 3.2 1B on my device.
I’ve observed some interesting phenomena and would really appreciate your help clarifying them.
My ExecuTorch is at commit f370e78.

  1. Long model loading time in storiesllama runner
    I noticed that in the storiesllama runner, model loading takes a considerable amount of time. Specifically, the function QnnManager::PreRegisterMem() seems to run for quite a while. Meanwhile, in the llama3 runner, I don’t see the same prolonged loading phase. Could it be that the two runners use very different memory allocation strategies at initialization? What factors contribute to the noticeably longer loading time in storiesllama compared to llama3?

  2. Comparing XNNPack and QNN on a quantized LLaMA3–1B model
    I exported a quantized 1B version of LLaMA3 (command snippet below) and ran inference with both XNNPack and QNN backends:

    python -m examples.models.llama.export_llama \
        --checkpoint <consolidated.00.pth> \
        -p <params.json> \
        -kv \
        --use_sdpa_with_kv_cache \
        -X \
        -qmode 8da4w \
        --group_size 128 \
        --metadata '{"get_bos_id":128000, "get_eos_ids":[128009, 128001]}' \
        --output_name="llama3_kv_sdpa_xnn.pte"

    To my surprise, XNNPack achieved around 65 tokens/s, which was actually faster than QNN’s 60 tokens/s. My understanding is that XNNPack runs purely on the CPU, whereas QNN uses HTP—so in theory, HTP hardware should outperform CPU-based inference. Could there be additional overhead on QNN that outweighs its hardware advantage, or are there other factors at play (e.g., memory layout differences, kernel-level optimizations, etc.)?

Any insights you could share would be very helpful! Thank you in advance for your time and assistance.

@shewu-quic
Collaborator

Congratulations on your effort. I’m pleased to help you.

  1. Good point. We also found long loading times in the Stories Llama runner. The function QnnManager::PreRegisterMem() uses ion memory to create a shared buffer to improve inference performance. Unfortunately, this results in long loading times.
    However, we have created a PR with another method of updating the kv cache that improves both model loading time and inference performance. Following this PR, a document PR describing how to update the kv cache will be submitted soon.

  2. Yes, that makes sense. This is why we created a static llama, a QNN HTP-friendly llama structure, in examples/qualcomm/oss_scripts/llama to get better performance.
    I believe the key point is the output of the kv cache. In examples/models/llama, the kv cache is updated through index_put and the entire kv cache is output, which results in memory pressure.
    However, in our examples/qualcomm/oss_scripts/llama, we only output the updated kv cache entries and manage the kv cache on the CPU side (see the sketch at the end of this comment). This improves overall performance.
    We also advise using graph sharding with --num_sharding. This has the most significant impact on performance. When HTP execution time exceeds RPC polling time, CPU frequency may drop to its minimum. Our experiments indicate that setting num_sharding to 4 yields the best results.
    We now get ~47 tokens/sec for Llama 3.2 1B on an SM8650 SoC with context length 4096, and ~67 tokens/sec on an SM8650 SoC with context length 512 and 4 shards. I think that using QNN HTP with larger context lengths (such as 4K) should provide a greater speed advantage over the CPU.
    We have also unified the llama2 and llama3_2 scripts. You could try again after this PR.
    You can get more information from the README.

If you have any questions, please don’t hesitate to let me know.
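
To make the point in item 2 about "only output the updated kv cache and manage the kv cache on the CPU side" concrete, here is a minimal, assumption-laden sketch in Python (the shapes, names, and the run_kv_graph callable are illustrative; the real bookkeeping lives in the C++ runner):

    import numpy as np

    # Conceptual CPU-side KV cache management for a static-shape decode graph.
    # All sizes are illustrative, not taken from the ExecuTorch sources.
    n_layers, n_kv_heads, head_dim, max_seq_len = 16, 8, 64, 128

    # Preallocated caches held on the CPU side; the NPU graph reads them as inputs.
    k_cache = np.zeros((n_layers, n_kv_heads, max_seq_len, head_dim), dtype=np.float16)
    v_cache = np.zeros((n_layers, n_kv_heads, max_seq_len, head_dim), dtype=np.float16)
    pos = 0

    def decode_step(run_kv_graph, token):
        """One KV-mode step: feed the token plus the current caches, get back only
        the newly produced per-layer K/V slices, and write them into the CPU-side
        buffers. run_kv_graph stands in for executing the delegated KV graph."""
        global pos
        logits, new_k, new_v = run_kv_graph(token, k_cache, v_cache, pos)
        # Only the new entries cross the CPU<->NPU boundary as outputs
        # (proportional to n_layers * n_kv_heads * head_dim per step), instead of
        # the whole cache (an extra factor of max_seq_len), which is the DMA
        # saving described above.
        k_cache[:, :, pos, :] = new_k
        v_cache[:, :, pos, :] = new_v
        pos += 1
        return logits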

@jds250
Author

jds250 commented Jan 22, 2025

Congratulations on your effort. I’m pleased to help you.

  1. Good point. We also found long loading times in the Stories Llama runner. The function QnnManager::PreRegisterMem() uses ion memory to create a shared buffer to improve inference performance. Unfortunately, this results in long loading times.
    However, we have created a PR with another method of updating the kv cache that improves both model loading time and inference performance. Following this PR, a document PR describing how to update the kv cache will be submitted soon.
  2. Yes, that makes sense. This is why we created a static llama, a QNN HTP-friendly llama structure, in examples/qualcomm/oss_scripts/llama to get better performance.
    I believe the key point is the output of the kv cache. In examples/models/llama, the kv cache is updated through index_put and the entire kv cache is output, which results in memory pressure.
    However, in our examples/qualcomm/oss_scripts/llama, we only output the updated kv cache entries and manage the kv cache on the CPU side. This improves overall performance.
    We also advise using graph sharding with --num_sharding. This has the most significant impact on performance. When HTP execution time exceeds RPC polling time, CPU frequency may drop to its minimum. Our experiments indicate that setting num_sharding to 4 yields the best results.
    We now get ~47 tokens/sec for Llama 3.2 1B on an SM8650 SoC with context length 4096, and ~67 tokens/sec on an SM8650 SoC with context length 512 and 4 shards. I think that using QNN HTP with larger context lengths (such as 4K) should provide a greater speed advantage over the CPU.
    We have also unified the llama2 and llama3_2 scripts. You could try again after this PR.
    You can get more information from the README.

If you have any questions, please don’t hesitate to let me know.

Thank you very much for your response! I have a few more questions that I hope you can help me with:

  1. I noticed that Qualcomm also provides a tutorial for Llama2, and I’ve read some documentation on the QNN SDK. From the documentation, I understand that you support converting a PyTorch model into the QNN model format (as .cpp and .bin files). I’m curious about the differences between this conversion process and the export process in Executorch. Could you explain the underlying differences between the two methods? Which one would you recommend using? From the model execution perspective, is there any fundamental difference between the .pte file and the QNN model file? My current understanding is that both are compiled into a static computation graph, but I’d appreciate your insights on this.

  2. I think I might be missing some understanding of the complete model inference flow on Qualcomm devices, especially regarding what you mentioned about "we only output the updated kv cache and manage the kv cache on the CPU side." I’m not entirely clear on the relationship between HTP memory management and the CPU (any resources on this topic you could point me to would be greatly appreciated). Even though mobile devices have a unified memory architecture, can we consider the memory for HTP and the CPU to be isolated from each other? So, does this mean that HTP only gets the complete kv cache from the CPU and outputs the newly generated kv values back to the CPU side, which then assembles the updated kv cache? From what I understand in XNNPACK’s implementation, there is no need for full kv cache transfer between the CPU and HTP, so when the context length is short, the computational advantage of HTP is less noticeable?

  3. I noticed that the current implementation includes modes for prefill, kv, and hybrid. Why do we need different implementations for these modes? Is it because the shape of the tensors in the computation graph is different during the prefill and kv modes, and hence different computation graphs are used for different files? Is the hybrid mode simply a combination of the prefill and kv models into one?

I would really appreciate any clarification on these points. If I misunderstood any part, I’d be grateful if you could point that out!

@shewu-quic
Collaborator

Hi @jds250,

Of course. Happy to help.

  1. Regarding my recommendation, it depends on your familiarity with QNN SDK or ExecuTorch. We’ve made several optimizations to improve performance in ExecuTorch, so both should be quite similar in terms of performance.
    Export flow of QNN SDK:
    nn.Module --> (some torch infrastructure) --> onnx model --> (qnn-onnx-converter) --> QNN Model --> (QNN SDK) --> context binary
    Export flow of ExecuTorch:
    nn.Module --> pte (including context binary)
    Since QNN SDK doesn’t support PyTorch 2 yet, you need to convert the PyTorch 2 model to ONNX first and then compile the model to a context binary with QNN SDK. The export flow of QNN SDK is a bit more complicated, so I recommend using Qualcomm AI Hub to export context binaries of Llama. You can also run these context binaries on ExecuTorch. Note that pre-compiled context binaries cannot be further fine-tuned for other downstream tasks.
  2. Generally, communication between the CPU and the NPU uses DMA, which introduces overhead that depends on the I/O size to the NPU. When the entire KV cache is transferred as I/O, the overhead is larger than when only the new KV cache entries are output. For XNNPACK, which runs on the CPU, I think this overhead does not exist.
  3. As per the document, we have three modes to choose from. We recommend using hybrid mode to get the best overall performance. Initially, it uses prefill mode to efficiently generate the prompt’s key-value (KV) cache. Then, it switches to KV cache mode, which excels at generating subsequent tokens. Additionally, we enable weight sharing in hybrid mode, so the size of the PTE is not doubled. (A rough sketch of the hybrid flow follows at the end of this comment.)

I hope this helps! Let me know if you need any further assistance.
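
As a rough illustration of the hybrid flow in point 3 (one prefill pass over the prompt, then the KV graph per generated token), here is a sketch with hypothetical function names; the actual graph selection happens inside the runner:

    def generate_hybrid(prompt_tokens, run_prefill_graph, run_kv_graph, eos_ids, max_new_tokens):
        """Hybrid-mode decoding sketch. Both graphs sit in one context binary and
        share weights; only the graph needed for each phase is executed. KV-cache
        bookkeeping (as in the earlier sketch) is assumed to happen inside the
        two run_* callables."""
        logits = run_prefill_graph(prompt_tokens)   # fills the KV cache for the whole prompt
        token = int(logits.argmax())                # greedy pick, i.e. temperature 0
        generated = []
        while len(generated) < max_new_tokens and token not in eos_ids:
            generated.append(token)
            logits = run_kv_graph(token)            # one token in, one cache slot filled
            token = int(logits.argmax())
        return generated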

@jds250
Author

jds250 commented Feb 1, 2025

Hi @jds250,

Of course. Happy to help.

  1. Regarding my recommendation, it depends on your familiarity with QNN SDK or ExecuTorch. We’ve made several optimizations to improve performance in ExecuTorch, so both should be quite similar in terms of performance.
    Export flow of QNN SDK:
    nn.Module --> (some torch infrastructure) --> onnx model --> (qnn-onnx-converter) --> QNN Model --> (QNN SDK) --> context binary
    Export flow of ExecuTorch:
    nn.Module --> pte (including context binary)
    Since QNN SDK doesn’t support PyTorch 2 yet, you need to convert the PyTorch 2 model to ONNX first and then compile the model to a context binary with QNN SDK. The export flow of QNN SDK is a bit more complicated, so I recommend using Qualcomm AI Hub to export context binaries of Llama. You can also run these context binaries on ExecuTorch. Note that pre-compiled context binaries cannot be further fine-tuned for other downstream tasks.
  2. Generally, communication between the CPU and the NPU uses DMA, which introduces overhead that depends on the I/O size to the NPU. When the entire KV cache is transferred as I/O, the overhead is larger than when only the new KV cache entries are output. For XNNPACK, which runs on the CPU, I think this overhead does not exist.
  3. As per the document, we have three modes to choose from. We recommend using hybrid mode to get the best overall performance. Initially, it uses prefill mode to efficiently generate the prompt’s key-value (KV) cache. Then, it switches to KV cache mode, which excels at generating subsequent tokens. Additionally, we enable weight sharing in hybrid mode, so the size of the PTE is not doubled.

I hope this helps! Let me know if you need any further assistance.

Thank you for your helpful response! I have a few follow-up questions regarding some details:

  1. Weight Sharing Logic: I was wondering where exactly the weight sharing logic is implemented. I have reviewed the entire compilation flow in llama.py, but it seems that for both prefill and kv modes, operations such as quantization and export are executed as two separate computation graphs, and I couldn't find explicit weight-sharing logic until the very end, where the graphs are serialized into a PTE file. BTW, I noticed that the context binary created here and the one downloaded from Qualcomm AI Hub both use the from_context_binary logic. Does this mean they are equivalent?

  2. Hybrid Computation Graphs: I'm curious whether, when the prefill and kv computation graphs are used in hybrid mode, any nodes in the graph (aside from weight parameters) are also shared? In other words, while reading through static_llama.py, I noticed that the only difference between the two modes is how attention is calculated. Does this imply that, except for the attention calculations, the rest of the computation graph can be shared between these two modes?

  3. KV Cache Memory Management in NPU: Lastly, I'd like to understand how the KV cache memory management works on the NPU side. Could you point me to the code or logic that handles this? I want to know if the overhead on the I/O size with NPU is a bottleneck for on-device inference.

I appreciate your time and help!


@shewu-quic
Collaborator

shewu-quic commented Feb 3, 2025

  1. Weight Sharing Logic: It is used in hybrid mode to share weights between the prefill model and the kv model. To enable weight sharing, we need to enable multi_graph first, which sets the config that allows common weights across graphs to be shared and stored in a single context binary. Then we put all the graphs that should share weights into one context binary. This helps reduce overall memory usage by sharing common weights. You can get more details from this PR.
    The function from_context_binary builds a custom op that loads a prebuilt context binary.

  2. The weights are shared only between the prefill model and the KV model. From the QNN perspective, they are two separate graphs within a single context binary, as in the figure below. At runtime, we specify the graph name to run either the prefill graph or the KV graph.
    [Image: two graphs sharing weights within a single context binary]

  3. This process is a bit complex. We’ll create a pull request to add a document detailing how to manage the KV cache. I’ll notify you as soon as we have any updates.

I hope this clarifies things! Feel free to reach out if you need any more help.

@jds250
Author

jds250 commented Feb 5, 2025

  1. Weight Sharing Logic: It is used in hybrid mode to share weights between the prefill model and the kv model. To enable weight sharing, we need to enable multi_graph first, which sets the config that allows common weights across graphs to be shared and stored in a single context binary. Then we put all the graphs that should share weights into one context binary. This helps reduce overall memory usage by sharing common weights. You can get more details from this PR.
    The function from_context_binary builds a custom op that loads a prebuilt context binary.
  2. The weights are shared only between the prefill model and the KV model. From the QNN perspective, they are two separate graphs within a single context binary, as in the figure below. At runtime, we specify the graph name to run either the prefill graph or the KV graph.
    [Image: two graphs sharing weights within a single context binary]
  3. This process is a bit complex. We’ll create a pull request to add a document detailing how to manage the KV cache. I’ll notify you as soon as we have any updates.

I hope this clarifies things! Feel free to reach out if you need any more help.

Thank you for your detailed reply! I'm really looking forward to your upcoming documentation on kv cache management.

In the meantime, I'd like to explore custom quantization algorithms to deploy a model on an NPU. I've noticed that Qualcomm provides the AIMET toolkit for quantization, and I'm also considering using executorch to implement a custom quantization flow. Could you share any guidance or examples on how to implement a custom quantization approach for NPU deployment? Specifically, I'm wondering if we need to implement custom quantization operators tailored to the NPU or if there's existing support we can leverage.

Any advice or references you could point me to would be incredibly helpful. Thanks again for all your assistance!

@haowhsu-quic
Collaborator

Hi @jds250,
I think the KV cache management documentation is already there; please take a look at this.
For custom quantization algorithms, there is an example of hooking in a custom PTQ observer. The corresponding use case is here.
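
For reference, a custom PTQ observer at the PyTorch level can be as small as the sketch below; the observer itself is plain torch.ao.quantization machinery, while the part the linked example actually covers (registering the observer with the QNN quantizer) is not shown here:

    import torch
    from torch.ao.quantization.observer import MinMaxObserver

    class ClippedMinMaxObserver(MinMaxObserver):
        """Toy PTQ observer: behaves like MinMaxObserver but shrinks the observed
        range by a fixed clip ratio to soften activation outliers. Purely
        illustrative; not the observer used in the linked example."""

        def __init__(self, clip_ratio: float = 0.95, **kwargs):
            super().__init__(**kwargs)
            self.clip_ratio = clip_ratio

        def calculate_qparams(self):
            # Temporarily narrow the tracked min/max before computing scale/zero-point.
            orig_min, orig_max = self.min_val.clone(), self.max_val.clone()
            self.min_val.copy_(orig_min * self.clip_ratio)
            self.max_val.copy_(orig_max * self.clip_ratio)
            scale, zero_point = super().calculate_qparams()
            self.min_val.copy_(orig_min)
            self.max_val.copy_(orig_max)
            return scale, zero_point

How such an observer is attached to the quantizer's annotations is exactly what the example and use case linked above demonstrate.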

@jds250 closed this as completed Feb 20, 2025