src/transformers/modeling_flash_attention_utils.py (26 changes: 10 additions & 16 deletions)
@@ -103,6 +103,16 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
     from flash_attn.bert_padding import unpad_input as unpad_input_fa2
     from flash_attn.layers.rotary import apply_rotary_emb
 
     HAS_FA2 = True
     FA_VERSION = 2
+elif is_torch_npu_available():
+    # patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
+    from .integrations.npu_flash_attention import npu_apply_rotary_emb as apply_rotary_emb  # noqa: F401
+    from .integrations.npu_flash_attention import npu_flash_attn_func as flash_attn_2_func
+    from .integrations.npu_flash_attention import npu_flash_attn_varlen_func as flash_attn_2_varlen_func
+    from .integrations.npu_flash_attention import pad_input as pad_input_fa2
+    from .integrations.npu_flash_attention import unpad_input as unpad_input_fa2
+
+    HAS_FA2 = True
+    FA_VERSION = 2
 else:
@@ -136,22 +146,6 @@ def _fa3_pad_input(hidden_states, indices, batch, seqlen):
     unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"]
     pad_input = globals()[f"pad_input_fa{FA_VERSION}"]
 
-# patch functions in package `flash-attn` when using flash-attention on Ascend NPU.
-if is_torch_npu_available():
-    from .integrations.npu_flash_attention import (
-        npu_apply_rotary_emb as apply_rotary_emb,  # noqa: F401
-    )
-    from .integrations.npu_flash_attention import (
-        npu_flash_attn_func as flash_attn_func,
-    )
-    from .integrations.npu_flash_attention import (
-        npu_flash_attn_varlen_func as flash_attn_varlen_func,
-    )
-    from .integrations.npu_flash_attention import (
-        pad_input,
-        unpad_input,
-    )
-
 
 _flash_supports_window_size = False
 
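Note for readers skimming the diff: the removed module-level patch and the added elif branch cover the same NPU case, but the new branch binds the version-suffixed names (flash_attn_2_func, pad_input_fa2, unpad_input_fa2) so the later globals() lookup resolves them uniformly, instead of re-patching the generic names afterwards. Below is a minimal, self-contained sketch of that dispatch pattern, not the transformers source itself; the _npu_available() helper and the stub functions are hypothetical stand-ins for is_torch_npu_available() and the real flash-attn / npu_flash_attention entry points.

# Sketch of the "suffix per backend, resolve once via globals()" pattern.

def _npu_available() -> bool:
    # Hypothetical stand-in for transformers' is_torch_npu_available().
    return False


if _npu_available():
    # On Ascend NPU the entry points would come from
    # transformers.integrations.npu_flash_attention; stubbed here so the sketch runs anywhere.
    def unpad_input_fa2(hidden_states, attention_mask):
        raise NotImplementedError("NPU backend stub")

    def pad_input_fa2(hidden_states, indices, batch, seqlen):
        raise NotImplementedError("NPU backend stub")

    FA_VERSION = 2
else:
    # CUDA flash-attn 2 path, also stubbed for illustration.
    def unpad_input_fa2(hidden_states, attention_mask):
        return hidden_states

    def pad_input_fa2(hidden_states, indices, batch, seqlen):
        return hidden_states

    FA_VERSION = 2

# Generic names resolve to whichever backend populated the fa{FA_VERSION} slots.
unpad_input = globals()[f"unpad_input_fa{FA_VERSION}"]
pad_input = globals()[f"pad_input_fa{FA_VERSION}"]

print(unpad_input, pad_input)  # both point at the selected backend's implementations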