
Qualcomm AI Engine Direct - Add llama sha transforming pass #6211


Merged

Conversation

chunit-quic
Collaborator

  • Add SHA pass
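For context, "SHA" in the PR title refers to rewriting the model's multi-head attention into per-head, single-head attention so that each head can be lowered to the QNN backend as an independent op. Below is a minimal, hypothetical sketch of that general idea; the helper name split_wq and the weight-slicing details are illustrative assumptions, not code from this PR.

import torch.nn as nn

def split_wq(wq: nn.Linear, n_heads: int, head_dim: int) -> nn.ModuleList:
    # Illustrative only: split one fused query projection into n_heads
    # per-head projections so each head becomes a standalone single-head op.
    heads = nn.ModuleList(
        [nn.Linear(wq.in_features, head_dim, bias=wq.bias is not None)
         for _ in range(n_heads)]
    )
    for i, head in enumerate(heads):
        # Copy the rows of the fused weight that belong to head i.
        head.weight.data.copy_(wq.weight[i * head_dim : (i + 1) * head_dim])
        if wq.bias is not None:
            head.bias.data.copy_(wq.bias[i * head_dim : (i + 1) * head_dim])
    return heads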


pytorch-bot bot commented Oct 15, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/6211

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 607bc6c with merge base 5b51bb8:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed label (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Oct 15, 2024
@facebook-github-bot
Contributor

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

1 similar comment


import sys

import torch

from .export_llama_lib import build_args_parser, export_llama

sys.setrecursionlimit(4096)
Contributor

what is this for?

Collaborator Author

We hit the maximum recursion depth during model = prepare(model, node_name_to_scope, is_qat=False) in builder.py, so we raise the limit here.
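For illustration, a minimal sketch of the workaround being discussed, assuming the standard PT2E entry point prepare_pt2e from torch.ao.quantization.quantize_pt2e and an exported graph deep enough to exceed CPython's default recursion limit of 1000; exported_model and quantizer are placeholders, not objects from this PR.

import sys

from torch.ao.quantization.quantize_pt2e import prepare_pt2e

# Raise the limit before preparing: a deeply nested exported graph can exceed
# the default limit of 1000 while prepare_pt2e traverses it.
sys.setrecursionlimit(4096)

# exported_model and quantizer stand in for the objects built by
# export_llama_lib, so the call itself is shown commented out.
# prepared_model = prepare_pt2e(exported_model, quantizer)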

@@ -260,21 +260,22 @@ class Attention(nn.Module):
     def __init__(self, args: ModelArgs, layer_id: int):
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
         assert args.n_heads % self.n_kv_heads == 0
         self.n_heads = args.n_heads
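As a side note on this hunk: n_kv_heads defaults to n_heads, and the assert guarantees the grouped-query layout divides evenly, so each KV head serves a whole number of query heads. A tiny illustrative check with made-up values:

n_heads, n_kv_heads = 32, 8  # example values only, not from the PR
assert n_heads % n_kv_heads == 0
n_rep = n_heads // n_kv_heads  # each KV head is shared by n_rep query heads
print(n_rep)  # prints 4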
Contributor

I guess it's still a draft, but it will be good to have a separate PR if we need to change llama_transformer.py.

Collaborator Author

Sure. We can separate this part into another PR.

Collaborator Author

Hi @cccclai
Just pushed a separate PR for it. If PR 6376 looks fine to you and is merged, I will rebase this PR. Thank you. :)

@facebook-github-bot
Contributor

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@chunit-quic chunit-quic force-pushed the dev1/chunit/add_llama_sha_pass branch from 4cd737a to 60aeec8 on November 1, 2024 04:10
@chunit-quic chunit-quic marked this pull request as ready for review November 6, 2024 03:17
@chunit-quic chunit-quic force-pushed the dev1/chunit/add_llama_sha_pass branch from 60aeec8 to a2a97a7 on November 6, 2024 03:22
@chunit-quic
Collaborator Author

@cccclai
Just a gentle ping. We rebased and changed the PR from draft to ready for review a few minutes ago. Would this be fine to import and merge? Thank you!

@cccclai
Contributor

cccclai commented Nov 6, 2024

Sure yeah, let me merge it

@facebook-github-bot
Contributor

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@cccclai
Contributor

cccclai commented Nov 7, 2024

Hi, this PR breaks some internal tests; I need to get you a patch to land this PR safely.

Contributor

@cccclai cccclai left a comment


Looks good


import sys

import torch

from .export_llama_lib import build_args_parser, export_llama

sys.setrecursionlimit(4096)
Contributor

Is it still required in the latest commit?

Collaborator Author

Yes, it is needed when use_qnn_sha is enabled. Otherwise it will trigger the maximum recursion depth error in the prepare_pt2e function.

Contributor

Can you add a comment and explain the reason? Also, how likely is it that we can guard it to args.qnn only?

Collaborator Author

Sorry for the late reply. I was running an experiment related to this comment.
I believe we can move this line and the import sys to a Qualcomm-specific condition:

@@ -557,6 +557,8 @@ def get_quantizer_and_quant_params(args):
     quantizers = get_pt2e_quantizers(pt2e_quant_params, args.so_library)
     quant_dtype = None
     if args.qnn and args.pt2e_quantize:
+        import sys
+        sys.setrecursionlimit(4096)

This guards it to args.qnn only. If this looks better, I will raise a PR to move it and add a comment.

@cccclai
Contributor

cccclai commented Nov 11, 2024

Can you apply this patch to fix the internal test?

--- a/executorch/examples/models/llama/TARGETS
+++ b/executorch/examples/models/llama/TARGETS
@@ -82,6 +82,7 @@
         "export_llama_lib.py",
         "model.py",
         "source_transformation/apply_spin_quant_r1_r2.py",
+        "source_transformation/attention.py",
         "source_transformation/lora.py",
         "source_transformation/pre_quantization.py",
         "source_transformation/prune_vocab.py",

@chunit-quic chunit-quic force-pushed the dev1/chunit/add_llama_sha_pass branch from a2a97a7 to 607bc6c on November 11, 2024 04:05
@chunit-quic
Collaborator Author

Can you apply this patch to fix the internal test?

Sure, just rebased and added it.

@facebook-github-bot
Contributor

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

1 similar comment

@facebook-github-bot facebook-github-bot merged commit 576e96c into pytorch:main Nov 11, 2024
40 checks passed