trtllm_batch_decode_with_kv_cache_mla trtllm-gen backend cum_seq_lens_q support #3238

Open

saltyminty wants to merge 5 commits into main from fix/mingyangw/support-cum-seq-lens-q-in-trtllm-batch-decode-with-kv-cache-mla-trtllm-g

Conversation

@saltyminty
Collaborator

@saltyminty saltyminty commented May 6, 2026

📌 Description

Adds cum_seq_lens_q support to the trtllm-gen path in trtllm_batch_decode_with_kv_cache_mla.
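
For orientation, a minimal sketch of the two query layouts involved (plain PyTorch only; the MLA dimensions are illustrative and every other kernel argument is elided):

import torch

# Illustrative MLA dimensions (e.g. kv_lora_rank 512 + qk_rope_head_dim 64 = 576).
batch_size, q_len, num_heads, head_dim_qk = 4, 2, 128, 576

# Dense path (unchanged): 4D query with one fixed q_len per request.
dense_query = torch.randn(batch_size, q_len, num_heads, head_dim_qk, dtype=torch.bfloat16)

# Ragged path (this PR): flattened [total_q, num_heads, head_dim_qk] query plus
# cum_seq_lens_q holding per-request token offsets (batch_size + 1 int32 entries).
cum_seq_lens_q = torch.tensor([0, 2, 3, 5, 8], dtype=torch.int32)
total_q = int(cum_seq_lens_q[-1])
ragged_query = torch.randn(total_q, num_heads, head_dim_qk, dtype=torch.bfloat16)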

🔍 Related Issues

#3131

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Validation run:

  • pre-commit run --files flashinfer/mla/_core.py flashinfer/trace/templates/attention.py tests/attention/test_trtllm_gen_mla.py
  • python3 -m py_compile flashinfer/mla/_core.py flashinfer/trace/templates/attention.py tests/attention/test_trtllm_gen_mla.py
  • git diff --check
  • Remote SM100 GPU pytest in FlashInfer CI container:
    • tests/attention/test_trtllm_gen_mla.py::test_trtllm_batch_decode_mla[True-True-False-trtllm-gen-False-False-2-64-dtype1-1.0-4-layer_dimensions0]
    • tests/attention/test_trtllm_gen_mla.py::test_trtllm_batch_decode_mla[True-False-False-trtllm-gen-False-False-2-64-dtype1-1.0-4-layer_dimensions0]

Reviewer Notes

The new argument is query-side metadata only. seq_lens and max_seq_len continue to describe KV cache lengths; max_q_len is derived from cum_seq_lens_q and is not exposed as a new public API argument.
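
Restated as a small sketch (values arbitrary; this mirrors the note above, not the exact in-tree code):

import torch

# seq_lens / max_seq_len keep describing per-request KV cache lengths.
seq_lens = torch.tensor([1024, 37, 4096], dtype=torch.int32)
max_seq_len = 4096

# cum_seq_lens_q is query-side metadata; max_q_len is derived from it internally
# rather than being accepted from the caller.
cum_seq_lens_q = torch.tensor([0, 2, 3, 7], dtype=torch.int32)
max_q_len = int((cum_seq_lens_q[1:] - cum_seq_lens_q[:-1]).max())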

Summary by CodeRabbit

  • New Features

    • Support for variable-length per-request (ragged) MLA queries in batch decode, with optional cumulated-sequence-lengths and max-query-length handling and runtime dispatch between dense and ragged traces.
    • New public trace variants for dense and ragged per-token MLA decodes and a dispatching trace entrypoint.
  • Tests

    • Added and extended tests covering dense/ragged decode paths, cumulated-sequence-lengths workflows, validation, and mismatch/error cases.

@coderabbitai
Contributor

coderabbitai Bot commented May 6, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Adds optional ragged (per-request variable-length) query support via cum_seq_lens_q/max_q_len, extends 3D-query validation, forwards cumulated metadata into the TRT‑LLM MLA kernel, adds dense/ragged trace templates, and updates tests to cover ragged workflows.

Changes

Variable-Length MLA Decode (single cohort)

  • API Signature & Doc (flashinfer/mla/_core.py): trtllm_batch_decode_with_kv_cache_mla now accepts cum_seq_lens_q: Optional[torch.Tensor] = None and max_q_len: Optional[int] = None; docs note backend/shape constraints.
  • Shape Validation (flashinfer/mla/_core.py): _check_trtllm_gen_mla_shape extended to accept batch_size and max_q_len and handle 3D query as flattened ragged input.
  • Decode Logic & Wiring (flashinfer/mla/_core.py): Detects has_var_q from cum_seq_lens_q, validates cumulated lengths and backend constraints, derives max_q_len when omitted, avoids flattening for ragged queries, and threads cum_seq_lens_q into the TRT‑LLM kernel call.
  • Kernel Invocation (flashinfer/mla/_core.py): Passes cum_seq_lens_q through to trtllm_paged_attention_decode instead of a hardcoded None.
  • Trace Templates (dense & ragged) (flashinfer/trace/templates/attention.py): Adds trtllm_batch_decode_mla_dense_trace and trtllm_batch_decode_mla_ragged_trace, a dispatcher trtllm_batch_decode_mla_trace, and per-token ragged vs per-request dense reference behavior.
  • Tests & Parametrization (tests/attention/test_trtllm_gen_mla.py, tests/trace/test_fi_trace.py): Parametrized tests to run with/without cum_seq_lens_q, build/validate cum_seq_lens_q and merged query inputs, add batch-mismatch and guard tests, and validate fi_trace dense/ragged definitions.

Sequence Diagram(s)

sequenceDiagram
    participant Test as Test/Client
    participant MLA as flashinfer.mla/_core
    participant Trace as trace/templates/attention
    participant Kernel as TRT-LLM Kernel

    Test->>MLA: call trtllm_batch_decode_with_kv_cache_mla(query, cum_seq_lens_q?)
    MLA->>MLA: validate shapes, compute max_q_len / has_var_q
    MLA->>Trace: select dense or ragged trace based on cum_seq_lens_q
    MLA->>Kernel: invoke decode kernel (pass cum_seq_lens_q when ragged)
    Kernel-->>MLA: per-token or per-batch MLA outputs
    MLA-->>Test: assemble and return outputs

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • sricketts
  • aleozlx
  • yzh119
  • cyx-6
  • bkryu
  • nv-yunzheq
  • samuellees
  • jimmyzho

Poem

🐰 I hop through ragged tokens, counting each queue,

cum_seq_lens_q in paw — lengths stitched into view.
From 4D down to 3D I tuck and keep things neat,
Per-token traces march so outputs align and meet.
Hooray — dense and ragged now both dance to the beat.

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

  • Docstring Coverage (⚠️ Warning): Docstring coverage is 25.00%, which is insufficient; the required threshold is 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.

✅ Passed checks (4 passed)

  • Title check (✅ Passed): The title accurately describes the main change: adding cum_seq_lens_q support to the trtllm-gen backend path in trtllm_batch_decode_with_kv_cache_mla.
  • Description check (✅ Passed): The description follows the template structure with Description, Related Issues, and Pre-commit/Test checklists completed. All required sections are present and sufficiently filled out.
  • Linked Issues check (✅ Passed): Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check (✅ Passed): Check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for variable-length (ragged) queries in the trtllm_batch_decode_with_kv_cache_mla function by adding a cum_seq_lens_q parameter. The changes include updated shape validation logic, reference implementation adjustments for ragged tensors, and expanded test coverage. Review feedback highlights a potential rank mismatch in the trace template for dense calls, suggests optimizing host-device synchronizations during tensor validation to improve performance, and recommends adding a consistency check between the query batch size and the provided sequence lengths.

Comment thread flashinfer/trace/templates/attention.py
Comment thread flashinfer/mla/_core.py Outdated
Comment thread flashinfer/mla/_core.py Outdated
Contributor

@coderabbitai coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/trace/templates/attention.py (1)

1245-1287: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Keep this trace template aligned with the public MLA API, not only the flattened backend form.

trtllm_batch_decode_with_kv_cache_mla() still accepts dense 4D query and returns dense 4D output when cum_seq_lens_q is omitted (flashinfer/mla/_core.py, Lines 798-837), but this template now hardcodes flattened 3D query/output shapes and also models workspace_buffer as ["num_pages"]. That makes the existing non-ragged call shape either untraceable or misleading. Please represent both public forms here, or split the ragged backend form into a separate template instead of redefining the existing one.

As per coding guidelines, flashinfer/trace/templates/**/*.py: Add a TraceTemplate in flashinfer/trace/templates/ for new or updated APIs to enable trace functionality.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/trace/templates/attention.py` around lines 1245 - 1287, The
template currently only models the flattened backend ragged shapes for
"query"/"output" and models "workspace_buffer" as ["num_pages"], which breaks or
misrepresents the public dense 4D API used by
trtllm_batch_decode_with_kv_cache_mla (see MLA caller code that omits
cum_seq_lens_q); update the trace templates to either (a) represent both public
dense 4D forms and the ragged flattened 3D forms in this template (accepting
["batch_size","q_len","num_heads","head_dim_qk"] and
["num_tokens","num_heads","head_dim_qk"] for "query", and matching shapes for
"output"), and model "workspace_buffer" as the correct shape(s) used by the
public API, or (b) split into two TraceTemplate entries (one for the public
dense API and one for the backend ragged API) and register both; ensure the
template names/TraceTemplate declarations match the API
(trtllm_batch_decode_with_kv_cache_mla) and include cum_seq_lens_q as optional
to disambiguate paths, following the flashinfer/trace/templates/* pattern so
trace functionality is enabled.
🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_mla.py (1)

495-508: ⚡ Quick win

This assertion still only validates the uniform-length case.

output.reshape(batch_size, q_len_per_request, ...) assumes every request contributes exactly q_len_per_request tokens, so use_cum_seq_lens_q=True is currently just checking a flattened dense batch. A launcher bug that ignores cum_seq_lens_q and keeps using fixed q_len_per_request would still pass here. Please switch this comparison to cum_seq_lens_q-driven slices and add at least one uneven offset pattern.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/attention/test_trtllm_gen_mla.py` around lines 495 - 508, The test
currently reshapes output and o_ref using q_len_per_request which only validates
uniform-length batches; change the comparison to iterate requests using
cum_seq_lens_q to compute per-request lengths (e.g., lengths =
cum_seq_lens_q[1:] - cum_seq_lens_q[:-1]) and slice output and o_ref along the
sequence dimension using those offsets instead of output.reshape; for each
request, compare the sliced tensors with torch.testing.assert_close using the
existing rtol/atol logic. Also modify the test input to include at least one
non-uniform cum_seq_lens_q pattern (uneven offsets) so the uneven-length case is
exercised, and keep references to symbols output, o_ref, cum_seq_lens_q, and
q_len_per_request when updating the assertions.
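
A possible shape for that per-request comparison (the helper name is hypothetical and the tolerances are placeholders for the test's existing rtol/atol selection):

import torch

def assert_ragged_close(output, o_ref, cum_seq_lens_q, rtol=1e-2, atol=1e-2):
    # Compare each request's token slice instead of reshaping to a fixed q_len.
    num_requests = cum_seq_lens_q.numel() - 1
    for i in range(num_requests):
        start, end = int(cum_seq_lens_q[i]), int(cum_seq_lens_q[i + 1])
        torch.testing.assert_close(output[start:end], o_ref[start:end], rtol=rtol, atol=atol)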
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Outside diff comments:
In `@flashinfer/trace/templates/attention.py`:
- Around line 1245-1287: The template currently only models the flattened
backend ragged shapes for "query"/"output" and models "workspace_buffer" as
["num_pages"], which breaks or misrepresents the public dense 4D API used by
trtllm_batch_decode_with_kv_cache_mla (see MLA caller code that omits
cum_seq_lens_q); update the trace templates to either (a) represent both public
dense 4D forms and the ragged flattened 3D forms in this template (accepting
["batch_size","q_len","num_heads","head_dim_qk"] and
["num_tokens","num_heads","head_dim_qk"] for "query", and matching shapes for
"output"), and model "workspace_buffer" as the correct shape(s) used by the
public API, or (b) split into two TraceTemplate entries (one for the public
dense API and one for the backend ragged API) and register both; ensure the
template names/TraceTemplate declarations match the API
(trtllm_batch_decode_with_kv_cache_mla) and include cum_seq_lens_q as optional
to disambiguate paths, following the flashinfer/trace/templates/* pattern so
trace functionality is enabled.

---

Nitpick comments:
In `@tests/attention/test_trtllm_gen_mla.py`:
- Around line 495-508: The test currently reshapes output and o_ref using
q_len_per_request which only validates uniform-length batches; change the
comparison to iterate requests using cum_seq_lens_q to compute per-request
lengths (e.g., lengths = cum_seq_lens_q[1:] - cum_seq_lens_q[:-1]) and slice
output and o_ref along the sequence dimension using those offsets instead of
output.reshape; for each request, compare the sliced tensors with
torch.testing.assert_close using the existing rtol/atol logic. Also modify the
test input to include at least one non-uniform cum_seq_lens_q pattern (uneven
offsets) so the uneven-length case is exercised, and keep references to symbols
output, o_ref, cum_seq_lens_q, and q_len_per_request when updating the
assertions.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 79e6ad90-8e57-4125-be7f-4440e5c8723f

📥 Commits

Reviewing files that changed from the base of the PR and between ba30d4f and 9496956e50ff2820f1161cf122ccb64ddeb67b55.

📒 Files selected for processing (3)
  • flashinfer/mla/_core.py
  • flashinfer/trace/templates/attention.py
  • tests/attention/test_trtllm_gen_mla.py

@saltyminty saltyminty force-pushed the fix/mingyangw/support-cum-seq-lens-q-in-trtllm-batch-decode-with-kv-cache-mla-trtllm-g branch 2 times, most recently from 752aa4e to e80c099 on May 6, 2026 00:33
Contributor

@coderabbitai coderabbitai Bot left a comment


Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/mla/_core.py (1)

760-826: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Add explicit rejection or documentation for sparse_mla_top_k with variable-length queries.

The trtllm-gen backend does not explicitly reject sparse_mla_top_k > 0 when cum_seq_lens_q is provided, nor does it document support for this combination. Currently:

  • Tests exercise cum_seq_lens_q without sparse mode, and sparse mode without cum_seq_lens_q (separate paths)
  • Docstrings for both parameters lack cross-references or restrictions
  • No validation prevents the combination from reaching the kernel

Either:

  1. Explicitly reject this combination alongside other backend-specific constraints (e.g., line 760–761 for skip_softmax + sparse), or
  2. Document the supported page_table layout and add a test exercising sparse + variable-length together
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/mla/_core.py` around lines 760 - 826, The code currently allows
sparse_mla_top_k > 0 to proceed when cum_seq_lens_q (variable-length queries) is
provided; add an explicit rejection for that combination: after determining
has_var_q (the cum_seq_lens_q checks) or alongside the existing skip_softmax
check at the top, raise a ValueError when sparse_mla_top_k != 0 and has_var_q is
True (clear message like "sparse MLA (sparse_mla_top_k>0) is not supported with
variable-length queries (cum_seq_lens_q) for trtllm-gen"). Also update the
docstrings for sparse_mla_top_k and cum_seq_lens_q to state this restriction and
add a unit test exercising sparse_mla_top_k > 0 with cum_seq_lens_q to ensure
the error is raised; reference symbols: sparse_mla_top_k, cum_seq_lens_q,
has_var_q, and _check_trtllm_gen_mla_shape.
🧹 Nitpick comments (7)
flashinfer/trace/templates/attention.py (3)

1357-1367: 💤 Low value

Dispatch wrapper looks good; consider documenting the templates attribute contract.

trtllm_batch_decode_mla_trace is a callable selector rather than a TraceTemplate instance, with templates attached as a side-channel attribute (type: ignore[attr-defined]). If the tracing infrastructure relies on this attribute for schema discovery (looks that way given the fi_trace test inspecting both names), it would help to either (a) leave a short docstring comment about why the list is attached and who consumes it, or (b) introduce a tiny helper class so the attribute isn't loose. Not a blocker — it works as-is.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/trace/templates/attention.py` around lines 1357 - 1367, The
dispatch function trtllm_batch_decode_mla_trace attaches a loose templates
attribute for schema discovery which is fragile/undocumented; either add a brief
docstring right above trtllm_batch_decode_mla_trace explaining that the tracing
infra discovers schema via the templates attribute (and who consumes it), or
replace the free function with a tiny helper wrapper (e.g., a small callable
class or named object) that exposes templates formally so the attribute isn't a
side-channel; reference trtllm_batch_decode_mla_trace and the templates list
(containing trtllm_batch_decode_mla_dense_trace and
trtllm_batch_decode_mla_ragged_trace) when making the change.

1218-1354: 💤 Low value

Consider adding a head_dim_qk == kv_lora_rank + qk_rope_head_dim constraint to both MLA traces.

The reference implementation enforces this with assert head_dim_qk == kv_lora_rank + qk_rope_head_dim (Line 1175), but the templates themselves declare head_dim_qk, kv_lora_rank, and qk_rope_head_dim as independent Const axes with no constraints=[...] linking them. Other MLA-style templates in this file (e.g., gqa_paged_decode_trace) use the constraints field to express such invariants, which helps consumers of the trace schema validate inputs without re-deriving the relationship.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/trace/templates/attention.py` around lines 1218 - 1354, The two
MLA trace templates trtllm_batch_decode_mla_dense_trace and
trtllm_batch_decode_mla_ragged_trace declare head_dim_qk, kv_lora_rank and
qk_rope_head_dim independently but must enforce the invariant head_dim_qk ==
kv_lora_rank + qk_rope_head_dim; update each TraceTemplate's axes to add a
constraints entry (similar to gqa_paged_decode_trace) that asserts head_dim_qk
== kv_lora_rank + qk_rope_head_dim so consumers can validate inputs consistently
with the reference (see the existing assert in
_trtllm_batch_decode_mla_reference).

1159-1215: 💤 Low value

Reference impl branches look correct; minor robustness note.

The cum_seq_lens_q branching cleanly handles both the dense (4D query) and ragged (3D query) layouts, with output shape derived from query.shape[:-1] + (kv_lora_rank,) in the ragged case. One small thing worth noting: in the ragged branch q_len is never bound, so referencing it would be a NameError; this is currently only used in the cum_seq_lens_q is None branch (line 1199), so it's safe but fragile if someone later refactors the inner loop. A defensive q_len = None initialization in the else branch would prevent a future regression.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/trace/templates/attention.py` around lines 1159 - 1215, The
ragged-path branch when cum_seq_lens_q is not None never binds q_len, which
risks a NameError if the loop logic is refactored; update the else branch (where
batch_size, num_heads, head_dim_qk, and output are set) to explicitly initialize
q_len = None (or q_len = 0) so references to q_len later (e.g. in the
cum_seq_lens_q is None branch and the outer loops using query/q_batch) remain
safe; ensure the change is applied near the cum_seq_lens_q handling alongside
query and output setup.
flashinfer/mla/_core.py (2)

674-678: ⚡ Quick win

Document additional constraints on cum_seq_lens_q.

The validation block enforces that cum_seq_lens_q is 1D int32, has at least two entries, starts with 0, ends at query.size(0), and is monotonically non-decreasing. Surfacing these requirements in the docstring would save users a debugging round-trip.

📝 Suggested docstring update
     cum_seq_lens_q : Optional[torch.Tensor] = None
-        Cumulative query sequence lengths for variable-length query support, shape: ``[batch_size + 1]``, dtype: ``torch.int32``.
-        Only supported by trtllm-gen backend.
-        When provided, ``query`` must have shape ``[total_q, num_heads, head_dim_qk]``.
+        Cumulative query sequence lengths for variable-length query support, shape: ``[batch_size + 1]``, dtype: ``torch.int32``.
+        Must be monotonically non-decreasing with ``cum_seq_lens_q[0] == 0`` and
+        ``cum_seq_lens_q[-1] == query.size(0)``; ``batch_size`` is inferred as
+        ``cum_seq_lens_q.size(0) - 1`` and must match ``seq_lens.size(0)``.
+        Only supported by trtllm-gen backend.
+        When provided, ``query`` must have shape ``[total_q, num_heads, head_dim_qk]``.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/mla/_core.py` around lines 674 - 678, Update the cum_seq_lens_q
docstring to list the exact validation constraints enforced in the code: it must
be a 1-D torch.Tensor of dtype torch.int32, length >= 2 and shape [batch_size +
1], the first element must be 0, the last element must equal query.size(0), and
the sequence must be monotonically non-decreasing; also note this parameter is
only supported by the trtllm-gen backend and that when provided query must have
shape [total_q, num_heads, head_dim_qk].

790-804: 💤 Low value

Validation forces a host sync per call.

cum_seq_lens_q.cpu() plus the subsequent .item() calls block the calling stream on every invocation of trtllm_batch_decode_with_kv_cache_mla. This pattern is consistent with BatchMLAPagedAttentionWrapper.plan() and is acceptable for correctness, but for a hot decode path it does add latency that doesn't exist when cum_seq_lens_q is None.

If you want to keep validation but avoid syncing, you can keep the structural checks (dtype/ndim/size) on-device and gate the [0]/[-1]/monotonicity checks behind a debug flag (e.g., __debug__ / env var). Otherwise it's worth a brief comment acknowledging the sync so readers don't try to "optimize" it into a different correctness footprint.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/mla/_core.py` around lines 790 - 804, The validation on
cum_seq_lens_q in trtllm_batch_decode_with_kv_cache_mla currently calls
cum_seq_lens_q.cpu() and .item() which forces a host sync every call; keep
on-device structural checks (dtype, ndim, size) directly on cum_seq_lens_q and
move the costly value checks (first/last elements and monotonicity that require
.cpu()/.item()) behind a debug gate (e.g., if __debug__ or an env var) or remove
them from the hot path, and if you choose to retain the sync keep a clear
comment in trtllm_batch_decode_with_kv_cache_mla (and/or document in
BatchMLAPagedAttentionWrapper.plan) explaining the intentional host sync so
future readers don’t silently rework correctness.
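
A sketch of the split the comment suggests (the helper name is hypothetical; the PR's final validation may differ):

import torch

def validate_cum_seq_lens_q(cum_seq_lens_q: torch.Tensor, total_q: int) -> None:
    # Structural checks stay on-device and never synchronize.
    if cum_seq_lens_q.dtype != torch.int32 or cum_seq_lens_q.dim() != 1:
        raise ValueError("cum_seq_lens_q must be a 1-D int32 tensor")
    if cum_seq_lens_q.numel() < 2:
        raise ValueError("cum_seq_lens_q must have at least two entries")
    if __debug__:
        # Value checks copy to host and block the calling stream; skipped under -O.
        host = cum_seq_lens_q.cpu()
        if host[0].item() != 0 or host[-1].item() != total_q:
            raise ValueError("cum_seq_lens_q must start at 0 and end at query.size(0)")
        if not bool(torch.all(host[1:] >= host[:-1])):
            raise ValueError("cum_seq_lens_q must be monotonically non-decreasing")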
tests/attention/test_trtllm_gen_mla.py (2)

323-335: ⚡ Quick win

Test gap: use_cum_seq_lens_q=True only exercises uniform query lengths.

cum_seq_lens_q = arange(0, batch_size + 1) * q_len_per_request describes the same partitioning as the dense path (every sequence has exactly q_len_per_request tokens). After the wrapper's query.flatten(0, 1) for the dense case, the kernel input is essentially identical between the two branches, so this test validates the API plumbing but not the variable-length behavior that cum_seq_lens_q is meant to enable.

Consider adding (or replacing one of the parameter axes with) a case where q_lens actually vary across the batch — e.g.

if use_cum_seq_lens_q:
    q_lens = [random.randint(1, q_len_per_request) for _ in range(batch_size)]
    # Build a flat query and reference path that reflect those q_lens.
    ...

Otherwise the doubled parametrization adds CI cost without meaningfully exercising the new code path. (Also worth checking with the matrix that the [False, True] axis isn't multiplied unnecessarily across all backends — xqa/cute-dsl are skipped via the gate at lines 306–307, but the parametrize still spawns those cases.)

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/attention/test_trtllm_gen_mla.py` around lines 323 - 335, The test
currently sets cum_seq_lens_q to arange(0, batch_size + 1) * q_len_per_request
which produces uniform query lengths and therefore doesn't exercise
variable-length behavior; change the use_cum_seq_lens_q branch to generate
variable per-batch q_lens (e.g., q_lens = [randint(1, q_len_per_request) for _
in range(batch_size)]), then build query_input by concatenating per-example
query slices into a single flat tensor and compute cum_seq_lens_q as the prefix
sum of q_lens (dtype int32) so the kernel receives genuinely variable lengths,
and also construct the test's reference/dense-path inputs by flattening
according to the same q_lens so comparisons remain valid (update symbols:
use_cum_seq_lens_q, query_input, cum_seq_lens_q, q_len_per_request, q_lens,
batch_size).

829-829: 💤 Low value

Confirm CI budget for the doubled parametrize matrix.

Combined with the existing layer_dimensions × batch_size × scale × dtype × page_size × q_len_per_request × dynamic_scale × enable_pdl × backend × skips_softmax × uses_shared_paged_kv_idx matrix, adding use_cum_seq_lens_q=[False, True] doubles an already large product. Most of the new cases are skipped for non-trtllm-gen backends via lines 844–845, but pytest still pays collection/skip overhead for each. If CI runtime is a concern, consider scoping this axis to a smaller subset of the matrix (e.g., parametrize on use_cum_seq_lens_q only inside a dedicated test with a few representative configs).

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/attention/test_trtllm_gen_mla.py` at line 829, The new
pytest.mark.parametrize("use_cum_seq_lens_q", [False, True]) doubles the
already-large test matrix and increases collection/skip overhead; either confirm
CI can absorb the extra runtime or reduce scope by moving this parametrize into
a smaller dedicated test or by conditionalizing it only for trtllm-gen runs.
Update tests/attention/test_trtllm_gen_mla.py to remove or limit the global
parametrize on use_cum_seq_lens_q and instead add a focused test (or inner
parametrize) that exercises both True/False for a minimal set of representative
configurations, or guard the parametrize so it only applies when the backend
equals "trtllm-gen" to avoid unnecessary collection overhead.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Outside diff comments:
In `@flashinfer/mla/_core.py`:
- Around line 760-826: The code currently allows sparse_mla_top_k > 0 to proceed
when cum_seq_lens_q (variable-length queries) is provided; add an explicit
rejection for that combination: after determining has_var_q (the cum_seq_lens_q
checks) or alongside the existing skip_softmax check at the top, raise a
ValueError when sparse_mla_top_k != 0 and has_var_q is True (clear message like
"sparse MLA (sparse_mla_top_k>0) is not supported with variable-length queries
(cum_seq_lens_q) for trtllm-gen"). Also update the docstrings for
sparse_mla_top_k and cum_seq_lens_q to state this restriction and add a unit
test exercising sparse_mla_top_k > 0 with cum_seq_lens_q to ensure the error is
raised; reference symbols: sparse_mla_top_k, cum_seq_lens_q, has_var_q, and
_check_trtllm_gen_mla_shape.

---

Nitpick comments:
In `@flashinfer/mla/_core.py`:
- Around line 674-678: Update the cum_seq_lens_q docstring to list the exact
validation constraints enforced in the code: it must be a 1-D torch.Tensor of
dtype torch.int32, length >= 2 and shape [batch_size + 1], the first element
must be 0, the last element must equal query.size(0), and the sequence must be
monotonically non-decreasing; also note this parameter is only supported by the
trtllm-gen backend and that when provided query must have shape [total_q,
num_heads, head_dim_qk].
- Around line 790-804: The validation on cum_seq_lens_q in
trtllm_batch_decode_with_kv_cache_mla currently calls cum_seq_lens_q.cpu() and
.item() which forces a host sync every call; keep on-device structural checks
(dtype, ndim, size) directly on cum_seq_lens_q and move the costly value checks
(first/last elements and monotonicity that require .cpu()/.item()) behind a
debug gate (e.g., if __debug__ or an env var) or remove them from the hot path,
and if you choose to retain the sync keep a clear comment in
trtllm_batch_decode_with_kv_cache_mla (and/or document in
BatchMLAPagedAttentionWrapper.plan) explaining the intentional host sync so
future readers don’t silently rework correctness.

In `@flashinfer/trace/templates/attention.py`:
- Around line 1357-1367: The dispatch function trtllm_batch_decode_mla_trace
attaches a loose templates attribute for schema discovery which is
fragile/undocumented; either add a brief docstring right above
trtllm_batch_decode_mla_trace explaining that the tracing infra discovers schema
via the templates attribute (and who consumes it), or replace the free function
with a tiny helper wrapper (e.g., a small callable class or named object) that
exposes templates formally so the attribute isn't a side-channel; reference
trtllm_batch_decode_mla_trace and the templates list (containing
trtllm_batch_decode_mla_dense_trace and trtllm_batch_decode_mla_ragged_trace)
when making the change.
- Around line 1218-1354: The two MLA trace templates
trtllm_batch_decode_mla_dense_trace and trtllm_batch_decode_mla_ragged_trace
declare head_dim_qk, kv_lora_rank and qk_rope_head_dim independently but must
enforce the invariant head_dim_qk == kv_lora_rank + qk_rope_head_dim; update
each TraceTemplate's axes to add a constraints entry (similar to
gqa_paged_decode_trace) that asserts head_dim_qk == kv_lora_rank +
qk_rope_head_dim so consumers can validate inputs consistently with the
reference (see the existing assert in _trtllm_batch_decode_mla_reference).
- Around line 1159-1215: The ragged-path branch when cum_seq_lens_q is not None
never binds q_len, which risks a NameError if the loop logic is refactored;
update the else branch (where batch_size, num_heads, head_dim_qk, and output are
set) to explicitly initialize q_len = None (or q_len = 0) so references to q_len
later (e.g. in the cum_seq_lens_q is None branch and the outer loops using
query/q_batch) remain safe; ensure the change is applied near the cum_seq_lens_q
handling alongside query and output setup.

In `@tests/attention/test_trtllm_gen_mla.py`:
- Around line 323-335: The test currently sets cum_seq_lens_q to arange(0,
batch_size + 1) * q_len_per_request which produces uniform query lengths and
therefore doesn't exercise variable-length behavior; change the
use_cum_seq_lens_q branch to generate variable per-batch q_lens (e.g., q_lens =
[randint(1, q_len_per_request) for _ in range(batch_size)]), then build
query_input by concatenating per-example query slices into a single flat tensor
and compute cum_seq_lens_q as the prefix sum of q_lens (dtype int32) so the
kernel receives genuinely variable lengths, and also construct the test's
reference/dense-path inputs by flattening according to the same q_lens so
comparisons remain valid (update symbols: use_cum_seq_lens_q, query_input,
cum_seq_lens_q, q_len_per_request, q_lens, batch_size).
- Line 829: The new pytest.mark.parametrize("use_cum_seq_lens_q", [False, True])
doubles the already-large test matrix and increases collection/skip overhead;
either confirm CI can absorb the extra runtime or reduce scope by moving this
parametrize into a smaller dedicated test or by conditionalizing it only for
trtllm-gen runs. Update tests/attention/test_trtllm_gen_mla.py to remove or
limit the global parametrize on use_cum_seq_lens_q and instead add a focused
test (or inner parametrize) that exercises both True/False for a minimal set of
representative configurations, or guard the parametrize so it only applies when
the backend equals "trtllm-gen" to avoid unnecessary collection overhead.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: c6eef880-ac98-4a30-acf6-9d3043a6a303

📥 Commits

Reviewing files that changed from the base of the PR and between 9496956e50ff2820f1161cf122ccb64ddeb67b55 and 752aa4e1fac049e52472380e68e87f483058ff58.

📒 Files selected for processing (4)
  • flashinfer/mla/_core.py
  • flashinfer/trace/templates/attention.py
  • tests/attention/test_trtllm_gen_mla.py
  • tests/trace/test_fi_trace.py

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
flashinfer/mla/_core.py (1)

718-725: ⚠️ Potential issue | 🟡 Minor | ⚡ Quick win

Reorder XQA guards so cum_seq_lens_q is rejected before the q_len shape check.

When a user passes a 3D query together with cum_seq_lens_q and the resolved backend is xqa, query.size(1) is num_heads (e.g. 128), so the q_len_per_request == 1 guard at lines 718–721 trips first and produces a misleading error referring to MTP query length, instead of the accurate "XQA MLA does not support cum_seq_lens_q" message at lines 724–725.

🛡️ Proposed reordering
         if sinks is not None:
             raise ValueError("XQA MLA does not support sinks")
+        if cum_seq_lens_q is not None:
+            raise ValueError("XQA MLA does not support cum_seq_lens_q")
         if query.size(1) != 1:
             raise ValueError(
                 f"XQA MLA only supports q_len_per_request == 1, got {query.size(1)}"
             )
         if skip_softmax_threshold_scale_factor is not None:
             raise ValueError("skip_softmax is not supported for XQA backend")
-        if cum_seq_lens_q is not None:
-            raise ValueError("XQA MLA does not support cum_seq_lens_q")
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/mla/_core.py` around lines 718 - 725, Reorder the XQA MLA input
guards so the check for cum_seq_lens_q runs before validating query.size(1):
move the "if cum_seq_lens_q is not None: raise ValueError('XQA MLA does not
support cum_seq_lens_q')" block to precede the "if query.size(1) != 1: ..."
check; keep the skip_softmax_threshold_scale_factor check as-is (or immediately
after the cum_seq_lens_q check) so that passing a 3D query together with
cum_seq_lens_q triggers the correct "does not support cum_seq_lens_q" error
instead of the misleading q_len_per_request error.
🧹 Nitpick comments (1)
tests/attention/test_trtllm_gen_mla.py (1)

323-335: ⚡ Quick win

Consider also exercising a non‑uniform cum_seq_lens_q.

The use_cum_seq_lens_q=True branch builds a perfectly uniform offset (arange * q_len_per_request), so every per-request length is identical to the dense path. That validates the THD wiring but does not actually exercise the variable-length code path the feature is intended to enable, so a regression that only manifests when per-request q_lens differ would slip through this parametrization. Adding a single test with non-uniform query lengths (e.g. random q_lens in [1, q_len_per_request] and recomputed cum_seq_lens_q/max_q_len) would close this gap.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/attention/test_trtllm_gen_mla.py` around lines 323 - 335, The current
use_cum_seq_lens_q=True branch only tests a uniform cum_seq_lens_q (arange *
q_len_per_request); change the test to also exercise a non‑uniform case by
generating per-request q_lens (e.g. random ints in [1, q_len_per_request]),
compute cum_seq_lens_q via cumulative sum (starting at 0), set max_q_len =
q_lens.max(), and then build query_input/reshape and any necessary padding to
match the expected (batch_size * variable_requests, num_heads, max_q_len) layout
so the variable-length code path is actually exercised; update any assertions
that rely on fixed lengths accordingly and keep the original uniform case as
well.
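
A small sketch of generating such a non-uniform pattern (batch_size and q_len_per_request are the test's existing parameters; the concrete values and seeding here are illustrative):

import random
import torch

batch_size, q_len_per_request = 4, 3
random.seed(0)  # the real test may prefer a fixture-controlled seed

q_lens = torch.tensor(
    [random.randint(1, q_len_per_request) for _ in range(batch_size)], dtype=torch.int32
)
cum_seq_lens_q = torch.cat(
    [torch.zeros(1, dtype=torch.int32), torch.cumsum(q_lens, dim=0).to(torch.int32)]
)
max_q_len = int(q_lens.max())
total_q = int(cum_seq_lens_q[-1])  # rows of the flattened [total_q, num_heads, head_dim_qk] query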
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@flashinfer/mla/_core.py`:
- Around line 788-810: The validation currently forces a device->host sync via
cum_seq_lens_q.cpu() and .item() calls; add an optional max_q_len parameter
(mirroring flashinfer.decode.trtllm_batch_decode_with_kv_cache) to the function
that owns cum_seq_lens_q and query so that when max_q_len is provided you skip
the cum_seq_lens_q_host = cum_seq_lens_q.cpu() path and related .item() checks,
instead only validate that cum_seq_lens_q and max_q_len are provided together
and use the supplied max_q_len to compute q_lens bounds; still keep the existing
validation when max_q_len is None (preserving behavior), and update error
messages to require both cum_seq_lens_q and max_q_len be supplied together to
avoid host sync.
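
A sketch of the pairing contract described above (the helper name is hypothetical; the PR's final argument handling may differ):

def resolve_max_q_len(cum_seq_lens_q, max_q_len):
    # Reject max_q_len on its own, and only fall back to a host sync when the
    # caller did not supply the bound alongside cum_seq_lens_q.
    if max_q_len is not None and cum_seq_lens_q is None:
        raise ValueError("max_q_len is only meaningful together with cum_seq_lens_q")
    if cum_seq_lens_q is not None and max_q_len is None:
        max_q_len = int((cum_seq_lens_q[1:] - cum_seq_lens_q[:-1]).max().item())
    return max_q_len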

---

Outside diff comments:
In `@flashinfer/mla/_core.py`:
- Around line 718-725: Reorder the XQA MLA input guards so the check for
cum_seq_lens_q runs before validating query.size(1): move the "if cum_seq_lens_q
is not None: raise ValueError('XQA MLA does not support cum_seq_lens_q')" block
to precede the "if query.size(1) != 1: ..." check; keep the
skip_softmax_threshold_scale_factor check as-is (or immediately after the
cum_seq_lens_q check) so that passing a 3D query together with cum_seq_lens_q
triggers the correct "does not support cum_seq_lens_q" error instead of the
misleading q_len_per_request error.

---

Nitpick comments:
In `@tests/attention/test_trtllm_gen_mla.py`:
- Around line 323-335: The current use_cum_seq_lens_q=True branch only tests a
uniform cum_seq_lens_q (arange * q_len_per_request); change the test to also
exercise a non‑uniform case by generating per-request q_lens (e.g. random ints
in [1, q_len_per_request]), compute cum_seq_lens_q via cumulative sum (starting
at 0), set max_q_len = q_lens.max(), and then build query_input/reshape and any
necessary padding to match the expected (batch_size * variable_requests,
num_heads, max_q_len) layout so the variable-length code path is actually
exercised; update any assertions that rely on fixed lengths accordingly and keep
the original uniform case as well.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 68a8e0bf-ec8e-4cfa-8722-4281d8f58f15

📥 Commits

Reviewing files that changed from the base of the PR and between 752aa4e1fac049e52472380e68e87f483058ff58 and e80c0991608b1e1f1d0499dcb79380ad7659e340.

📒 Files selected for processing (4)
  • flashinfer/mla/_core.py
  • flashinfer/trace/templates/attention.py
  • tests/attention/test_trtllm_gen_mla.py
  • tests/trace/test_fi_trace.py

Comment thread flashinfer/mla/_core.py Outdated
@saltyminty saltyminty force-pushed the fix/mingyangw/support-cum-seq-lens-q-in-trtllm-batch-decode-with-kv-cache-mla-trtllm-g branch from e80c099 to 1f39730 on May 6, 2026 16:36
@saltyminty
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !635 has been created, and the CI pipeline #50449084 is currently running. I'll report back once the pipeline job completes.

Comment thread flashinfer/mla/_core.py Outdated
Comment thread tests/attention/test_trtllm_gen_mla.py
Comment thread flashinfer/mla/_core.py
Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 1

🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@tests/attention/test_trtllm_gen_mla.py`:
- Line 1128: The regex passed to pytest.raises as the match argument uses regex
metacharacters but is a normal string; update the pytest.raises call (the match=
parameter) to use a raw string literal (prefix with r) so the pattern is
explicit (e.g., change the match value to a raw string like r"...") to satisfy
the RUF043 linting rule and avoid escape-sequence issues.
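
The general shape of the RUF043 fix (this particular message and pattern are hypothetical, not the test's actual ones):

import pytest

# A raw string makes the escaped regex metacharacters explicit.
with pytest.raises(ValueError, match=r"cum_seq_lens_q\[0\] must be 0"):
    raise ValueError("cum_seq_lens_q[0] must be 0")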

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 4240a533-7da8-4cd1-8f27-b1994d5271b4

📥 Commits

Reviewing files that changed from the base of the PR and between e80c0991608b1e1f1d0499dcb79380ad7659e340 and fb63d29add97d7dd4c44f473bfe487c1e32ef34e.

📒 Files selected for processing (4)
  • flashinfer/mla/_core.py
  • flashinfer/trace/templates/attention.py
  • tests/attention/test_trtllm_gen_mla.py
  • tests/trace/test_fi_trace.py
🚧 Files skipped from review as they are similar to previous changes (2)
  • tests/trace/test_fi_trace.py
  • flashinfer/trace/templates/attention.py

Comment thread tests/attention/test_trtllm_gen_mla.py Outdated
@saltyminty saltyminty force-pushed the fix/mingyangw/support-cum-seq-lens-q-in-trtllm-batch-decode-with-kv-cache-mla-trtllm-g branch 2 times, most recently from 53735de to d3d3562 on May 7, 2026 21:40
@saltyminty
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !635 has been updated with latest changes, and the CI pipeline #50602487 is currently running. I'll report back once the pipeline job completes.

Collaborator

@qsang-nv qsang-nv left a comment


LGTM

@saltyminty
Collaborator Author

saltyminty commented May 8, 2026

Investigating internal CI failures
Update: use_cum_seq_lens_q is not compatible with skip_softmax

@saltyminty saltyminty force-pushed the fix/mingyangw/support-cum-seq-lens-q-in-trtllm-batch-decode-with-kv-cache-mla-trtllm-g branch from 32b4ff7 to f026afe on May 8, 2026 23:57
@saltyminty
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !635 has been updated with latest changes, and the CI pipeline #50733198 is currently running. I'll report back once the pipeline job completes.

…_q support

Expose the TRT-LLM Gen MLA query-length metadata that the backend launcher already accepts, while keeping the existing KV seq_lens/max_seq_len contract unchanged. The wrapper now accepts flattened THD query input when cum_seq_lens_q is provided, derives max_q_len internally, and rejects the new argument on unsupported backends.

Constraint: TRT-LLM Gen backend ABI already has a cum_seq_lens_q slot and expects max_q_len separately at launch time
Rejected: Reuse seq_lens/max_seq_len for query lengths | those arguments describe KV cache lengths for all existing callers
Rejected: Expose max_q_len as a public API argument | it is derivable from cum_seq_lens_q and would duplicate caller state
Confidence: high
Scope-risk: moderate
Tested: pre-commit run --files flashinfer/mla/_core.py flashinfer/trace/templates/attention.py tests/attention/test_trtllm_gen_mla.py
Tested: python3 -m py_compile flashinfer/mla/_core.py flashinfer/trace/templates/attention.py tests/attention/test_trtllm_gen_mla.py
Tested: git diff --check
Tested: Remote SM100 pytest for THD cum_seq_lens_q=True TRT-LLM Gen MLA decode with shared and separate page indices
Not-tested: Full TRT-LLM Gen MLA parameter matrix
Split the TRT-LLM Gen MLA trace schema into dense and ragged variants so tracing matches the public API rank in both call modes. Validate query and KV batch metadata before backend module setup, and collapse cum_seq_lens_q scalar checks through one host copy.

Constraint: Dense calls still use 4D public query tensors while ragged cum_seq_lens_q calls use flattened 3D query tensors.

Rejected: Keep one flattened trace template | dense trace consumers would see the wrong rank contract.

Confidence: high

Scope-risk: narrow

Directive: Keep Python metadata validation before TRT-LLM Gen module lookup so invalid inputs do not trigger JIT setup.

Tested: pre-commit on changed files; py_compile; git diff --check; remote SM100 focused pytest for dense/ragged trace schema, batch mismatch validation, and two THD TRT-LLM Gen cases
The ragged TRT-LLM Gen MLA path can avoid synchronizing cum_seq_lens_q to host when the caller supplies max_q_len alongside the query metadata. Keep the previous host-validation fallback for callers that omit max_q_len, but reject max_q_len without cum_seq_lens_q.

Constraint: max_q_len mirrors the non-MLA TRT-LLM batch decode API and is only meaningful for the ragged cum_seq_lens_q path.

Rejected: Require max_q_len for all cum_seq_lens_q calls | that would break the fallback behavior already added in this PR.

Confidence: high

Scope-risk: narrow

Directive: Keep max_q_len validation on the Python side before TRT-LLM Gen module lookup to avoid JIT setup for invalid metadata.

Tested: pre-commit on changed files; py_compile; git diff --check
Address follow-up review comments for TRTLLM Gen MLA ragged query support. Reject unsupported XQA and sparse MLA combinations before shape-sensitive access, document max_q_len bounds, and update tests to cover non-uniform THD query lengths plus overestimated max_q_len.

Constraint: cum_seq_lens_q and max_q_len are only supported for trtllm-gen dense MLA paths

Rejected: Let sparse MLA accept ragged query metadata | backend support is not established

Confidence: high

Scope-risk: narrow

Tested: pre-commit on changed files; py_compile; git diff --check

Tested: SM100 GPU focused tests: 3 passed, 39366 deselected; representative ragged decode case: 1 passed
TRTLLM-gen MLA currently produces invalid output for variable-length query metadata combined with skip-softmax, so reject the public API combination and skip that unsupported parametrized test matrix.

Constraint: Pipeline failures showed NaN outputs for use_cum_seq_lens_q=True with skip_softmax=True on SM100/SM103/SM103-class jobs.
Confidence: high
Scope-risk: narrow
Tested: python -m py_compile flashinfer/mla/_core.py tests/attention/test_trtllm_gen_mla.py; git diff --check; venv pre-commit on touched files; prior focused GPU verification before deleting the redundant reject test
Not-tested: Full GPU matrix after deleting only the redundant focused reject test
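
A sketch of how the unsupported parametrized combination might be skipped (the parameter names follow the discussion above and may differ in the tree):

import pytest

def maybe_skip_unsupported(use_cum_seq_lens_q: bool, skips_softmax: bool) -> None:
    # trtllm-gen produced NaN outputs for this pairing, so skip it in the matrix.
    if use_cum_seq_lens_q and skips_softmax:
        pytest.skip("cum_seq_lens_q is not supported together with skip_softmax on trtllm-gen")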
@saltyminty saltyminty force-pushed the fix/mingyangw/support-cum-seq-lens-q-in-trtllm-batch-decode-with-kv-cache-mla-trtllm-g branch from f026afe to 67b3660 on May 11, 2026 16:59
