@@ -103,6 +103,16 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
     from flash_attn.bert_padding import unpad_input as unpad_input_fa2
     from flash_attn.layers.rotary import apply_rotary_emb
 
+    HAS_FA2 = True
+    FA_VERSION = 2
+elif is_torch_npu_available():
+    # patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
+    from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb  # noqa: F401
+    from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_2_func
+    from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_2_varlen_func
+    from .integrations.npu_flash_attention import pad_input as pad_input_fa2
+    from .integrations.npu_flash_attention import unpad_input as unpad_input_fa2
+
     HAS_FA2 = True
     FA_VERSION = 2
 else:
@@ -136,22 +146,6 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
     unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"]
     pad_input = globals()[f"pad_input_fa{FA_VERSION}"]
 
-# patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
-if is_torch_npu_available():
-    from .integrations.npu_flash_attention import (
-        npu_apply_rotary_emb as apply_rotary_emb,  # noqa: F401
-    )
-    from .integrations.npu_flash_attention import (
-        npu_flash_attn_func as flash_attn_func,
-    )
-    from .integrations.npu_flash_attention import (
-        npu_flash_attn_varlen_func as flash_attn_varlen_func,
-    )
-    from .integrations.npu_flash_attention import (
-        pad_input,
-        unpad_input,
-    )
-
 
 _flash_supports_window_size = False
 
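The change keeps the module's existing dispatch pattern: each backend branch binds its implementations to versioned names (`flash_attn_2_func`, `pad_input_fa2`, ...), and the generic names are resolved once afterwards via a `globals()` lookup keyed on `FA_VERSION`, rather than being monkey-patched in a separate `is_torch_npu_available()` block after the fact. A minimal, self-contained sketch of that pattern follows; the `try/except` guard and `_local_pad_input` fallback are illustrative stand-ins for the real availability checks, not transformers code.

# Minimal sketch of the version-keyed dispatch used above; only the
# globals() lookup mirrors the real file, the rest is illustrative.

FA_VERSION = None
HAS_FA2 = False


def _local_pad_input(hidden_states, indices, batch, seqlen):
    # hypothetical fallback so the sketch runs without flash-attn installed
    return hidden_states


try:
    # when flash-attn is installed, bind the real kernel under a versioned alias
    from flash_attn.bert_padding import pad_input as pad_input_fa2

    HAS_FA2 = True
    FA_VERSION = 2
except ImportError:
    # otherwise bind the local stand-in under the same versioned name
    pad_input_fa2 = _local_pad_input
    FA_VERSION = 2

# a single lookup then resolves the generic name for whichever backend was selected
pad_input = globals()[f"pad_input_fa{FA_VERSION}"]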