diff --git a/flashinfer/mla/_core.py b/flashinfer/mla/_core.py index 8308da1fb5..b39865ba7e 100644 --- a/flashinfer/mla/_core.py +++ b/flashinfer/mla/_core.py @@ -145,9 +145,21 @@ def _check_trtllm_gen_mla_shape( page_table: torch.Tensor, page_size: int, uses_shared_paged_kv_idx: bool = True, + batch_size: Optional[int] = None, + max_q_len: Optional[int] = None, ) -> torch.Tensor: - if query.ndim != 4: - raise ValueError(f"Expected query.ndim == 4, got {query.ndim}") + if query.ndim == 4: + num_seqs, num_tokens, _, qk_head_dim = query.shape + elif query.ndim == 3: + if batch_size is None or max_q_len is None: + raise ValueError( + "batch_size and max_q_len are required when query.ndim == 3" + ) + num_seqs = batch_size + num_tokens = max_q_len + _, _, qk_head_dim = query.shape + else: + raise ValueError(f"Expected query.ndim == 3 or 4, got {query.ndim}") # Support both 3D and 4D kv_cache for backward compatibility if kv_cache.ndim == 3: @@ -169,7 +181,6 @@ def _check_trtllm_gen_mla_shape( f"Unsupported MLA dimensions, got kv_lora_rank={kv_lora_rank} and qk_rope_head_dim={qk_rope_head_dim}, supported dimensions are: {supported_mla_head_dimensions}" ) - num_seqs, num_tokens, _, qk_head_dim = query.shape ckv_dim = kv_cache.shape[3] expected_qk_head_dim = kv_lora_rank + qk_rope_head_dim if qk_head_dim != expected_qk_head_dim or ckv_dim != expected_qk_head_dim: @@ -1661,6 +1672,8 @@ def trtllm_batch_decode_with_kv_cache_mla( lse: Optional[torch.Tensor] = None, return_lse: bool = False, cute_dsl_impl: str = "auto", + cum_seq_lens_q: Optional[torch.Tensor] = None, + max_q_len: Optional[int] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: """ Parameters @@ -1672,6 +1685,7 @@ def trtllm_batch_decode_with_kv_cache_mla( kv_lora_rank: kv_lora_rank, must be 512 or 256 qk_rope_head_dim: qk_rope_head_dim, must be 64 sparse_mla_top_k: sparse MLA top k, must be 0 for non-sparse MLA. + Not supported together with ``cum_seq_lens_q``. block_tables: page table of kv cache. When ``uses_shared_paged_kv_idx`` is True (default): shape ``[batch_size, max_num_pages_per_seq]``. When ``uses_shared_paged_kv_idx`` is False: shape ``[batch_size, 2, max_num_pages_per_seq]`` @@ -1743,6 +1757,23 @@ def trtllm_batch_decode_with_kv_cache_mla( * ``"monolithic"`` — strict. Always run the monolithic kernels; raise :class:`ValueError` if the call uses any modular-only feature (e.g. ``sinks``). + cum_seq_lens_q : Optional[torch.Tensor] = None + Cumulative query sequence lengths for variable-length query support, + shape ``[batch_size + 1]``, dtype ``torch.int32``. Must be a 1D tensor + with at least two entries. When ``max_q_len`` is not provided, this + function validates that it starts with 0, ends at ``query.size(0)``, + and is monotonically non-decreasing. Only supported by the + ``trtllm-gen`` backend. When provided, ``query`` must have shape + ``[total_q, num_heads, head_dim_qk]``. + For best performance, provide ``max_q_len`` together with + ``cum_seq_lens_q`` to avoid host-side metadata validation. + max_q_len : Optional[int] = None + Maximum query sequence length across all requests when using + ``cum_seq_lens_q``. Provide with ``cum_seq_lens_q`` to avoid + host-side metadata validation. Must be greater than or equal to the + maximum segment length represented by ``cum_seq_lens_q``. Over-estimation + is safe but may waste work; under-estimation is invalid and may produce + incorrect output. Note ---- @@ -1789,11 +1820,15 @@ def trtllm_batch_decode_with_kv_cache_mla( if isinstance(bmm2_scale, torch.Tensor): if bmm2_scale.dtype != torch.float32: raise TypeError("bmm2_scale tensor must have dtype torch.float32") + if max_q_len is not None and cum_seq_lens_q is None: + raise ValueError("max_q_len is only supported when cum_seq_lens_q is provided") if backend == "auto" and get_compute_capability(query.device)[0] != 10: backend = "xqa" if backend == "xqa": + if cum_seq_lens_q is not None or max_q_len is not None: + raise ValueError("XQA MLA does not support cum_seq_lens_q / max_q_len") if not is_sm12x_supported(query.device): raise ValueError( "XQA MLA requires SM120a (CUDA >= 12.8) or SM121a (CUDA >= 13.0)" @@ -1838,7 +1873,6 @@ def trtllm_batch_decode_with_kv_cache_mla( sinks, enable_pdl, ) - if backend not in ("auto", "trtllm-gen", "cute-dsl"): raise ValueError(f"Backend {backend} not supported") @@ -1858,6 +1892,125 @@ def trtllm_batch_decode_with_kv_cache_mla( if skip_softmax_threshold_scale_factor is not None and sparse_mla_top_k != 0: raise ValueError("skip_softmax is not supported for sparse MLA") + has_var_q = cum_seq_lens_q is not None + if has_var_q: + if backend == "cute-dsl": + raise ValueError("cute-dsl MLA does not support cum_seq_lens_q") + if sparse_mla_top_k != 0: + raise ValueError( + "sparse MLA (sparse_mla_top_k > 0) is not supported with " + "variable-length queries (cum_seq_lens_q) for trtllm-gen" + ) + if return_lse or lse is not None: + raise NotImplementedError( + "trtllm-gen MLA does not support return_lse/lse with cum_seq_lens_q" + ) + if query.ndim != 3: + raise ValueError( + "query must have shape [total_q, num_heads, head_dim_qk] " + "when cum_seq_lens_q is provided" + ) + check_shape_dtype_device( + cum_seq_lens_q, + None, + torch.int32, + query.device, + "cum_seq_lens_q", + ) + if cum_seq_lens_q.ndim != 1: + raise ValueError( + f"Expected cum_seq_lens_q.ndim == 1, got {cum_seq_lens_q.ndim}" + ) + if cum_seq_lens_q.size(0) < 2: + raise ValueError("cum_seq_lens_q must contain at least two entries") + batch_size = cum_seq_lens_q.size(0) - 1 + if batch_size != seq_lens.size(0): + raise ValueError( + "Batch size mismatch: cum_seq_lens_q describes " + f"{batch_size} sequences, but seq_lens has {seq_lens.size(0)} entries" + ) + if max_q_len is None: + cum_seq_lens_q_host = cum_seq_lens_q.cpu() + if cum_seq_lens_q_host[0].item() != 0: + raise ValueError("cum_seq_lens_q must start with 0") + if cum_seq_lens_q_host[-1].item() != query.size(0): + raise ValueError( + "cum_seq_lens_q[-1] must match the flattened query length" + ) + q_lens = cum_seq_lens_q_host[1:] - cum_seq_lens_q_host[:-1] + if torch.any(q_lens < 0).item(): + raise ValueError("cum_seq_lens_q must be monotonically non-decreasing") + max_q_len = q_lens.max().item() + if max_q_len <= 0: + raise ValueError( + "cum_seq_lens_q must describe at least one query token" + ) + elif max_q_len <= 0: + raise ValueError("max_q_len must be greater than 0") + elif max_q_len > query.size(0): + raise ValueError("max_q_len cannot exceed the flattened query length") + + kv_cache = _check_trtllm_gen_mla_shape( + query, + kv_cache, + kv_lora_rank, + qk_rope_head_dim, + sparse_mla_top_k, + block_tables, + block_size, + uses_shared_paged_kv_idx, + batch_size=batch_size, + max_q_len=max_q_len, + ) + + expected_out_shape = query.shape[:-1] + (kv_lora_rank,) + if out is None: + out = torch.empty( + expected_out_shape, dtype=torch.bfloat16, device=query.device + ) + else: + check_shape_dtype_device( + out, + expected_out_shape, + torch.bfloat16, + query.device, + "out", + ) + + get_trtllm_gen_fmha_module().trtllm_paged_attention_decode( + out, + None, # fp4 output (unsupported by wrapper) + query, + kv_cache, + kv_cache, + workspace_buffer, + block_tables, + seq_lens, + max_q_len, + max_seq_len, + bmm1_scale, + bmm2_scale, + -1, # o_sf_scale + -1, # o_sf_vec_size + 0, # o_sf_start_index + batch_size, + -1, # window_left + sparse_mla_top_k, + sm_count, + enable_pdl, + workspace_buffer.numel() * workspace_buffer.element_size(), + sinks, + cum_seq_lens_q, + None, # key_block_scales + None, # value_block_scales + skip_softmax_threshold_scale_factor, + uses_shared_paged_kv_idx, + None, # lse + 0, # lse_stride_tokens + 0, # lse_stride_heads + ) + return out + # Normalize kv_cache to 4D and validate MLA dimensions. Despite the name, # the shape/dim checks here apply to both backends. kv_cache = _check_trtllm_gen_mla_shape( diff --git a/flashinfer/trace/templates/attention.py b/flashinfer/trace/templates/attention.py index 96a891d99b..a48a3824f3 100644 --- a/flashinfer/trace/templates/attention.py +++ b/flashinfer/trace/templates/attention.py @@ -1715,10 +1715,26 @@ def _trtllm_batch_decode_mla_reference( """Reference for trtllm_batch_decode_with_kv_cache_mla. Query is concatenated [Q_nope, Q_pe] along the head_dim axis; the KV - cache is [ckv ‖ kpe]. Output is the K_nope-projected attention - (``[batch, q_len, num_heads, kv_lora_rank]``). + cache is [ckv ‖ kpe]. Dense calls return the K_nope-projected + attention as ``[batch, q_len, num_heads, kv_lora_rank]``; ragged + calls return ``[num_tokens, num_heads, kv_lora_rank]``. """ - batch_size, q_len, num_heads, head_dim_qk = query.shape + cum_seq_lens_q = kwargs.get("cum_seq_lens_q") + if cum_seq_lens_q is None: + batch_size, q_len, num_heads, head_dim_qk = query.shape + output = torch.zeros( + (batch_size, q_len, num_heads, kv_lora_rank), + dtype=query.dtype, + device=query.device, + ) + else: + batch_size = cum_seq_lens_q.numel() - 1 + num_heads, head_dim_qk = query.shape[1:] + output = torch.zeros( + (*query.shape[:-1], kv_lora_rank), + dtype=query.dtype, + device=query.device, + ) assert head_dim_qk == kv_lora_rank + qk_rope_head_dim bmm1_scale = kwargs.get("bmm1_scale", 1.0) bmm1_scale = ( @@ -1733,11 +1749,6 @@ def _trtllm_batch_decode_mla_reference( if kv_cache.dim() == 4: kv_cache = kv_cache.squeeze(1) page_size = kv_cache.shape[1] - output = torch.zeros( - (batch_size, q_len, num_heads, kv_lora_rank), - dtype=query.dtype, - device=query.device, - ) for b in range(batch_size): kv_len = int(seq_lens[b].item()) n_pages = (kv_len + page_size - 1) // page_size @@ -1746,13 +1757,24 @@ def _trtllm_batch_decode_mla_reference( # MLA split: first kv_lora_rank dims = ckv (K_nope), last qk_rope_head_dim dims = kpe Kn = flat[:, :kv_lora_rank] Kp = flat[:, kv_lora_rank:] - for t in range(q_len): - q = query[b, t].to(torch.float32) # [num_heads, head_dim_qk] + if cum_seq_lens_q is None: + q_start = 0 + q_end = q_len + q_batch = query[b] + else: + q_start = int(cum_seq_lens_q[b].item()) + q_end = int(cum_seq_lens_q[b + 1].item()) + q_batch = query[q_start:q_end] + for t in range(q_end - q_start): + q = q_batch[t].to(torch.float32) # [num_heads, head_dim_qk] Qn = q[:, :kv_lora_rank] # [num_heads, kv_lora_rank] Qp = q[:, kv_lora_rank:] # [num_heads, qk_rope_head_dim] logits = (Qn @ Kn.T + Qp @ Kp.T) * bmm1_scale attn = torch.softmax(logits, dim=-1) - output[b, t] = (attn @ Kn * bmm2_scale).to(query.dtype) + if cum_seq_lens_q is None: + output[b, t] = (attn @ Kn * bmm2_scale).to(query.dtype) + else: + output[q_start + t] = (attn @ Kn * bmm2_scale).to(query.dtype) return output @@ -1943,13 +1965,66 @@ def _trtllm_batch_decode_mla_sparse_init( } -trtllm_batch_decode_mla_trace = TraceTemplate( +def _trtllm_batch_decode_mla_ragged_init( + *, + batch_size: int, + q_len_per_request: int, + num_heads: int = 128, + head_dim_qk: int = 576, + kv_lora_rank: int = 512, + qk_rope_head_dim: int = 64, + qk_nope_head_dim: int = 512, + num_pages: int = 0, + kv_pad_dim: int = 1, + page_size: int = 64, + max_pages_per_seq: int = 0, + workspace_size: int = 256 << 20, + device: str = "cuda", + seed: int = 0, +): + dense = _trtllm_batch_decode_mla_init( + batch_size=batch_size, + q_len_per_request=q_len_per_request, + num_heads=num_heads, + head_dim_qk=head_dim_qk, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + qk_nope_head_dim=qk_nope_head_dim, + num_pages=num_pages, + kv_pad_dim=kv_pad_dim, + page_size=page_size, + max_pages_per_seq=max_pages_per_seq, + workspace_size=workspace_size, + device=device, + seed=seed, + ) + query = dense["query"] + q_lens = ( + torch.arange(batch_size, device=query.device, dtype=torch.int32) + % q_len_per_request + ) + 1 + q_lens[-1] = q_len_per_request + cum_seq_lens_q = torch.empty(batch_size + 1, device=query.device, dtype=torch.int32) + cum_seq_lens_q[0] = 0 + cum_seq_lens_q[1:] = torch.cumsum(q_lens, dim=0) + dense["query"] = torch.cat( + [query[i, : int(q_lens[i].item())] for i in range(batch_size)], + dim=0, + ) + dense["cum_seq_lens_q"] = cum_seq_lens_q + dense["max_q_len"] = int(q_len_per_request) + return dense + + +trtllm_batch_decode_mla_dense_trace = TraceTemplate( op_type="mla_paged", - name_prefix="trtllm_batch_decode_mla", + name_prefix="trtllm_batch_decode_mla_dense", description=( "SM100+ TRT-LLM MLA paged decode. Query is concatenated [Q_nope, " "Q_pe] with head_dim_qk = kv_lora_rank + qk_rope_head_dim; KV cache " - "is [ckv ‖ kpe]. Output dim equals kv_lora_rank." + "is [ckv ‖ kpe]. Dense API calls pass " + "[batch_size, q_len_per_request, num_heads, head_dim_qk] and return " + "[batch_size, q_len_per_request, num_heads, kv_lora_rank]." ), axes={ "batch_size": Var(), @@ -1971,7 +2046,7 @@ def _trtllm_batch_decode_mla_sparse_init( inputs={ "query": Tensor( ["batch_size", "q_len_per_request", "num_heads", "head_dim_qk"], - description="Concatenated [Q_nope, Q_pe] query.", + description="Concatenated [Q_nope, Q_pe] dense query.", ), "kv_cache": Tensor( ["num_pages", "kv_pad_dim", "page_size", "head_dim_qk"], @@ -2017,6 +2092,7 @@ def _trtllm_batch_decode_mla_sparse_init( "output": Tensor( ["batch_size", "q_len_per_request", "num_heads", "kv_lora_rank"], dtype_from="query", + description="Dense MLA output.", ), }, tags=["status:verified", "stage:decode", "backend:trtllm", "mla"], @@ -2025,6 +2101,96 @@ def _trtllm_batch_decode_mla_sparse_init( ) +trtllm_batch_decode_mla_ragged_trace = TraceTemplate( + op_type="mla_paged", + name_prefix="trtllm_batch_decode_mla_ragged", + description=( + "SM100+ TRT-LLM MLA paged decode for variable query lengths. Query is " + "concatenated [Q_nope, Q_pe] in flattened " + "[num_tokens, num_heads, head_dim_qk] form with cum_seq_lens_q. Output " + "dim equals kv_lora_rank." + ), + axes={ + "batch_size": Var(), + "num_tokens": Var(description="Total query tokens for variable query lengths."), + "batch_size_plus_1": Var(description="batch_size + 1 for cum_seq_lens_q."), + "num_heads": Const(abbrev="h"), + "head_dim_qk": Const(abbrev="d_qk"), + "kv_lora_rank": Const(abbrev="ckv"), + "qk_rope_head_dim": Const(abbrev="kpe"), + "qk_nope_head_dim": Const(abbrev="nope"), + "num_pages": Var(), + "kv_pad_dim": Const( + abbrev="", + description="Always 1; backwards-compat singleton dim in the rank-4 kv_cache layout.", + ), + "page_size": Const(abbrev="ps"), + "max_pages_per_seq": Var(), + "workspace_size": Var(description="Workspace buffer length in bytes."), + }, + inputs={ + "query": Tensor( + ["num_tokens", "num_heads", "head_dim_qk"], + description="Concatenated [Q_nope, Q_pe] flattened ragged query.", + ), + "kv_cache": Tensor( + ["num_pages", "kv_pad_dim", "page_size", "head_dim_qk"], + description=( + "Paged KV cache [ckv ‖ kpe]. The kernel accepts both the 3D " + "[num_pages, page_size, head_dim_qk] layout and the rank-4 " + "[num_pages, 1, page_size, head_dim_qk] layout for backwards " + "compatibility; this template models the rank-4 form." + ), + ), + "workspace_buffer": Tensor( + ["workspace_size"], + dtype="uint8", + description="Workspace scratch (flat byte buffer).", + ), + "qk_nope_head_dim": Scalar("int32"), + "kv_lora_rank": Scalar("int32"), + "qk_rope_head_dim": Scalar("int32"), + "block_tables": Tensor( + ["batch_size", "max_pages_per_seq"], + dtype="int32", + description="Page table mapping per sequence.", + ), + "seq_lens": Tensor(["batch_size"], dtype="int32"), + "max_seq_len": Scalar("int32"), + "bmm1_scale": Scalar( + "float32", + optional=True, + description="Fused scale applied after Q @ K^T (includes 1/sqrt(head_dim_qk)).", + ), + "bmm2_scale": Scalar( + "float32", + optional=True, + description="Scale applied after softmax @ V.", + ), + "cum_seq_lens_q": Tensor( + ["batch_size_plus_1"], + dtype="int32", + description="Cumulative query sequence lengths for variable query lengths.", + ), + "max_q_len": Scalar( + "int32", + optional=True, + description="Maximum query sequence length when cum_seq_lens_q is provided.", + ), + }, + outputs={ + "output": Tensor( + ["num_tokens", "num_heads", "kv_lora_rank"], + dtype_from="query", + description="Flattened ragged MLA output.", + ), + }, + tags=["status:verified", "stage:decode", "backend:trtllm", "mla"], + reference=_trtllm_batch_decode_mla_reference, + init=_trtllm_batch_decode_mla_ragged_init, +) + + trtllm_batch_decode_mla_sparse_trace = TraceTemplate( op_type="mla_paged", name_prefix="trtllm_batch_decode_mla_sparse", @@ -2121,11 +2287,14 @@ def trtllm_batch_decode_mla_trace_dispatch(**kwargs): sparse_mla_top_k = int(kwargs.get("sparse_mla_top_k", 0) or 0) if sparse_mla_top_k > 0: return trtllm_batch_decode_mla_sparse_trace - return trtllm_batch_decode_mla_trace + if kwargs.get("cum_seq_lens_q") is None: + return trtllm_batch_decode_mla_dense_trace + return trtllm_batch_decode_mla_ragged_trace trtllm_batch_decode_mla_trace_dispatch.templates = [ # type: ignore[attr-defined] - trtllm_batch_decode_mla_trace, + trtllm_batch_decode_mla_dense_trace, + trtllm_batch_decode_mla_ragged_trace, trtllm_batch_decode_mla_sparse_trace, ] diff --git a/tests/attention/test_trtllm_gen_mla.py b/tests/attention/test_trtllm_gen_mla.py index fd0fefad89..ed18359222 100755 --- a/tests/attention/test_trtllm_gen_mla.py +++ b/tests/attention/test_trtllm_gen_mla.py @@ -292,6 +292,7 @@ def trtllm_batch_decode_mla( MAX_SEQ_LEN: int, skips_softmax: bool, uses_shared_paged_kv_idx: bool = True, + use_cum_seq_lens_q: bool = False, ): compute_capability = get_compute_capability(torch.device(device="cuda")) if backend == "xqa": @@ -322,6 +323,8 @@ def trtllm_batch_decode_mla( if skips_softmax and backend != "trtllm-gen": pytest.skip("skips_softmax is only supported for trtllm-gen backend") + if use_cum_seq_lens_q and backend != "trtllm-gen": + pytest.skip("cum_seq_lens_q is only supported for trtllm-gen backend") torch.manual_seed(42) device = "cuda:0" @@ -337,6 +340,29 @@ def trtllm_batch_decode_mla( qk_head_dim, device=device, ).to(dtype) + if use_cum_seq_lens_q: + q_lens = ( + torch.arange(batch_size, device=device, dtype=torch.int32) + % q_len_per_request + ) + 1 + q_lens[-1] = q_len_per_request + cum_seq_lens_q = torch.empty(batch_size + 1, device=device, dtype=torch.int32) + cum_seq_lens_q[0] = 0 + cum_seq_lens_q[1:] = torch.cumsum(q_lens, dim=0) + query_input = torch.cat( + [query[i, : int(q_lens[i].item())] for i in range(batch_size)], + dim=0, + ) + # Overestimate to verify CUDA-graph-friendly max_q_len contracts. + # Exact max is q_lens.max(). + max_q_len = min(q_len_per_request + 1, query_input.size(0)) + else: + query_input = query + cum_seq_lens_q = None + q_lens = torch.full( + (batch_size,), q_len_per_request, device=device, dtype=torch.int32 + ) + max_q_len = None num_tokens = MAX_SEQ_LEN * batch_size num_blocks = (num_tokens + page_size - 1) // page_size @@ -429,7 +455,10 @@ def maybe_get_lse_guard_end(softmax_end: int) -> int | None: # Only the trtllm-gen MLA path supports LSE output; other backends raise NotImplementedError. check_lse = ( - backend == "trtllm-gen" and not skips_softmax and dtype != torch.float8_e4m3fn + backend == "trtllm-gen" + and not skips_softmax + and not use_cum_seq_lens_q + and dtype != torch.float8_e4m3fn ) softmax_end = None guard_end = None @@ -458,7 +487,7 @@ def maybe_get_lse_guard_end(softmax_end: int) -> int | None: # Run decode-MLA output_and_lse = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( - query=query, + query=query_input, kv_cache=kv_cache.unsqueeze(1), workspace_buffer=workspace_buffer, qk_nope_head_dim=layer_dimensions.head_dimensions.qk_nope_head_dim, @@ -475,6 +504,8 @@ def maybe_get_lse_guard_end(softmax_end: int) -> int | None: uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, lse=provided_lse, return_lse=check_lse, + cum_seq_lens_q=cum_seq_lens_q, + max_q_len=max_q_len, ) if check_lse: output, lse_out = output_and_lse @@ -513,10 +544,13 @@ def maybe_get_lse_guard_end(softmax_end: int) -> int | None: query = query.to(torch.bfloat16) kv_cache = kv_cache.to(torch.bfloat16) - q_indptr = ( - torch.arange(0, batch_size + 1, device=device, dtype=torch.int32) - * q_len_per_request - ) + q_ref = query_input if use_cum_seq_lens_q else query + if dtype == torch.float8_e4m3fn and use_cum_seq_lens_q: + q_ref = q_ref.to(torch.bfloat16) + + q_indptr = torch.empty(batch_size + 1, device=device, dtype=torch.int32) + q_indptr[0] = 0 + q_indptr[1:] = torch.cumsum(q_lens, dim=0) kv_indptr = torch.zeros_like(q_indptr) kv_indptr[1:] = torch.cumsum(blocks_per_seq, dim=0) kv_indices = all_block_ids.int() @@ -532,16 +566,16 @@ def maybe_get_lse_guard_end(softmax_end: int) -> int | None: page_size, True, sm_scale, - query.dtype, + q_ref.dtype, kv_cache.dtype, ) - q_nope = query[..., : layer_dimensions.head_dimensions.kv_lora_rank].view( - batch_size * q_len_per_request, + q_nope = q_ref[..., : layer_dimensions.head_dimensions.kv_lora_rank].reshape( + -1, layer_dimensions.num_heads, layer_dimensions.head_dimensions.kv_lora_rank, ) - q_pe = query[..., layer_dimensions.head_dimensions.kv_lora_rank :].view( - batch_size * q_len_per_request, + q_pe = q_ref[..., layer_dimensions.head_dimensions.kv_lora_rank :].reshape( + -1, layer_dimensions.num_heads, layer_dimensions.head_dimensions.qk_rope_head_dim, ) @@ -561,9 +595,16 @@ def maybe_get_lse_guard_end(softmax_end: int) -> int | None: assert not torch.isnan(o_ref).any(), "o_ref is nan" assert not torch.isnan(output).any(), "output is nan" - o_ref_view = o_ref.view( - batch_size, q_len_per_request, layer_dimensions.num_heads, -1 - ) + if use_cum_seq_lens_q: + output_view = output + o_ref_view = o_ref + else: + output_view = output.reshape( + batch_size, q_len_per_request, layer_dimensions.num_heads, -1 + ) + o_ref_view = o_ref.view( + batch_size, q_len_per_request, layer_dimensions.num_heads, -1 + ) if dtype == torch.float8_e4m3fn: rtol, atol = 1e-1, 1e-1 @@ -571,7 +612,7 @@ def maybe_get_lse_guard_end(softmax_end: int) -> int | None: rtol, atol = 1e-2, 1e-2 try: - torch.testing.assert_close(output, o_ref_view, rtol=rtol, atol=atol) + torch.testing.assert_close(output_view, o_ref_view, rtol=rtol, atol=atol) except AssertionError as fa2_err: if backend == "cute-dsl": # fa2 reference may diverge from cute-dsl in some configs; @@ -887,6 +928,7 @@ def trtllm_batch_decode_mla_sparse( @pytest.mark.parametrize("backend", ["trtllm-gen", "xqa", "cute-dsl"]) @pytest.mark.parametrize("skips_softmax", [False, True]) @pytest.mark.parametrize("uses_shared_paged_kv_idx", [True, False]) +@pytest.mark.parametrize("use_cum_seq_lens_q", [False, True]) def test_trtllm_batch_decode_mla( layer_dimensions: MLALayerDimensions, batch_size: int, @@ -899,7 +941,10 @@ def test_trtllm_batch_decode_mla( backend: str, skips_softmax: bool, uses_shared_paged_kv_idx: bool, + use_cum_seq_lens_q: bool, ): + if use_cum_seq_lens_q and backend != "trtllm-gen": + pytest.skip("cum_seq_lens_q is only supported for trtllm-gen backend") if backend == "xqa" and layer_dimensions.head_dimensions == smaller_mla_dimensions: pytest.skip("XQA MLA does not support smaller MLA dimensions yet.") if backend == "xqa" and layer_dimensions.num_heads != 128: @@ -927,6 +972,7 @@ def test_trtllm_batch_decode_mla( 1024, skips_softmax, uses_shared_paged_kv_idx=uses_shared_paged_kv_idx, + use_cum_seq_lens_q=use_cum_seq_lens_q, ) @@ -1074,3 +1120,122 @@ def test_trtllm_batch_decode_mla_preallocated_out( ) assert result_pre.shape == expected_shape torch.testing.assert_close(result_none, result_pre, rtol=1e-3, atol=1e-3) + + +def test_trtllm_batch_decode_mla_cum_seq_lens_q_batch_mismatch(): + cc = get_compute_capability(torch.device("cuda")) + if cc[0] != 10: + pytest.skip("trtllm-gen MLA requires SM100/SM103") + + device = "cuda:0" + layer_dim = supported_mla_layer_dimensions[0] + kv_lora_rank = layer_dim.head_dimensions.kv_lora_rank + qk_nope_head_dim = layer_dim.head_dimensions.qk_nope_head_dim + qk_rope_head_dim = layer_dim.head_dimensions.qk_rope_head_dim + num_heads = layer_dim.num_heads + head_dim_qk = kv_lora_rank + qk_rope_head_dim + + page_size = 64 + max_seq_len = 64 + batch_size = 2 + num_pages_per_seq = (max_seq_len + page_size - 1) // page_size + kv_cache = torch.randn( + num_pages_per_seq * batch_size, + 1, + page_size, + head_dim_qk, + dtype=torch.bfloat16, + device=device, + ) + block_tables = torch.arange( + num_pages_per_seq * batch_size, + device=device, + dtype=torch.int32, + ).reshape(batch_size, num_pages_per_seq) + seq_lens = torch.full((batch_size,), max_seq_len, device=device, dtype=torch.int32) + query = torch.randn(5, num_heads, head_dim_qk, dtype=torch.bfloat16, device=device) + cum_seq_lens_q = torch.tensor([0, 5], device=device, dtype=torch.int32) + + global global_trtllm_gen_fmha_workspace_buffer + if global_trtllm_gen_fmha_workspace_buffer is None: + global_trtllm_gen_fmha_workspace_buffer = torch.zeros( + workspace_size, + dtype=torch.int8, + device=device, + ) + + with pytest.raises(ValueError, match="Batch size mismatch"): + flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( + query=query, + kv_cache=kv_cache, + workspace_buffer=global_trtllm_gen_fmha_workspace_buffer, + qk_nope_head_dim=qk_nope_head_dim, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + block_tables=block_tables, + seq_lens=seq_lens, + max_seq_len=max_seq_len, + bmm1_scale=1.0 / (head_dim_qk**0.5), + bmm2_scale=1.0, + backend="trtllm-gen", + enable_pdl=False, + cum_seq_lens_q=cum_seq_lens_q, + max_q_len=5, + ) + + +def test_trtllm_batch_decode_mla_max_q_len_requires_cum_seq_lens_q(): + with pytest.raises( + ValueError, match="max_q_len is only supported when cum_seq_lens_q is provided" + ): + flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( + query=torch.empty(1, 1, 1, 576, dtype=torch.bfloat16), + kv_cache=torch.empty(1, 1, 64, 576, dtype=torch.bfloat16), + workspace_buffer=torch.empty(1024, dtype=torch.int8), + qk_nope_head_dim=512, + kv_lora_rank=512, + qk_rope_head_dim=64, + block_tables=torch.zeros(1, 1, dtype=torch.int32), + seq_lens=torch.ones(1, dtype=torch.int32), + max_seq_len=64, + backend="trtllm-gen", + max_q_len=1, + ) + + +def test_trtllm_batch_decode_mla_sparse_rejects_cum_seq_lens_q(): + cc = get_compute_capability(torch.device("cuda")) + if cc[0] != 10: + pytest.skip("trtllm-gen MLA requires SM100/SM103") + + device = "cuda:0" + layer_dim = supported_mla_layer_dimensions[0] + kv_lora_rank = layer_dim.head_dimensions.kv_lora_rank + qk_nope_head_dim = layer_dim.head_dimensions.qk_nope_head_dim + qk_rope_head_dim = layer_dim.head_dimensions.qk_rope_head_dim + num_heads = layer_dim.num_heads + head_dim_qk = kv_lora_rank + qk_rope_head_dim + + with pytest.raises(ValueError, match=r"sparse MLA .* variable-length queries"): + flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla( + query=torch.empty( + 2, num_heads, head_dim_qk, dtype=torch.bfloat16, device=device + ), + kv_cache=torch.empty( + 1, 1, 64, head_dim_qk, dtype=torch.bfloat16, device=device + ), + workspace_buffer=torch.empty( + workspace_size, dtype=torch.int8, device=device + ), + qk_nope_head_dim=qk_nope_head_dim, + kv_lora_rank=kv_lora_rank, + qk_rope_head_dim=qk_rope_head_dim, + block_tables=torch.zeros(1, 1, 1, dtype=torch.int32, device=device), + seq_lens=torch.ones(1, dtype=torch.int32, device=device), + max_seq_len=64, + sparse_mla_top_k=1, + backend="trtllm-gen", + enable_pdl=False, + cum_seq_lens_q=torch.tensor([0, 2], dtype=torch.int32, device=device), + max_q_len=2, + ) diff --git a/tests/trace/test_fi_trace.py b/tests/trace/test_fi_trace.py index 17b530e0cc..24afdf5fad 100644 --- a/tests/trace/test_fi_trace.py +++ b/tests/trace/test_fi_trace.py @@ -617,6 +617,68 @@ def test_usecase_sampling_vocab_discovery(): assert parsed["outputs"]["samples"]["dtype"] == "int64" +def test_trtllm_batch_decode_mla_fi_trace_dense_and_ragged(): + import flashinfer.mla + + common = { + "kv_cache": torch.empty(4, 64, 576, dtype=torch.bfloat16), + "workspace_buffer": torch.empty(1024, dtype=torch.int8), + "qk_nope_head_dim": 512, + "kv_lora_rank": 512, + "qk_rope_head_dim": 64, + "block_tables": torch.zeros(2, 1, dtype=torch.int32), + "seq_lens": torch.full((2,), 64, dtype=torch.int32), + "max_seq_len": 64, + } + + dense = flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla.fi_trace( + query=torch.empty(2, 3, 128, 576, dtype=torch.bfloat16), + **common, + ) + _check_defn( + dense, + "mla_paged", + "flashinfer.mla._core.trtllm_batch_decode_with_kv_cache_mla", + ) + assert dense["name"].startswith("trtllm_batch_decode_mla_dense") + assert dense["inputs"]["query"]["shape"] == [ + "batch_size", + "q_len_per_request", + "num_heads", + "head_dim_qk", + ] + assert dense["outputs"]["output"]["shape"] == [ + "batch_size", + "q_len_per_request", + "num_heads", + "kv_lora_rank", + ] + + ragged = flashinfer.mla.trtllm_batch_decode_with_kv_cache_mla.fi_trace( + query=torch.empty(5, 128, 576, dtype=torch.bfloat16), + cum_seq_lens_q=torch.tensor([0, 2, 5], dtype=torch.int32), + max_q_len=3, + **common, + ) + _check_defn( + ragged, + "mla_paged", + "flashinfer.mla._core.trtllm_batch_decode_with_kv_cache_mla", + ) + assert ragged["name"].startswith("trtllm_batch_decode_mla_ragged") + assert ragged["inputs"]["query"]["shape"] == [ + "num_tokens", + "num_heads", + "head_dim_qk", + ] + assert ragged["outputs"]["output"]["shape"] == [ + "num_tokens", + "num_heads", + "kv_lora_rank", + ] + assert ragged["inputs"]["max_q_len"]["shape"] is None + + # --------------------------------------------------------------------------- # JSON file output # ---------------------------------------------------------------------------