Qualcomm AI Engine Direct - Add llama sha transforming pass #6211
Conversation
chunit-quic commented on Oct 15, 2024
- Add SHA pass
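For context, a minimal sketch of what a single-head-attention (SHA) split does to a fused attention projection, assuming a fused weight layout of `(n_heads * head_dim, dim)`. The helper name below is hypothetical and illustrative only, not the actual pass added in this PR:

```python
import torch
import torch.nn as nn

def split_projection_per_head(fused: nn.Linear, n_heads: int, head_dim: int) -> nn.ModuleList:
    """Split a fused (n_heads * head_dim, dim) projection into one Linear per head,
    so each attention head becomes an independent matmul."""
    dim = fused.in_features
    heads = nn.ModuleList(nn.Linear(dim, head_dim, bias=False) for _ in range(n_heads))
    with torch.no_grad():
        for i, head in enumerate(heads):
            head.weight.copy_(fused.weight[i * head_dim : (i + 1) * head_dim, :])
    return heads
```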
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/6211
Note: Links to docs will display an error until the docs builds have been completed.
✅ No Failures as of commit 607bc6c with merge base 5b51bb8.
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
```python
import torch

from .export_llama_lib import build_args_parser, export_llama

sys.setrecursionlimit(4096)
```
what is this for?
We hit the maximum recursion depth during `model = prepare(model, node_name_to_scope, is_qat=False)` in builder.py, so we raise the limit here.
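A minimal sketch of the workaround, assuming the standard PT2E quantization flow; the wrapper function below is hypothetical, not the actual builder.py code:

```python
import sys

from torch.ao.quantization.quantize_pt2e import prepare_pt2e

def prepare_with_higher_recursion_limit(exported_module, quantizer):
    # The graphs produced by the SHA split are deep enough that prepare_pt2e
    # exceeds Python's default recursion limit (1000), so raise it first.
    sys.setrecursionlimit(4096)
    return prepare_pt2e(exported_module, quantizer)
```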
```diff
@@ -260,21 +260,22 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
         assert args.n_heads % self.n_kv_heads == 0
         self.n_heads = args.n_heads
```
I guess it's still a draft, but it would be good to have a separate PR if we need to change llama_transformer.py.
Sure. We can separate this part into another PR.
Hi @cccclai,
Just pushed a separate PR for it. If PR 6376 looks fine to you and is merged, I will rebase this PR. Thank you. :)
@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Force-pushed from 4cd737a to 60aeec8
Force-pushed from 60aeec8 to a2a97a7
@cccclai
Sure yeah, let me merge it
@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
Hi, this PR breaks some internal tests. I need to get you a patch to land this PR safely.
Looks good
```python
import torch

from .export_llama_lib import build_args_parser, export_llama

sys.setrecursionlimit(4096)
```
Is it still required in the latest commit?
Yes, it is needed when use_qnn_sha is enabled. Otherwise it will hit the maximum recursion depth in the prepare_pt2e function.
Can you add a comment explaining the reason? Also, how feasible is it to guard it to args.qnn only?
Sorry for the late reply. I was running an experiment related to this comment.
I believe we can move this line, together with the `import sys`, under a Qualcomm-specific condition:

```diff
@@ -557,6 +557,8 @@ def get_quantizer_and_quant_params(args):
     quantizers = get_pt2e_quantizers(pt2e_quant_params, args.so_library)
     quant_dtype = None
     if args.qnn and args.pt2e_quantize:
+        import sys
+        sys.setrecursionlimit(4096)
```

That guards it to args.qnn only. If this looks better, I will raise a PR to move it and add a comment.
Can you apply this patch to fix the internal test?
Force-pushed from a2a97a7 to 607bc6c
Sure, just rebased and added it.
@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.