[performance_optim] define flash attention mask on NPU device directly #37698
Conversation
Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the `Ready for review` button.
The branch was force-pushed from 327fd48 to 8eda2ef.
@MekkCyber @SunMarc please help me review it, thanks :)
MekkCyber
left a comment
Sounds good! Thanks for catching that 🤗
SunMarc
left a comment
SGTM!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
The branch was force-pushed from 63af693 to 43736d4.
@MekkCyber @SunMarc I have rebased the code onto the latest main.
The branch was force-pushed from 43736d4 to 0a9eaf7.
@MekkCyber @SunMarc All CI checks pass and this PR seems to be ready for merge :)
Merged 🎊! Thanks a lot

[performance_optim] define flash attention mask on NPU device directly (huggingface#37698) Co-authored-by: Mohamed Mekkouri <[email protected]>
What does this PR do?
When using Flash Attention 2 on Ascend NPU, we found that CPU memory keeps increasing when calling `npu_flash_attn_varlen_func` or `npu_flash_attn_func`. The root cause is that the attention mask generated by `torch.ones()` is initially created on the CPU, occupying CPU memory before being transferred to the NPU device. As `npu_flash_attn_varlen_func` or `npu_flash_attn_func` is called repeatedly, the CPU memory consumption keeps accumulating, which is not optimal. Below is one example:

`transformers/src/transformers/integrations/npu_flash_attention.py`, line 225 at commit `12f65ee`
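A minimal sketch of the pattern this PR changes (the helper names and the fixed 2048×2048 mask size are illustrative assumptions, not the exact code in `npu_flash_attention.py`):

```python
import torch

SEQ_LEN = 2048  # illustrative mask size; the real integration may use a different shape


def make_causal_mask_before(device: torch.device) -> torch.Tensor:
    # Before: torch.ones() materializes the mask on the CPU first,
    # then copies it to the NPU. Repeated calls keep growing CPU memory usage.
    mask = torch.triu(torch.ones([SEQ_LEN, SEQ_LEN]), diagonal=1).bool()
    return mask.to(device)


def make_causal_mask_after(device: torch.device) -> torch.Tensor:
    # After: passing `device=` to torch.ones() allocates the mask directly
    # on the NPU, so no intermediate CPU tensor is created.
    return torch.triu(torch.ones([SEQ_LEN, SEQ_LEN], device=device), diagonal=1).bool()
```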
Therefore, this PR solves the problem by defining the attention mask tensor with `torch.ones()` directly on the NPU device.

Fixes # (issue)

Not related.
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.