Skip to content

[Reland] Qualcomm AI Engine Direct - Unify Llama2&Llama3 and Small Accuracy Im… #8004

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions backends/qualcomm/_passes/insert_requantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,9 @@ def _single_output_annotation(
requantize_dict = n.meta.pop(QCOM_REQUANTIZE)
# {quant_attr: user_node_name_list}
group_quant_attr_dict = self._invert_dict(requantize_dict)
# TODO: If users of the node contain output node,
# we replace the node with to_copy op. However, it would
# be problem when the node has multiple to_copy ops
add_output = len(group_quant_attr_dict) == 1

for hashable_quant_attr, user_nodes in group_quant_attr_dict.items():
user_nodes_copy = user_nodes.copy()
if add_output:
user_nodes_copy.append("output")
self._insert_to_copy(gm, n, dict(hashable_quant_attr), user_nodes_copy)

def _insert(self, graph_module: torch.fx.GraphModule) -> torch.fx.GraphModule:
Expand Down
74 changes: 68 additions & 6 deletions backends/qualcomm/quantizer/custom_annotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,80 @@
QuantizationConfig,
)
from executorch.exir.dialects._ops import ops as exir_ops
from torch.ao.quantization.observer import MinMaxObserver
from torch.ao.quantization.observer import FixedQParamsObserver, MinMaxObserver
from torch.ao.quantization.quantizer import (
QuantizationAnnotation,
QuantizationSpec,
SharedQuantizationSpec,
)
from torch.fx import Node


def annotate_matmul_16a8w( # noqa: C901
gm: torch.fx.GraphModule, traverse_input1=True
) -> None:
def annotate_linear_16a8w_in_affine_layer(gm: torch.fx.GraphModule) -> None:
def annotate_conv2d(node: Node, quantization_config: QuantizationConfig) -> None:
input_qspec_map = {}
input_act = node.args[0]
input_spec = quantization_config.input_activation
input_qspec_map[input_act] = input_spec

weight = node.args[1]
input_qspec_map[weight] = quantization_config.weight

node.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=quantization_config.output_activation,
_annotated=True,
)

quantization_config_16a8w_per_channel = get_ptq_per_channel_quant_config(
torch.uint16, weight_dtype=torch.int8, act_observer=MinMaxObserver
)
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.conv2d.default:
if "nn_module_stack" in node.meta:
module_values_list = list(node.meta["nn_module_stack"].values())
full_qualified_name = module_values_list[-1][0]
if full_qualified_name == "output.conv":
annotate_conv2d(
node, quantization_config=quantization_config_16a8w_per_channel
)


def annotate_prefill_kv_output(gm: torch.fx.GraphModule, kv_quant_attrs: dict):
for node in gm.graph.nodes:
if node.op == "output":
for index, prefill_output in enumerate(node.args[0]):
kv_quant_attr = kv_quant_attrs[index]
fixed_observer = FixedQParamsObserver.with_args(
scale=kv_quant_attr[0],
zero_point=kv_quant_attr[1],
quant_min=kv_quant_attr[2],
quant_max=kv_quant_attr[3],
dtype=kv_quant_attr[4],
qscheme=torch.torch.per_tensor_affine,
)

fixed_output_spec = QuantizationSpec(
quant_min=kv_quant_attr[2],
quant_max=kv_quant_attr[3],
dtype=kv_quant_attr[4],
ch_axis=0,
observer_or_fake_quant_ctr=fixed_observer,
)

input_qspec_map = {}
for input in prefill_output.args:
if isinstance(input, Node):
input_qspec_map[input] = fixed_output_spec

prefill_output.meta[QUANT_ANNOTATION_KEY] = QuantizationAnnotation(
input_qspec_map=input_qspec_map,
output_qspec=fixed_output_spec,
_annotated=True,
)


