
Commit 650af93

Fix Hopper tests
1 parent 37ffcbb commit 650af93

File tree

flash_attn/cute/interface.py
tests/cute/test_flash_attn.py
tests/cute/test_flash_attn_race_condition.py
tests/cute/test_mask_mod.py

4 files changed: 18 additions (+), 1 deletion (-)

flash_attn/cute/interface.py

Lines changed: 7 additions & 0 deletions

@@ -629,6 +629,13 @@ def _flash_attn_bwd(
         AtomLayoutMdQ = 1
         cluster_size = 1
         assert window_size_left is None and window_size_right is None, "local not supported yet on 9.x"
+        is_varlen = (
+            cu_seqlens_q is not None
+            or cu_seqlens_k is not None
+            or seqused_q is not None
+            or seqused_k is not None
+        )
+        assert not is_varlen, "varlen backward is not yet supported on sm90"
     else:
         m_block_size = 128
         n_block_size = 128
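
The new guard treats the backward call as variable-length (varlen) whenever any of the four optional per-sequence length arguments is supplied, and refuses to dispatch the sm90 backward in that case. A minimal sketch of the same detection pattern as a standalone check, assuming a hypothetical helper name and signature (not part of the commit):

# Hedged sketch of the varlen guard added above. `check_sm90_bwd_varlen` is a
# hypothetical helper name; the commit performs the same check inline.
from typing import Optional

import torch


def check_sm90_bwd_varlen(
    cu_seqlens_q: Optional[torch.Tensor],
    cu_seqlens_k: Optional[torch.Tensor],
    seqused_q: Optional[torch.Tensor],
    seqused_k: Optional[torch.Tensor],
) -> None:
    # Any of the four optional tensors being present selects the
    # variable-length (packed) code path.
    is_varlen = any(
        t is not None for t in (cu_seqlens_q, cu_seqlens_k, seqused_q, seqused_k)
    )
    assert not is_varlen, "varlen backward is not yet supported on sm90"

Under such a guard, a Hopper caller that passes cu_seqlens_q/cu_seqlens_k (the packed varlen layout) fails fast with a clear message instead of launching an unsupported kernel.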

tests/cute/test_flash_attn.py

Lines changed: 1 addition & 0 deletions

@@ -709,6 +709,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
         and not attention_chunk != 0
         and dv == d
         and not has_learnable_sink
+        and not IS_SM90
         # and False
     ):
         g_unpad = torch.randn_like(out_unpad)
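
With this change the gradient-checking block of the test no longer runs on Hopper. The diff does not show how IS_SM90 is defined; one plausible module-level definition, assuming it is derived from the CUDA compute capability reported by torch (an assumption, not taken from the commit):

# Assumed definition of the IS_SM90 flag gating the backward checks above;
# the real flag is defined elsewhere in the test module.
import torch

# Hopper GPUs report compute capability (9, 0); only the major version matters here.
IS_SM90 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] == 9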

tests/cute/test_flash_attn_race_condition.py

Lines changed: 7 additions & 0 deletions

@@ -26,6 +26,7 @@
     flash_attn_varlen_func,
     flash_attn_combine,
     _flash_attn_bwd,
+    _get_device_capability,
 )


@@ -407,6 +408,11 @@ def test_flash_attn_varlen_output(
     local = local_enum > 0
     if local and causal:
         pytest.skip()
+    is_sm90 = _get_device_capability() == 9
+    if is_sm90 and local:
+        pytest.xfail("bwd local attention not supported on sm90")
+    if is_sm90 and deterministic:
+        pytest.xfail("bwd deterministic not supported on sm90")
     if (
         causal or local
     ):  # Right now reference only supports causal attention with seqlen_k == seqlen_q
@@ -645,6 +651,7 @@ def _gen_unused_masks(padding_mask, add_unused, max_seq_len, bs, device):
         and not attention_chunk != 0
         and dv == d
         and not has_learnable_sink
+        and not is_sm90
         # and False
     ):
         g_unpad = torch.randn_like(out_unpad)
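
The race-condition test takes the same approach, but queries the capability at runtime through the library's _get_device_capability helper and marks unsupported configurations with pytest.xfail, so they are reported as expected failures rather than silently skipped. A small illustrative test showing the gating pattern in isolation (the test name, parameters, and capability stand-in are placeholders, not code from the repository):

# Illustrative pytest gate mirroring the pattern added above; everything here
# is a placeholder sketch, not code from the commit.
import pytest
import torch


def _device_major_capability() -> int:
    # Stand-in for flash_attn's _get_device_capability: major compute
    # capability of the current CUDA device (9 on Hopper).
    return torch.cuda.get_device_capability()[0]


@pytest.mark.parametrize("local", [False, True])
@pytest.mark.parametrize("deterministic", [False, True])
def test_bwd_gate_example(local, deterministic):
    is_sm90 = _device_major_capability() == 9
    if is_sm90 and local:
        pytest.xfail("bwd local attention not supported on sm90")
    if is_sm90 and deterministic:
        pytest.xfail("bwd deterministic not supported on sm90")
    # ... the real test would exercise the forward/backward pass here ...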

tests/cute/test_mask_mod.py

Lines changed: 3 additions & 1 deletion

@@ -277,6 +277,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias):

     # SM90 block-sparse backward expects BlockMask granularity (128, 128) regardless of fwd tiling.
     sparse_tile_m_bwd = sparse_tile_m
+    tile_n_bwd = tile_n
     if COMPUTE_CAPABILITY == 9 and use_block_sparsity and (sparse_tile_m, tile_n) != (128, 128):
         bm_bwd = create_block_mask(
             mask_mod_flex,
@@ -301,6 +302,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias):
             *_,
         ) = bm_bwd.as_tuple()
         sparse_tile_m_bwd = 128
+        tile_n_bwd = 128

     softmax_scale = 1.0 / math.sqrt(headdim)

@@ -323,7 +325,7 @@ def mask_mod_flex(b, h, q_idx, kv_idx, bias=bias):
             mask_block_idx=q_mask_idx,
             full_block_cnt=full_q_cnt,
             full_block_idx=full_q_idx,
-            block_size=(sparse_tile_m_bwd, tile_n),
+            block_size=(sparse_tile_m_bwd, tile_n_bwd),
         )
         if use_block_sparsity
         else None
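
On SM90 the block-sparse backward expects the BlockMask at (128, 128) granularity regardless of the forward tiling, so the test now rebuilds the mask at that granularity and threads both rebuilt tile sizes, tile_n_bwd included, into the block_size argument. A rough sketch of regenerating a flex-attention BlockMask at (128, 128), assuming the torch.nn.attention.flex_attention.create_block_mask API; the mask_mod and shapes are placeholders:

# Sketch of rebuilding a BlockMask at (128, 128) granularity for the backward
# pass, as the test does on SM90. The mask_mod and sequence lengths are
# placeholders, not values from the test.
import torch
from torch.nn.attention.flex_attention import create_block_mask


def causal_mask_mod(b, h, q_idx, kv_idx):
    # Placeholder mask_mod; the test uses its own mask_mod_flex closure.
    return q_idx >= kv_idx


seqlen_q = seqlen_k = 1024
sparse_tile_m_bwd = tile_n_bwd = 128

# Forward tiling may differ, but the SM90 backward wants (128, 128) blocks.
bm_bwd = create_block_mask(
    causal_mask_mod,
    B=None,        # broadcast over batch
    H=None,        # broadcast over heads
    Q_LEN=seqlen_q,
    KV_LEN=seqlen_k,
    device="cuda",
    BLOCK_SIZE=(sparse_tile_m_bwd, tile_n_bwd),
)
# The test then unpacks bm_bwd.as_tuple() and passes
# block_size=(sparse_tile_m_bwd, tile_n_bwd) to the backward kernel.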
