
Commit a8ffcaa

fix fa4 test (PaddlePaddle#6408)
1 parent: 3ce842b

File tree (3 files changed: +7, -2 lines)

- fastdeploy/model_executor/layers/attention/flash_attn_backend.py
- tests/layers/test_flash_attn_func.py
- tests/operators/test_flash_mask_attn.py

fastdeploy/model_executor/layers/attention/flash_attn_backend.py
Lines changed: 2 additions & 2 deletions

@@ -118,9 +118,9 @@ def flash_attn_func(
     head_dim: int = 128,
     version: Optional[int] = None,
 ):
+    if FLASH_ATTN_VERSION is None:
+        init_flash_attn_version()
     if version is None:
-        if FLASH_ATTN_VERSION is None:
-            init_flash_attn_version()
         version = FLASH_ATTN_VERSION
     if version == 4:
         assert (
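
The point of this hunk: before the change, `init_flash_attn_version()` only ran on the `version is None` path, so calling `flash_attn_func(..., version=4)` explicitly could reach the FA4 branch with `FLASH_ATTN_VERSION` still unset. Below is a minimal standalone sketch of that ordering bug and the fix; the stub bodies are assumptions for illustration, not the FastDeploy implementation.

```python
# Standalone sketch, not FastDeploy code: the real module probes which
# flash-attention build is installed; this stub just pretends FA4 exists.
FLASH_ATTN_VERSION = None  # populated lazily, as in the original module


def init_flash_attn_version():
    """Stand-in for the real version probe (assumed behavior)."""
    global FLASH_ATTN_VERSION
    FLASH_ATTN_VERSION = 4


def flash_attn_func(version=None):
    # Fixed ordering from the hunk: initialize before branching on `version`,
    # so an explicit version=4 no longer skips initialization.
    if FLASH_ATTN_VERSION is None:
        init_flash_attn_version()
    if version is None:
        version = FLASH_ATTN_VERSION
    if version == 4:
        # With the old ordering, FLASH_ATTN_VERSION could still be None here.
        assert FLASH_ATTN_VERSION is not None, "backend never initialized"
    return version


print(flash_attn_func(version=4))  # 4: explicit version now initializes too
```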

tests/layers/test_flash_attn_func.py
Lines changed: 3 additions & 0 deletions

@@ -198,6 +198,9 @@ def test_fa4(self):
             k,
             v,
             attn_mask_q=attn_mask_q,
+            num_heads=num_heads,
+            kv_num_heads=kv_num_heads,
+            head_dim=head_dim,
             version=4,
         )
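
Why the test now passes `num_heads`, `kv_num_heads`, and `head_dim` explicitly: `flash_attn_func` has defaults (`head_dim=128` per the first hunk), and relying on them breaks whenever the test's tensors use a different geometry. A toy reshape makes the failure mode concrete; `split_heads` is a hypothetical helper, not the FastDeploy kernel.

```python
# Hypothetical illustration of a default head_dim disagreeing with the
# tensors: the implicit geometry fails where the explicit one succeeds.
import numpy as np


def split_heads(x, num_heads, head_dim=128):
    # Mimics backends that reshape [tokens, hidden] -> [tokens, heads, head_dim].
    tokens, hidden = x.shape
    assert hidden == num_heads * head_dim, (
        f"hidden={hidden} != num_heads*head_dim={num_heads * head_dim}"
    )
    return x.reshape(tokens, num_heads, head_dim)


q = np.zeros((4, 8 * 64))            # 8 heads of dim 64, hidden = 512
try:
    split_heads(q, num_heads=8)      # relies on default head_dim=128 -> fails
except AssertionError as e:
    print("default mismatch:", e)

ok = split_heads(q, num_heads=8, head_dim=64)  # explicit geometry, as in the fix
print(ok.shape)                      # (4, 8, 64)
```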

tests/operators/test_flash_mask_attn.py
Lines changed: 2 additions & 0 deletions

@@ -91,6 +91,8 @@ def paddle_flash_attn_mask(self, q_input, k_input, v_input, attn_out, mask):
         )
 
     def test_flash_mask_attention(self):
+        if self.sm_version < 89 or self.sm_version >= 100:
+            self.skipTest("flash_mask_attention V3 requires SM89+ but less than SM100.")
         q_input = paddle.randn([self.q_len, self.num_head * self.head_dim], dtype="bfloat16")
         k_input = paddle.randn([self.q_len + self.k_len, self.num_kv_head, self.head_dim], dtype="bfloat16")
         v_input = paddle.randn(k_input.shape, dtype="bfloat16")
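
The new guard skips the test outside the SM89-SM99 window, i.e. only Ada (SM89) and Hopper (SM90) class GPUs run it, while pre-Ada and Blackwell (SM100+) devices skip. The hunk compares against `self.sm_version`, whose setup sits outside this diff; the sketch below is a plausible way to derive it with Paddle's public device API. The class name and `setUp` body are assumptions, not the test's actual code.

```python
# Plausible sketch (assumed, not from the diff) of computing the SM number
# that the new guard compares against, using Paddle's public CUDA API.
import unittest

import paddle


class FlashMaskAttnLikeTest(unittest.TestCase):  # hypothetical test class
    def setUp(self):
        if paddle.device.is_compiled_with_cuda():
            # get_device_capability() returns (major, minor), e.g. (8, 9).
            major, minor = paddle.device.cuda.get_device_capability()
            self.sm_version = major * 10 + minor  # (8, 9) -> 89, (9, 0) -> 90
        else:
            self.sm_version = 0  # no CUDA device: every guarded test skips

    def test_guard(self):
        # Same window as the fix: SM89 and SM90 run, everything else skips.
        if self.sm_version < 89 or self.sm_version >= 100:
            self.skipTest("flash_mask_attention V3 requires SM89+ but less than SM100.")


if __name__ == "__main__":
    unittest.main()
```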
