Qualcomm AI Engine Direct - Optimize the performance for AR-N model #9079


Merged
merged 4 commits into pytorch:main from dev1/hutton/optimize_arn_model
Mar 13, 2025

Conversation


@shewu-quic shewu-quic commented Mar 10, 2025

Summary:

  • Fix a bug in the RMS norm builder
  • Use the Hugging Face version of RoPE to improve performance, due to stride = 1 in the StrideSlice op
  • Modify the axis order of the conv in qkv, feedforward, and output
    • Original (AR:128, CL:2048): QNN_RmsNorm (1,1,128,2048) -> QNN_Reshape (1,128,2048,1)->QNN_Transpose (1,128,1,2048)->self.output-> QNN_Transpose(1,128,2048,1) -> QNN_Reshape (1,1,128,2048)
    • New: QNN_RmsNorm (1,1,128,2048) -> QNN_Reshape (1,128,1,2048)->QNN_Transpose (1,1,128,2048)->self.output-> QNN_Transpose(1,128,1,2048) -> QNN_Reshape (1,1,128,2048)
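The two op chains above can be traced in PyTorch shape terms. This is only a sketch: the conv (`self.output`) between the transposes is elided and treated as shape-preserving, and the QNN ops are stood in for by `reshape`/`permute`.

```python
import torch

x = torch.randn(1, 1, 128, 2048)  # RmsNorm output, AR = 128, CL = 2048

# Original chain (conv elided; we only trace shapes):
a = x.reshape(1, 128, 2048, 1)     # QNN_Reshape   -> (1, 128, 2048, 1)
a = a.permute(0, 1, 3, 2)          # QNN_Transpose -> (1, 128, 1, 2048)
a = a.permute(0, 1, 3, 2)          # QNN_Transpose -> (1, 128, 2048, 1)
a = a.reshape(1, 1, 128, 2048)     # QNN_Reshape   -> (1, 1, 128, 2048)

# New chain: the conv input now carries seq_len on the width axis:
b = x.reshape(1, 128, 1, 2048)     # QNN_Reshape   -> (1, 128, 1, 2048)
b = b.permute(0, 2, 1, 3)          # QNN_Transpose -> (1, 1, 128, 2048)
b = b.permute(0, 2, 1, 3)          # QNN_Transpose -> (1, 128, 1, 2048)
b = b.reshape(1, 1, 128, 2048)     # QNN_Reshape   -> (1, 1, 128, 2048)

assert a.shape == b.shape == (1, 1, 128, 2048)
```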

Test Result:

  • Verify the output for story llama with smart mask, CL=128, prefill_ar_n=16, prompt="Once"
    Note that using Hugging Face RoPE will slightly affect accuracy
    • Original (mainline)
INFO:root:Results[0]:
Once upon a time, there was a little girl named Lily. She loved to play with her toys and her favorite toy was a big, red ball. One day, Lily's mom asked her to help her with the laundry. Lily was happy to help and she put all the clothes in the washing machine. 
After the clothes were washed, Lily's mom asked her to help her hang them up to dry. Lily saw a big, black rake and asked her mom what it was. Her mom told her it was a rake and that it helps to
  • Optimized (this PR)
INFO:root:Results[0]:
Once upon a time, there was a little girl named Lily. She loved to play with her toys and her favorite toy was a big, red ball. One day, Lily's mom asked her to help her with the laundry. Lily was happy to help and she put all the clothes in the washing machine. 
After the clothes were washed, Lily's mom asked her to help her hang them up to dry. Lily saw a big, black iron on the counter and asked her mom what it was for. Her mom explained that it was used to make clothes smooth
  • Verify the performance for llama 3.2 1B with shift pointer, CL=2048, prefill_ar_n=256
    • Original (mainline)
I 00:00:02.048851 executorch:runner.cpp:354] Prompt Processor: total 256 tokens (AR-256 * 1 iters)
I 00:00:36.606984 executorch:runner.cpp:456] 	Prompt Tokens: 256    Generated Tokens: 1791
I 00:00:36.607049 executorch:runner.cpp:462] 	Model Load Time:		2.012000 (seconds)
I 00:00:36.607062 executorch:runner.cpp:472] 	Total inference time:		34.592000 (seconds)		 Rate: 	51.774977 (tokens/second)
I 00:00:36.607072 executorch:runner.cpp:480] 		Prompt evaluation:	0.293000 (seconds)		 Rate: 	873.720137 (tokens/second)
I 00:00:36.607080 executorch:runner.cpp:491] 		Generated 1791 tokens:	34.299000 (seconds)		 Rate: 	52.217266 (tokens/second)
I 00:00:36.607089 executorch:runner.cpp:499] 	Time to first generated token:	0.293000 (seconds)
I 00:00:36.607099 executorch:runner.cpp:506] 	Sampling time over 1791 tokens:	1.473000 (seconds)
  • Optimized (this PR)
I 00:00:01.827440 executorch:runner.cpp:354] Prompt Processor: total 256 tokens (AR-256 * 1 iters)
I 00:00:03.143673 executorch:runner.cpp:456] 	Prompt Tokens: 256    Generated Tokens: 64
I 00:00:03.143686 executorch:runner.cpp:462] 	Model Load Time:		1.791000 (seconds)
I 00:00:03.143698 executorch:runner.cpp:472] 	Total inference time:		1.350000 (seconds)		 Rate: 	47.407407 (tokens/second)
I 00:00:03.143706 executorch:runner.cpp:480] 		Prompt evaluation:	0.126000 (seconds)		 Rate: 	2031.746032 (tokens/second)
I 00:00:03.143715 executorch:runner.cpp:491] 		Generated 64 tokens:	1.224000 (seconds)		 Rate: 	52.287582 (tokens/second)
I 00:00:03.143723 executorch:runner.cpp:499] 	Time to first generated token:	0.126000 (seconds)
I 00:00:03.143733 executorch:runner.cpp:506] 	Sampling time over 64 tokens:	0.058000 (seconds)
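The rates printed by the runner can be cross-checked from the logged token counts and timings; a quick sanity check of the optimized run's numbers:

```python
# rate = tokens / seconds, using the values from the optimized-run log above
prompt_tokens, prompt_secs = 256, 0.126
gen_tokens, gen_secs = 64, 1.224

prompt_rate = prompt_tokens / prompt_secs  # ~2031.75 tokens/second
gen_rate = gen_tokens / gen_secs           # ~52.29 tokens/second

print(f"prompt: {prompt_rate:.6f} tok/s, decode: {gen_rate:.6f} tok/s")
```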

@shewu-quic shewu-quic requested a review from cccclai as a code owner March 10, 2025 06:05

pytorch-bot bot commented Mar 10, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/9079

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure, 1 Unrelated Failure

As of commit c94c0bd with merge base acae017:

NEW FAILURE - The following job has failed:

BROKEN TRUNK - The following job failed but was present on the merge base:

👉 Rebase onto the `viable/strict` branch to avoid these failures

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Mar 10, 2025

shewu-quic commented Mar 10, 2025

Hi @cccclai,
This PR aims to enhance the performance of the prompt processor (AR-N model).
The main improvement comes from using a stride slice operation with a stride of 1, which offers better performance. Consequently, we've switched to the Hugging Face version of RoPE.
Could you take a look?

@facebook-github-bot

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@@ -16,8 +17,9 @@ class RecomposeRmsNorm(ExportPass):
Merge decomposed operators back to one super node.
"""

def __init__(self):
super().__init__()
def __init__(self, edge_program: torch.export.ExportedProgram):
Contributor

We can follow #8505 to get rid of some recompose logic and reduce engineering effort there

Collaborator Author

Thanks for your information. I will try it.

@@ -19,12 +19,12 @@
def apply_rotary_emb_single(
x: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
) -> torch.Tensor:
x_r, x_i = x[..., ::2], x[..., 1::2]

# Change to RoPE of huggingface version
Contributor

Which one is the huggingface version and why is it better?

Collaborator Author

The Hugging Face implementation of RoPE processes the query and key as two contiguous halves instead of in an interleaved way.
The main difference is the stride in the StrideSlice op: the interleaved approach uses a stride of two, which is unfriendly to the HTP backend's memory handling.
Ref: huggingface/transformers#25199
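The difference between the two conventions can be sketched as follows (a simplified illustration, not the exact code in this PR; the function names are hypothetical, and `x` is the last dim of a query/key head with `cos`/`sin` broadcastable over half of it):

```python
import torch

def rope_interleaved(x, cos, sin):
    # Meta-style RoPE: real/imag parts interleaved along the last dim,
    # so the slices below have stride = 2 (the pattern the HTP dislikes).
    x_r, x_i = x[..., ::2], x[..., 1::2]
    out_r = x_r * cos - x_i * sin
    out_i = x_r * sin + x_i * cos
    return torch.stack([out_r, out_i], dim=-1).flatten(-2)

def rope_half(x, cos, sin):
    # Hugging Face-style RoPE: the head dim is split into two contiguous
    # halves, so both slices have stride = 1.
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)
```

Note the two forms are not numerically identical for the same `cos`/`sin` tables; switching conventions effectively permutes the frequencies, which is why the PR notes a slight accuracy impact.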

Contributor

Maybe add this as a code comment, just so others know the context.

@cccclai cccclai left a comment

The perf improvement looks awesome!!


cccclai commented Mar 12, 2025

There is still a lint error, can you fix it?

@shewu-quic shewu-quic force-pushed the dev1/hutton/optimize_arn_model branch from b4d5a63 to fde0f80 Compare March 12, 2025 03:51
@shewu-quic shewu-quic requested a review from SS-JIA as a code owner March 12, 2025 03:51
@shewu-quic

There is still a lint error, can you fix it?

Done. Thanks :)

@cccclai cccclai added the release notes: qualcomm Changes to the Qualcomm backend delegate label Mar 12, 2025
@facebook-github-bot

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


cccclai commented Mar 12, 2025

This seems to need a rebase.

    Summary:
    - Fix the bug of rms norm builder
    - Use HuggingFace version RoPE to improve the performance due to
      stride = 1 in StrideSlice Op
    - Modify the axis order of the conv in qkv, feedforward and output
      - Original (AR:128, CL:2048): QNN_RmsNorm (1,1,128,2048) -> QNN_Reshape (1,128,2048,1)->QNN_Transpose (1,128,1,2048)->self.output-> QNN_Transpose(1,128,2048,1) -> QNN_Reshape (1,1,128,2048)
      - New: QNN_RmsNorm (1,1,128,2048) -> QNN_Reshape (1,128,1,2048)->QNN_Transpose (1,1,128,2048)->self.output-> QNN_Transpose(1,128,1,2048) -> QNN_Reshape (1,1,128,2048)
@shewu-quic shewu-quic force-pushed the dev1/hutton/optimize_arn_model branch from fde0f80 to c5c149c Compare March 12, 2025 23:12
@facebook-github-bot

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


cccclai commented Mar 13, 2025

Can you also share which part of the logic does the following optimization?

Modify the axis order of the conv in qkv, feedforward and output
Original (AR:128, CL:2048): QNN_RmsNorm (1,1,128,2048) -> QNN_Reshape (1,128,2048,1)->QNN_Transpose (1,128,1,2048)->self.output-> QNN_Transpose(1,128,2048,1) -> QNN_Reshape (1,1,128,2048)
New: QNN_RmsNorm (1,1,128,2048) -> QNN_Reshape (1,128,1,2048)->QNN_Transpose (1,1,128,2048)->self.output-> QNN_Transpose(1,128,1,2048) -> QNN_Reshape (1,1,128,2048)

Is it the weight permutation or something else? Also good to have it as part of the code comment so it's easy to understand the intention.

@shewu-quic

Can you also share which part of the logic does the following optimization?

Modify the axis order of the conv in qkv, feedforward and output
Original (AR:128, CL:2048): QNN_RmsNorm (1,1,128,2048) -> QNN_Reshape (1,128,2048,1)->QNN_Transpose (1,128,1,2048)->self.output-> QNN_Transpose(1,128,2048,1) -> QNN_Reshape (1,1,128,2048)
New: QNN_RmsNorm (1,1,128,2048) -> QNN_Reshape (1,128,1,2048)->QNN_Transpose (1,1,128,2048)->self.output-> QNN_Transpose(1,128,1,2048) -> QNN_Reshape (1,1,128,2048)

Is it the weight permutation or something else? Also good to have it as part of the code comment so it's easy to understand the intention.

Got it. But this optimization is not entirely general. Based on our experiments, for the input axis order of the conv op, placing the sequence length on the width dimension (1, 1, seq_len, CL) performs better than placing it on the height dimension (1, seq_len, 1, CL). Another reason is that this change brings the structure closer to the AI Hub version of Llama.
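A minimal check that the two layouts feed the same math through a pointwise conv, in PyTorch NCHW terms with CL as the channel dim (the conv dims here are illustrative stand-ins for `self.output`, not its real shape):

```python
import torch

seq_len, cl = 128, 2048
conv = torch.nn.Conv2d(cl, cl, kernel_size=1, bias=False)  # stand-in for self.output

# seq_len on the height axis of the conv input: (1, CL, seq_len, 1)
x_h = torch.randn(1, cl, seq_len, 1)
# seq_len on the width axis (the layout this PR switches to): (1, CL, 1, seq_len)
x_w = x_h.permute(0, 1, 3, 2)

y_h = conv(x_h)  # (1, cl, seq_len, 1)
y_w = conv(x_w)  # (1, cl, 1, seq_len)

# A 1x1 conv acts per spatial position, so the two layouts are mathematically
# equivalent; the difference on the HTP backend is purely memory layout / perf.
assert torch.allclose(y_h, y_w.permute(0, 1, 3, 2), atol=1e-4)
```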

@facebook-github-bot

@cccclai has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.

@cccclai cccclai merged commit baf35d2 into pytorch:main Mar 13, 2025
49 of 52 checks passed