
[Fix] Fix XQA V tile reading from wrong page when nbVItersPerXIter > 1 #3022

Open
qsang-nv wants to merge 2 commits into flashinfer-ai:main from qsang-nv:fix_xqa_sm120_headdim_256

Conversation

@qsang-nv
Collaborator

@qsang-nv qsang-nv commented Apr 9, 2026

📌 Description

Summary

Fix incorrect XQA attention results on architectures with cacheVTileSeqLen = 32 (SM120/SM121) when head_dim = 256 and page_size < 64.

Bug

On SM120 with head_dim=256 and page_size of 16 or 32, the XQA mha.cu kernel produces incorrect attention outputs. All page_size=64 cases pass; all page_size < 64 cases fail. The bug is independent of dtype, kv_layout, batch_size, window_left, and other parameters. SM90 and SM100 are unaffected.

Root Cause

In mha.cu, the V tile page advancement logic (loadVTilePart lambda) has two branches based on xIterSeqStride vs tokensPerPage. The else branch (when xIterSeqStride > tokensPerPage) assumes nbVItersPerXIter == 1, meaning each warp X tile (64 tokens) is covered by a single V tile load. This assumption is violated under a specific combination of compile-time constants:

  • cacheVTileSeqLen = 32 (SM120/SM121)
  • head_dim = 256, so gemm1WarpsPerGrp = 4 and gemm1NbWarpGrps = 1
  • cacheVTileSeqStride = 32 × 1 = 32 < warpTile.x = 64
  • nbVItersPerXIter = 2
With nbVItersPerXIter = 2, each warp X tile requires two V iterations (vIter=0 covering tokens [0, 32) and vIter=1 covering tokens [32, 64)). When page_size < 64, these two V iterations land on different pages. However, loadPages() was only called after the last V iteration (vIter == nbVItersPerXIter - 1), leaving the page index stale for vIter=1 — it reads KV cache data from the wrong page.

On SM90/SM100 (cacheVTileSeqLen = 64), cacheVTileSeqStride = 64 and nbVItersPerXIter = 1, so the bug never triggers. On SM120 with head_dim = 128, gemm1NbWarpGrps = 2 makes cacheVTileSeqStride = 64, also avoiding the issue.

Fix

Replace the single page advancement at the end of each X iteration with per-V-iteration page advancement:

  • Intermediate V iteration (vIter < nbVItersPerXIter - 1): advance idxPageBeg by step_per_viter = cacheVTileSeqStride / tokensPerPage and reload pages.
  • Last V iteration, last beam (isLastVIter && isLastBeam): advance with CTA-tile boundary wrapping (multi-block mode), same as the original logic but using step_per_viter.
  • Last V iteration, not last beam (isLastVIter && !isLastBeam): rewind idxPageBeg so the next beam restarts from vIter=0's page position.

When nbVItersPerXIter == 1, every V iteration is the last one, so only the last-V-iteration branches fire and the behavior matches the original code — no performance or correctness impact on existing working paths.

Test Plan

  • test_trtllm_batch_decode with backend=xqa, head_dim=256, all page sizes (16, 32, 64) on SM120
  • test_trtllm_batch_decode with backend=xqa, head_dim=128, all page sizes on SM120 (regression check)
  • test_trtllm_batch_decode with backend=xqa, head_dim=256, all page sizes on SM90 (regression check)

🔍 Related Issues

🚀 Pull Request Checklist

Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.

✅ Pre-commit Checks

  • I have installed pre-commit by running pip install pre-commit (or used your preferred method).
  • I have installed the hooks with pre-commit install.
  • I have run the hooks manually with pre-commit run --all-files and fixed any reported issues.

If you are unsure about how to set up pre-commit, see the pre-commit documentation.

🧪 Tests

  • Tests have been added or updated as needed.
  • All tests are passing (unittest, etc.).

Reviewer Notes

