29 changes: 19 additions & 10 deletions flash_attn/cute/interface.py
@@ -63,6 +63,17 @@ def _get_device_capability():
"""Cached device capability check."""
return torch.cuda.get_device_capability()[0]


def _validate_head_dims(head_dim: int, head_dim_v: int, compute_capability: int) -> None:
is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128
is_extended_qk = 128 < head_dim <= 192 and 8 <= head_dim_v <= 128

assert is_standard_range or is_extended_qk, (
f"(head_dim, head_dim_v)=({head_dim}, {head_dim_v}) is not supported on SM{compute_capability * 10}. "
f"head_dim must be between 8 and 192, head_dim_v must be between 8 and 128."
)


def maybe_contiguous(x):
return x.contiguous() if x is not None and x.stride(-1) != 1 else x

@@ -211,11 +222,17 @@ def _flash_attn_fwd(
learnable_sink,
)
), "inputs must be on CUDA device"
compute_capability = (
_get_device_capability()
if _compute_capability is None
else _compute_capability
)
assert compute_capability in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"
assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
assert head_dim <= 256, "head_dim must be less than or equal to 256"
alignment = 16 // q.element_size()
assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
_validate_head_dims(head_dim, head_dim_v, compute_capability)
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(head_dim)
if softcap == 0.0:
Expand Down Expand Up @@ -247,14 +264,6 @@ def _flash_attn_fwd(
_validate_tensor(lse, "lse", lse_shape, torch.float32, device)

dtype = torch2cute_dtype_map[q.dtype]
compute_capability = (
_get_device_capability()
if _compute_capability is None
else _compute_capability
)

assert compute_capability in [9, 10, 11], "Unsupported compute capability. Supported: 9.x, 10.x, 11.x"

use_block_sparsity = block_sparse_tensors is not None

if mask_mod is None:
@@ -705,10 +714,10 @@ def _flash_attn_bwd(
t is None or t.is_cuda for t in (q, k, v, out, dout, lse, cu_seqlens_q, cu_seqlens_k)
), "inputs must be on CUDA device"
assert num_head % num_head_kv == 0, "num_head must be divisible by num_head_kv"
assert head_dim <= 256, "head_dim must be less than or equal to 256"
alignment = 16 // q.element_size()
assert head_dim % alignment == 0, f"head_dim must be divisible by {alignment}"
assert head_dim_v % alignment == 0, f"head_dim_v must be divisible by {alignment}"
_validate_head_dims(head_dim, head_dim_v, compute_capability)
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(head_dim)
qhead_per_kvhead = num_head // num_head_kv
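A quick illustration of the rule the new `_validate_head_dims` helper enforces (a minimal sketch, not part of the diff; the helper name below is hypothetical and only restates the two ranges from the patch):

def _head_dims_supported(head_dim: int, head_dim_v: int) -> bool:
    # Same two ranges as _validate_head_dims, minus the assert and error message.
    is_standard_range = 8 <= head_dim <= 128 and 8 <= head_dim_v <= 128
    is_extended_qk = 128 < head_dim <= 192 and 8 <= head_dim_v <= 128
    return is_standard_range or is_extended_qk

assert _head_dims_supported(64, 64)        # classic square head dims
assert _head_dims_supported(192, 128)      # extended QK head dim, V capped at 128
assert not _head_dims_supported(192, 192)  # head_dim_v above 128 is rejected
assert not _head_dims_supported(256, 128)  # head_dim above 192 is rejected
assert not _head_dims_supported(4, 64)     # below the minimum of 8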
15 changes: 15 additions & 0 deletions flash_attn/cute/utils.py
@@ -187,6 +187,21 @@ def parse_swizzle_from_pointer(ptr: cute.Pointer) -> cute.Swizzle:
raise ValueError(f"Could not parse swizzle_type: {swizzle_str}")


@dsl_user_op
def smid(*, loc=None, ip=None) -> Int32:
return Int32(
llvm.inline_asm(
T.i32(),
[],
"mov.u32 $0, %smid;",
"=r",
has_side_effects=False,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
)
)


@dsl_user_op
def fmax(
a: float | Float32, b: float | Float32, c: float | Float32 | None = None, *, loc=None, ip=None
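One plausible use of the new `smid()` helper, stated as an assumption rather than anything claimed in the PR: logging which streaming multiprocessor a CTA landed on while debugging scheduling or persistent-kernel load balancing. `%smid` is a PTX special register intended for profiling and debugging, and its values are not guaranteed to be contiguous. The kernel scaffolding and the `cute.printf` / `cute.arch` calls below are a hypothetical sketch following the CuTe DSL conventions used elsewhere in this repo:

# Hypothetical debugging snippet inside a cute kernel body (not part of the diff):
#     tidx, _, _ = cute.arch.thread_idx()
#     bidx, _, _ = cute.arch.block_idx()
#     if tidx == 0:
#         cute.printf("CTA %d running on SM %d\n", bidx, smid())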
27 changes: 17 additions & 10 deletions tests/cute/test_flash_attn.py
@@ -51,16 +51,8 @@
# @pytest.mark.parametrize("local_enum", [0])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [False])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192, 256])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
# @pytest.mark.parametrize("d", [64, 128, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128])
# @pytest.mark.parametrize("d", [64, 96, 128, 192])
# @pytest.mark.parametrize("d", [128, 192])
@pytest.mark.parametrize("d", [64, 128])
# @pytest.mark.parametrize("d", [128])
# @pytest.mark.parametrize("d", [32, 40, 56, 64, 80, 96, 128, 144, 160, 192])
@pytest.mark.parametrize(
"seqlen_q,seqlen_k",
[
@@ -119,7 +111,7 @@ def test_flash_attn_output(
nheads_kv = nheads if mha_type == "mha" else (3 if mha_type == "gqa" else 1)
dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
# dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
dv_vals = [128] if d == 192 else ([d] if d != 128 else [64, d])
dv_vals = [128] if d > 128 else ([d] if d != 128 else [64, d])
if dtype == torch.float8_e4m3fn or TEST_BWD_ONLY:
dv_vals = [d]
# attention_chunk_vals = [torch.randint(1, seqlen_k * 2, (1,)).item(), 0]
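# For reference (not part of the diff): with the new d > 128 predicate, dv_vals evaluates to, e.g.:
#   d = 64  -> [64]
#   d = 128 -> [64, 128]
#   d = 160 -> [128]   (now covered; the old check only special-cased d == 192)
#   d = 192 -> [128]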
@@ -1521,3 +1513,18 @@ def test_flash_attn_combine(num_splits, seqlen, d, dtype):
assert torch.allclose(out_no_lse, out, atol=1e-5, rtol=1e-5), (
"Output should be the same regardless of return_lse"
)


@pytest.mark.parametrize("head_dim", [4, 144, 256])
def test_flash_attn_invalid_head_dim(head_dim):
"""Verify that invalid head dimensions raise AssertionError."""
device = "cuda"
dtype = torch.bfloat16
batch_size, seqlen, nheads = 1, 64, 4

q = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch_size, seqlen, nheads, head_dim, device=device, dtype=dtype)

with pytest.raises(AssertionError):
flash_attn_func(q, k, v)
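For symmetry, a hedged sketch (not in the patch) of the positive case the relaxed validation is meant to allow: a QK head dimension in the extended 128–192 range paired with a V head dimension of 128. Shapes follow the (batch, seqlen, nheads, head_dim) layout of the negative test above; whether a given dtype/architecture combination is supported by the kernel itself is assumed, not verified here.

def _smoke_extended_qk_head_dim():
    """Hypothetical positive counterpart of test_flash_attn_invalid_head_dim."""
    device, dtype = "cuda", torch.bfloat16
    batch_size, seqlen, nheads = 1, 64, 4
    q = torch.randn(batch_size, seqlen, nheads, 192, device=device, dtype=dtype)
    k = torch.randn(batch_size, seqlen, nheads, 192, device=device, dtype=dtype)
    v = torch.randn(batch_size, seqlen, nheads, 128, device=device, dtype=dtype)
    # Expected to pass _validate_head_dims: 128 < head_dim <= 192 with head_dim_v <= 128.
    flash_attn_func(q, k, v)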