Commit 4dd18b1

Nicer headdim error message
stack-info: PR: #2227, branch: drisspg/stack/9
1 parent d39b629 commit 4dd18b1

File tree: 2 files changed (+41, −10 lines)

flash_attn/cute/interface.py

Lines changed: 26 additions & 10 deletions
@@ -63,6 +63,24 @@ def _get_device_capability():
     """Cached device capability check."""
     return torch.cuda.get_device_capability()[0]
 
+
+def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int) -> None:
+    """Validate head dimension constraints based on compute capability."""
+    is_deepseek_shape = head_dim == 192 and head_dim_v == 128
+    is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128
+
+    if compute_capability == 9:
+        assert is_standard_range, (
+            f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM90. "
+            f"head_dim and head_dim_v must be between 8 and 128."
+        )
+    elif compute_capability in [10, 11]:
+        assert is_standard_range or is_deepseek_shape, (
+            f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM100/SM110. "
+            f"head_dim and head_dim_v must be between 8 and 128, or (192, 128) for DeepSeek."
+        )
+
+
 def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x
 
@@ -211,11 +229,17 @@ def _flash_attn_fwd(
             learnable_sink,
         )
     ), "inputs must be on CUDA device"
+    compute_capability = (
+        _get_device_capability()
+        if _compute_capability is None
+        else _compute_capability
+    )
+    assert compute_capability in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
     assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
-    assert head_dim <= 256, "head_dim must be less than or equal to 256"
     alignment = 16 // q.element_size()
     assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
     assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
+    _validate_head_dims(head_dim, head_dim_v, compute_capability)
     if softmax_scale is None:
         softmax_scale = 1.0 / math.sqrt(head_dim)
     if softcap == 0.0:
@@ -247,14 +271,6 @@ def _flash_attn_fwd(
         _validate_tensor(lse, "lse", lse_shape, torch.float32, device)
 
     dtype = torch2cute_dtype_map[q.dtype]
-    compute_capability = (
-        _get_device_capability()
-        if _compute_capability is None
-        else _compute_capability
-    )
-
-    assert compute_capability in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
-
     use_block_sparsity = block_sparse_tensors is not None
 
     if mask_mod is None:
@@ -698,10 +714,10 @@ def _flash_attn_bwd(
         t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)
     ), "inputs must be on CUDA device"
     assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
-    assert head_dim <= 256, "head_dim must be less than or equal to 256"
     alignment = 16 // q.element_size()
     assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
     assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
+    _validate_head_dims(head_dim, head_dim_v, compute_capability)
     if softmax_scale is None:
         softmax_scale = 1.0 / math.sqrt(head_dim)
     qhead_per_kvhead = num_head // num_head_kv

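For reference, a minimal sketch of how the new check behaves in isolation, assuming _validate_head_dims is importable from flash_attn.cute.interface as added above (the import path is an assumption, not shown in this commit):

    from flash_attn.cute.interface import _validate_head_dims  # import path assumed

    # SM90 (compute capability 9): only 8 <= head_dim, head_dim_v <= 128 is accepted.
    _validate_head_dims(head_dim=128, head_dim_v=128, compute_capability=9)   # passes

    # SM100/SM110 additionally accept the DeepSeek shape (192, 128).
    _validate_head_dims(head_dim=192, head_dim_v=128, compute_capability=10)  # passes

    # Anything else now fails with the per-architecture message instead of the old
    # blanket "head_dim must be less than or equal to 256" check.
    try:
        _validate_head_dims(head_dim=256, head_dim_v=256, compute_capability=9)
    except AssertionError as e:
        print(e)  # (head_dim, head_dim_v)=(256, 256) is not supported on SM90. ...
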
tests/cute/test_flash_attn.py

Lines changed: 15 additions & 0 deletions
@@ -1520,3 +1520,18 @@ def test_flash_attn_combine(num_splits, seqlen, d, dtype):
     assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), (
         "Output should be the same regardless of return_lse"
     )
+
+
+@pytest.mark.parametrize("head_dim", [4, 144, 256])
+def test_flash_attn_invalid_head_dim(head_dim):
+    """Verify that invalid head dimensions raise AssertionError."""
+    device = "cuda"
+    dtype = torch.bfloat16
+    batch_size, seqlen, nheads = 1, 64, 4
+
+    q = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype)
+    k = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype)
+    v = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype)
+
+    with pytest.raises(AssertionError):
+        flash_attn_func(q, k, v)

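As a rough guide to which assertion each parametrized value is expected to trip on supported hardware, assuming bfloat16 inputs (alignment = 16 // 2 = 8) and the assertion order in _flash_attn_fwd above: head_dim=4 fails the alignment check before the new range check, while 144 and 256 pass alignment but fall outside 8–128 and hit _validate_head_dims. A hedged sketch of a stricter variant of the same test (the flash_attn_func import path is assumed):

    import pytest
    import torch
    from flash_attn.cute.interface import flash_attn_func  # import path assumed

    @pytest.mark.parametrize(
        "head_dim, match",
        [(4, "divisible by 8"), (144, "is not supported"), (256, "is not supported")],
    )
    def test_flash_attn_invalid_head_dim_message(head_dim, match):
        # Same shapes as the test above; the asserts fire before any kernel launch.
        q = torch.randn(1, 64, 4, head_dim, device="cuda", dtype=torch.bfloat16)
        k = torch.randn(1, 64, 4, head_dim, device="cuda", dtype=torch.bfloat16)
        v = torch.randn(1, 64, 4, head_dim, device="cuda", dtype=torch.bfloat16)
        with pytest.raises(AssertionError, match=match):
            flash_attn_func(q, k, v)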