@@ -103,6 +103,16 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
     from flash_attn.bert_padding import unpad_input as unpad_input_fa2
     from flash_attn.layers.rotary import apply_rotary_emb
 
+    HAS_FA2 = True
+    FA_VERSION = 2
+elif is_torch_npu_available():
+    # patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
+    from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb  # noqa: F401
+    from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_2_func
+    from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_2_varlen_func
+    from .integrations.npu_flash_attention import pad_input as pad_input_fa2
+    from .integrations.npu_flash_attention import unpad_input as unpad_input_fa2
+
     HAS_FA2 = True
     FA_VERSION = 2
 else:
@@ -136,22 +146,6 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
     unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"]
     pad_input = globals()[f"pad_input_fa{FA_VERSION}"]
 
-# patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
-if is_torch_npu_available():
-    from .integrations.npu_flash_attention import (
-        npu_apply_rotary_emb as apply_rotary_emb,  # noqa: F401
-    )
-    from .integrations.npu_flash_attention import (
-        npu_flash_attn_func as flash_attn_func,
-    )
-    from .integrations.npu_flash_attention import (
-        npu_flash_attn_varlen_func as flash_attn_varlen_func,
-    )
-    from .integrations.npu_flash_attention import (
-        pad_input,
-        unpad_input,
-    )
-
 
 _flash_supports_window_size = False
 
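For context (not part of this commit): the change moves the Ascend NPU imports into an `elif` branch that binds the same `*_fa2`-suffixed names, so the NPU implementations flow through the module's `FA_VERSION` dispatch instead of being patched in after the fact. A minimal sketch of that import-time dispatch pattern, with placeholder functions standing in for the real `flash-attn` / NPU kernels:

```python
# Minimal sketch of the FA_VERSION-based dispatch shown in the diff above.
# The *_fa2 functions below are placeholders, not the real flash-attn or NPU kernels.

def unpad_input_fa2(hidden_states, attention_mask):
    # stand-in for flash_attn.bert_padding.unpad_input (or its NPU patch)
    return hidden_states

def pad_input_fa2(hidden_states, indices, batch, seqlen):
    # stand-in for flash_attn.bert_padding.pad_input (or its NPU patch)
    return hidden_states

FA_VERSION = 2  # set by whichever backend branch succeeded at import time

# One lookup per name picks the version-suffixed implementation; the rest of
# the module only ever refers to the unsuffixed `unpad_input` / `pad_input`.
unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"]
pad_input = globals()[f"pad_input_fa{FA_VERSION}"]
```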