Summary by CodeRabbit

  • Bug Fixes

    • Improved attention page-loading to correctly handle multi-iteration stride scenarios, including proper advance, rewind and reload behavior across iteration and beam boundaries for more reliable caching.
  • Tests

    • Re-enabled tests for configurations with head_dim == 256 that were previously skipped to verify multi-iteration page-loading.

@coderabbitai
Contributor

coderabbitai Bot commented Apr 9, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.

📝 Walkthrough

Compute a per-vIter page step for V-cache paging in the MHA kernel and branch on isLastVIter/isLastBeam to advance, rewind, or increment idxPageBeg, always reloading pages after updates. Also allow xqa backend test cases previously skipped for head_dim == 256 to run.

Changes

V-cache Page-Loading Logic

  • Data Shape / Constants (csrc/xqa/mha.cu): Introduce step_per_viter = exactDiv(cacheVTileSeqStride, tokensPerPage) (per-vIter page step).
  • Core Control Flow (csrc/xqa/mha.cu): Remove the single-vIter assumption and branch on isLastVIter and isLastBeam to: (a) perform end-of-range advance with subsequence wrap/reposition and load, (b) rewind by multiple step_per_viter when vIter is last but the beam isn't, or (c) increment by one step_per_viter for non-terminal vIter.
  • I/O / Page Loading (csrc/xqa/mha.cu): Always call loadPages(idxPageBeg) after computing/updating idxPageBeg in each branch.

Test Skip Adjustment

  • Test Logic (tests/attention/test_trtllm_gen_attention.py): Deleted the pytest.skip that excluded xqa backend cases when head_dim == 256; those configurations now run the shared decode test path.
  • Test Execution (tests/attention/test_trtllm_gen_attention.py): No other test files changed; head_dim == 256 xqa coverage is now exercised by the existing test call path.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Possibly related PRs

Suggested reviewers

  • bkryu
  • nv-yunzheq
  • saltyminty
  • yzh119

Poem

🐰 I hopped through tiles with a careful chart,

step_per_viter set, every page to start.
Rewind or wrap, then load the view anew,
Dim‑256 wakes up and joins the queue,
Rabbit nibbles code and hums a kernel tune.

🚥 Pre-merge checks | ✅ 5
✅ Passed checks (5 passed)
Check name Status Explanation
Title check ✅ Passed The title clearly and specifically identifies the fix: XQA V tile page reading bug when nbVItersPerXIter > 1, which matches the core issue addressed in the changeset.
Description check ✅ Passed The description is comprehensive and includes all critical sections: a clear Summary, detailed Bug explanation, Root Cause analysis, Fix description, and Test Plan with specific verification cases.
Docstring Coverage ✅ Passed No functions found in the changed files to evaluate docstring coverage. Skipping docstring coverage check.
Linked Issues check ✅ Passed Check skipped because no linked issues were found for this pull request.
Out of Scope Changes check ✅ Passed Check skipped because no linked issues were found for this pull request.






Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request updates the idxPageBeg calculation logic in csrc/xqa/mha.cu to support multiple V-iterations per X-iteration, replacing a previous assertion that restricted it to a single iteration. Additionally, it enables tests for the XQA backend with a head dimension of 256 by removing a previous skip condition. Feedback suggests refactoring the complex branching logic used for the idxPageBeg update into a helper function to improve code maintainability.

Comment thread csrc/xqa/mha.cu
Comment on lines +2252 to 2266
constexpr auto step_per_viter = exactDiv(cacheVTileSeqStride, tokensPerPage);
bool const isLastVIter = (vIter == nbVItersPerXIter - 1);
bool const isLastBeam = (idxBeam == beamWidth - 1 || isConvergedTile(seqIter));
if (isLastVIter && isLastBeam) {
  idxPageBeg += (idxPageBeg % nbPagesPerCtaTile + step_per_viter >= nbPagesPerCtaTile
                     ? nbPagesPerCtaTile * (nbSubSeqPerSeq - 1) + step_per_viter
                     : step_per_viter);
  loadPages(idxPageBeg);
} else if (isLastVIter) {
  idxPageBeg -= step_per_viter * (nbVItersPerXIter - 1);
  loadPages(idxPageBeg);
} else {
  idxPageBeg += step_per_viter;
  loadPages(idxPageBeg);
}
Contributor


