flashinfer/mla/_core.py (126 changes: 115 additions & 11 deletions)
@@ -142,9 +142,21 @@ def _check_trtllm_gen_mla_shape(
page_table: torch.Tensor,
page_size: int,
uses_shared_paged_kv_idx: bool = True,
batch_size: Optional[int] = None,
max_q_len: Optional[int] = None,
) -> torch.Tensor:
-    if query.ndim != 4:
-        raise ValueError(f"Expected query.ndim == 4, got {query.ndim}")
if query.ndim == 4:
num_seqs, num_tokens, _, qk_head_dim = query.shape
elif query.ndim == 3:
if batch_size is None or max_q_len is None:
raise ValueError(
"batch_size and max_q_len are required when query.ndim == 3"
)
num_seqs = batch_size
num_tokens = max_q_len
_, _, qk_head_dim = query.shape
else:
raise ValueError(f"Expected query.ndim == 3 or 4, got {query.ndim}")

# Support both 3D and 4D kv_cache for backward compatibility
if kv_cache.ndim == 3:
@@ -166,7 +178,6 @@ def _check_trtllm_gen_mla_shape(
f"Unsupported MLA dimensions, got kv_lora_rank={kv_lora_rank} and qk_rope_head_dim={qk_rope_head_dim}, supported dimensions are: {supported_mla_head_dimensions}"
)

-    num_seqs, num_tokens, _, qk_head_dim = query.shape
ckv_dim = kv_cache.shape[3]
expected_qk_head_dim = kv_lora_rank + qk_rope_head_dim
if qk_head_dim != expected_qk_head_dim or ckv_dim != expected_qk_head_dim:
@@ -615,6 +626,8 @@ def trtllm_batch_decode_with_kv_cache_mla(
backend: str = "auto",
is_var_seq: bool = True,
uses_shared_paged_kv_idx: bool = True,
cum_seq_lens_q: Optional[torch.Tensor] = None,
max_q_len: Optional[int] = None,
) -> torch.Tensor:
"""
Parameters
@@ -626,13 +639,14 @@ def trtllm_batch_decode_with_kv_cache_mla(
kv_lora_rank: rank of the compressed KV latent dimension, must be 512 or 256
qk_rope_head_dim: head dimension of the RoPE part of the query/key, must be 64
sparse_mla_top_k: sparse MLA top k, must be 0 for non-sparse MLA.
Not supported together with ``cum_seq_lens_q``.
block_tables: page table of kv cache.
When ``uses_shared_paged_kv_idx`` is True (default): shape ``[batch_size, max_num_pages_per_seq]``.
When ``uses_shared_paged_kv_idx`` is False: shape ``[batch_size, 2, max_num_pages_per_seq]``
where dim 1 distinguishes K (0) and V (1) page indices. For MLA both rows will
typically be identical since K and V share the same compressed representation.
-    seq_lens: query_len
-    max_seq_len: max sequence length for kv_cache
seq_lens: per-request KV sequence lengths
max_seq_len: max KV sequence length for kv_cache
out: output tensor, if not provided, will be allocated internally
bmm1_scale: fused scale for mla bmm1 input.
When using ``trtllm-gen`` backend, it can be a ``torch.Tensor`` with dtype ``torch.float32``.
@@ -646,6 +660,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
If no value is provided, then standard attention is used.
Setting the threshold to a higher value generally increases kernel performance at the cost of accuracy degradation.
The actual threshold value equals the provided threshold_scale_factor divided by the context length.
Not supported together with ``cum_seq_lens_q``.
backend : str = "auto"
The implementation backend, could be ``auto``/``xqa``, ``trtllm-gen``, or ``cute-dsl``. Defaults to ``auto``.
When set to ``auto``, the backend will be chosen based on the device architecture and kernel availability.
@@ -660,6 +675,19 @@ def trtllm_batch_decode_with_kv_cache_mla(
True (default) uses vLLM/FlashInfer layout with a 2D page table.
False uses TRT-LLM layout with a 3D page table ``[batch_size, 2, max_num_pages_per_seq]``.
False is only supported for trtllm-gen backend.
cum_seq_lens_q : Optional[torch.Tensor] = None
Cumulative query sequence lengths for variable-length query support, shape: ``[batch_size + 1]``, dtype: ``torch.int32``.
Must be a 1D tensor with at least two entries. When ``max_q_len`` is not provided,
this function validates that it starts with 0, ends at ``query.size(0)``, and is
monotonically non-decreasing.
Only supported by trtllm-gen backend.
When provided, ``query`` must have shape ``[total_q, num_heads, head_dim_qk]``.
For best performance, provide ``max_q_len`` together with ``cum_seq_lens_q`` to avoid host-side metadata validation.
max_q_len : Optional[int] = None
Maximum query sequence length across all requests when using ``cum_seq_lens_q``.
Only supported by trtllm-gen backend. Provide together with ``cum_seq_lens_q`` to avoid host-side metadata validation.
Must be greater than or equal to the maximum segment length represented by ``cum_seq_lens_q``.
Over-estimation is safe but may waste work; under-estimation is invalid and may produce incorrect output.
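
    A hedged end-to-end sketch of the variable-length path (sizes, scales, the
    ``kv_cache`` layout, the ``qk_nope_head_dim`` value, and the workspace size
    are illustrative assumptions, not API requirements)::

        import math
        import torch
        from flashinfer.mla import trtllm_batch_decode_with_kv_cache_mla

        num_heads, page_size = 128, 32
        kv_lora_rank, qk_rope_head_dim = 512, 64
        head_dim_qk = kv_lora_rank + qk_rope_head_dim      # 576

        # Two requests with 3 and 1 query tokens -> cum_seq_lens_q == [0, 3, 4]
        q_lens = torch.tensor([3, 1], dtype=torch.int32, device="cuda")
        cum_seq_lens_q = torch.cat(
            [torch.zeros(1, dtype=torch.int32, device="cuda"),
             torch.cumsum(q_lens, 0).int()]
        )
        query = torch.randn(int(q_lens.sum()), num_heads, head_dim_qk,
                            device="cuda", dtype=torch.bfloat16)

        seq_lens = torch.tensor([70, 20], dtype=torch.int32, device="cuda")
        max_seq_len = 70
        pages_per_seq = (max_seq_len + page_size - 1) // page_size
        block_tables = torch.arange(2 * pages_per_seq, dtype=torch.int32,
                                    device="cuda").reshape(2, pages_per_seq)
        # Assumed shared-KV MLA cache layout: [num_pages, 1, page_size, head_dim_qk]
        kv_cache = torch.randn(2 * pages_per_seq, 1, page_size, head_dim_qk,
                               device="cuda", dtype=torch.bfloat16)
        workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8,
                                device="cuda")

        out = trtllm_batch_decode_with_kv_cache_mla(
            query=query,                      # [total_q, num_heads, head_dim_qk]
            kv_cache=kv_cache,
            workspace_buffer=workspace,
            qk_nope_head_dim=128,             # assumed; not shown in this diff
            kv_lora_rank=kv_lora_rank,
            qk_rope_head_dim=qk_rope_head_dim,
            block_tables=block_tables,        # 2D: uses_shared_paged_kv_idx=True
            seq_lens=seq_lens,
            max_seq_len=max_seq_len,
            bmm1_scale=1.0 / math.sqrt(128 + qk_rope_head_dim),  # assumed scale
            backend="trtllm-gen",
            cum_seq_lens_q=cum_seq_lens_q,
            max_q_len=int(q_lens.max()),      # optional; skips host validation
        )
        # out: [total_q, num_heads, kv_lora_rank]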

Note
----
@@ -689,7 +717,11 @@ def trtllm_batch_decode_with_kv_cache_mla(
bmm1_scale = bmm1_scale * log2e
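    # The factor log2(e) is folded into the scale here so the kernel can
    # evaluate softmax with exp2 rather than exp (an assumption consistent
    # with the variable name; not otherwise shown in this diff).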
if isinstance(bmm2_scale, torch.Tensor):
assert bmm2_scale.dtype == torch.float32
if max_q_len is not None and cum_seq_lens_q is None:
raise ValueError("max_q_len is only supported when cum_seq_lens_q is provided")
if backend == "xqa":
if cum_seq_lens_q is not None or max_q_len is not None:
raise ValueError("XQA MLA does not support cum_seq_lens_q / max_q_len")
if not is_sm12x_supported(query.device):
raise ValueError(
"XQA MLA requires SM120a (CUDA >= 12.8) or SM121a (CUDA >= 13.0)"
@@ -734,8 +766,6 @@ def trtllm_batch_decode_with_kv_cache_mla(
enable_pdl = (
device_support_pdl(query.device) if enable_pdl is None else enable_pdl
)
-    run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode
-    sm_count = get_device_sm_count(query.device)

# Extract block_size (works for both 3D and 4D)
block_size = kv_cache.size(-2)
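    # E.g. (assumed layouts): a 3D cache [num_pages, page_size, head_dim] and
    # a 4D cache [num_pages, 1, page_size, head_dim] both carry the page size
    # at dim -2, so size(-2) works for either layout.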
@@ -747,6 +777,74 @@ def trtllm_batch_decode_with_kv_cache_mla(
if skip_softmax_threshold_scale_factor is not None and sparse_mla_top_k != 0:
raise ValueError("skip_softmax is not supported for sparse MLA")

has_var_q = cum_seq_lens_q is not None
if has_var_q and sparse_mla_top_k != 0:
raise ValueError(
"sparse MLA (sparse_mla_top_k > 0) is not supported with "
"variable-length queries (cum_seq_lens_q) for trtllm-gen"
)
if has_var_q and skip_softmax_threshold_scale_factor is not None:
raise ValueError(
"skip_softmax is not supported with variable-length queries "
"(cum_seq_lens_q) for trtllm-gen MLA"
)
if has_var_q:
if query.ndim != 3:
raise ValueError(
"query must have shape [total_q, num_heads, head_dim_qk] "
"when cum_seq_lens_q is provided"
)
check_shape_dtype_device(
cum_seq_lens_q,
None,
torch.int32,
query.device,
"cum_seq_lens_q",
)
if cum_seq_lens_q.ndim != 1:
raise ValueError(
f"Expected cum_seq_lens_q.ndim == 1, got {cum_seq_lens_q.ndim}"
)
if cum_seq_lens_q.size(0) < 2:
raise ValueError("cum_seq_lens_q must contain at least two entries")
batch_size = cum_seq_lens_q.size(0) - 1
if batch_size != seq_lens.size(0):
raise ValueError(
"Batch size mismatch: cum_seq_lens_q describes "
f"{batch_size} sequences, but seq_lens has "
f"{seq_lens.size(0)} entries"
)
if max_q_len is None:
cum_seq_lens_q_host = cum_seq_lens_q.cpu()
if cum_seq_lens_q_host[0].item() != 0:
raise ValueError("cum_seq_lens_q must start with 0")
if cum_seq_lens_q_host[-1].item() != query.size(0):
raise ValueError(
"cum_seq_lens_q[-1] must match the flattened query length"
)
q_lens = cum_seq_lens_q_host[1:] - cum_seq_lens_q_host[:-1]
if torch.any(q_lens < 0).item():
raise ValueError(
"cum_seq_lens_q must be monotonically non-decreasing"
)
max_q_len = q_lens.max().item()
if max_q_len <= 0:
raise ValueError(
"cum_seq_lens_q must describe at least one query token"
)
elif max_q_len <= 0:
raise ValueError("max_q_len must be greater than 0")
elif max_q_len > query.size(0):
raise ValueError("max_q_len cannot exceed the flattened query length")
else:
if query.ndim != 4:
raise ValueError(
"query must have shape [batch_size, q_len_per_request, "
"num_heads, head_dim_qk] when cum_seq_lens_q is not provided"
)
batch_size = query.size(0)
max_q_len = query.size(1)
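
    # Host-side sketch of the metadata the var-len branch above derives when
    # max_q_len is omitted (lengths are illustrative):
    #   q_lens = cum_seq_lens_q[1:] - cum_seq_lens_q[:-1]   # e.g. [3, 1, 4]
    #   max_q_len = int(q_lens.max())                       # e.g. 4
    # Supplying max_q_len up front skips the cum_seq_lens_q.cpu() round-trip,
    # avoiding a device-to-host sync on the hot path.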

# Validate and normalize to 4D
kv_cache = _check_trtllm_gen_mla_shape(
query,
@@ -757,8 +855,13 @@ def trtllm_batch_decode_with_kv_cache_mla(
block_tables,
block_size,
uses_shared_paged_kv_idx,
batch_size=batch_size,
max_q_len=max_q_len,
)

run_func = get_trtllm_gen_fmha_module().trtllm_paged_attention_decode
sm_count = get_device_sm_count(query.device)

expected_out_shape = query.shape[:-1] + (kv_lora_rank,)
if out is None:
out = torch.empty(
@@ -773,9 +876,8 @@ def trtllm_batch_decode_with_kv_cache_mla(
"out",
)

-    batch_size = query.size(0)
-    max_q_len = query.size(1)
-    query = query.flatten(0, 1)  # [B*S, H, D]
if not has_var_q:
query = query.flatten(0, 1) # [B*S, H, D]
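    # Shape sketch: the fixed-length path flattens [B, S, H, D] -> [B*S, H, D],
    # while the var-len path already arrives as [total_q, H, D]; either way the
    # kernel sees a token-major query. `out` keeps the original (unflattened)
    # query layout with the last dim replaced by kv_lora_rank.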

run_func(
out,
@@ -800,7 +902,7 @@ def trtllm_batch_decode_with_kv_cache_mla(
enable_pdl,
workspace_buffer.numel() * workspace_buffer.element_size(),
sinks,
-        None,  # cum_seq_lens_q
cum_seq_lens_q,
None, # key_block_scales
None, # value_block_scales
skip_softmax_threshold_scale_factor,
@@ -846,6 +948,8 @@ def trtllm_batch_decode_with_kv_cache_mla(
"cute-dsl backend (MLA decode kernel) does not support separate KV page indices "
"(uses_shared_paged_kv_idx=False)"
)
if cum_seq_lens_q is not None:
raise ValueError("cute-dsl MLA does not support cum_seq_lens_q")

return cute_dsl_mla_decode(
query=query,