trtllm_batch_decode_with_kv_cache_mla: trtllm-gen backend cum_seq_lens_q support #3238
Conversation
Note: Reviews paused. This branch is under active development; to avoid overwhelming the author with review comments from an influx of new commits, CodeRabbit has automatically paused this review. This behavior is configurable in the settings.
📝 Walkthrough
Adds optional ragged (per-request variable-length) query support via `cum_seq_lens_q`.
Changes: Variable-Length MLA Decode (single cohort)
Sequence Diagram(s)

```mermaid
sequenceDiagram
    participant Test as Test/Client
    participant MLA as flashinfer.mla/_core
    participant Trace as trace/templates/attention
    participant Kernel as TRT-LLM Kernel
    Test->>MLA: call trtllm_batch_decode_with_kv_cache_mla(query, cum_seq_lens_q?)
    MLA->>MLA: validate shapes, compute max_q_len / has_var_q
    MLA->>Trace: select dense or ragged trace based on cum_seq_lens_q
    MLA->>Kernel: invoke decode kernel (pass cum_seq_lens_q when ragged)
    Kernel-->>MLA: per-token or per-batch MLA outputs
    MLA-->>Test: assemble and return outputs
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (1 warning)
Code Review
This pull request introduces support for variable-length (ragged) queries in the trtllm_batch_decode_with_kv_cache_mla function by adding a cum_seq_lens_q parameter. The changes include updated shape validation logic, reference implementation adjustments for ragged tensors, and expanded test coverage. Review feedback highlights a potential rank mismatch in the trace template for dense calls, suggests optimizing host-device synchronizations during tensor validation to improve performance, and recommends adding a consistency check between the query batch size and the provided sequence lengths.
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
flashinfer/trace/templates/attention.py (1)
1245-1287: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift
Keep this trace template aligned with the public MLA API, not only the flattened backend form.
`trtllm_batch_decode_with_kv_cache_mla()` still accepts a dense 4D `query` and returns a dense 4D `output` when `cum_seq_lens_q` is omitted (flashinfer/mla/_core.py, lines 798-837), but this template now hardcodes flattened 3D `query`/`output` shapes and also models `workspace_buffer` as `["num_pages"]`. That makes the existing non-ragged call shape either untraceable or misleading. Please represent both public forms here, or split the ragged backend form into a separate template instead of redefining the existing one. As per coding guidelines:
`flashinfer/trace/templates/**/*.py`: Add a `TraceTemplate` in `flashinfer/trace/templates/` for new or updated APIs to enable trace functionality.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/trace/templates/attention.py` around lines 1245 - 1287, The template currently only models the flattened backend ragged shapes for "query"/"output" and models "workspace_buffer" as ["num_pages"], which breaks or misrepresents the public dense 4D API used by trtllm_batch_decode_with_kv_cache_mla (see MLA caller code that omits cum_seq_lens_q); update the trace templates to either (a) represent both public dense 4D forms and the ragged flattened 3D forms in this template (accepting ["batch_size","q_len","num_heads","head_dim_qk"] and ["num_tokens","num_heads","head_dim_qk"] for "query", and matching shapes for "output"), and model "workspace_buffer" as the correct shape(s) used by the public API, or (b) split into two TraceTemplate entries (one for the public dense API and one for the backend ragged API) and register both; ensure the template names/TraceTemplate declarations match the API (trtllm_batch_decode_with_kv_cache_mla) and include cum_seq_lens_q as optional to disambiguate paths, following the flashinfer/trace/templates/* pattern so trace functionality is enabled.
🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_mla.py (1)
495-508: ⚡ Quick win
This assertion still only validates the uniform-length case.
`output.reshape(batch_size, q_len_per_request, ...)` assumes every request contributes exactly `q_len_per_request` tokens, so `use_cum_seq_lens_q=True` is currently just checking a flattened dense batch. A launcher bug that ignores `cum_seq_lens_q` and keeps using a fixed `q_len_per_request` would still pass here. Please switch this comparison to `cum_seq_lens_q`-driven slices and add at least one uneven offset pattern.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/attention/test_trtllm_gen_mla.py` around lines 495 - 508, The test currently reshapes output and o_ref using q_len_per_request which only validates uniform-length batches; change the comparison to iterate requests using cum_seq_lens_q to compute per-request lengths (e.g., lengths = cum_seq_lens_q[1:] - cum_seq_lens_q[:-1]) and slice output and o_ref along the sequence dimension using those offsets instead of output.reshape; for each request, compare the sliced tensors with torch.testing.assert_close using the existing rtol/atol logic. Also modify the test input to include at least one non-uniform cum_seq_lens_q pattern (uneven offsets) so the uneven-length case is exercised, and keep references to symbols output, o_ref, cum_seq_lens_q, and q_len_per_request when updating the assertions.
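The per-request comparison the prompt describes can be sketched in plain Python. This is a minimal illustration, not the test's actual code: lists stand in for tensors, a simple tolerance check stands in for `torch.testing.assert_close`, and the name `assert_per_request_close` is hypothetical.

```python
def assert_per_request_close(output, o_ref, cum_seq_lens_q, atol=1e-3):
    # Per-request lengths come from the cumulative offsets
    # (torch equivalent: cum_seq_lens_q[1:] - cum_seq_lens_q[:-1]).
    for i in range(len(cum_seq_lens_q) - 1):
        start, end = cum_seq_lens_q[i], cum_seq_lens_q[i + 1]
        got, want = output[start:end], o_ref[start:end]
        assert len(got) == len(want)
        for g, w in zip(got, want):
            assert abs(g - w) <= atol, f"request {i} mismatch"

# Uneven offsets: requests of length 2, 1, and 3 tokens.
cum = [0, 2, 3, 6]
out = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6]
assert_per_request_close(out, out, cum)
```

Unlike the `reshape`-based comparison, this slicing still works when requests contribute different numbers of tokens, which is the whole point of `cum_seq_lens_q`.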
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 79e6ad90-8e57-4125-be7f-4440e5c8723f
📥 Commits
Reviewing files that changed from the base of the PR and between ba30d4f and 9496956e50ff2820f1161cf122ccb64ddeb67b55.
📒 Files selected for processing (3)
flashinfer/mla/_core.py, flashinfer/trace/templates/attention.py, tests/attention/test_trtllm_gen_mla.py
Force-pushed from 752aa4e to e80c099 (Compare)
⚠️ Outside diff range comments (1)
flashinfer/mla/_core.py (1)
760-826: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Add explicit rejection or documentation for `sparse_mla_top_k` with variable-length queries.
The trtllm-gen backend does not explicitly reject `sparse_mla_top_k > 0` when `cum_seq_lens_q` is provided, nor does it document support for this combination. Currently:
- Tests exercise `cum_seq_lens_q` without sparse mode, and sparse mode without `cum_seq_lens_q` (separate paths)
- Docstrings for both parameters lack cross-references or restrictions
- No validation prevents the combination from reaching the kernel
Either:
- Explicitly reject this combination alongside other backend-specific constraints (e.g., line 760–761 for skip_softmax + sparse), or
- Document the supported page_table layout and add a test exercising sparse + variable-length together
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/mla/_core.py` around lines 760 - 826, The code currently allows sparse_mla_top_k > 0 to proceed when cum_seq_lens_q (variable-length queries) is provided; add an explicit rejection for that combination: after determining has_var_q (the cum_seq_lens_q checks) or alongside the existing skip_softmax check at the top, raise a ValueError when sparse_mla_top_k != 0 and has_var_q is True (clear message like "sparse MLA (sparse_mla_top_k>0) is not supported with variable-length queries (cum_seq_lens_q) for trtllm-gen"). Also update the docstrings for sparse_mla_top_k and cum_seq_lens_q to state this restriction and add a unit test exercising sparse_mla_top_k > 0 with cum_seq_lens_q to ensure the error is raised; reference symbols: sparse_mla_top_k, cum_seq_lens_q, has_var_q, and _check_trtllm_gen_mla_shape.
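The suggested guard could look roughly like the following. This is a hedged plain-Python sketch: the hypothetical `check_trtllm_gen_var_q` stands in for the real validation inside `flashinfer/mla/_core.py`, which operates on tensors rather than lists.

```python
def check_trtllm_gen_var_q(sparse_mla_top_k: int, cum_seq_lens_q) -> None:
    # Mirror of the review's suggested rejection: sparse MLA and
    # variable-length queries are mutually exclusive for trtllm-gen.
    has_var_q = cum_seq_lens_q is not None
    if sparse_mla_top_k != 0 and has_var_q:
        raise ValueError(
            "sparse MLA (sparse_mla_top_k>0) is not supported with "
            "variable-length queries (cum_seq_lens_q) for trtllm-gen"
        )
```

Placed alongside the existing backend-specific constraints (e.g. the skip_softmax + sparse check), this makes the unsupported combination fail fast with a clear message instead of reaching the kernel.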
🧹 Nitpick comments (7)
flashinfer/trace/templates/attention.py (3)
1357-1367: 💤 Low value
Dispatch wrapper looks good; consider documenting the `templates` attribute contract.
`trtllm_batch_decode_mla_trace` is a callable selector rather than a `TraceTemplate` instance, with `templates` attached as a side-channel attribute (`type: ignore[attr-defined]`). If the tracing infrastructure relies on this attribute for schema discovery (it looks that way, given the fi_trace test inspecting both names), it would help to either (a) leave a short docstring comment about why the list is attached and who consumes it, or (b) introduce a tiny helper class so the attribute isn't loose. Not a blocker — it works as-is.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/trace/templates/attention.py` around lines 1357 - 1367, The dispatch function trtllm_batch_decode_mla_trace attaches a loose templates attribute for schema discovery which is fragile/undocumented; either add a brief docstring right above trtllm_batch_decode_mla_trace explaining that the tracing infra discovers schema via the templates attribute (and who consumes it), or replace the free function with a tiny helper wrapper (e.g., a small callable class or named object) that exposes templates formally so the attribute isn't a side-channel; reference trtllm_batch_decode_mla_trace and the templates list (containing trtllm_batch_decode_mla_dense_trace and trtllm_batch_decode_mla_ragged_trace) when making the change.
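Option (b), a tiny helper class that formalizes the `templates` attribute instead of attaching it loosely to a function, might look like this sketch. `TraceDispatcher` is a hypothetical name, not part of flashinfer.

```python
class TraceDispatcher:
    """Callable selector that carries its candidate templates as a
    first-class attribute rather than a type: ignore side channel."""

    def __init__(self, selector, templates):
        self._selector = selector
        self.templates = list(templates)  # discoverable by tracing infra

    def __call__(self, *args, **kwargs):
        return self._selector(*args, **kwargs)


# Hypothetical usage: pick the ragged or dense template by flag.
dispatch = TraceDispatcher(
    lambda ragged: "ragged_template" if ragged else "dense_template",
    ["dense_template", "ragged_template"],
)
```

The dispatch behavior is unchanged; the only difference is that consumers inspecting `templates` see a documented attribute of a small class instead of an ad-hoc function attribute.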
1218-1354: 💤 Low value
Consider adding a `head_dim_qk == kv_lora_rank + qk_rope_head_dim` constraint to both MLA traces.
The reference implementation enforces this with `assert head_dim_qk == kv_lora_rank + qk_rope_head_dim` (line 1175), but the templates themselves declare `head_dim_qk`, `kv_lora_rank`, and `qk_rope_head_dim` as independent `Const` axes with no `constraints=[...]` linking them. Other MLA-style templates in this file (e.g., `gqa_paged_decode_trace`) use the `constraints` field to express such invariants, which helps consumers of the trace schema validate inputs without re-deriving the relationship.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/trace/templates/attention.py` around lines 1218 - 1354, The two MLA trace templates trtllm_batch_decode_mla_dense_trace and trtllm_batch_decode_mla_ragged_trace declare head_dim_qk, kv_lora_rank and qk_rope_head_dim independently but must enforce the invariant head_dim_qk == kv_lora_rank + qk_rope_head_dim; update each TraceTemplate's axes to add a constraints entry (similar to gqa_paged_decode_trace) that asserts head_dim_qk == kv_lora_rank + qk_rope_head_dim so consumers can validate inputs consistently with the reference (see the existing assert in _trtllm_batch_decode_mla_reference).
1159-1215: 💤 Low value
Reference impl branches look correct; minor robustness note.
The `cum_seq_lens_q` branching cleanly handles both the dense (4D query) and ragged (3D query) layouts, with the output shape derived from `query.shape[:-1] + (kv_lora_rank,)` in the ragged case. One small thing worth noting: in the ragged branch `q_len` is never bound, so referencing it would be a `NameError`; it is currently only used in the `cum_seq_lens_q is None` branch (line 1199), so this is safe but fragile if someone later refactors the inner loop. A defensive `q_len = None` initialization in the `else` branch would prevent a future regression.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/trace/templates/attention.py` around lines 1159 - 1215, The ragged-path branch when cum_seq_lens_q is not None never binds q_len, which risks a NameError if the loop logic is refactored; update the else branch (where batch_size, num_heads, head_dim_qk, and output are set) to explicitly initialize q_len = None (or q_len = 0) so references to q_len later (e.g. in the cum_seq_lens_q is None branch and the outer loops using query/q_batch) remain safe; ensure the change is applied near the cum_seq_lens_q handling alongside query and output setup.
flashinfer/mla/_core.py (2)
674-678: ⚡ Quick win
Document additional constraints on `cum_seq_lens_q`.
The validation block enforces that `cum_seq_lens_q` is 1D `int32`, has at least two entries, starts with 0, ends at `query.size(0)`, and is monotonically non-decreasing. Surfacing these requirements in the docstring would save users a debugging round-trip.
📝 Suggested docstring update

```diff
 cum_seq_lens_q : Optional[torch.Tensor] = None
-    Cumulative query sequence lengths for variable-length query support, shape: ``[batch_size + 1]``, dtype: ``torch.int32``.
-    Only supported by trtllm-gen backend.
-    When provided, ``query`` must have shape ``[total_q, num_heads, head_dim_qk]``.
+    Cumulative query sequence lengths for variable-length query support, shape: ``[batch_size + 1]``, dtype: ``torch.int32``.
+    Must be monotonically non-decreasing with ``cum_seq_lens_q[0] == 0`` and
+    ``cum_seq_lens_q[-1] == query.size(0)``; ``batch_size`` is inferred as
+    ``cum_seq_lens_q.size(0) - 1`` and must match ``seq_lens.size(0)``.
+    Only supported by trtllm-gen backend.
+    When provided, ``query`` must have shape ``[total_q, num_heads, head_dim_qk]``.
```

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/mla/_core.py` around lines 674 - 678, Update the cum_seq_lens_q docstring to list the exact validation constraints enforced in the code: it must be a 1-D torch.Tensor of dtype torch.int32, length >= 2 and shape [batch_size + 1], the first element must be 0, the last element must equal query.size(0), and the sequence must be monotonically non-decreasing; also note this parameter is only supported by the trtllm-gen backend and that when provided query must have shape [total_q, num_heads, head_dim_qk].
790-804: 💤 Low value
Validation forces a host sync per call.
`cum_seq_lens_q.cpu()` plus the subsequent `.item()` calls block the calling stream on every invocation of `trtllm_batch_decode_with_kv_cache_mla`. This pattern is consistent with `BatchMLAPagedAttentionWrapper.plan()` and is acceptable for correctness, but for a hot decode path it does add latency that doesn't exist when `cum_seq_lens_q is None`.
If you want to keep validation but avoid syncing, you can keep the structural checks (dtype/ndim/size) on-device and gate the `[0]`/`[-1]`/monotonicity checks behind a debug flag (e.g., `__debug__` / env var). Otherwise it's worth a brief comment acknowledging the sync so readers don't try to "optimize" it into a different correctness footprint.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/mla/_core.py` around lines 790 - 804, The validation on cum_seq_lens_q in trtllm_batch_decode_with_kv_cache_mla currently calls cum_seq_lens_q.cpu() and .item() which forces a host sync every call; keep on-device structural checks (dtype, ndim, size) directly on cum_seq_lens_q and move the costly value checks (first/last elements and monotonicity that require .cpu()/.item()) behind a debug gate (e.g., if __debug__ or an env var) or remove them from the hot path, and if you choose to retain the sync keep a clear comment in trtllm_batch_decode_with_kv_cache_mla (and/or document in BatchMLAPagedAttentionWrapper.plan) explaining the intentional host sync so future readers don’t silently rework correctness.tests/attention/test_trtllm_gen_mla.py (2)
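The debug-gated split could be sketched as follows. Plain-Python lists stand in for tensors and `validate_cum_seq_lens_q` is a hypothetical helper; with real CUDA tensors, the value checks in the strict branch are the ones that would force `.cpu()`/`.item()` syncs.

```python
def validate_cum_seq_lens_q(cum_seq_lens_q, total_q, strict=True):
    # Structural checks: cheap metadata, no device->host transfer needed
    # for real tensors (dtype/ndim/size would be checked the same way).
    if len(cum_seq_lens_q) < 2:
        raise ValueError("cum_seq_lens_q needs at least two entries")
    if not strict:
        return
    # Value checks: on a CUDA tensor these require reading element values
    # (.cpu()/.item()), which blocks the stream, hence the debug gate.
    if cum_seq_lens_q[0] != 0:
        raise ValueError("cum_seq_lens_q must start at 0")
    if cum_seq_lens_q[-1] != total_q:
        raise ValueError("cum_seq_lens_q must end at query.size(0)")
    for a, b in zip(cum_seq_lens_q, cum_seq_lens_q[1:]):
        if b < a:
            raise ValueError("cum_seq_lens_q must be non-decreasing")
```

In the hot path one would call `validate_cum_seq_lens_q(..., strict=__debug__)` or key `strict` off an environment variable, so release runs skip the syncing checks while debug runs keep the full correctness net.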
323-335: ⚡ Quick win
Test gap: `use_cum_seq_lens_q=True` only exercises uniform query lengths.
`cum_seq_lens_q = arange(0, batch_size + 1) * q_len_per_request` describes the same partitioning as the dense path (every sequence has exactly `q_len_per_request` tokens). After the wrapper's `query.flatten(0, 1)` for the dense case, the kernel input is essentially identical between the two branches, so this test validates the API plumbing but not the variable-length behavior that `cum_seq_lens_q` is meant to enable.
Consider adding (or replacing one of the parameter axes with) a case where the `q_lens` actually vary across the batch — e.g.

```python
if use_cum_seq_lens_q:
    q_lens = [random.randint(1, q_len_per_request) for _ in range(batch_size)]
    # Build a flat query and reference path that reflect those q_lens.
    ...
```

Otherwise the doubled parametrization adds CI cost without meaningfully exercising the new code path. (Also worth checking with the matrix that the `[False, True]` axis isn't multiplied unnecessarily across all backends — `xqa`/`cute-dsl` are skipped via the gate at lines 306–307, but the parametrize still spawns those cases.)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/attention/test_trtllm_gen_mla.py` around lines 323 - 335, The test currently sets cum_seq_lens_q to arange(0, batch_size + 1) * q_len_per_request which produces uniform query lengths and therefore doesn't exercise variable-length behavior; change the use_cum_seq_lens_q branch to generate variable per-batch q_lens (e.g., q_lens = [randint(1, q_len_per_request) for _ in range(batch_size)]), then build query_input by concatenating per-example query slices into a single flat tensor and compute cum_seq_lens_q as the prefix sum of q_lens (dtype int32) so the kernel receives genuinely variable lengths, and also construct the test's reference/dense-path inputs by flattening according to the same q_lens so comparisons remain valid (update symbols: use_cum_seq_lens_q, query_input, cum_seq_lens_q, q_len_per_request, q_lens, batch_size).
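Building a genuinely non-uniform `cum_seq_lens_q` from random per-request lengths can be sketched as follows; plain Python stands in for the `torch.int32` tensor the kernel expects, and the variable names mirror the ones the comment uses.

```python
import random

random.seed(0)  # deterministic lengths for a reproducible test case
batch_size, q_len_per_request = 4, 3

# Variable per-request query lengths, as the comment suggests.
q_lens = [random.randint(1, q_len_per_request) for _ in range(batch_size)]

# cum_seq_lens_q is the prefix sum of q_lens with a leading zero
# (torch equivalent: cat([zeros(1), cumsum(q_lens)]).to(torch.int32)).
cum_seq_lens_q = [0]
for n in q_lens:
    cum_seq_lens_q.append(cum_seq_lens_q[-1] + n)

# Number of rows in the flattened [total_q, num_heads, head_dim_qk] query.
total_q = cum_seq_lens_q[-1]
```

The reference path then slices per request via `cum_seq_lens_q[i]:cum_seq_lens_q[i+1]` instead of reshaping by a fixed `q_len_per_request`, so the ragged kernel path is actually exercised.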
829: 💤 Low value
Confirm CI budget for the doubled parametrize matrix.
Combined with the existing `layer_dimensions × batch_size × scale × dtype × page_size × q_len_per_request × dynamic_scale × enable_pdl × backend × skips_softmax × uses_shared_paged_kv_idx` matrix, adding `use_cum_seq_lens_q=[False, True]` doubles an already large product. Most of the new cases are skipped for non-`trtllm-gen` backends via lines 844–845, but pytest still pays collection/skip overhead for each. If CI runtime is a concern, consider scoping this axis to a smaller subset of the matrix (e.g., parametrize on `use_cum_seq_lens_q` only inside a dedicated test with a few representative configs).
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/attention/test_trtllm_gen_mla.py` at line 829, The new pytest.mark.parametrize("use_cum_seq_lens_q", [False, True]) doubles the already-large test matrix and increases collection/skip overhead; either confirm CI can absorb the extra runtime or reduce scope by moving this parametrize into a smaller dedicated test or by conditionalizing it only for trtllm-gen runs. Update tests/attention/test_trtllm_gen_mla.py to remove or limit the global parametrize on use_cum_seq_lens_q and instead add a focused test (or inner parametrize) that exercises both True/False for a minimal set of representative configurations, or guard the parametrize so it only applies when the backend equals "trtllm-gen" to avoid unnecessary collection overhead.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: c6eef880-ac98-4a30-acf6-9d3043a6a303
📥 Commits
Reviewing files that changed from the base of the PR and between 9496956e50ff2820f1161cf122ccb64ddeb67b55 and 752aa4e1fac049e52472380e68e87f483058ff58.
📒 Files selected for processing (4)
flashinfer/mla/_core.py, flashinfer/trace/templates/attention.py, tests/attention/test_trtllm_gen_mla.py, tests/trace/test_fi_trace.py
Actionable comments posted: 1
⚠️ Outside diff range comments (1)
flashinfer/mla/_core.py (1)
718-725: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win
Reorder XQA guards so `cum_seq_lens_q` is rejected before the q_len shape check.
When a user passes a 3D `query` together with `cum_seq_lens_q` and the resolved backend is `xqa`, `query.size(1)` is `num_heads` (e.g. 128), so the `q_len_per_request == 1` guard at lines 718–721 trips first and produces a misleading error referring to MTP query length, instead of the accurate "XQA MLA does not support cum_seq_lens_q" message at lines 724–725.
if sinks is not None: raise ValueError("XQA MLA does not support sinks") + if cum_seq_lens_q is not None: + raise ValueError("XQA MLA does not support cum_seq_lens_q") if query.size(1) != 1: raise ValueError( f"XQA MLA only supports q_len_per_request == 1, got {query.size(1)}" ) if skip_softmax_threshold_scale_factor is not None: raise ValueError("skip_softmax is not supported for XQA backend") - if cum_seq_lens_q is not None: - raise ValueError("XQA MLA does not support cum_seq_lens_q")🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/mla/_core.py` around lines 718 - 725, Reorder the XQA MLA input guards so the check for cum_seq_lens_q runs before validating query.size(1): move the "if cum_seq_lens_q is not None: raise ValueError('XQA MLA does not support cum_seq_lens_q')" block to precede the "if query.size(1) != 1: ..." check; keep the skip_softmax_threshold_scale_factor check as-is (or immediately after the cum_seq_lens_q check) so that passing a 3D query together with cum_seq_lens_q triggers the correct "does not support cum_seq_lens_q" error instead of the misleading q_len_per_request error.
🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_mla.py (1)
323-335: ⚡ Quick win
Consider also exercising a non-uniform `cum_seq_lens_q`.
The `use_cum_seq_lens_q=True` branch builds a perfectly uniform offset (`arange * q_len_per_request`), so every per-request length is identical to the dense path. That validates the THD wiring but does not actually exercise the variable-length code path the feature is intended to enable, so a regression that only manifests when per-request `q_lens` differ would slip through this parametrization. Adding a single test with non-uniform query lengths (e.g. random `q_lens` in `[1, q_len_per_request]` and recomputed `cum_seq_lens_q`/`max_q_len`) would close this gap.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@tests/attention/test_trtllm_gen_mla.py` around lines 323 - 335, The current use_cum_seq_lens_q=True branch only tests a uniform cum_seq_lens_q (arange * q_len_per_request); change the test to also exercise a non‑uniform case by generating per-request q_lens (e.g. random ints in [1, q_len_per_request]), compute cum_seq_lens_q via cumulative sum (starting at 0), set max_q_len = q_lens.max(), and then build query_input/reshape and any necessary padding to match the expected (batch_size * variable_requests, num_heads, max_q_len) layout so the variable-length code path is actually exercised; update any assertions that rely on fixed lengths accordingly and keep the original uniform case as well.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@flashinfer/mla/_core.py`:
- Around line 788-810: The validation currently forces a device->host sync via
cum_seq_lens_q.cpu() and .item() calls; add an optional max_q_len parameter
(mirroring flashinfer.decode.trtllm_batch_decode_with_kv_cache) to the function
that owns cum_seq_lens_q and query so that when max_q_len is provided you skip
the cum_seq_lens_q_host = cum_seq_lens_q.cpu() path and related .item() checks,
instead only validate that cum_seq_lens_q and max_q_len are provided together
and use the supplied max_q_len to compute q_lens bounds; still keep the existing
validation when max_q_len is None (preserving behavior), and update error
messages to require both cum_seq_lens_q and max_q_len be supplied together to
avoid host sync.
---
Outside diff comments:
In `@flashinfer/mla/_core.py`:
- Around line 718-725: Reorder the XQA MLA input guards so the check for
cum_seq_lens_q runs before validating query.size(1): move the "if cum_seq_lens_q
is not None: raise ValueError('XQA MLA does not support cum_seq_lens_q')" block
to precede the "if query.size(1) != 1: ..." check; keep the
skip_softmax_threshold_scale_factor check as-is (or immediately after the
cum_seq_lens_q check) so that passing a 3D query together with cum_seq_lens_q
triggers the correct "does not support cum_seq_lens_q" error instead of the
misleading q_len_per_request error.
---
Nitpick comments:
In `@tests/attention/test_trtllm_gen_mla.py`:
- Around line 323-335: The current use_cum_seq_lens_q=True branch only tests a
uniform cum_seq_lens_q (arange * q_len_per_request); change the test to also
exercise a non‑uniform case by generating per-request q_lens (e.g. random ints
in [1, q_len_per_request]), compute cum_seq_lens_q via cumulative sum (starting
at 0), set max_q_len = q_lens.max(), and then build query_input/reshape and any
necessary padding to match the expected (batch_size * variable_requests,
num_heads, max_q_len) layout so the variable-length code path is actually
exercised; update any assertions that rely on fixed lengths accordingly and keep
the original uniform case as well.
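The XQA guard reordering described in the outside-diff prompt above can be sketched as follows. The function name, error strings, and shape-tuple argument are illustrative stand-ins for the real checks in `flashinfer/mla/_core.py`, not the actual implementation:

```python
def check_xqa_mla_inputs(query_shape, cum_seq_lens_q=None):
    # Reject the unsupported ragged argument first, so a 3D ragged query
    # paired with cum_seq_lens_q reports the real problem...
    if cum_seq_lens_q is not None:
        raise ValueError("XQA MLA does not support cum_seq_lens_q")
    # ...and only then validate the per-request query length, which is the
    # check that would otherwise fire with a misleading message.
    if query_shape[1] != 1:
        raise ValueError("XQA MLA requires q_len_per_request == 1")

# A ragged call now fails with the cum_seq_lens_q error, not the shape error.
try:
    check_xqa_mla_inputs((16, 64, 512), cum_seq_lens_q=[0, 3, 7])
except ValueError as e:
    message = str(e)

assert "cum_seq_lens_q" in message
```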
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 68a8e0bf-ec8e-4cfa-8722-4281d8f58f15
📥 Commits
Reviewing files that changed from the base of the PR and between 752aa4e1fac049e52472380e68e87f483058ff58 and e80c0991608b1e1f1d0499dcb79380ad7659e340.
📒 Files selected for processing (4)
- flashinfer/mla/_core.py
- flashinfer/trace/templates/attention.py
- tests/attention/test_trtllm_gen_mla.py
- tests/trace/test_fi_trace.py
e80c099 to 1f39730 (Compare)
/bot run
Actionable comments posted: 1
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@tests/attention/test_trtllm_gen_mla.py`:
- Line 1128: The regex passed to pytest.raises as the match argument uses regex
metacharacters but is a normal string; update the pytest.raises call (the match=
parameter) to use a raw string literal (prefix with r) so the pattern is
explicit (e.g., change the match value to a raw string like r"...") to satisfy
the RUF043 linting rule and avoid escape-sequence issues.
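The RUF043 finding above matters because `pytest.raises(..., match=...)` applies the pattern with `re.search`, so regex metacharacters in a plain string silently change its meaning. A small sketch with `re` directly, using an illustrative message rather than the exact one from the PR:

```python
import re

message = "query.size(1) must be 1"

# A plain-string pattern with unescaped parens also matches "query-size1":
# "." matches any character and "(1)" is a capture group around literal 1.
loose = re.search("query.size(1)", "query-size1")

# The raw-string, escaped pattern matches only the literal text.
strict_pattern = r"query\.size\(1\) must be 1"
strict = re.search(strict_pattern, message)

assert loose is not None   # accidental match: metacharacters are active
assert strict is not None  # intended literal match
```

In the test, this means writing `pytest.raises(ValueError, match=r"...")` with escaped metacharacters, which is what the raw-string prefix makes explicit.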
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4240a533-7da8-4cd1-8f27-b1994d5271b4
📥 Commits
Reviewing files that changed from the base of the PR and between e80c0991608b1e1f1d0499dcb79380ad7659e340 and fb63d29add97d7dd4c44f473bfe487c1e32ef34e.
📒 Files selected for processing (4)
- flashinfer/mla/_core.py
- flashinfer/trace/templates/attention.py
- tests/attention/test_trtllm_gen_mla.py
- tests/trace/test_fi_trace.py
🚧 Files skipped from review as they are similar to previous changes (2)
- tests/trace/test_fi_trace.py
- flashinfer/trace/templates/attention.py
53735de to d3d3562 (Compare)
/bot run
Investigating internal CI failures
32b4ff7 to f026afe (Compare)
/bot run
…_q support

Expose the TRT-LLM Gen MLA query-length metadata that the backend launcher already accepts, while keeping the existing KV seq_lens/max_seq_len contract unchanged. The wrapper now accepts flattened THD query input when cum_seq_lens_q is provided, derives max_q_len internally, and rejects the new argument on unsupported backends.
Constraint: TRT-LLM Gen backend ABI already has a cum_seq_lens_q slot and expects max_q_len separately at launch time
Rejected: Reuse seq_lens/max_seq_len for query lengths | those arguments describe KV cache lengths for all existing callers
Rejected: Expose max_q_len as a public API argument | it is derivable from cum_seq_lens_q and would duplicate caller state
Confidence: high
Scope-risk: moderate
Tested: pre-commit run --files flashinfer/mla/_core.py flashinfer/trace/templates/attention.py tests/attention/test_trtllm_gen_mla.py
Tested: python3 -m py_compile flashinfer/mla/_core.py flashinfer/trace/templates/attention.py tests/attention/test_trtllm_gen_mla.py
Tested: git diff --check
Tested: Remote SM100 pytest for THD cum_seq_lens_q=True TRT-LLM Gen MLA decode with shared and separate page indices
Not-tested: Full TRT-LLM Gen MLA parameter matrix
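The flattened THD layout this commit describes can be sketched as follows. This is illustrative Python, not FlashInfer code: plain nested lists stand in for torch tensors, and the numbers are arbitrary. The point is the contrast between a dense padded `[batch, q_len, heads, dim]` query and the ragged `[total_tokens, heads, dim]` query addressed through `cum_seq_lens_q`:

```python
num_heads, head_dim = 2, 4
q_lens = [3, 1, 2]             # per-request query lengths
cum_seq_lens_q = [0, 3, 4, 6]  # exclusive prefix sum of q_lens
max_q_len = max(q_lens)

def token(i):
    # One query token: [num_heads, head_dim] filled with a marker value.
    return [[float(i)] * head_dim for _ in range(num_heads)]

# THD layout: the tokens of all requests concatenated along dim 0, with no
# padding; total length is cum_seq_lens_q[-1], not batch * max_q_len.
thd = [token(i) for i in range(cum_seq_lens_q[-1])]

# Request b's tokens are recovered by slicing with the cumulative offsets.
def request_tokens(b):
    return thd[cum_seq_lens_q[b]:cum_seq_lens_q[b + 1]]

assert len(thd) == sum(q_lens)
assert [len(request_tokens(b)) for b in range(3)] == q_lens
```

The dense path would instead pad every request to `max_q_len`, which is why the ragged variant needs its own trace schema rank.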
Split the TRT-LLM Gen MLA trace schema into dense and ragged variants so tracing matches the public API rank in both call modes. Validate query and KV batch metadata before backend module setup, and collapse cum_seq_lens_q scalar checks through one host copy.
Constraint: Dense calls still use 4D public query tensors while ragged cum_seq_lens_q calls use flattened 3D query tensors.
Rejected: Keep one flattened trace template | dense trace consumers would see the wrong rank contract.
Confidence: high
Scope-risk: narrow
Directive: Keep Python metadata validation before TRT-LLM Gen module lookup so invalid inputs do not trigger JIT setup.
Tested: pre-commit on changed files; py_compile; git diff --check; remote SM100 focused pytest for dense/ragged trace schema, batch mismatch validation, and two THD TRT-LLM Gen cases
The ragged TRT-LLM Gen MLA path can avoid synchronizing cum_seq_lens_q to host when the caller supplies max_q_len alongside the query metadata. Keep the previous host-validation fallback for callers that omit max_q_len, but reject max_q_len without cum_seq_lens_q.
Constraint: max_q_len mirrors the non-MLA TRT-LLM batch decode API and is only meaningful for the ragged cum_seq_lens_q path.
Rejected: Require max_q_len for all cum_seq_lens_q calls | that would break the fallback behavior already added in this PR.
Confidence: high
Scope-risk: narrow
Directive: Keep max_q_len validation on the Python side before TRT-LLM Gen module lookup to avoid JIT setup for invalid metadata.
Tested: pre-commit on changed files; py_compile; git diff --check
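The validation split this commit describes can be sketched with a hand-rolled helper. Everything here is illustrative: `resolve_max_q_len`, its error messages, and `copy_to_host` (standing in for the `.cpu()` copy of a device tensor) are hypothetical names, not the real `_core.py` API:

```python
def resolve_max_q_len(cum_seq_lens_q, max_q_len=None, copy_to_host=None):
    if max_q_len is not None:
        if cum_seq_lens_q is None:
            raise ValueError("max_q_len requires cum_seq_lens_q")
        return max_q_len  # fast path: no device->host sync needed
    if cum_seq_lens_q is None:
        return None       # dense call: nothing to resolve
    # Fallback: one host copy, then derive the max per-request length.
    cum_host = copy_to_host(cum_seq_lens_q)
    return max(b - a for a, b in zip(cum_host, cum_host[1:]))

syncs = []
def fake_cpu(t):
    syncs.append(True)  # record that a device->host copy happened
    return t

# Caller-supplied max_q_len: no host copy is made.
assert resolve_max_q_len([0, 2, 5], max_q_len=3, copy_to_host=fake_cpu) == 3
assert syncs == []

# Fallback path: exactly one host copy, max derived from the offsets.
assert resolve_max_q_len([0, 2, 5], copy_to_host=fake_cpu) == 3
assert len(syncs) == 1
```

Because an overestimated `max_q_len` only loosens the bound, the fast path can accept it without per-request checks, which is what the tests for "overestimated max_q_len" cover.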
Address follow-up review comments for TRTLLM Gen MLA ragged query support. Reject unsupported XQA and sparse MLA combinations before shape-sensitive access, document max_q_len bounds, and update tests to cover non-uniform THD query lengths plus overestimated max_q_len.
Constraint: cum_seq_lens_q and max_q_len are only supported for trtllm-gen dense MLA paths
Rejected: Let sparse MLA accept ragged query metadata | backend support is not established
Confidence: high
Scope-risk: narrow
Tested: pre-commit on changed files; py_compile; git diff --check
Tested: SM100 GPU focused tests: 3 passed, 39366 deselected; representative ragged decode case: 1 passed
TRTLLM-gen MLA currently produces invalid output for variable-length query metadata combined with skip-softmax, so reject the public API combination and skip that unsupported parametrized test matrix.
Constraint: Pipeline failures showed NaN outputs for use_cum_seq_lens_q=True with skip_softmax=True on SM100/SM103/SM103-class jobs.
Confidence: high
Scope-risk: narrow
Tested: python -m py_compile flashinfer/mla/_core.py tests/attention/test_trtllm_gen_mla.py; git diff --check; venv pre-commit on touched files; prior focused GPU verification before deleting the redundant focused reject test
Not-tested: Full GPU matrix after deleting only the redundant focused reject test
f026afe to 67b3660 (Compare)
📌 Description
Adds `cum_seq_lens_q` support to the `trtllm-gen` path in `trtllm_batch_decode_with_kv_cache_mla`.
🔍 Related Issues
#3131
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
- I have installed `pre-commit` by running `pip install pre-commit` (or used your preferred method).
- I have installed the hooks with `pre-commit install`.
- I have run the hooks manually with `pre-commit run --all-files` and fixed any reported issues.
🧪 Tests
- All tests are passing (`unittest`, etc.).
Validation run:
- `pre-commit run --files flashinfer/mla/_core.py flashinfer/trace/templates/attention.py tests/attention/test_trtllm_gen_mla.py`
- `python3 -m py_compile flashinfer/mla/_core.py flashinfer/trace/templates/attention.py tests/attention/test_trtllm_gen_mla.py`
- `git diff --check`
- `tests/attention/test_trtllm_gen_mla.py::test_trtllm_batch_decode_mla[True-True-False-trtllm-gen-False-False-2-64-dtype1-1.0-4-layer_dimensions0]`
- `tests/attention/test_trtllm_gen_mla.py::test_trtllm_batch_decode_mla[True-False-False-trtllm-gen-False-False-2-64-dtype1-1.0-4-layer_dimensions0]`
Reviewer Notes
The new argument is query-side metadata only.
`seq_lens` and `max_seq_len` continue to describe KV cache lengths; `max_q_len` is derived from `cum_seq_lens_q` and is not exposed as a new public API argument.
Summary by CodeRabbit
New Features
Tests