[Draft] Qualcomm AI Engine Direct - Enable story llama model in quantized and fp #4030
Conversation
🔗 Helpful Links: 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/4030
Note: links to docs will display an error until the docs builds have been completed. ❌ 7 new failures as of commit e68e225 with merge base de300e0.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@shewu-quic great job! Does it support llama2 7b?
Unfortunately, it does not support llama2 7b in this draft, but we are actively working on enabling llama2 7b.
This is great :) I have some questions and would like to understand the motivation behind the changes. Thanks in advance!
@@ -266,6 +277,12 @@ class OpResizeNearestNeighbor:
    param_half_pixel_centers: str = "half_pixel_centers"


@dataclass(init=False, frozen=True)
Is it used for index_put?
Yes, I chose the QNN ScatterND op to implement index_put for the llama use case, because I couldn't find a way to generate the index tensor required by the QNN ScatterElements op.
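For illustration, here is a minimal sketch of the KV-cache write that shows up as index_put in the exported llama graph, alongside an equivalent scatter-style formulation of the kind ScatterND expresses (write updates into a tensor at integer indices). The shapes and names are assumptions for the example, not code from this PR.

import torch

# KV-cache update seen as index_put / advanced indexing in the graph.
cache = torch.zeros(2, 8, 16, 4)   # (batch, heads, max_seq_len, head_dim)
new_kv = torch.randn(2, 8, 1, 4)   # key/value for the current position
pos = torch.tensor([5])            # current decode position

cache_a = cache.clone()
cache_a[:, :, pos] = new_kv

# Scatter-style form: same result, with the index broadcast along all dims.
cache_b = cache.clone().scatter_(
    2, pos.view(1, 1, 1, 1).expand(2, 8, 1, 4), new_kv
)
assert torch.equal(cache_a, cache_b)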
class FuseConsecutiveTranspose(ExportPass):
    """
    This pass fuses consecutive transpose / permute into one to reduce runtime
I notice that the view_copy node before/after linear stays there; is there any specific reason we keep them?
I think we need it because keep_dims is not supported for the linear op in QNN HTP.
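As an aside, the fusion idea behind FuseConsecutiveTranspose can be illustrated with plain tensors: two back-to-back permutes compose into a single permute whose order is the first order indexed by the second. This is only a sketch of the composition rule, not the pass itself.

import torch

def compose_permutes(first, second):
    # order of the single permute equivalent to permute(first) then permute(second)
    return [first[i] for i in second]

x = torch.randn(2, 3, 4, 5)
p1, p2 = [0, 2, 1, 3], [0, 1, 3, 2]
assert torch.equal(x.permute(p1).permute(p2), x.permute(compose_permutes(p1, p2)))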
@@ -248,12 +248,12 @@ Error Runner::generate(
      "Sequence length exceeded - please increase the seq_len value passed to generate()");

  // start the main loop
- int64_t pos = 0; // position in the sequence
+ int32_t pos = 0; // position in the sequence
Any specific reason we cast from int64_t to int32_t?
Because int64 is not well supported in QNN HTP, for example for the index tensor of the ScatterND op.
@@ -107,6 +108,47 @@ def forward(
    return y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)


class SDPAQNN(torch.nn.Module):
We can rename it to something else; it's another SDPA replacement, not necessarily QNN-specific.
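For context, such a replacement typically decomposes scaled dot-product attention into explicit matmul, mask-add, and softmax ops that a backend can quantize and delegate individually. A minimal sketch of that pattern (assumed shapes; not the SDPAQNN code in this PR):

import math
import torch

class DecomposedSDPA(torch.nn.Module):
    def forward(self, q, k, v, mask):
        # q, k, v: (batch, heads, seq_len, head_dim); mask: additive attention mask
        scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
        attn = torch.softmax(scores + mask, dim=-1)
        return attn @ v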
Sounds great:)
def replace_causal_mask(module: torch.nn.Module):
    for buffer_fqn_name, buffer in module.named_buffers():
        buffer_name = buffer_fqn_name.split(".")[-1]
        if buffer_name == "mask":
            max_seq_len = buffer.shape[-1]
            mask = torch.full(
                (max_seq_len, max_seq_len),
-               float("-inf"),
+               float("-255"),
Any specific reason we replace inf with -255?
Actually, we have a pass that replaces inf with a min or max value, because inf is not friendly for quantization or computation on QNN HTP and could result in numerical error.
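A minimal sketch of that idea, building the causal mask with a large finite negative value instead of -inf (an illustration only, not the actual pass):

import torch

def make_causal_mask(max_seq_len: int, mask_value: float = -255.0) -> torch.Tensor:
    mask = torch.full((max_seq_len, max_seq_len), mask_value)
    # zeros on and below the diagonal, mask_value above it
    return torch.triu(mask, diagonal=1)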
Another challenge we need to conquer is model sharding.
Actually, I have a version that supports model sharding and can share the example code.
Hi @cccclai, the accuracy issue seems to be related to insufficient calibration.
Ah yes, we will use a more generic way to calibrate. I merged this PR (#3756) so that we can use lm_eval to calibrate the model.
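For reference, a pt2e-style calibration step is just a forward pass of representative prompts through the observer-instrumented model before convert_pt2e; the tokenizer API and shapes below are assumptions for illustration, not the code in #3756.

import torch

def calibrate(prepared_model, tokenizer, prompts, max_seq_len=128):
    for prompt in prompts:
        token_ids = tokenizer.encode(prompt)  # hypothetical tokenizer API
        for pos, tid in enumerate(token_ids[:max_seq_len]):
            tokens = torch.tensor([[tid]], dtype=torch.int32)
            input_pos = torch.tensor([pos], dtype=torch.int32)
            prepared_model(tokens, input_pos)  # forward only; observers record ranges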
May I know how you shard the model?
Thanks for the information. Will it be used in export_llama_lib?
Sorry for the delay, I was distracted by the performance review last week... I use the ExecutorBackend and tag every 8 layers; I will publish it soon. I think having a no-op (maybe a custom op instead of clone, because clone can also be expensive) for cutting the model can also be a generic way to shard the model.
This is my current change; I'm still trying to debug an op but it's getting close. I think it's still worth exploring the custom no-op solution to break the graph. What is your preference?
Wow, that makes it clear to me how to run the sharded model at runtime.
I think it is a good idea.

# custom_fallback_op.py
import torch
from torch.library import impl, Library

fallback_op_lib = Library("qnn_llama", "DEF")
fallback_op_lib.define("fallback(Tensor input) -> Tensor")

@impl(fallback_op_lib, "fallback", dispatch_key="CompositeExplicitAutograd")
def fallback_impl(a: torch.Tensor) -> torch.Tensor:
    return a

# registering the out variant.
fallback_op_lib.define(
    "fallback.out(Tensor input, *, Tensor(a!) output) -> Tensor(a!)"
)

# split_graph.py
import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult

class SplitGraph(ExportPass):
    def __init__(self, shares):
        super().__init__()
        self.shares = shares

    def _insert_fallback_op(
        self, graph_module: torch.fx.GraphModule
    ) -> torch.fx.GraphModule:
        for node in graph_module.graph.nodes:
            if "nn_module_stack" in node.meta:
                module_values_list = list(node.meta["nn_module_stack"].values())
                full_qualified_name = module_values_list[-1][0]
                owning_module = module_values_list[-1][1]
                print(f"[Hutton] node: {node}; full_qualified_name: {full_qualified_name}; owning_module: {owning_module}; meta: {node.meta}")
                # if node not in [the node which wants to find]:
                #     continue
                with graph_module.graph.inserting_after(node):
                    users = list(node.users.keys())
                    inserted_node = graph_module.graph.create_node(
                        "call_function",
                        exir_ops.edge.qnn_llama.fallback.default,
                        (node,),
                    )
                    inserted_node.meta["val"] = node.meta["val"]
                    for user in users:
                        user.replace_input_with(node, inserted_node)
        return graph_module

    def call(self, graph_module: torch.fx.GraphModule):
        self._insert_fallback_op(graph_module)
        graph_module.recompile()
        return PassResult(graph_module, True)
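A hypothetical way to wire this pass into the lowering flow (names and call sites are assumptions, not the PR's code): transform the edge program with SplitGraph before partitioning, so each segment between fallback nodes can become its own QNN context.

import torch
from executorch.exir import to_edge

def shard_and_lower(model, example_inputs, num_shards):
    # example_inputs is a tuple of sample inputs for torch.export
    ep = torch.export.export(model, example_inputs)
    edge = to_edge(ep)
    edge = edge.transform([SplitGraph(num_shards)])  # SplitGraph from the snippet above
    # ...then run the QNN partitioner / to_backend on the transformed program.
    return edge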
This is great. I think if we have a custom graph-break op, it doesn't have to be QNN-specific and can be applicable to other flows or backends.
Like, where to insert this custom op inside the graph? I feel like we can find the last node of every 8 layers based on source_fn and the module stack. Is that not working? Another question: I imagine we need to unload the QNN context binary in the graph-break custom op. Is that what you're doing? Also, the patch is pretty much the idea; there is a bug I need to fix before it works properly... I'll send another patch soon.
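A rough sketch of that idea (hypothetical, not the actual patch): record the last node seen for each decoder layer via nn_module_stack and cut at every N-th layer boundary. As noted in the reply below, this can be ambiguous when several nodes map to the same layer.

import re
from collections import OrderedDict

def find_shard_boundaries(graph_module, layers_per_shard=8):
    last_node_of_layer = OrderedDict()
    for node in graph_module.graph.nodes:
        stack = node.meta.get("nn_module_stack")
        if not stack:
            continue
        fqn = list(stack.values())[-1][0]  # e.g. "layers.11.feed_forward"
        match = re.search(r"layers\.(\d+)", fqn)
        if match:
            last_node_of_layer[int(match.group(1))] = node  # keeps the last node seen
    return [
        node
        for layer_idx, node in last_node_of_layer.items()
        if (layer_idx + 1) % layers_per_shard == 0
    ]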
Sounds great.
I originally thought so too, but I found it will get multiple nodes in the same layer.
So maybe I also need stack_trace to identify which node we want. Is it stable?
Do you mean we need to handle the life cycle of the processed binary in the custom op?
Thanks a lot.
This PR is partly based on #4142. We will continue the llama2-7b tasks in this PR.
Hmm, I was thinking whether finding the last add node of the current layer is sufficient, but maybe I'm missing something.
Yeah, that's my understanding too. However, for 4 shards we need to init(shard_1) -> destroy(shard_1) -> init(shard_2) -> destroy(shard_2) -> ...; if we do init(shard_1) -> init(shard_2) -> init(shard_3) -> init(shard_4) -> ... -> destroy(shard_1) -> ... -> destroy(shard_4), will it OOM on the DSP?
I think it will not OOM if we use the multi-context feature, because I could run the composite llama on the device.
Do you mean you were able to use multi-context for the 7b model 😮 To my understanding, multi-context means multiple graphs in the QNN context binary. How does it work with 4 shards (4 sets of graphs) in this case?
It works with the multiple-pte case. If we want to enable multi-context, we just need to set the right group handle for each pte, which is the first context handle; we use a static variable to accomplish this. And we need to set max_sf_buf_size, which is the size of the blob, at AOT.
I was checking the doc.
To my understanding, spill-fill is used for intermediate tensors among the splits, like split_1 -> output (in spill-fill) -> split_2. It's for input/output such as activations, but I'm not sure if it does any optimization for weights. Did I miss anything?
According to your description, it should be the shared buffer (zero copy), which eliminates data copies between multiple ptes on the CPU and the HTP accelerator; it's for the input/output of the graph. Spill-fill buffer sharing is an optimization that allocates one buffer shared by all the contexts of an LLM, so we do not need to allocate space for each of the graphs.
That's my understanding too, and I thought it was for re-using the input/output across all splits in VTCM, but not for weights across all splits. Like ...act_1 -> split_1 -> act_2 -> split_2 -> act_3 -> split_4...; here act_1, act_2 and act_3 will share the same buffer, also known as the spill tensor buffer.
Hey, I probably need some help to fix a matmul validation error - it causes a graph break but I'm not sure what the issue is. It only shows up after I apply the model sharding patch, but the graph inside the qnn_partitioner is supposed to be the same as the first layer of the graph. I debugged inside.
For both the op validation success and failure cases, the input nodes are exactly the same. The first
I feel we are misaligned on some terms.

Shared buffer (zero copy): the purpose is to avoid data copies between CPU and HTP. In addition, we can create a bigger RPC memory to store act_1, act_2, ... etc. We have implemented this in our llama2: it creates one RPC memory to store all inputs and outputs and just sets the correct offset for each I/O tensor.

Spill-fill buffer: VTCM space on each SoC is limited, so when we need to make space within this region, the data is copied back to DDR (the spill-fill buffer in this case). Therefore, we allocate one spill-fill buffer for the intermediate tensors in a graph (split).

VTCM: a hardware resource which provides fast stores and loads. It is controlled by HTP, and we can only set the maximum VTCM usage.

So back to your example: act_1, act_2 and act_3 (I/O tensors) will share the same buffer, which is RPC memory rather than the spill tensor buffer. The intermediate tensors in each graph (split) will use a spill-fill buffer.

...act_1 (rpc_mem) -> split_1_1 -> intermediate tensor_0 ... (spill-fill buffer) -> split_1_2 -> act_2 (rpc_mem) -> split_2 (spill-fill buffer) -> act_3 (rpc_mem) -> split_4 (spill-fill buffer)...
May I know which version of QNN you are using?
If you use quantization, I think the problem is a missing quant attr or something wrong with the quant parameters in the node's meta. Could you help check it?
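A quick way to eyeball this is to dump the node meta for the failing ops (a sketch; the "quant_attrs" meta key is an assumption about how the QNN backend stores quantization parameters):

def dump_quant_attrs(graph_module, op_substring="matmul"):
    for node in graph_module.graph.nodes:
        if node.op == "call_function" and op_substring in str(node.target):
            print(node.name, node.meta.get("quant_attrs"))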
I'm using QNN 2.23, and the matmul node metadata is:
Fail case:
They look very similar... how would you debug next?
Hmm, I think llama_main has the same error. I tried it a bit more, and it looks like the dummy llama params work, but the stories params don't. Can it be related to the checkpoint data type?
The error usually means invalid arguments in the input or output tensors when calling
Hmm, can I confirm with you the latency numbers for stories and llama2 in the latest commit? I just would like to make sure we start from the same place.
It seems Hutton also pushed the annotation of the 16a8w matmul... so we might see stories llama at 600 tokens/second and llama2 at 15 tokens/second, if I remember correctly... What numbers did you see?
As a note, this is the patch we apply for group query attention support, if we can lower
Thanks a lot for the patch. We will try it and focus on enabling llama3.
Actually, do you mind dropping the stories pte file here? I realize there is something off and I want to double-check the performance number.
Oh, also I have it combined with [the other matmul annotation](https://github.com/pytorch/executorch/blob/faeeca8ec9040ae2db23973139c1b5f71ea51d4c/examples/qualcomm/llama2/llama.py#L59), as the cat annotation seems off.
Sorry about that; there seems to be an accuracy issue for stories llama in fp and quantized mode when I rebuild the runner and backend lib in this PR. Performance:
Results:
Oops, I got it. It seems that the tokenizer has some changes and I need to regenerate tokenizer.bin.
Results:
Would you like to check 16a4w?
Yeah
I'm getting the following performance with your .pte file. The command line is:
#4355 removes the
I think there should be no impact, as the final graph should be the same. I checked our CI and it doesn't show a regression.
That was just my thought when seeing their PR. After seeing the comment in the original PR, I think it should be good now. My current device is a OnePlus 12 with 16GB of RAM. What is your command to build the llama runner?
Got it.
Hey. Next, I will create three PRs to enable llama.
Do you have any concerns about this plan? Thanks a lot.
I encountered this issue on the OnePlus 12 with the SM8650 chip. I have updated the OS but the problem is still there. Do you have other suggestions, @chiwwang? My QNN version is 2.22.6.240515.
@leigao97, if you saw the exact
Then it's related to the system. It's hard to do anything on the application side. I can only suggest
Summary:
- Fully delegate the meta llama model in fp and quantized mode
- Add simple calibration
- Use a custom fallback op to split the graph
- Add a model sharding argument
- Add the spill-fill feature

Note that if you want to run llama 7b, you need to specify num_sharding due to memory limitations on the device. It is also recommended to reboot the device before running to ensure that it has enough memory.
But it will result in the embedding op falling back; if pos_ids is changed to int32, it will be fully delegated.
- Support GQA, repeating kv caches (see the sketch below)
- Support the Tiktoken tokenizer for llm/export/builder.py
- Support the --embedding-quantize option for the Qualcomm lowering flow
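For reference, "repeating kv caches" for GQA usually means expanding each KV head across its query group. A common formulation is sketched below, assuming a (batch, n_kv_heads, seq_len, head_dim) layout; this is an illustration, not the patch itself.

import torch

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, n_kv_heads, seq_len, head_dim) -> (batch, n_kv_heads * n_rep, seq_len, head_dim)
    if n_rep == 1:
        return x
    bsz, n_kv_heads, seqlen, head_dim = x.shape
    return (
        x[:, :, None, :, :]
        .expand(bsz, n_kv_heads, n_rep, seqlen, head_dim)
        .reshape(bsz, n_kv_heads * n_rep, seqlen, head_dim)
    )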
Hi @cccclai,

@@ -25,6 +25,8 @@ install_executorch_and_backend_lib() {
     -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \
     -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \
     -DEXECUTORCH_BUILD_XNNPACK=ON \
+    -DEXECUTORCH_BUILD_QNN=ON \
+    -DQNN_SDK_ROOT=$QNN_SDK_ROOT \
     -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
     -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
     -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \

@@ -39,18 +41,20 @@ build_llama_runner() {
     ANDROID_ABI=arm64-v8a
     cmake -DBUCK2="${BUCK2}" \
         -DCMAKE_TOOLCHAIN_FILE="$ANDROID_NDK"/build/cmake/android.toolchain.cmake \
+        -DEXECUTORCH_USE_TIKTOKEN=ON \
         -DANDROID_ABI="${ANDROID_ABI}" \
         -DANDROID_PLATFORM=android-23 \
         -DCMAKE_INSTALL_PREFIX=cmake-android-out \
         -DCMAKE_BUILD_TYPE=Release -DPYTHON_EXECUTABLE=python \
         -DEXECUTORCH_BUILD_XNNPACK=ON \
+        -DEXECUTORCH_BUILD_QNN=ON \
         -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \
         -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \
         -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \
         -Bcmake-android-out/examples/models/llama2 examples/models/llama2
     cmake --build cmake-android-out/examples/models/llama2 -j4 --config Release

About the exporting commands, here is an example for you; you can also add:

python -m examples.models.llama2.export_llama -t ${tokenizer.model} -p ${params.json} -c ${consolidated.00.bf16.pth} --use_kv_cache --qnn --disable_dynamic_shape --num_sharding 1 --pt2e_quantize qnn_16a4w

Thank you
@chiwwang Thank you for the help. Rooting the device works for me.
Summary:
If pos_ids is changed to int32, it will be fully delegated.
There are still accuracy issues for llama 7b in 16a4w, and more complicated quantization algorithms are needed.
Note that if you want to run llama 7b, you need to specify num_sharding due to memory limitations on the device.
It is also recommended to reboot the device before running to ensure that it has enough memory.
Install executorch and backend lib:
Build llama runner:
Export llama in qnn:
Local Results:

llama-7b-chat with 8 splits in 16a4w
story llama in 16a4w

story llama in 8a8w

story llama in fp16
