[Fix] Fix XQA V tile reading from wrong page when nbVItersPerXIter > 1#3022
[Fix] Fix XQA V tile reading from wrong page when nbVItersPerXIter > 1#3022qsang-nv wants to merge 2 commits into
Conversation
|
Note Reviews pausedIt looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the Use the following commands to manage reviews:
Use the checkboxes below for quick actions:
📝 WalkthroughWalkthroughCompute a per-vIter page step for V-cache paging in the MHA kernel and branch on isLastVIter/isLastBeam to advance, rewind, or increment idxPageBeg, always reloading pages after updates. Also allow xqa backend test cases previously skipped for head_dim == 256 to run. ChangesV-cache Page-Loading Logic
Test Skip Adjustment
Estimated code review effort🎯 3 (Moderate) | ⏱️ ~25 minutes Possibly related PRs
Suggested reviewers
Poem
🚥 Pre-merge checks | ✅ 5✅ Passed checks (5 passed)
✏️ Tip: You can configure your own custom pre-merge checks in the settings. ✨ Finishing Touches🧪 Generate unit tests (beta)
Tip 💬 Introducing Slack Agent: The best way for teams to turn conversations into code.Slack Agent is built on CodeRabbit's deep understanding of your code, so your team can collaborate across the entire SDLC without losing context.
Built for teams:
One agent for your entire SDLC. Right inside Slack. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
There was a problem hiding this comment.
Code Review
This pull request updates the idxPageBeg calculation logic in csrc/xqa/mha.cu to support multiple V-iterations per X-iteration, replacing a previous assertion that restricted it to a single iteration. Additionally, it enables tests for the XQA backend with a head dimension of 256 by removing a previous skip condition. Feedback suggests refactoring the complex branching logic used for the idxPageBeg update into a helper function to improve code maintainability.
| constexpr auto step_per_viter = exactDiv(cacheVTileSeqStride, tokensPerPage); | ||
| bool const isLastVIter = (vIter == nbVItersPerXIter - 1); | ||
| bool const isLastBeam = (idxBeam == beamWidth - 1 || isConvergedTile(seqIter)); | ||
| if (isLastVIter && isLastBeam) { | ||
| idxPageBeg += (idxPageBeg % nbPagesPerCtaTile + step_per_viter >= nbPagesPerCtaTile | ||
| ? nbPagesPerCtaTile * (nbSubSeqPerSeq - 1) + step_per_viter | ||
| : step_per_viter); | ||
| loadPages(idxPageBeg); | ||
| } else if (isLastVIter) { | ||
| idxPageBeg -= step_per_viter * (nbVItersPerXIter - 1); | ||
| loadPages(idxPageBeg); | ||
| } else { | ||
| idxPageBeg += step_per_viter; | ||
| loadPages(idxPageBeg); | ||
| } |
|
/bot run |
yzh119
left a comment
There was a problem hiding this comment.
Can you add some test cases where nbVItersPerXIter != 1?
The deleted skipping part(sm120 + headdim 256) in tests/attention/test_trtllm_gen_attention.py makes nbVItersPerXIter != 1, as you can see in the Root Cause part. |
|
/bot run |
saltyminty
left a comment
There was a problem hiding this comment.
Approved pending CI rerun after rebase.
ad0230a to
5766f3e
Compare
|
/bot run |
There was a problem hiding this comment.
🧹 Nitpick comments (1)
csrc/xqa/mha.cu (1)
2252-2266: Add a brief comment explaining the three-way page advancement; minor redundant work whennbVItersPerXIter == 1.The page-advance logic here is correct (math:
nbXItersPerCtaTile * nbVItersPerXIter * step_per_viter = nbPagesPerCtaTile, sostep_per_viterdivides the CTA-tile span and the wrap test in theisLastVIter && isLastBeambranch is exact). However, two small things are worth addressing:
As per coding guidelines (
{include/flashinfer/**/*.cuh,csrc/**/*.cu}: "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered"), please add a short comment explaining the three cases — (a) last vIter + last beam: advance bystep_per_viterwith CTA‑tile boundary wrap to the next sub‑seq slice, (b) last vIter + non‑last beam: rewind bystep_per_viter * (nbVItersPerXIter - 1)so the next beam restarts at the original vIter=0 page, (c) otherwise: simple intra‑xIter step. The dense ternary in case (a) and the rewind in case (b) are non-obvious to a future reader.When
nbVItersPerXIter == 1and we hit theelse if (isLastVIter)branch (i.e., beam search on a divergent tile, non‑last beam), the rewind isstep_per_viter * 0 == 0and the subsequentloadPages(idxPageBeg)reissues the same async page load. This is functionally a no-op but adds extraldgststraffic on a path that the previous code did not exercise — slightly contradicting the PR description's "WhennbVItersPerXIter == 1, behavior is unchanged." Consider short-circuiting (or guarding withif constexpr (nbVItersPerXIter > 1)).♻️ Suggested doc + small guard
} else { + // xIterSeqStride > tokensPerPage: each xIter spans multiple pages, and each + // V iteration within an xIter may itself fall on a different page when + // cacheVTileSeqStride >= tokensPerPage. Advance idxPageBeg per V iter: + // - last vIter + last beam : step forward, wrapping past sibling sub-seqs + // when crossing a CTA-tile boundary. + // - last vIter, more beams : rewind to the xIter's first-vIter page so the + // next beam restarts from the same V positions. + // - otherwise : simple intra-xIter forward step. constexpr auto step_per_viter = exactDiv(cacheVTileSeqStride, tokensPerPage); bool const isLastVIter = (vIter == nbVItersPerXIter - 1); bool const isLastBeam = (idxBeam == beamWidth - 1 || isConvergedTile(seqIter)); if (isLastVIter && isLastBeam) { idxPageBeg += (idxPageBeg % nbPagesPerCtaTile + step_per_viter >= nbPagesPerCtaTile ? nbPagesPerCtaTile * (nbSubSeqPerSeq - 1) + step_per_viter : step_per_viter); loadPages(idxPageBeg); - } else if (isLastVIter) { + } else if (isLastVIter && nbVItersPerXIter > 1) { idxPageBeg -= step_per_viter * (nbVItersPerXIter - 1); loadPages(idxPageBeg); - } else { + } else if (!isLastVIter) { idxPageBeg += step_per_viter; loadPages(idxPageBeg); }As per coding guidelines:
{include/flashinfer/**/*.cuh,csrc/**/*.cu}: "For performance-critical hot paths, leave comments with justification for special algorithmic choices and mention alternative approaches considered".🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@csrc/xqa/mha.cu` around lines 2252 - 2266, Add a short comment above the three-way page-advance block describing the three cases handled: (a) isLastVIter && isLastBeam — advance idxPageBeg by step_per_viter with CTA-tile wrap to move to the next sub-sequence slice, (b) isLastVIter (non-last beam) — rewind idxPageBeg by step_per_viter * (nbVItersPerXIter - 1) so the next beam restarts at vIter=0, and (c) otherwise — advance by step_per_viter within the current xIter; mention the ternary wrap in case (a) and why the rewind is needed. Also prevent the redundant load when nbVItersPerXIter == 1 by guarding the rewind-and-load path (the branch using idxPageBeg -= step_per_viter * (nbVItersPerXIter - 1); loadPages(idxPageBeg);) with a compile-time or runtime check (e.g., if constexpr (nbVItersPerXIter > 1) or an if (nbVItersPerXIter > 1)) so we avoid reissuing loadPages when the multiplier is zero; keep references to step_per_viter, isLastVIter, isLastBeam, idxPageBeg, loadPages, and nbVItersPerXIter.
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Nitpick comments:
In `@csrc/xqa/mha.cu`:
- Around line 2252-2266: Add a short comment above the three-way page-advance
block describing the three cases handled: (a) isLastVIter && isLastBeam —
advance idxPageBeg by step_per_viter with CTA-tile wrap to move to the next
sub-sequence slice, (b) isLastVIter (non-last beam) — rewind idxPageBeg by
step_per_viter * (nbVItersPerXIter - 1) so the next beam restarts at vIter=0,
and (c) otherwise — advance by step_per_viter within the current xIter; mention
the ternary wrap in case (a) and why the rewind is needed. Also prevent the
redundant load when nbVItersPerXIter == 1 by guarding the rewind-and-load path
(the branch using idxPageBeg -= step_per_viter * (nbVItersPerXIter - 1);
loadPages(idxPageBeg);) with a compile-time or runtime check (e.g., if constexpr
(nbVItersPerXIter > 1) or an if (nbVItersPerXIter > 1)) so we avoid reissuing
loadPages when the multiplier is zero; keep references to step_per_viter,
isLastVIter, isLastBeam, idxPageBeg, loadPages, and nbVItersPerXIter.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 1f3fb8a0-7011-43ed-b247-ec01ce43213a
📥 Commits
Reviewing files that changed from the base of the PR and between 2af4e7cc6ff726ffc75d699544b0279a2f963022 and 5766f3e3c703d07551e4695101635e9fadb95c5b.
📒 Files selected for processing (2)
csrc/xqa/mha.cutests/attention/test_trtllm_gen_attention.py
💤 Files with no reviewable changes (1)
- tests/attention/test_trtllm_gen_attention.py
5766f3e to
b8861fe
Compare
a9323fb to
2482404
Compare
There was a problem hiding this comment.
🧹 Nitpick comments (1)
csrc/xqa/mha.cu (1)
2252-2266: 💤 Low valueOptional: add a short explanatory comment for the three branches.
The branching is non-obvious (especially the rewind for the not-last-beam case, which mirrors the per-beam restart semantics of
loadPages). A 3-line comment summarizing the intent of each branch would help future readers and pairs naturally with the root-cause description in the PR.📝 Suggested comment
} else { constexpr auto step_per_viter = exactDiv(cacheVTileSeqStride, tokensPerPage); bool const isLastVIter = (vIter == nbVItersPerXIter - 1); bool const isLastBeam = (idxBeam == beamWidth - 1 || isConvergedTile(seqIter)); + // nbVItersPerXIter > 1 may straddle multiple pages within one warp X tile, so + // page state must be advanced per V iteration: + // - last vIter + last beam: advance to next xIter, with CTA-tile wrap for multi-block. + // - last vIter + more beams remaining: rewind to vIter==0 page so next beam restarts. + // - intermediate vIter: advance by one V-iter step. if (isLastVIter && isLastBeam) {🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@csrc/xqa/mha.cu` around lines 2252 - 2266, Add a short 2–3 line comment immediately above the branching that explains the three cases for updating idxPageBeg before calling loadPages: (1) when isLastVIter && isLastBeam — advance by step_per_viter but wrap to the next CTA tile across subsequences using nbPagesPerCtaTile and nbSubSeqPerSeq, (2) when isLastVIter only — rewind by step_per_viter*(nbVItersPerXIter-1) to restart per-beam page sequence, and (3) otherwise — advance by step_per_viter for the next v-iteration; reference step_per_viter, isLastVIter, isLastBeam, idxPageBeg, loadPages, nbVItersPerXIter, nbPagesPerCtaTile, and nbSubSeqPerSeq in the comment so future readers see the intent and the relation to loadPages semantics.
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Nitpick comments:
In `@csrc/xqa/mha.cu`:
- Around line 2252-2266: Add a short 2–3 line comment immediately above the
branching that explains the three cases for updating idxPageBeg before calling
loadPages: (1) when isLastVIter && isLastBeam — advance by step_per_viter but
wrap to the next CTA tile across subsequences using nbPagesPerCtaTile and
nbSubSeqPerSeq, (2) when isLastVIter only — rewind by
step_per_viter*(nbVItersPerXIter-1) to restart per-beam page sequence, and (3)
otherwise — advance by step_per_viter for the next v-iteration; reference
step_per_viter, isLastVIter, isLastBeam, idxPageBeg, loadPages,
nbVItersPerXIter, nbPagesPerCtaTile, and nbSubSeqPerSeq in the comment so future
readers see the intent and the relation to loadPages semantics.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: cc5644d9-df0c-445d-98c7-253924e965d5
📥 Commits
Reviewing files that changed from the base of the PR and between a9323fb6360cc5c2fa0e874eee32f9db41c25659 and 2482404900ad2308939e0ee5991cfa4b61035cd3.
📒 Files selected for processing (2)
csrc/xqa/mha.cutests/attention/test_trtllm_gen_attention.py
💤 Files with no reviewable changes (1)
- tests/attention/test_trtllm_gen_attention.py
Signed-off-by: Qidi Sang <200703406+qsang-nv@users.noreply.github.com>
2482404 to
f9f4680
Compare
|
/bot run |
📌 Description
Summary
Fix incorrect XQA attention results on architectures with
cacheVTileSeqLen = 32(SM120/SM121) whenhead_dim = 256andpage_size < 64.Bug
On SM120 with
head_dim=256andpage_sizeof 16 or 32, the XQAmha.cukernel produces incorrect attention outputs. Allpage_size=64cases pass; allpage_size < 64cases fail. The bug is independent of dtype,kv_layout,batch_size,window_left, and other parameters. SM90 and SM100 are unaffected.Root Cause
In
mha.cu, the V tile page advancement logic (loadVTilePartlambda) has two branches based onxIterSeqStridevstokensPerPage. Theelsebranch (whenxIterSeqStride > tokensPerPage) assumesnbVItersPerXIter == 1, meaning each warp X tile (64 tokens) is covered by a single V tile load. This assumption is violated under a specific combination of compile-time constants:cacheVTileSeqLen = 32(SM120/SM121)head_dim = 256→gemm1WarpsPerGrp = 4,gemm1NbWarpGrps = 1cacheVTileSeqStride = 32 × 1 = 32 < warpTile.x = 64nbVItersPerXIter = 2With
nbVItersPerXIter = 2, each warp X tile requires two V iterations (vIter=0 covering tokens [0, 32) and vIter=1 covering tokens [32, 64)). Whenpage_size < 64, these two V iterations land on different pages. However,loadPages()was only called after the last V iteration (vIter == nbVItersPerXIter - 1), leaving the page index stale for vIter=1 — it reads KV cache data from the wrong page.On SM90/SM100 (
cacheVTileSeqLen = 64),cacheVTileSeqStride = 64andnbVItersPerXIter = 1, so the bug never triggers. On SM120 withhead_dim = 128,gemm1NbWarpGrps = 2makescacheVTileSeqStride = 64, also avoiding the issue.Fix
Replace the single page advancement at the end of each X iteration with per-V-iteration page advancement:
vIter < nbVItersPerXIter - 1): advanceidxPageBegbystep_per_viter = cacheVTileSeqStride / tokensPerPageand reload pages.isLastVIter && isLastBeam): advance with CTA-tile boundary wrapping (multi-block mode), same as original but usingstep_per_viter.isLastVIter && !isLastBeam): resetidxPageBegbackward so the next beam restarts fromvIter=0's page position.When
nbVItersPerXIter == 1, only the first branch fires and the behavior is identical to the original code — no performance or correctness impact on existing working paths.Test Plan
test_trtllm_batch_decodewithbackend=xqa,head_dim=256, all page sizes (16, 32, 64) on SM120test_trtllm_batch_decodewithbackend=xqa,head_dim=128, all page sizes on SM120 (regression check)test_trtllm_batch_decodewithbackend=xqa,head_dim=256, all page sizes on SM90 (regression check)🔍 Related Issues
🚀 Pull Request Checklist
Thank you for contributing to FlashInfer! Before we review your pull request, please make sure the following items are complete.
✅ Pre-commit Checks
pre-commitby runningpip install pre-commit(or used your preferred method).pre-commit install.pre-commit run --all-filesand fixed any reported issues.🧪 Tests
unittest, etc.).Reviewer Notes
Summary by CodeRabbit
Bug Fixes
Tests