
Commit 2015537

winskuo-quic authored and cccclai committed
Qualcomm AI Engine Direct - Unify Llama2&Llama3 and Small Accuracy Improvement. (#7618)
Qualcomm AI Engine Direct - Unify Llama2 and Llama3
1 parent 3540723 commit 2015537

24 files changed (+342, −2022 lines)

backends/qualcomm/_passes/insert_requantize.py

Lines changed: 0 additions & 6 deletions
@@ -89,15 +89,9 @@ def _single_output_annotation(
         requantize_dict = n.meta.pop(QCOM_REQUANTIZE)
         # {quant_attr: user_node_name_list}
         group_quant_attr_dict = self._invert_dict(requantize_dict)
-        # TODO: If users of the node contain output node,
-        # we replace the node with to_copy op. However, it would
-        # be problem when the node has multiple to_copy ops
-        add_output = len(group_quant_attr_dict) == 1

         for hashable_quant_attr, user_nodes in group_quant_attr_dict.items():
             user_nodes_copy = user_nodes.copy()
-            if add_output:
-                user_nodes_copy.append("output")
             self._insert_to_copy(gm, n, dict(hashable_quant_attr), user_nodes_copy)

     def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:

backends/qualcomm/quantizer/custom_annotation.py

Lines changed: 68 additions & 6 deletions
@@ -14,17 +14,80 @@
     QuantizationConfig,
 )
 from executorch.exir.dialects._ops import ops as exir_ops
-from torch.ao.quantization.observer import MinMaxObserver
+from torch.ao.quantization.observer import FixedQParamsObserver, MinMaxObserver
 from torch.ao.quantization.quantizer import (
     QuantizationAnnotation,
+    QuantizationSpec,
     SharedQuantizationSpec,
 )
 from torch.fx import Node


-def annotate_matmul_16a8w(  # noqa: C901
-    gm: torch.fx.GraphModule, traverse_input1=True
-) -> None:
+def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None:
+    def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
+        input_qspec_map = {}
+        input_act = node.args[0]
+        input_spec = quantization_config.input_activation
+        input_qspec_map[input_act] = input_spec
+
+        weight = node.args[1]
+        input_qspec_map[weight] = quantization_config.weight
+
+        node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+            input_qspec_map=input_qspec_map,
+            output_qspec=quantization_config.output_activation,
+            _annotated=True,
+        )
+
+    quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
+        torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
+    )
+    for node in gm.graph.nodes:
+        if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
+            if "nn_module_stack" in node.meta:
+                module_values_list = list(node.meta["nn_module_stack"].values())
+                full_qualified_name = module_values_list[-1][0]
+                if full_qualified_name == "output.conv":
+                    annotate_conv2d(
+                        node, quantization_config=quantization_config_16a8w_per_channel
+                    )
+
+
+def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
+    for node in gm.graph.nodes:
+        if node.op == "output":
+            for index, prefill_output in enumerate(node.args[0]):
+                kv_quant_attr = kv_quant_attrs[index]
+                fixed_observer = FixedQParamsObserver.with_args(
+                    scale=kv_quant_attr[0],
+                    zero_point=kv_quant_attr[1],
+                    quant_min=kv_quant_attr[2],
+                    quant_max=kv_quant_attr[3],
+                    dtype=kv_quant_attr[4],
+                    qscheme=torch.torch.per_tensor_affine,
+                )
+
+                fixed_output_spec = QuantizationSpec(
+                    quant_min=kv_quant_attr[2],
+                    quant_max=kv_quant_attr[3],
+                    dtype=kv_quant_attr[4],
+                    ch_axis=0,
+                    observer_or_fake_quant_ctr=fixed_observer,
+                )
+
+                input_qspec_map = {}
+                for input in prefill_output.args:
+                    if isinstance(input, Node):
+                        input_qspec_map[input] = fixed_output_spec
+
+                prefill_output.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
+                    input_qspec_map=input_qspec_map,
+                    output_qspec=fixed_output_spec,
+                    _annotated=True,
+                )
+
+
+def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
     """
     This function is specific for matmul op 16a8w.
     For k, we will tag such as the below, and
@@ -142,8 +205,7 @@ def annotate_matmul_input1(node: Node):
     for node in gm.graph.nodes:
         if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
             annotate_matmul(node, quantization_config_16a8w)
-            if traverse_input1:
-                annotate_matmul_input1(node.args[1])
+            annotate_matmul_input1(node.args[1])


 def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None:  # noqa: C901
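How these new annotators are meant to be used (a sketch, not part of this diff): the llama example passes them to the Qualcomm quantizer as custom annotation callbacks, binding `kv_quant_attrs` with `functools.partial` so the prefill graph's KV outputs are quantized with the exact encodings expected by the KV-cache graph. The snippet below assumes the `QnnQuantizer.add_custom_quant_annotations` API used elsewhere in the Qualcomm examples; the `kv_quant_attrs` values are made-up placeholders.

```python
# Hypothetical wiring of the annotators added above; values are placeholders.
from functools import partial

import torch

from executorch.backends.qualcomm.quantizer.custom_annotation import (
    annotate_linear_16a8w_in_affine_layer,
    annotate_matmul_16a8w,
    annotate_prefill_kv_output,
)
from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer

quantizer = QnnQuantizer()

# One entry per prefill KV output that must match the KV-cache model's encoding:
# (scale, zero_point, quant_min, quant_max, dtype). Placeholder numbers only;
# in practice these are read back from the already-quantized KV model.
kv_quant_attrs = {0: (0.01, 0, 0, 65535, torch.uint16)}

quantizer.add_custom_quant_annotations(
    (
        annotate_matmul_16a8w,
        annotate_linear_16a8w_in_affine_layer,
        partial(annotate_prefill_kv_output, kv_quant_attrs=kv_quant_attrs),
    )
)
```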

backends/qualcomm/tests/test_qnn_delegate.py

Lines changed: 3 additions & 1 deletion
@@ -3529,7 +3529,7 @@ def test_stories_single_llama(self):

         cmds = [
             "python",
-            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama2/llama.py",
+            f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
             "--artifact",
             self.artifact_dir,
             "--build_folder",
@@ -3556,6 +3556,8 @@ def test_stories_single_llama(self):
             "16a4w",
             "--temperature",
             "0",
+            "--llama_model",
+            "stories110m",
         ]
         if self.host:
             cmds.extend(["--host", self.host])

examples/qualcomm/CMakeLists.txt

Lines changed: 2 additions & 5 deletions
@@ -84,11 +84,8 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag})
 # build qnn_executor_runner
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/executor_runner)

-# build qnn_llama_runner for llama2
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama2)
-
-# build qnn_llama_runner for llama3.2
-add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama3_2)
+# build qnn_llama_runner for llama
+add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama)

 # build qaihub_llama2_7b_runner and qaihub_llama3_8b_runner
 add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/qaihub_scripts/llama)

examples/qualcomm/README.md

Lines changed: 2 additions & 2 deletions
@@ -4,10 +4,10 @@ This directory contains examples for some AI models.

 We have seperated the example scripts into the following subfolders, please refer to [README.md](../../backends/qualcomm/README.md) for the example scripts' directory structure:

-1. executor_runner: This folder contains a general executor runner capable of running most of the models. As a rule of thumb, if a model does not have its own customized runner, execute the model using [executor_runner](./executor_runner/qnn_executor_runner.cpp). On the other hand, if a model has its own runner, such as [llama2](./oss_scripts/llama2/qnn_llama_runner.cpp), use the customized runner to execute the model. Customized runner should be located under the same folder as the model's python script.
+1. executor_runner: This folder contains a general executor runner capable of running most of the models. As a rule of thumb, if a model does not have its own customized runner, execute the model using [executor_runner](./executor_runner/qnn_executor_runner.cpp). On the other hand, if a model has its own runner, such as [llama](./oss_scripts/llama/qnn_llama_runner.cpp), use the customized runner to execute the model. Customized runner should be located under the same folder as the model's python script.

 2. oss_scripts: OSS stands for Open Source Software. This folder contains python scripts for open source models. Some models under this folder might also have their own customized runner.
-For example, [llama2](./oss_scripts/llama2/qnn_llama_runner.cpp) contains not only the python scripts to prepare the model but also a customized runner for executing the model.
+For example, [llama](./oss_scripts/llama/qnn_llama_runner.cpp) contains not only the python scripts to prepare the model but also a customized runner for executing the model.

 3. qaihub_scripts: QAIHub stands for [Qualcomm AI Hub](https://aihub.qualcomm.com/). On QAIHub, users can find pre-compiled context binaries, a format used by QNN to save its models. This provides users with a new option for model deployment. Different from oss_scripts & scripts, which the example scripts are converting a model from nn.Module to ExecuTorch .pte files, qaihub_scripts provides example scripts for converting pre-compiled context binaries to ExecuTorch .pte files. Additionaly, users can find customized example runners specific to the QAIHub models for execution. For example [qaihub_llama2_7b](./qaihub_scripts/llama2/qaihub_llama2_7b.py) is a script converting context binaries to ExecuTorch .pte files, and [qaihub_llama2_7b_runner](./qaihub_scripts/llama2/qaihub_llama2_7b_runner.cpp) is a customized example runner to execute llama2 .pte files. Please be aware that context-binaries downloaded from QAIHub are tied to a specific QNN SDK version.
 Before executing the scripts and runner, please ensure that you are using the QNN SDK version that is matching the context binary. Please refer to [Check context binary version](#check-context-binary-version) for tutorial on how to check the QNN Version for a context binary.

examples/qualcomm/oss_scripts/llama3_2/CMakeLists.txt renamed to examples/qualcomm/oss_scripts/llama/CMakeLists.txt

Lines changed: 14 additions & 17 deletions
@@ -18,38 +18,35 @@ target_link_libraries(
 )
 target_link_options_shared_lib(custom_ops)

-# preprocess qnn runner src files for llama3.2
-set(_llama3_2_runner__srcs ${_llama_runner__srcs})
-list(TRANSFORM _llama3_2_runner__srcs PREPEND "${EXECUTORCH_SOURCE_DIR}/")
-list(FILTER _llama3_2_runner__srcs EXCLUDE REGEX ".*(/runner/).*")
+# preprocess qnn runner src files for llama
+set(_llama_runner__srcs ${_llama_runner__srcs})
+list(TRANSFORM _llama_runner__srcs PREPEND "${EXECUTORCH_SOURCE_DIR}/")
+list(FILTER _llama_runner__srcs EXCLUDE REGEX ".*(/runner/).*")
 list(
   PREPEND
-  _llama3_2_runner__srcs
-  ${CMAKE_CURRENT_LIST_DIR}/qnn_llama3_2_runner.cpp
+  _llama_runner__srcs
+  ${CMAKE_CURRENT_LIST_DIR}/qnn_llama_runner.cpp
   ${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
   ${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
   ${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.cpp
   ${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.h
 )

-list(
-  APPEND _llama3_2_runner__srcs
-  ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizer/tiktoken.cpp
-)
 list(
   APPEND
-  _llama3_2_runner__srcs
+  _llama_runner__srcs
+  ${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizer/tiktoken.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/../../../models/llama/tokenizer/llama_tiktoken.cpp
 )

-# build qnn llama3.2 1b runner
-add_executable(qnn_llama3_2_runner ${_llama3_2_runner__srcs})
+# build qnn llama runner
+add_executable(qnn_llama_runner ${_llama_runner__srcs})
 target_include_directories(
-  qnn_llama3_2_runner PUBLIC ${_common_include_directories}
+  qnn_llama_runner PUBLIC ${_common_include_directories}
 )

 target_link_libraries(
-  qnn_llama3_2_runner
+  qnn_llama_runner
   qnn_executorch_backend
   executorch_core
   extension_data_loader
@@ -60,8 +57,8 @@ target_link_libraries(
   custom_ops
 )
 target_compile_options(
-  qnn_llama3_2_runner PUBLIC ${_common_compile_options}
+  qnn_llama_runner PUBLIC ${_common_compile_options}
 )
 set_target_properties(
-  qnn_llama3_2_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
+  qnn_llama_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
 )

examples/qualcomm/oss_scripts/llama/README.md (new file)

Lines changed: 70 additions & 0 deletions
# Summary

## Overview
This file provides instructions for running LLAMA models with different parameters via the Qualcomm HTP backend. We currently support the following models:
1. LLAMA2 Stories 110M
2. LLAMA3.2 1B
3. LLAMA3.2 3B (WIP)

We offer the following modes to execute the model:

Prefill Mode: This is also known as batch prefill mode, where the model takes in a list of tokens as input and generates the next token along with the key-value (KV) cache for all tokens. This mode is efficient for generating the initial sequence of tokens (usually the user's prompt).

KV Cache Mode: In KV cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt.

Hybrid Mode: Hybrid mode leverages the strengths of both batch prefill and KV cache modes to optimize token generation speed. Initially, it uses prefill mode to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens.
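To make the hybrid flow concrete, below is a minimal Python sketch of the generation loop. It is illustrative only: the real runner is implemented in C++ under the example's `runner/` directory, and `prefill_forward`, `decode_forward`, and all token values here are placeholders rather than the actual exported graphs.

```python
# Illustrative sketch of hybrid mode: one prefill pass, then token-by-token decode.


def prefill_forward(prompt_tokens, kv_cache):
    """One pass over the whole prompt: fill the KV cache, return the next token."""
    for pos, tok in enumerate(prompt_tokens):
        kv_cache.append((pos, tok))  # placeholder for the real K/V tensors
    return 1000  # placeholder "first generated token"


def decode_forward(token, kv_cache):
    """One token in, one token out, reusing the cache built so far."""
    kv_cache.append((len(kv_cache), token))
    return token + 1  # placeholder "next token"


def generate(prompt_tokens, prefill_seq_len=32, kv_seq_len=128):
    assert len(prompt_tokens) <= prefill_seq_len, "prompt must fit the prefill graph"
    kv_cache = []
    token = prefill_forward(prompt_tokens, kv_cache)  # prefill stage
    generated = [token]
    while len(kv_cache) < kv_seq_len - 1:  # KV-cache decode stage
        token = decode_forward(token, kv_cache)
        generated.append(token)
    return generated


print(generate([101, 102, 103, 104]))
```

The `--prefill_seq_len` and `--kv_seq_len` flags used in the commands below play the role of the two sequence-length bounds in this sketch.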
## Instructions
### Note
1. For hybrid mode, the export time will be longer and can take 1 to 4 hours to complete, depending on the specific model being exported.
2. When exporting a hybrid mode model, memory consumption will be higher. Taking LLAMA3.2 1B as an example, please ensure the device has at least 80 GB of memory and swap space.


### Step 1: Setup
1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch.
2. Follow the [tutorial](https://pytorch.org/executorch/stable/build-run-qualcomm-ai-engine-direct-backend.html) to build the Qualcomm AI Engine Direct Backend.

### Step 2: Prepare Model

#### LLAMA2
Download and prepare the stories110M model:

```bash
# tokenizer.model & stories110M.pt:
wget "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt"
wget "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model"

# tokenizer.bin:
python -m extension.llm.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin

# params.json:
echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json
```

#### LLAMA3.2
Follow the [instructions](https://www.llama.com/) to download the models.
At the end of this step, users should have the following files ready: `consolidated.00.pth`, `params.json`, and `tokenizer.model`.
### Step 3: Run default examples using hybrid mode
#### LLAMA2
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --llama_model stories110m --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "Once upon a time"
```

#### LLAMA3.2
Default example using hybrid mode.
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1"
```

### Additional Configs when running the script
If you would like to compile the model only, we have provided the flag `--compile_only`. Taking LLAMA3.2 as an example:
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1" --compile_only
```

On the other hand, if you already have a pre-compiled .pte model, you can perform inference by providing the flag `--pre_gen_pte` and specifying the folder that contains the .pte model. Taking LLAMA3.2 as an example:
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE}
```

examples/qualcomm/oss_scripts/llama3_2/TARGETS renamed to examples/qualcomm/oss_scripts/llama/TARGETS

Lines changed: 12 additions & 2 deletions
@@ -5,15 +5,25 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

 oncall("executorch")

+python_library(
+    name = "static_llama",
+    srcs = [
+        "model/static_llama.py",
+    ],
+    deps = [
+        "//caffe2:torch",
+    ],
+)
+
 python_binary(
     name = "llama",
     srcs = ["llama.py"],
-    main_function = "executorch.examples.qualcomm.oss_scripts.llama3_2.llama.main",
+    main_function = "executorch.examples.qualcomm.oss_scripts.llama.llama.main",
     preload_deps = [
         "//executorch/extension/llm/custom_ops:model_sharding_py",
     ],
     deps = [
-        "//executorch/examples/qualcomm/oss_scripts/llama2:static_llama",
+        "//executorch/examples/qualcomm/oss_scripts/llama:static_llama",
         "//caffe2:torch",
         "//executorch/extension/pybindings:aten_lib",
         "//executorch/backends/qualcomm/partition:partition",