
Commit affcd3c

Nicer headdim error message
stack-info: PR: #2227, branch: drisspg/stack/9
1 parent 40779bc commit affcd3c

File tree

flash_attn/cute/interface.py
tests/cute/test_flash_attn.py

2 files changed: +41 -10 lines changed

flash_attn/cute/interface.py

Lines changed: 26 additions & 10 deletions
@@ -68,6 +68,24 @@ def _get_device_capability():
 def _get_use_clc_scheduler_default():
     return os.environ.get("FA4_CLC", "0") == "1"
 
+
+def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int) -> None:
+    """Validate head dimension constraints based on compute capability."""
+    is_deepseek_shape = head_dim == 192 and head_dim_v == 128
+    is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128
+
+    if compute_capability == 9:
+        assert is_standard_range, (
+            f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM90. "
+            f"head_dim and head_dim_v must be between 8 and 128."
+        )
+    elif compute_capability in [10, 11]:
+        assert is_standard_range or is_deepseek_shape, (
+            f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM100/SM110. "
+            f"head_dim and head_dim_v must be between 8 and 128, or (192, 128) for DeepSeek."
+        )
+
+
 def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x
 
@@ -218,11 +236,17 @@ def _flash_attn_fwd(
             learnable_sink,
         )
     ), "inputs must be on CUDA device"
+    compute_capability = (
+        _get_device_capability()
+        if _compute_capability is None
+        else _compute_capability
+    )
+    assert compute_capability in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
     assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
-    assert head_dim <= 256, "head_dim must be less than or equal to 256"
     alignment = 16 // q.element_size()
     assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
     assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
+    _validate_head_dims(head_dim, head_dim_v, compute_capability)
     if softmax_scale is None:
         softmax_scale = 1.0 / math.sqrt(head_dim)
     if softcap == 0.0:
@@ -254,14 +278,6 @@ def _flash_attn_fwd(
         _validate_tensor(lse, "lse", lse_shape, torch.float32, device)
 
     dtype = torch2cute_dtype_map[q.dtype]
-    compute_capability = (
-        _get_device_capability()
-        if _compute_capability is None
-        else _compute_capability
-    )
-
-    assert compute_capability in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
-
     use_block_sparsity = block_sparse_tensors is not None
 
     if mask_mod is None:
@@ -712,10 +728,10 @@ def _flash_attn_bwd(
         t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)
     ), "inputs must be on CUDA device"
     assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
-    assert head_dim <= 256, "head_dim must be less than or equal to 256"
     alignment = 16 // q.element_size()
     assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
     assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
+    _validate_head_dims(head_dim, head_dim_v, compute_capability)
     if softmax_scale is None:
         softmax_scale = 1.0 / math.sqrt(head_dim)
     qhead_per_kvhead = num_head // num_head_kv
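
For reference, a minimal illustration (not part of the commit) of how the new check behaves, assuming _validate_head_dims as defined in the first hunk above is in scope; it replaces the old blanket head_dim <= 256 assert with capability-specific limits:

# Illustration only; example shapes are hypothetical.
_validate_head_dims(64, 64, 9)      # passes on SM90: both dims lie in [8, 128]
_validate_head_dims(192, 128, 10)   # passes on SM100: the (192, 128) DeepSeek shape is special-cased
_validate_head_dims(192, 128, 9)    # AssertionError: the DeepSeek shape is not allowed on SM90
_validate_head_dims(144, 144, 10)   # AssertionError: 144 is outside [8, 128] and is not (192, 128)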

tests/cute/test_flash_attn.py

Lines changed: 15 additions & 0 deletions
@@ -1520,3 +1520,18 @@ def test_flash_attn_combine(num_splits, seqlen, d, dtype):
     assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), (
         "Output should be the same regardless of return_lse"
     )
+
+
+@pytest.mark.parametrize("head_dim", [4, 144, 256])
+def test_flash_attn_invalid_head_dim(head_dim):
+    """Verify that invalid head dimensions raise AssertionError."""
+    device = "cuda"
+    dtype = torch.bfloat16
+    batch_size, seqlen, nheads = 1, 64, 4
+
+    q = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype)
+    k = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype)
+    v = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype)
+
+    with pytest.raises(AssertionError):
+        flash_attn_func(q, k, v)
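
Usage note (not part of the commit): the new parametrized test can be run on its own with pytest's -k filter, e.g. pytest tests/cute/test_flash_attn.py -k test_flash_attn_invalid_head_dim, assuming a CUDA-capable device is available.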
