
Commit 6023ca8

FightingZhen authored and Cyrilvallez committed
[bugfix] fix flash attention 2 unavailable error on Ascend NPU (#39166)
[bugfix] fix flash attention 2 error on Ascend NPU
1 parent e7e78b2 · commit 6023ca8

1 file changed: 10 additions, 16 deletions


src/transformers/modeling_flash_attention_utils.py

Lines changed: 10 additions & 16 deletions
@@ -103,6 +103,16 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
     from flash_attn.bert_padding import unpad_input as unpad_input_fa2
     from flash_attn.layers.rotary import apply_rotary_emb
 
+    HAS_FA2 = True
+    FA_VERSION = 2
+elif is_torch_npu_available():
+    # patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
+    from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb  # noqa: F401
+    from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_2_func
+    from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_2_varlen_func
+    from .integrations.npu_flash_attention import pad_input as pad_input_fa2
+    from .integrations.npu_flash_attention import unpad_input as unpad_input_fa2
+
     HAS_FA2 = True
     FA_VERSION = 2
 else:
@@ -136,22 +146,6 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
     unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"]
     pad_input = globals()[f"pad_input_fa{FA_VERSION}"]
 
-# patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
-if is_torch_npu_available():
-    from .integrations.npu_flash_attention import (
-        npu_apply_rotary_emb as apply_rotary_emb,  # noqa: F401
-    )
-    from .integrations.npu_flash_attention import (
-        npu_flash_attn_func as flash_attn_func,
-    )
-    from .integrations.npu_flash_attention import (
-        npu_flash_attn_varlen_func as flash_attn_varlen_func,
-    )
-    from .integrations.npu_flash_attention import (
-        pad_input,
-        unpad_input,
-    )
-
 
 _flash_supports_window_size = False
 
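Read as a whole, the change folds the Ascend NPU path into the same backend-selection chain used for the CUDA flash-attn package: the NPU wrappers from integrations.npu_flash_attention are now bound to the *_fa2-suffixed names and the branch sets HAS_FA2 / FA_VERSION itself, so the later globals()[f"...fa{FA_VERSION}"] lookup resolves to them with no after-the-fact patching. The snippet below is a minimal, self-contained sketch of that dispatch pattern, not the real module: the availability checks and kernel functions prefixed with an underscore are illustrative stand-ins, and the condition on the branch above the added elif (the CUDA flash-attn check) is not shown in this hunk.

# Minimal sketch of the selection pattern in modeling_flash_attention_utils.py
# after this commit. Underscore-prefixed names are illustrative stand-ins,
# not part of transformers.

def _cuda_fa2_available() -> bool:
    """Stand-in for the CUDA flash-attn availability check (not shown in the hunk)."""
    return False  # pretend the flash-attn package is not installed


def _npu_available() -> bool:
    """Stand-in for is_torch_npu_available()."""
    return True  # pretend we are running on an Ascend NPU


def _npu_flash_attn_func(*args, **kwargs):
    return "npu flash_attn_func called"


def _npu_flash_attn_varlen_func(*args, **kwargs):
    return "npu flash_attn_varlen_func called"


def _npu_pad_input(*args, **kwargs):
    return "npu pad_input called"


def _npu_unpad_input(*args, **kwargs):
    return "npu unpad_input called"


HAS_FA2 = False
FA_VERSION = None

if _cuda_fa2_available():
    # The real module imports flash_attn_2_func, pad_input_fa2, ... from flash-attn here.
    HAS_FA2 = True
    FA_VERSION = 2
elif _npu_available():
    # NPU branch: bind the NPU wrappers to the same *_fa2 names and advertise
    # FA2 support, mirroring what the diff above does with the
    # integrations.npu_flash_attention module.
    flash_attn_2_func = _npu_flash_attn_func
    flash_attn_2_varlen_func = _npu_flash_attn_varlen_func
    pad_input_fa2 = _npu_pad_input
    unpad_input_fa2 = _npu_unpad_input

    HAS_FA2 = True
    FA_VERSION = 2

if FA_VERSION is not None:
    # Same generic lookup as in the diff: pick the fa{N} variants by name.
    unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"]
    pad_input = globals()[f"pad_input_fa{FA_VERSION}"]

print(HAS_FA2, FA_VERSION, unpad_input())
# -> True 2 npu unpad_input called

Judging from the diff alone, this also explains the old failure mode: the removed patch block ran after the globals() selection and never set HAS_FA2 or FA_VERSION, so without the CUDA flash-attn package installed the module still reported flash attention 2 as unavailable on Ascend NPU, which is the error named in the commit title.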