def annotate_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901
"""
This function is specific for matmul op 16a8w.
For k, we will tag such as the below, and
Expand Down Expand Up @@ -142,8 +205,7 @@ def annotate_matmul_input1(node: Node):
for node in gm.graph.nodes:
if node.op == "call_function" and node.target == torch.ops.aten.matmul.default:
annotate_matmul(node, quantization_config_16a8w)
if traverse_input1:
annotate_matmul_input1(node.args[1])
annotate_matmul_input1(node.args[1])


def custom_annotate_llama_matmul_16a8w(gm: torch.fx.GraphModule) -> None: # noqa: C901
Expand Down
4 changes: 3 additions & 1 deletion backends/qualcomm/tests/test_qnn_delegate.py
Original file line number Diff line number Diff line change
Expand Up @@ -3529,7 +3529,7 @@ def test_stories_single_llama(self):

cmds = [
"python",
f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama2/llama.py",
f"{self.executorch_root}/examples/qualcomm/oss_scripts/llama/llama.py",
"--artifact",
self.artifact_dir,
"--build_folder",
Expand All @@ -3556,6 +3556,8 @@ def test_stories_single_llama(self):
"16a4w",
"--temperature",
"0",
"--llama_model",
"stories110m",
]
if self.host:
cmds.extend(["--host", self.host])
Expand Down
7 changes: 2 additions & 5 deletions examples/qualcomm/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,8 @@ set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag})
# build qnn_executor_runner
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/executor_runner)

# build qnn_llama_runner for llama2
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama2)

# build qnn_llama_runner for llama3.2
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama3_2)
# build qnn_llama_runner for llama
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/oss_scripts/llama)

# build qaihub_llama2_7b_runner and qaihub_llama3_8b_runner
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/qaihub_scripts/llama)
Expand Down
4 changes: 2 additions & 2 deletions examples/qualcomm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ This directory contains examples for some AI models.

We have seperated the example scripts into the following subfolders, please refer to [README.md](../../backends/qualcomm/README.md) for the example scripts' directory structure:

1. executor_runner: This folder contains a general executor runner capable of running most of the models. As a rule of thumb, if a model does not have its own customized runner, execute the model using [executor_runner](./executor_runner/qnn_executor_runner.cpp). On the other hand, if a model has its own runner, such as [llama2](./oss_scripts/llama2/qnn_llama_runner.cpp), use the customized runner to execute the model. Customized runner should be located under the same folder as the model's python script.
1. executor_runner: This folder contains a general executor runner capable of running most of the models. As a rule of thumb, if a model does not have its own customized runner, execute the model using [executor_runner](./executor_runner/qnn_executor_runner.cpp). On the other hand, if a model has its own runner, such as [llama](./oss_scripts/llama/qnn_llama_runner.cpp), use the customized runner to execute the model. Customized runner should be located under the same folder as the model's python script.

2. oss_scripts: OSS stands for Open Source Software. This folder contains python scripts for open source models. Some models under this folder might also have their own customized runner.
For example, [llama2](./oss_scripts/llama2/qnn_llama_runner.cpp) contains not only the python scripts to prepare the model but also a customized runner for executing the model.
For example, [llama](./oss_scripts/llama/qnn_llama_runner.cpp) contains not only the python scripts to prepare the model but also a customized runner for executing the model.

3. qaihub_scripts: QAIHub stands for [Qualcomm AI Hub](https://aihub.qualcomm.com/). On QAIHub, users can find pre-compiled context binaries, a format used by QNN to save its models. This provides users with a new option for model deployment. Different from oss_scripts & scripts, which the example scripts are converting a model from nn.Module to ExecuTorch .pte files, qaihub_scripts provides example scripts for converting pre-compiled context binaries to ExecuTorch .pte files. Additionaly, users can find customized example runners specific to the QAIHub models for execution. For example [qaihub_llama2_7b](./qaihub_scripts/llama2/qaihub_llama2_7b.py) is a script converting context binaries to ExecuTorch .pte files, and [qaihub_llama2_7b_runner](./qaihub_scripts/llama2/qaihub_llama2_7b_runner.cpp) is a customized example runner to execute llama2 .pte files. Please be aware that context-binaries downloaded from QAIHub are tied to a specific QNN SDK version.
Before executing the scripts and runner, please ensure that you are using the QNN SDK version that is matching the context binary. Please refer to [Check context binary version](#check-context-binary-version) for tutorial on how to check the QNN Version for a context binary.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,38 +18,35 @@ target_link_libraries(
)
target_link_options_shared_lib(custom_ops)

# preprocess qnn runner src files for llama3.2
set(_llama3_2_runner__srcs ${_llama_runner__srcs})
list(TRANSFORM _llama3_2_runner__srcs PREPEND "${EXECUTORCH_SOURCE_DIR}/")
list(FILTER _llama3_2_runner__srcs EXCLUDE REGEX ".*(/runner/).*")
# preprocess qnn runner src files for llama
set(_llama_runner__srcs ${_llama_runner__srcs})
list(TRANSFORM _llama_runner__srcs PREPEND "${EXECUTORCH_SOURCE_DIR}/")
list(FILTER _llama_runner__srcs EXCLUDE REGEX ".*(/runner/).*")
list(
PREPEND
_llama3_2_runner__srcs
${CMAKE_CURRENT_LIST_DIR}/qnn_llama3_2_runner.cpp
_llama_runner__srcs
${CMAKE_CURRENT_LIST_DIR}/qnn_llama_runner.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/runner.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/runner.h
${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.cpp
${CMAKE_CURRENT_LIST_DIR}/runner/io_memory.h
)

list(
APPEND _llama3_2_runner__srcs
${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizer/tiktoken.cpp
)
list(
APPEND
_llama3_2_runner__srcs
_llama_runner__srcs
${CMAKE_CURRENT_SOURCE_DIR}/../../../../extension/llm/tokenizer/tiktoken.cpp
${CMAKE_CURRENT_SOURCE_DIR}/../../../models/llama/tokenizer/llama_tiktoken.cpp
)

# build qnn llama3.2 1b runner
add_executable(qnn_llama3_2_runner ${_llama3_2_runner__srcs})
# build qnn llama runner
add_executable(qnn_llama_runner ${_llama_runner__srcs})
target_include_directories(
qnn_llama3_2_runner PUBLIC ${_common_include_directories}
qnn_llama_runner PUBLIC ${_common_include_directories}
)

target_link_libraries(
qnn_llama3_2_runner
qnn_llama_runner
qnn_executorch_backend
executorch_core
extension_data_loader
Expand All @@ -60,8 +57,8 @@ target_link_libraries(
custom_ops
)
target_compile_options(
qnn_llama3_2_runner PUBLIC ${_common_compile_options}
qnn_llama_runner PUBLIC ${_common_compile_options}
)
set_target_properties(
qnn_llama3_2_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
qnn_llama_runner PROPERTIES LINK_FLAGS "-Wl,-rpath='$ORIGIN'"
)
70 changes: 70 additions & 0 deletions examples/qualcomm/oss_scripts/llama/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Summary

## Overview
This file provides you the instructions to run LLAMA model with different parameters via Qualcomm HTP backend. We currently support the following models:
1. LLAMA2 Stories 110M
2. LLAMA3.2 1B
3. LLAMA3.2 3B (WIP)
We offer the following modes to execute the model:

Prefill Mode: This is also known as batch prefill mode, where the model takes in a list of tokens as input and generates the next token along with the key-value (KV) cache for all tokens. This mode is efficient for generating the initial sequence of tokens (usually the user's prompt).

KV Cache Mode: In KV Cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt.

Hybrid Mode: Hybrid mode leverages the strengths of both batch prefill and KV cache modes to optimize token generation speed. Initially, it uses prefill mode to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens.


## Instructions
### Note
1. For hybrid mode, the export time will be longer and can take up to 1-4 hours to complete, depending on the specific model users are exporting.
2. When exporting a hybrid mode model, memory consumption will be higher. Taking LLAMA3.2 1B as an example, please ensure the device has at least 80 GB of memory and swap space.


### Step 1: Setup
1. Follow the [tutorial](https://pytorch.org/executorch/main/getting-started-setup) to set up ExecuTorch.
2. Follow the [tutorial](https://pytorch.org/executorch/stable/build-run-qualcomm-ai-engine-direct-backend.html) to build Qualcomm AI Engine Direct Backend.

### Step 2: Prepare Model

#### LLAMA2
Download and prepare stories110M model

```bash
# tokenizer.model & stories110M.pt:
wget "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories110M.pt"
wget "https://raw.githubusercontent.com/karpathy/llama2.c/master/tokenizer.model"

# tokenizer.bin:
python -m extension.llm.tokenizer.tokenizer -t tokenizer.model -o tokenizer.bin

# params.json:
echo '{"dim": 768, "multiple_of": 32, "n_heads": 12, "n_layers": 12, "norm_eps": 1e-05, "vocab_size": 32000}' > params.json
```

#### LLAMA3.2
Follow the [instructions](https://www.llama.com/) to download models.
At the end of this step, users should have the following files ready: `consolidated.00.pth`, `params.json`, and `tokenizer.model`.


### Step3: Run default examples using hybrid mode.
#### LLAMA2
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint stories110M.pt --params params.json --tokenizer_model tokenizer.model --tokenizer_bin tokenizer.bin --llama_model stories110m --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "Once upon a time"
```

#### LLAMA3.2
Default example using hybrid mode.
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1"
```

### Additional Configs when running the script
If you would like to compile the model only, we have provided the flag `--compile_only`. Taking LLAMA3.2 as an example:
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1" --compile_only
```

On the other hand, if you already have a pre-compiled .pte model, you can perform inference by providing the flag `--pre_gen_pte` and specifying the folder that contains the .pte model. Taking LLAMA3.2 as an example:
```bash
python examples/qualcomm/oss_scripts/llama/llama.py -b build-android -s ${SERIAL_NUM} -m ${SOC_MODEL} --ptq 16a4w --checkpoint consolidated.00.pth --params params.json --tokenizer_model tokenizer.model --llama_model llama3_2 --model_mode hybrid --prefill_seq_len 32 --kv_seq_len 128 --prompt "what is 1+1" --pre_gen_pte ${FOLDER_TO_PRE_GEN_PTE}
```
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ load("@fbsource//xplat/executorch/build:runtime_wrapper.bzl", "runtime")

oncall("executorch")


python_library(
name = "static_llama",
srcs = [
Expand All @@ -19,9 +18,12 @@ python_library(
python_binary(
name = "llama",
srcs = ["llama.py"],
main_function = "executorch.examples.qualcomm.oss_scripts.llama2.llama.main",
main_function = "executorch.examples.qualcomm.oss_scripts.llama.llama.main",
preload_deps = [
"//executorch/extension/llm/custom_ops:model_sharding_py",
],
deps = [
":static_llama",
"//executorch/examples/qualcomm/oss_scripts/llama:static_llama",
"//caffe2:torch",
"//executorch/extension/pybindings:aten_lib",
"//executorch/backends/qualcomm/partition:partition",
Expand All @@ -38,6 +40,8 @@ runtime.command_alias(
name = "llama_qnn",
env = {
"LD_LIBRARY_PATH": "$(location fbsource//third-party/qualcomm/qnn/qnn-{0}:qnn_offline_compile_libs)".format(get_qnn_library_verision()),
# Place holder to pass the QNN_SDK_ROOT check in executorch/examples/qualcomm/utils.py
"QNN_SDK_ROOT": "",
},
exe = ":llama",
)
Loading