Qualcomm AI Engine Direct - Enable AR-N model for prompt processing in hybrid mode #8210
Conversation
🔗 Helpful Links 🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/8210
Note: Links to docs will display an error until the docs builds have been completed.
❌ 1 New Failure as of commit 7d9a14e with merge base 77589c6. The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes.
The branch was force-pushed from fc66256 to d36b867 (compare).
This PR enables the AR-N model for prompt processing in hybrid mode. Regarding the change in lower_module_backend.py, it is intended to prevent a double deletion of the persistent buffers. I observed that the buffers (freq_cos and freq_sin) are copied to each delegate node (due to graph sharding), and the original buffer is eventually deleted. Since each copied buffer shares the same target, this would result in a double deletion.
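For illustration only, a minimal sketch of the guard described above: when several delegate shards copy the same persistent buffer, the original entry should be deleted at most once. The names below are hypothetical and not the actual lower_module_backend.py code.

```python
# Sketch only, assuming a plain dict of persistent buffers and a list of the
# buffer targets referenced by the copied placeholders across all delegate shards.
def delete_shared_buffers_once(copied_buffer_targets, persistent_buffers):
    deleted = set()
    for target in copied_buffer_targets:
        # Copies from different shards share the same target, so skip targets
        # that were already deleted to avoid a double deletion.
        if target in deleted or target not in persistent_buffers:
            continue
        del persistent_buffers[target]
        deleted.add(target)

# Example: freq_cos is referenced by two shards but removed exactly once.
buffers = {"freq_cos": ..., "freq_sin": ...}
delete_shared_buffers_once(["freq_cos", "freq_cos", "freq_sin"], buffers)
assert not buffers
```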
@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Awesome! Thank you for getting it to work so quickly. Can you help fix these errors?
KV Cache Mode: In KV Cache mode, the model takes in a single previous token and generates the next predicted token along with its KV cache. It is efficient for generating subsequent tokens after the initial prompt.
Hybrid Mode: Hybrid mode leverages the strengths of both batch prefill and KV cache modes to optimize token generation speed. Initially, it uses prefill mode to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens.
Hybrid Mode: Hybrid mode leverages the strengths of both the AR-N model and KV cache modes to optimize token generation speed. Initially, it uses the AR-N model to efficiently generate the prompt's key-value (KV) cache. Then, the mode switches to KV cache mode, which excels at generating subsequent tokens.
- AR-N model: The auto-regression (AR) length determines the number of tokens to consume and the number of logits to produce. Use it to process the prompt and generate the key-value (KV) cache, which serves as a prompt processor in hybrid mode.
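To make the AR-N description above more concrete, here is a rough sketch (not the PR's actual API; the model call signature is assumed) of how an AR-N prompt processor consumes the prompt in fixed windows of the AR length while filling the KV cache, before hybrid mode hands off to KV cache mode:

```python
# Illustrative sketch; `model` is assumed to take (tokens, positions, kv_cache)
# and return (logits, kv_cache), which is not necessarily the real interface.
def process_prompt_ar_n(model, prompt_tokens, ar_len, kv_cache):
    logits = None
    for start in range(0, len(prompt_tokens), ar_len):
        chunk = prompt_tokens[start:start + ar_len]
        # Pad the last window so every step consumes exactly ar_len tokens
        # and produces ar_len logits.
        chunk = chunk + [0] * (ar_len - len(chunk))
        positions = list(range(start, start + ar_len))
        logits, kv_cache = model(chunk, positions, kv_cache)
    # Token generation then switches to KV cache mode.
    return logits, kv_cache
```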
Can we add the diagram you shared as part of the readme? It's much easier to understand with it.
Certainly, that will not be a problem.
Thanks! And please fix the lint as well.
The branch was force-pushed from 2d5fa26 to 203b87b (compare).
Whoops, thanks for your effort.
@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Hey, it seems like there's a merge conflict. Can you rebase?
The branch was force-pushed from 5dc5c17 to 0722beb (compare).
Qualcomm AI Engine Direct - Enable AR-N model for prompt processing in hybrid mode
Summary:
- Add `max_seq_len` to refer to the maximum number of tokens that the model can process and consider at once to generate predictions/responses.
- Add `prefill_ar_n` to determine the number of tokens to consume and the number of logits to produce for the prompt processor in hybrid mode.
- Remove prefill mode
The branch was force-pushed from 0722beb to 7d9a14e (compare).
@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
// If the cache length is zero, it indicates a BERT model, which does not use
// position ids or KV cache inputs.
const bool is_bert_{false};
Why waste a byte storing this separately rather than using a private method like so?
bool is_bert() const {
  return prefill_cache_len_ == 0;
}
Thank you for mentioning that. We will include it in the upcoming PR.
Summary:
- Add `--max_seq_len` to refer to the maximum number of tokens that the model can process and consider at once to generate predictions/responses.
- Add `--prefill_ar_n` to determine the number of tokens to consume and the number of logits to produce for the prompt processor in hybrid mode.
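As a rough illustration of how the two options relate (an assumption for clarity, not text from the PR): `--prefill_ar_n` sets the window size of each AR-N step, so a prompt is processed in roughly ceil(prompt_len / prefill_ar_n) steps, and the AR length must fit within `--max_seq_len`.

```python
import math

# Sketch of the assumed relationship between the two options.
def num_prefill_steps(prompt_len: int, prefill_ar_n: int, max_seq_len: int) -> int:
    assert 0 < prefill_ar_n <= max_seq_len, "AR length must fit within the context window"
    assert prompt_len <= max_seq_len, "prompt must fit within max_seq_len"
    return math.ceil(prompt_len / prefill_ar_n)

# e.g. a 100-token prompt with prefill_ar_n=32 and max_seq_len=2048 takes 4 AR-N steps
print(num_prefill_steps(100, 32, 2048))
```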
Test Plan
- Try to find the best AR-N with CL=2048 in hybrid mode.
- Ensure accuracy for Stories Llama 16a4w, prompt: "Once upon a time".
Command