medium

The logic for idxPageBeg update is complex and involves multiple branches. Consider extracting this into a helper function or a more readable structure to improve maintainability, as this is a critical part of the V-tile loading logic.

@qsang-nv
Collaborator Author

qsang-nv commented Apr 9, 2026

/bot run

@flashinfer-bot
Collaborator

GitLab MR !528 has been created, and the CI pipeline #48089802 is currently running. I'll report back once the pipeline job completes.

Collaborator

@yzh119 yzh119 left a comment


Can you add some test cases where nbVItersPerXIter != 1?

@qsang-nv
Collaborator Author

qsang-nv commented Apr 9, 2026

Can you add some test cases where nbVItersPerXIter != 1?

The deleted skip (SM120 + head_dim 256) in tests/attention/test_trtllm_gen_attention.py makes nbVItersPerXIter != 1, as explained in the Root Cause section.

@qsang-nv
Collaborator Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !528 has been updated with latest changes, and the CI pipeline #48946604 is currently running. I'll report back once the pipeline job completes.

@yzh119 yzh119 added the run-ci label Apr 21, 2026
Collaborator

@saltyminty saltyminty left a comment


Approved pending CI rerun after rebase.

@saltyminty saltyminty force-pushed the fix_xqa_sm120_headdim_256 branch from ad0230a to 5766f3e Compare April 25, 2026 00:02
@saltyminty
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !528 has been updated with latest changes, and the CI pipeline #49445339 is currently running. I'll report back once the pipeline job completes.

Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
csrc/xqa/mha.cu (1)

2252-2266: Add a brief comment explaining the three-way page advancement; minor redundant work when nbVItersPerXIter == 1.

The page-advance logic here is correct (math: nbXItersPerCtaTile * nbVItersPerXIter * step_per_viter = nbPagesPerCtaTile, so step_per_viter divides the CTA-tile span and the wrap test in the isLastVIter && isLastBeam branch is exact). However, two small things are worth addressing:

  1. As per coding guidelines ({include/flashinfer/**/*.cuh,csrc/**/*.cu}: "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered"), please add a short comment explaining the three cases — (a) last vIter + last beam: advance by step_per_viter with CTA‑tile boundary wrap to the next sub‑seq slice, (b) last vIter + non‑last beam: rewind by step_per_viter * (nbVItersPerXIter - 1) so the next beam restarts at the original vIter=0 page, (c) otherwise: simple intra‑xIter step. The dense ternary in case (a) and the rewind in case (b) are non-obvious to a future reader.

  2. When nbVItersPerXIter == 1 and we hit the else if (isLastVIter) branch (i.e., beam search on a divergent tile, non‑last beam), the rewind is step_per_viter * 0 == 0 and the subsequent loadPages(idxPageBeg) reissues the same async page load. This is functionally a no-op but adds extra ldgsts traffic on a path that the previous code did not exercise — slightly contradicting the PR description's "When nbVItersPerXIter == 1, behavior is unchanged." Consider short-circuiting (or guarding with if constexpr (nbVItersPerXIter > 1)).

