[draft] copy of Fix barrier deadlock in fmha_v2 fp8+head_dim=256 transpose_v_tile #3276

Open

jimmyzho wants to merge 4 commits into flashinfer-ai:main from jimmyzho:fmhav2-fp8

Conversation

@jimmyzho (Contributor) commented May 9, 2026

Forked and modified from @bobboli's Original PR: #2957

Root cause

On Hopper (SM90), the warp-specialized FP8 FMHAv2 kernel runs a dedicated DMA warpgroup that explicitly transposes the V tile (via STSM) from smem_v_scratch into smem_v before BMM2, because FP8 QGMMA can't accept a transposed B-operand. The V-scratch FIFO between TMA load and STSM transpose was hardcoded to depth 1 (V_SCRATCH_BUFFERS = DMA_GROUP_TRANSPOSE_V ? 1 : 0).

At depth 1, producer (_wptr) and consumer (_rptr) share one physical slot and the phase bit alone distinguishes epochs N and N+1. Any momentary ordering inversion between the 128 transposer threads and DMA thread 0 corrupts the phase and deadlocks at the next bar.sync — observable as a hang in FP8 prefill at head_dim=256 after enabling FP8 output coverage.

The fix is to raise the V-scratch FIFO to depth KV_BUFFERS (≥ 2 for FP8 head_size ≤ 128). Raising the depth unmasks two latent bugs that depth-1 was hiding, and exposes a third race the kernel didn't previously have a window for. All three need to land together for KV_BUFFERS > 1 to be correct, and a fourth fix is needed for the persistent scheduler on the FP8 path.
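The failure mode hinges on how a monotonically increasing FIFO pointer maps to a physical slot and a phase bit. A minimal host-side sketch of that arithmetic (a simplified model, not the kernel's CircularBuffer class):

```python
def slot_and_phase(ptr: int, depth: int) -> tuple[int, int]:
    """Map a monotonic FIFO pointer to (physical slot, phase bit)."""
    return ptr % depth, (ptr // depth) & 1

# Depth 1: epochs N and N+1 share slot 0 and differ only in the phase bit,
# so a momentary producer/consumer ordering inversion that flips the phase
# early leaves the consumer polling a phase that will never arrive.
assert slot_and_phase(0, 1) == (0, 0)
assert slot_and_phase(1, 1) == (0, 1)

# Depth >= 2: consecutive epochs land in distinct slots, so a one-epoch
# producer lead no longer collides with the consumer's expected phase.
assert slot_and_phase(0, 2) == (0, 0)
assert slot_and_phase(1, 2) == (1, 0)
assert slot_and_phase(2, 2) == (0, 1)
```

At depth 1 the phase bit is the only thing separating epochs, which is why the deadlock only needed a transient ordering inversion to trigger.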

Changes

kernel_traits.h — raise the V-scratch FIFO depth from 1 to KV_BUFFERS. Also tighten the V-transposer UNROLL gate from STEP_KV > 128 to STEP_KV >= 128 to avoid register spilling at the deeper buffer depth.

dma.h — three coupled fixes the new depth requires:

  • smem_v_scratch[v_scratch_barrier_id] → smem_v_scratch[v_scratch_barrier_id * TILE_SIZE_V]. At depth 1, v_scratch_barrier_id is always 0, so the missing stride was dormant; at depth > 1 it would corrupt the transposed V output.
  • Thread the captured v_scratch_barrier_id into cbr_v_scratch.peek/wait/pop() instead of letting the reader use its internal _rptr. Required because _rptr walks independently of the slot the writer reserved when the V tile was loaded.
  • Insert named_barrier_wait(SYNC_BARRIER, NUM_THREADS_IN_DMA_GROUP) immediately after threadReserve() in transpose_v_tile. Without it, DMA thread 0 can race ahead into iteration N+1 (advancing the consumed-barrier phase) while warps 1–3 still poll iteration N's expected phase, deadlocking the warpgroup. Canonical leader-ahead fix, same pattern push_with_sync uses.
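The stride bug in the first bullet can be illustrated with a toy offset calculation (TILE_SIZE_V here is an arbitrary illustrative value, not the kernel's actual tile size):

```python
TILE_SIZE_V = 4096  # illustrative slot size; the real value comes from kernel traits

def scratch_offset_buggy(v_scratch_barrier_id: int) -> int:
    # Missing stride: every slot id indexes element [id] of the base array,
    # so all slots alias the front of the scratch buffer.
    return v_scratch_barrier_id

def scratch_offset_fixed(v_scratch_barrier_id: int) -> int:
    # Each FIFO slot owns its own TILE_SIZE_V-element region.
    return v_scratch_barrier_id * TILE_SIZE_V

# At depth 1 the barrier id is always 0, so both forms agree and the bug stays dormant.
assert scratch_offset_buggy(0) == scratch_offset_fixed(0) == 0
# At depth > 1 the buggy form lands inside slot 0's tile instead of at slot 1.
assert scratch_offset_buggy(1) == 1
assert scratch_offset_fixed(1) == TILE_SIZE_V
```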

circular_buffer.h — add ptr-taking overloads of peek/wait/advance/pop so callers can wait on a specific slot rather than the reader's internal _rptr. Required by the dma.h fix above.
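The shape of the new overloads can be sketched in a host-side model (a simplification: the real class wraps hardware arrive/wait barriers, and since Python lacks overloading, the peek_at/pop_at names stand in for the C++ peek(int ptr)/pop(int tid0, int ptr) overloads):

```python
class CircularBufferReader:
    def __init__(self, depth: int):
        self.depth = depth
        self._rptr = 0
        self._ready = set()  # stand-in for slots whose producer barrier arrived

    # New pointer-taking variants: operate on a caller-supplied slot pointer.
    def peek_at(self, ptr: int) -> bool:
        return ptr in self._ready

    def pop_at(self, ptr: int) -> None:
        self._ready.discard(ptr)

    # Existing no-argument methods delegate via the internal _rptr, so
    # legacy call sites keep their behavior unchanged.
    def peek(self) -> bool:
        return self.peek_at(self._rptr)

    def pop(self) -> None:
        self.pop_at(self._rptr)
        self._rptr += 1

reader = CircularBufferReader(depth=2)
reader._ready.add(1)          # producer filled the slot it reserved (ptr 1)
assert not reader.peek()      # the internal _rptr (0) knows nothing about it...
assert reader.peek_at(1)      # ...but the captured pointer can wait on it directly
```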

FP8 persistent scheduler for ragged q-tiles — decode_exact_dynamic_tile_id walks cu_q_seqlens and computes each batch element's q-tile count on the fly. The old scheduler assumed a uniform num_tiles_per_head across the batch and relied on downstream skip-tile logic to ignore invalid tiles. That worked for FP16/BF16, but on the FP8 path the DMA warpgroup does real work for those ghost tiles (TMA + transpose) and the barrier machinery doesn't unwind cleanly. Gated on DMA_GROUP_TRANSPOSE_V so FP16/BF16 keeps its tuned scheduling.
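The decode can be mirrored host-side to see the mapping it produces (STEP_Q and NUM_COMPUTE_GROUPS are illustrative values; the kernel derives them from its traits):

```python
STEP_Q = 64             # illustrative q-step size
NUM_COMPUTE_GROUPS = 2  # illustrative

def q_tiles_per_head(q_seqlen: int) -> int:
    tile = STEP_Q * NUM_COMPUTE_GROUPS
    return (q_seqlen + tile - 1) // tile

def decode_exact_dynamic_tile_id(cu_q_seqlens, num_heads, tile_id):
    """Walk per-batch tile counts until the batch owning tile_id is found."""
    remaining = tile_id
    for b in range(len(cu_q_seqlens) - 1):
        tiles = q_tiles_per_head(cu_q_seqlens[b + 1] - cu_q_seqlens[b])
        batch_tiles = tiles * num_heads
        if remaining < batch_tiles:
            return b, remaining // tiles, (remaining % tiles) * NUM_COMPUTE_GROUPS
        remaining -= batch_tiles
    return None  # tile_id points past the last real q-tile: a ghost tile

# Ragged batch: 100 tokens, an empty sequence, then 200 tokens; 2 heads.
cu = [0, 100, 100, 300]
assert decode_exact_dynamic_tile_id(cu, 2, 0) == (0, 0, 0)   # batch 0, head 0
assert decode_exact_dynamic_tile_id(cu, 2, 1) == (0, 1, 0)   # batch 0, head 1
assert decode_exact_dynamic_tile_id(cu, 2, 2) == (2, 0, 0)   # empty batch 1 skipped
assert decode_exact_dynamic_tile_id(cu, 2, 6) is None        # ghost tile
```

Unlike the uniform-count scheduler, ghost tiles decode to None here, so the DMA warpgroup never starts TMA or transpose work for them.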

FP8-output coverage + smem budget — fmha_library.py now drops kv_tile_buffers to 1 for FP8-output head_dim=256 (FP8→FP8 adds two output staging buffers, pushing past H100's 228KB cap). Depth-1 is safe at this config because the new named_barrier_wait after threadReserve keeps the DMA warpgroup synchronized across iterations.
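The buffering decision amounts to shrinking the KV depth until the layout fits the cap; a hedged sketch with made-up sizes (the real accounting in fmha_library.py uses the kernel's actual tile layout):

```python
SMEM_CAP = 228 * 1024  # H100 per-CTA shared-memory cap, in bytes

def choose_kv_tile_buffers(fixed_bytes: int, kv_tile_bytes: int, preferred: int = 2) -> int:
    """Largest KV buffer count (floored at 1) that fits under the cap.
    fixed_bytes covers everything that doesn't scale with KV depth, e.g. the
    Q tile and, for FP8 output, the two extra output staging buffers."""
    buffers = preferred
    while buffers > 1 and fixed_bytes + buffers * kv_tile_bytes > SMEM_CAP:
        buffers -= 1
    return buffers

# Illustrative numbers only: with FP8 output staging inflating the fixed cost,
# the config no longer fits two KV buffers and drops to depth 1.
assert choose_kv_tile_buffers(fixed_bytes=120 * 1024, kv_tile_bytes=64 * 1024) == 1
assert choose_kv_tile_buffers(fixed_bytes=56 * 1024, kv_tile_bytes=64 * 1024) == 2
```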

Test

  • Enables FP8 output prefill coverage in tests/attention/test_fmha_v2_prefill.py.
  • Removes the FP8 sliding-window test markers that were masking the original hang.

@coderabbitai (Bot) commented May 9, 2026

📝 Walkthrough

CircularBufferReader gains pointer-parameterized peek/wait/advance/pop operations to support flexible barrier-specific consumption. V-transpose scratch buffers now size to KV_BUFFERS depth. DMA scheduling introduces exact dynamic tile-id decoding to map tile indices to precise batch/head/Q-step positions. V-transpose synchronization uses the new pointer API for barrier-id-specific operations. FP8 support is enabled through output dtype configuration and removal of runtime validation restrictions.

Changes

FMHA v2 V-Transpose Pipeline and FP8 Support

  • Circular Buffer Reader Pointer-Parameterized Operations (csrc/fmha_v2/fmha/warpspec/circular_buffer.h):
    CircularBufferReader interface extended with peek(int ptr), wait(int ptr), advance(int ptr), and pop(int tid0, int ptr) overloads for explicit pointer-based barrier synchronization; existing no-argument methods delegate to the new overloads using the internal _rptr.
  • V-Transpose Scratch Buffer Allocation (csrc/fmha_v2/fmha/warpspec/kernel_traits.h):
    V_SCRATCH_BUFFERS constant updated to use KV_BUFFERS (double-buffer depth) instead of a fixed single buffer when DMA_GROUP_TRANSPOSE_V is enabled, providing staging capacity for circular buffer operations.
  • Dynamic Tile-ID Decoding Functions (csrc/fmha_v2/fmha/warpspec/dma.h):
    New device helper methods compute_dynamic_q_tiles_per_head() and decode_exact_dynamic_tile_id() provide precise mapping from tile_id to batch index, head index, and Q-step offset by iterating per-batch sequences.
  • GMMA Transposer Unroll Threshold Adjustment (csrc/fmha_v2/fmha/warpspec/dma.h):
    Transposer K-parameter unroll condition refined from STEP_KV > 128 to STEP_KV >= 128.
  • DMA Scheduling with Exact Tile Decoding (csrc/fmha_v2/fmha/warpspec/dma.h):
    Both run_packed_qkv and run_separate_q_and_kv scheduling paths invoke decode_exact_dynamic_tile_id when DMA_GROUP_TRANSPOSE_V is enabled, deriving deterministic Q-tile allocation; the separate path computes local_q_tile_offset from the decoded q_step_offset.
  • V-Transpose Barrier-ID Synchronization (csrc/fmha_v2/fmha/warpspec/dma.h):
    transpose_v_tile() refactored to use v_scratch_barrier_id for pointer-specific circular buffer operations, adds explicit inter-DMA-thread synchronization after destination V-barrier reservation, and pops the scratch ring using an explicit barrier_id.
  • FP8 Effective Output Dtype and Conditional Buffering (flashinfer/jit/attention/fmha_v2/fmha_library.py):
    generate_kernel_spec introduces effective_output_dtype tracking; the SM90 warp-specialized FP8 path conditionally sets kv_tile_buffers = 1 when the effective output is FP8 e4m3 to fit the H100 shared-memory budget.
  • FP8 Runtime Validation and Test Enablement (flashinfer/prefill.py, tests/attention/test_fmha_v2_prefill.py):
    Removes FP8 (e4m3) sliding-window runtime validation for specific configurations; removes the pytest.skip guard from test execution for the torch.float8_e4m3fn parameterization.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested reviewers

  • yzh119
  • aleozlx
  • yongwww
  • sricketts
  • dhiraj113
  • cyx-6
  • samuellees
  • bkryu
  • nv-yunzheq

Poem

A rabbit hops through barriers bright,
Decoding tiles with pointer's light,
From scratch to V, the syncs align,
FP8 now runs just fine—
Double-buffered, fast, and right! 🐰✨

🚥 Pre-merge checks | ✅ 3 | ❌ 2

❌ Failed checks (1 warning, 1 inconclusive)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 32.26%, below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
  • Title check: ❓ Inconclusive. The title is vague and uses meta language like '[draft] copy of' that doesn't convey the actual technical change; it references another PR instead of summarizing the work. Resolution: revise the title to clearly describe the main fix, e.g., 'Fix FP8 transpose buffer deadlock in warp-specialized fmha_v2 on SM90'.

✅ Passed checks (3 passed)

  • Linked Issues check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Out of Scope Changes check: ✅ Passed. Check skipped because no linked issues were found for this pull request.
  • Description check: ✅ Passed. The PR description is comprehensive and well-structured, covering root cause, specific code changes, and test modifications, though it deviates from the template structure.


@jimmyzho (Contributor, Author) commented May 9, 2026

/bot run

@flashinfer-bot (Collaborator) commented

GitLab MR !653 has been created, and the CI pipeline #50736976 is currently running. I'll report back once the pipeline job completes.

@gemini-code-assist (Bot) left a comment

Code Review

This pull request addresses several issues in the FP8 FMHA v2 implementation, including fixing undefined behavior in synchronization patterns, preventing phase-flip races in the DMA path, and optimizing shared memory usage to fit H100 budgets. It also re-enables FP8 tests on SM90. Feedback focuses on the efficiency and safety of the newly added decode_exact_dynamic_tile_id function, specifically regarding redundant calculations across threads, potential integer overflow risks when casting tile IDs, and redundant global memory accesses within the decoding loop.

Comment on lines +154 to +175

    static inline __device__ bool decode_exact_dynamic_tile_id(
        bert::Fused_multihead_attention_params_v2 const& params, uint32_t tile_id, int& bidb,
        int& bidh, int& q_step_offset) {
      int remaining = static_cast<int>(tile_id);

    #pragma unroll 1
      for (int batch_idx = 0; batch_idx < params.b; ++batch_idx) {
        int const actual_q_seqlen =
            params.cu_q_seqlens[batch_idx + 1] - params.cu_q_seqlens[batch_idx];
        int const q_tiles_per_head = compute_dynamic_q_tiles_per_head(actual_q_seqlen);
        int const batch_tiles = q_tiles_per_head * params.h;
        if (remaining < batch_tiles) {
          bidb = batch_idx;
          bidh = remaining / q_tiles_per_head;
          q_step_offset = (remaining % q_tiles_per_head) * NUM_COMPUTE_GROUPS;
          return true;
        }
        remaining -= batch_tiles;
      }

      return false;
    }
medium

The decode_exact_dynamic_tile_id function is executed by all 128 threads in the DMA group (when DMA_GROUP_TRANSPOSE_V is true). Since all threads in the group process the same tile_id_, performing this linear scan in every thread is highly redundant and inefficient, especially as the batch size params.b grows. It would be significantly better to have only the group leader (e.g., thread 0) perform the decode and then broadcast the results (bidb, bidh, q_step_offset) to the rest of the group using __shfl_sync or shared memory.

    static inline __device__ bool decode_exact_dynamic_tile_id(
        bert::Fused_multihead_attention_params_v2 const& params, uint32_t tile_id, int& bidb,
        int& bidh, int& q_step_offset) {
      int remaining = static_cast<int>(tile_id);
medium

The tile_id is a uint32_t, but it is cast to a signed int for remaining. While the total number of tiles is unlikely to exceed $2^{31}-1$ in current workloads, using a signed integer here is risky. If remaining becomes negative, the loop condition remaining < batch_tiles could evaluate to true if batch_tiles is 0 (e.g., for an empty sequence), leading to a division or modulo by zero at line 167-168. It is safer to keep remaining as uint32_t.

      uint32_t remaining = tile_id;

Comment on lines +161 to +162

        int const actual_q_seqlen =
            params.cu_q_seqlens[batch_idx + 1] - params.cu_q_seqlens[batch_idx];
medium

The loop performs redundant global memory loads from params.cu_q_seqlens. In each iteration batch_idx, it loads both params.cu_q_seqlens[batch_idx + 1] and params.cu_q_seqlens[batch_idx]. You can optimize this by loading the initial value before the loop and then only loading the 'next' value in each iteration.

      int prev_cu_len = params.cu_q_seqlens[0];
#pragma unroll 1
      for (int batch_idx = 0; batch_idx < params.b; ++batch_idx) {
        int const next_cu_len = params.cu_q_seqlens[batch_idx + 1];
        int const actual_q_seqlen = next_cu_len - prev_cu_len;
        prev_cu_len = next_cu_len;

@coderabbitai (Bot) left a comment

Actionable comments posted: 1


📥 Commits

Reviewing files that changed from the base of the PR and between f6717ff and d352e3b.

📒 Files selected for processing (8)
  • csrc/fmha_v2/fmha/hopper/arrive_wait.h
  • csrc/fmha_v2/fmha/warpspec/circular_buffer.h
  • csrc/fmha_v2/fmha/warpspec/compute.h
  • csrc/fmha_v2/fmha/warpspec/dma.h
  • csrc/fmha_v2/fmha/warpspec/kernel_traits.h
  • flashinfer/jit/attention/fmha_v2/fmha_library.py
  • flashinfer/prefill.py
  • tests/attention/test_fmha_v2_prefill.py
💤 Files with no reviewable changes (1)
  • flashinfer/prefill.py

Comment on lines +189 to +191

    # Warp-specialized fp8 kernels use an exact dynamic tile decode in the DMA path, so they can
    # stay on the persistent scheduler even when a launch mixes different q-tile counts.
    spec["scheduling_mode"] = 1

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Scope the persistent-scheduler override to the FP8 Hopper path.

Line 191 now forces scheduling_mode = 1 for every generated spec, but the exact tile-id decode added in dma.h is only wired up for the FP8 transpose-V path. That broadens this fix to unrelated SM120/BF16/FP16 kernels and also stops run_separate_q_and_kv() from ever generating its existing SCHEDULING_MODE == 2 balanced specialization.

Suggested fix
-    # Warp-specialized fp8 kernels use an exact dynamic tile decode in the DMA path, so they can
-    # stay on the persistent scheduler even when a launch mixes different q-tile counts.
-    spec["scheduling_mode"] = 1
+    # Warp-specialized FP8 Hopper kernels use an exact dynamic tile decode in the DMA path, so
+    # they can stay on the persistent scheduler even when a launch mixes different q-tile counts.
+    if warp_specialization and dtype == "e4m3":
+        spec["scheduling_mode"] = 1

@coderabbitai (Bot) left a comment

🧹 Nitpick comments (1)
csrc/fmha_v2/fmha/warpspec/dma.h (1)

154-175: 🏗️ Heavy lift

Avoid an O(batch) scan in the persistent tile decoder.

decode_exact_dynamic_tile_id() re-walks params.cu_q_seqlens from batch 0 for every claimed tile. On the transpose-V persistent path that turns scheduling into O(num_tiles * b) control work plus repeated global-memory reads, which is likely to show up on large-batch / short-sequence workloads. Please consider carrying decode state forward between successive tile IDs or decoding from a precomputed prefix structure instead.
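A host-side Python sketch of the precomputed-prefix alternative the comment suggests (function names, the bisect-based lookup, and the 128-token tile size are all illustrative, not the kernel's code):

```python
import bisect

def build_tile_prefix(cu_q_seqlens, num_heads, tiles_per_head):
    """Cumulative tile counts per batch element, computed once per launch."""
    prefix = [0]
    for b in range(len(cu_q_seqlens) - 1):
        tiles = tiles_per_head(cu_q_seqlens[b + 1] - cu_q_seqlens[b])
        prefix.append(prefix[-1] + tiles * num_heads)
    return prefix

def decode_with_prefix(prefix, num_heads, tile_id):
    """O(log b) decode: binary-search the batch, then split head / q-tile."""
    b = bisect.bisect_right(prefix, tile_id) - 1
    if b >= len(prefix) - 1:
        return None  # ghost tile beyond the last batch
    remaining = tile_id - prefix[b]
    tiles = (prefix[b + 1] - prefix[b]) // num_heads
    return b, remaining // tiles, remaining % tiles

tiles_fn = lambda seqlen: -(-seqlen // 128)  # illustrative 128-token q-tiles
prefix = build_tile_prefix([0, 100, 100, 300], 2, tiles_fn)
assert prefix == [0, 2, 2, 6]
assert decode_with_prefix(prefix, 2, 0) == (0, 0, 0)
assert decode_with_prefix(prefix, 2, 2) == (2, 0, 0)   # empty batch skipped
assert decode_with_prefix(prefix, 2, 6) is None        # ghost tile
```

bisect_right naturally skips zero-tile (empty-sequence) batches because their prefix entries repeat, so the divide-by-zero hazard of the linear scan never arises.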

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@csrc/fmha_v2/fmha/warpspec/dma.h` around lines 154 - 175,
decode_exact_dynamic_tile_id performs an O(b) walk over params.cu_q_seqlens for
each tile_id, causing O(num_tiles * b) control work and repeated global reads;
fix by avoiding full re-scan—either (A) change the caller/iterator to carry
decode state forward so successive calls to decode_exact_dynamic_tile_id advance
from the previous batch index/remaining rather than restarting at batch 0, or
(B) add a small precomputed prefix/index structure (e.g., an array of cumulative
q_tile counts or per-batch q_tiles_per_head) computed once (from
params.cu_q_seqlens and compute_dynamic_q_tiles_per_head) and then use binary
search or direct lookup to map tile_id to bidb/bidh/q_step_offset in O(log b) or
O(1); update decode_exact_dynamic_tile_id (or introduce a new decode_from_index
helper) to use that prefix so it no longer scans params.cu_q_seqlens on every
call.


📥 Commits

Reviewing files that changed from the base of the PR and between d352e3b and 5c50d2a.

📒 Files selected for processing (6)
  • csrc/fmha_v2/fmha/warpspec/circular_buffer.h
  • csrc/fmha_v2/fmha/warpspec/dma.h
  • csrc/fmha_v2/fmha/warpspec/kernel_traits.h
  • flashinfer/jit/attention/fmha_v2/fmha_library.py
  • flashinfer/prefill.py
  • tests/attention/test_fmha_v2_prefill.py
💤 Files with no reviewable changes (1)
  • flashinfer/prefill.py
🚧 Files skipped from review as they are similar to previous changes (4)
  • tests/attention/test_fmha_v2_prefill.py
  • csrc/fmha_v2/fmha/warpspec/circular_buffer.h
  • csrc/fmha_v2/fmha/warpspec/kernel_traits.h
  • flashinfer/jit/attention/fmha_v2/fmha_library.py