♻️ Suggested doc + small guard
       } else {
+        // xIterSeqStride > tokensPerPage: each xIter spans multiple pages, and each
+        // V iteration within an xIter may itself fall on a different page when
+        // cacheVTileSeqStride >= tokensPerPage. Advance idxPageBeg per V iter:
+        //  - last vIter + last beam : step forward, wrapping past sibling sub-seqs
+        //                             when crossing a CTA-tile boundary.
+        //  - last vIter, more beams : rewind to the xIter's first-vIter page so the
+        //                             next beam restarts from the same V positions.
+        //  - otherwise              : simple intra-xIter forward step.
         constexpr auto step_per_viter = exactDiv(cacheVTileSeqStride, tokensPerPage);
         bool const isLastVIter = (vIter == nbVItersPerXIter - 1);
         bool const isLastBeam = (idxBeam == beamWidth - 1 || isConvergedTile(seqIter));
         if (isLastVIter && isLastBeam) {
           idxPageBeg += (idxPageBeg % nbPagesPerCtaTile + step_per_viter >= nbPagesPerCtaTile
                              ? nbPagesPerCtaTile * (nbSubSeqPerSeq - 1) + step_per_viter
                              : step_per_viter);
           loadPages(idxPageBeg);
-        } else if (isLastVIter) {
+        } else if (isLastVIter && nbVItersPerXIter > 1) {
           idxPageBeg -= step_per_viter * (nbVItersPerXIter - 1);
           loadPages(idxPageBeg);
-        } else {
+        } else if (!isLastVIter) {
           idxPageBeg += step_per_viter;
           loadPages(idxPageBeg);
         }

As per coding guidelines: {include/flashinfer/**/*.cuh,csrc/**/*.cu}: "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered".

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.

In `@csrc/xqa/mha.cu` around lines 2252 - 2266, Add a short comment above the
three-way page-advance block describing the three cases handled: (a) isLastVIter
&& isLastBeam — advance idxPageBeg by step_per_viter with CTA-tile wrap to move
to the next sub-sequence slice, (b) isLastVIter (non-last beam) — rewind
idxPageBeg by step_per_viter * (nbVItersPerXIter - 1) so the next beam restarts
at vIter=0, and (c) otherwise — advance by step_per_viter within the current
xIter; mention the ternary wrap in case (a) and why the rewind is needed. Also
prevent the redundant load when nbVItersPerXIter == 1 by guarding the
rewind-and-load path (the branch using idxPageBeg -= step_per_viter *
(nbVItersPerXIter - 1); loadPages(idxPageBeg);) with a compile-time or runtime
check (e.g., if constexpr (nbVItersPerXIter > 1) or an if (nbVItersPerXIter >
1)) so we avoid reissuing loadPages when the multiplier is zero; keep references
to step_per_viter, isLastVIter, isLastBeam, idxPageBeg, loadPages, and
nbVItersPerXIter.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.

Nitpick comments:
In `@csrc/xqa/mha.cu`:
- Around line 2252-2266: Add a short comment above the three-way page-advance
block describing the three cases handled: (a) isLastVIter && isLastBeam —
advance idxPageBeg by step_per_viter with CTA-tile wrap to move to the next
sub-sequence slice, (b) isLastVIter (non-last beam) — rewind idxPageBeg by
step_per_viter * (nbVItersPerXIter - 1) so the next beam restarts at vIter=0,
and (c) otherwise — advance by step_per_viter within the current xIter; mention
the ternary wrap in case (a) and why the rewind is needed. Also prevent the
redundant load when nbVItersPerXIter == 1 by guarding the rewind-and-load path
(the branch using idxPageBeg -= step_per_viter * (nbVItersPerXIter - 1);
loadPages(idxPageBeg);) with a compile-time or runtime check (e.g., if constexpr
(nbVItersPerXIter > 1) or an if (nbVItersPerXIter > 1)) so we avoid reissuing
loadPages when the multiplier is zero; keep references to step_per_viter,
isLastVIter, isLastBeam, idxPageBeg, loadPages, and nbVItersPerXIter.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 1f3fb8a0-7011-43ed-b247-ec01ce43213a

📥 Commits

Reviewing files that changed from the base of the PR and between 2af4e7cc6ff726ffc75d699544b0279a2f963022 and 5766f3e3c703d07551e4695101635e9fadb95c5b.

📒 Files selected for processing (2)
  • csrc/xqa/mha.cu
  • tests/attention/test_trtllm_gen_attention.py
💤 Files with no reviewable changes (1)
  • tests/attention/test_trtllm_gen_attention.py

@saltyminty saltyminty force-pushed the fix_xqa_sm120_headdim_256 branch from 5766f3e to b8861fe Compare April 27, 2026 18:19
@saltyminty saltyminty force-pushed the fix_xqa_sm120_headdim_256 branch 2 times, most recently from a9323fb to 2482404 Compare May 6, 2026 16:51
Contributor

@coderabbitai coderabbitai Bot left a comment


🧹 Nitpick comments (1)
csrc/xqa/mha.cu (1)

2252-2266: 💤 Low value

Optional: add a short explanatory comment for the three branches.

The branching is non-obvious (especially the rewind for the not-last-beam case, which mirrors the per-beam restart semantics of loadPages). A 3-line comment summarizing the intent of each branch would help future readers and pairs naturally with the root-cause description in the PR.

📝 Suggested comment
       } else {
         constexpr auto step_per_viter = exactDiv(cacheVTileSeqStride, tokensPerPage);
         bool const isLastVIter = (vIter == nbVItersPerXIter - 1);
         bool const isLastBeam = (idxBeam == beamWidth - 1 || isConvergedTile(seqIter));
+        // nbVItersPerXIter > 1 may straddle multiple pages within one warp X tile, so
+        // page state must be advanced per V iteration:
+        //   - last vIter + last beam: advance to next xIter, with CTA-tile wrap for multi-block.
+        //   - last vIter + more beams remaining: rewind to vIter==0 page so next beam restarts.
+        //   - intermediate vIter: advance by one V-iter step.
         if (isLastVIter && isLastBeam) {
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@csrc/xqa/mha.cu` around lines 2252 - 2266, Add a short 2–3 line comment
immediately above the branching that explains the three cases for updating
idxPageBeg before calling loadPages: (1) when isLastVIter && isLastBeam —
advance by step_per_viter but wrap to the next CTA tile across subsequences
using nbPagesPerCtaTile and nbSubSeqPerSeq, (2) when isLastVIter only — rewind
by step_per_viter*(nbVItersPerXIter-1) to restart per-beam page sequence, and
(3) otherwise — advance by step_per_viter for the next v-iteration; reference
step_per_viter, isLastVIter, isLastBeam, idxPageBeg, loadPages,
nbVItersPerXIter, nbPagesPerCtaTile, and nbSubSeqPerSeq in the comment so future
readers see the intent and the relation to loadPages semantics.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Nitpick comments:
In `@csrc/xqa/mha.cu`:
- Around line 2252-2266: Add a short 2–3 line comment immediately above the
branching that explains the three cases for updating idxPageBeg before calling
loadPages: (1) when isLastVIter && isLastBeam — advance by step_per_viter but
wrap to the next CTA tile across subsequences using nbPagesPerCtaTile and
nbSubSeqPerSeq, (2) when isLastVIter only — rewind by
step_per_viter*(nbVItersPerXIter-1) to restart per-beam page sequence, and (3)
otherwise — advance by step_per_viter for the next v-iteration; reference
step_per_viter, isLastVIter, isLastBeam, idxPageBeg, loadPages,
nbVItersPerXIter, nbPagesPerCtaTile, and nbSubSeqPerSeq in the comment so future
readers see the intent and the relation to loadPages semantics.

ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: cc5644d9-df0c-445d-98c7-253924e965d5

📥 Commits

Reviewing files that changed from the base of the PR and between a9323fb6360cc5c2fa0e874eee32f9db41c25659 and 2482404900ad2308939e0ee5991cfa4b61035cd3.

📒 Files selected for processing (2)
  • csrc/xqa/mha.cu
  • tests/attention/test_trtllm_gen_attention.py
💤 Files with no reviewable changes (1)
  • tests/attention/test_trtllm_gen_attention.py

Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
@saltyminty saltyminty force-pushed the fix_xqa_sm120_headdim_256 branch from 2482404 to f9f4680 Compare May 7, 2026 18:28
@samuellees
Collaborator

/bot run

@flashinfer-bot
Collaborator

GitLab MR !528 has been updated with latest changes, and the CI pipeline #50985783 is currently running. I'll report back once the pipeline job completes.

