[draft] Fix barrier deadlock in fmha_v2 fp8+head_dim=256 transpose_v_tile #3276
jimmyzho wants to merge 4 commits into
Conversation
📝 Walkthrough

CircularBufferReader gains pointer-parameterized peek/wait/advance/pop operations to support flexible barrier-specific consumption. V-transpose scratch buffers now size to KV_BUFFERS depth. DMA scheduling introduces exact dynamic tile-id decoding to map tile indices to precise batch/head/Q-step positions. V-transpose synchronization uses the new pointer API for barrier-id-specific operations. FP8 support is enabled through output dtype configuration and removal of runtime validation restrictions.

Changes: FMHA v2 V-Transpose Pipeline and FP8 Support
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
🚥 Pre-merge checks: ✅ 3 passed | ❌ 2 failed (1 warning, 1 inconclusive)
/bot run
Code Review
This pull request addresses several issues in the FP8 FMHA v2 implementation, including fixing undefined behavior in synchronization patterns, preventing phase-flip races in the DMA path, and optimizing shared memory usage to fit H100 budgets. It also re-enables FP8 tests on SM90. Feedback focuses on the efficiency and safety of the newly added decode_exact_dynamic_tile_id function, specifically regarding redundant calculations across threads, potential integer overflow risks when casting tile IDs, and redundant global memory accesses within the decoding loop.
```cpp
static inline __device__ bool decode_exact_dynamic_tile_id(
    bert::Fused_multihead_attention_params_v2 const& params, uint32_t tile_id, int& bidb,
    int& bidh, int& q_step_offset) {
  int remaining = static_cast<int>(tile_id);

#pragma unroll 1
  for (int batch_idx = 0; batch_idx < params.b; ++batch_idx) {
    int const actual_q_seqlen =
        params.cu_q_seqlens[batch_idx + 1] - params.cu_q_seqlens[batch_idx];
    int const q_tiles_per_head = compute_dynamic_q_tiles_per_head(actual_q_seqlen);
    int const batch_tiles = q_tiles_per_head * params.h;
    if (remaining < batch_tiles) {
      bidb = batch_idx;
      bidh = remaining / q_tiles_per_head;
      q_step_offset = (remaining % q_tiles_per_head) * NUM_COMPUTE_GROUPS;
      return true;
    }
    remaining -= batch_tiles;
  }

  return false;
}
```
The decode_exact_dynamic_tile_id function is executed by all 128 threads in the DMA group (when DMA_GROUP_TRANSPOSE_V is true). Since all threads in the group process the same tile_id_, performing this linear scan in every thread is highly redundant and inefficient, especially as the batch size params.b grows. It would be significantly better to have only the group leader (e.g., thread 0) perform the decode and then broadcast the results (bidb, bidh, q_step_offset) to the rest of the group using __shfl_sync or shared memory.
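To make the tile-id → (batch, head, q-step) mapping concrete, here is a host-side C++ sketch of the same linear-scan decode. All names are illustrative and standalone, not the kernel's identifiers; a fixed q-tile size stands in for `compute_dynamic_q_tiles_per_head`, and the sketch folds in the two safety points raised in this review (`remaining` stays unsigned, and zero-tile batches are skipped explicitly):

```cpp
#include <cassert>
#include <vector>

struct Decoded {
  int bidb, bidh, q_step_offset;
  bool ok;
};

// Map a flat persistent-scheduler tile id onto (batch, head, q-step) for a
// ragged batch described by cumulative q sequence lengths (cu_q_seqlens).
// Mirrors the device function's linear scan over batches.
Decoded decode_tile_id(const std::vector<int>& cu_q_seqlens, int num_heads,
                       int q_tile_size, int num_compute_groups, unsigned tile_id) {
  unsigned remaining = tile_id;  // kept unsigned, as the review suggests
  int num_batches = static_cast<int>(cu_q_seqlens.size()) - 1;
  for (int batch_idx = 0; batch_idx < num_batches; ++batch_idx) {
    int seqlen = cu_q_seqlens[batch_idx + 1] - cu_q_seqlens[batch_idx];
    // ceil-divide: q tiles this batch element needs per head
    unsigned q_tiles_per_head = (seqlen + q_tile_size - 1) / q_tile_size;
    unsigned batch_tiles = q_tiles_per_head * num_heads;
    // batch_tiles > 0 guards the division below against empty sequences
    if (batch_tiles > 0 && remaining < batch_tiles) {
      return {batch_idx, static_cast<int>(remaining / q_tiles_per_head),
              static_cast<int>(remaining % q_tiles_per_head) * num_compute_groups,
              true};
    }
    remaining -= batch_tiles;
  }
  return {-1, -1, -1, false};  // tile_id beyond the last real tile
}
```

With `cu_q_seqlens = {0, 100, 100, 260}`, two heads, and 128-wide q tiles, tile ids 0–1 map to batch 0 (one head each), the empty batch 1 is skipped, and ids 2–5 map to batch 2.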
```cpp
static inline __device__ bool decode_exact_dynamic_tile_id(
    bert::Fused_multihead_attention_params_v2 const& params, uint32_t tile_id, int& bidb,
    int& bidh, int& q_step_offset) {
  int remaining = static_cast<int>(tile_id);
```
The tile_id is a uint32_t, but it is cast to a signed int for remaining. While the total number of tiles is unlikely to exceed INT_MAX, if it ever did, remaining would become negative; the loop condition remaining < batch_tiles could then evaluate to true even when batch_tiles is 0 (e.g., for an empty sequence), leading to a division or modulo by zero at line 167-168. It is safer to keep remaining as uint32_t.

Suggested fix:

```cpp
uint32_t remaining = tile_id;
```

```cpp
    int const actual_q_seqlen =
        params.cu_q_seqlens[batch_idx + 1] - params.cu_q_seqlens[batch_idx];
```
The loop performs redundant global memory loads from params.cu_q_seqlens. In each iteration batch_idx, it loads both params.cu_q_seqlens[batch_idx + 1] and params.cu_q_seqlens[batch_idx]. You can optimize this by loading the initial value before the loop and then only loading the 'next' value in each iteration.
```cpp
int prev_cu_len = params.cu_q_seqlens[0];
#pragma unroll 1
for (int batch_idx = 0; batch_idx < params.b; ++batch_idx) {
  int const next_cu_len = params.cu_q_seqlens[batch_idx + 1];
  int const actual_q_seqlen = next_cu_len - prev_cu_len;
  prev_cu_len = next_cu_len;
  // ... rest of the loop body unchanged ...
}
```
Actionable comments posted: 1
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 4bb99ca6-1882-464a-8a78-e58ce8c383e6
📒 Files selected for processing (8)
- csrc/fmha_v2/fmha/hopper/arrive_wait.h
- csrc/fmha_v2/fmha/warpspec/circular_buffer.h
- csrc/fmha_v2/fmha/warpspec/compute.h
- csrc/fmha_v2/fmha/warpspec/dma.h
- csrc/fmha_v2/fmha/warpspec/kernel_traits.h
- flashinfer/jit/attention/fmha_v2/fmha_library.py
- flashinfer/prefill.py
- tests/attention/test_fmha_v2_prefill.py
💤 Files with no reviewable changes (1)
- flashinfer/prefill.py
```python
# Warp-specialized fp8 kernels use an exact dynamic tile decode in the DMA path, so they can
# stay on the persistent scheduler even when a launch mixes different q-tile counts.
spec["scheduling_mode"] = 1
```
Scope the persistent-scheduler override to the FP8 Hopper path.
Line 191 now forces scheduling_mode = 1 for every generated spec, but the exact tile-id decode added in dma.h is only wired up for the FP8 transpose-V path. That broadens this fix to unrelated SM120/BF16/FP16 kernels and also stops run_separate_q_and_kv() from ever generating its existing SCHEDULING_MODE == 2 balanced specialization.
Suggested fix

```diff
- # Warp-specialized fp8 kernels use an exact dynamic tile decode in the DMA path, so they can
- # stay on the persistent scheduler even when a launch mixes different q-tile counts.
- spec["scheduling_mode"] = 1
+ # Warp-specialized FP8 Hopper kernels use an exact dynamic tile decode in the DMA path, so
+ # they can stay on the persistent scheduler even when a launch mixes different q-tile counts.
+ if warp_specialization and dtype == "e4m3":
+     spec["scheduling_mode"] = 1
```
+ spec["scheduling_mode"] = 1🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
In `@flashinfer/jit/attention/fmha_v2/fmha_library.py` around lines 189 - 191, The
unconditional spec["scheduling_mode"] = 1 should be made conditional so it only
applies to the FP8 Hopper transpose-V path; restrict the override to the
FP8/Hopper kernel generation (the path that wires up the exact tile-id decode in
dma.h) and do not change scheduling_mode for other types (e.g. SM120/BF16/FP16)
so run_separate_q_and_kv() can still emit its SCHEDULING_MODE == 2 balanced
specialization. Locate the assignment to spec["scheduling_mode"] and wrap it
with a check for the FP8 Hopper/transpose-V condition (the same condition used
to select the FP8 transpose-V kernel), leaving other cases untouched.
🧹 Nitpick comments (1)
csrc/fmha_v2/fmha/warpspec/dma.h (1)
154-175: 🏗️ Heavy lift: Avoid an O(batch) scan in the persistent tile decoder.

`decode_exact_dynamic_tile_id()` re-walks `params.cu_q_seqlens` from batch 0 for every claimed tile. On the transpose-V persistent path that turns scheduling into O(num_tiles * b) control work plus repeated global-memory reads, which is likely to show up on large-batch / short-sequence workloads. Please consider carrying decode state forward between successive tile IDs or decoding from a precomputed prefix structure instead.
📒 Files selected for processing (6)
- csrc/fmha_v2/fmha/warpspec/circular_buffer.h
- csrc/fmha_v2/fmha/warpspec/dma.h
- csrc/fmha_v2/fmha/warpspec/kernel_traits.h
- flashinfer/jit/attention/fmha_v2/fmha_library.py
- flashinfer/prefill.py
- tests/attention/test_fmha_v2_prefill.py
💤 Files with no reviewable changes (1)
- flashinfer/prefill.py
🚧 Files skipped from review as they are similar to previous changes (4)
- tests/attention/test_fmha_v2_prefill.py
- csrc/fmha_v2/fmha/warpspec/circular_buffer.h
- csrc/fmha_v2/fmha/warpspec/kernel_traits.h
- flashinfer/jit/attention/fmha_v2/fmha_library.py
Forked and modified from @bobboli's Original PR: #2957
Root cause
On Hopper (SM90), the warp-specialized FP8 FMHAv2 kernel runs a dedicated DMA warpgroup that explicitly transposes the V tile (via STSM) from `smem_v_scratch` into `smem_v` before BMM2, because FP8 QGMMA can't accept a transposed B-operand. The V-scratch FIFO between TMA load and STSM transpose was hardcoded to depth 1 (`V_SCRATCH_BUFFERS = DMA_GROUP_TRANSPOSE_V ? 1 : 0`).

At depth 1, producer (`_wptr`) and consumer (`_rptr`) share one physical slot and the phase bit alone distinguishes epochs N and N+1. Any momentary ordering inversion between the 128 transposer threads and DMA thread 0 corrupts the phase and deadlocks at the next `bar.sync`, observable as a hang in FP8 prefill at `head_dim=256` after enabling FP8 output coverage.

The fix is to raise the V-scratch FIFO to depth `KV_BUFFERS` (≥ 2 for FP8 head_size ≤ 128). Raising the depth unmasks two latent bugs that depth-1 was hiding, and exposes a third race the kernel didn't previously have a window for. All three need to land together for `KV_BUFFERS > 1` to be correct, and a fourth fix is needed for the persistent scheduler on the FP8 path.

Changes
- `kernel_traits.h`: raise V-scratch FIFO depth from 1 to `KV_BUFFERS`. Also tighten the V-transposer UNROLL gate from `STEP_KV > 128` to `STEP_KV >= 128` to avoid register spilling at the deeper buffer state.
- `dma.h`: three coupled fixes the new depth requires:
  - `smem_v_scratch[v_scratch_barrier_id]` → `smem_v_scratch[v_scratch_barrier_id * TILE_SIZE_V]`. At depth 1 `v_scratch_barrier_id` is always 0, so the missing stride was dormant; at depth > 1 it would corrupt the transposed V output.
  - Pass `v_scratch_barrier_id` into `cbr_v_scratch.peek/wait/pop()` instead of letting the reader use its internal `_rptr`. Required because `_rptr` walks independently of the slot the writer reserved when the V tile was loaded.
  - Add a `named_barrier_wait(SYNC_BARRIER, NUM_THREADS_IN_DMA_GROUP)` immediately after `threadReserve()` in `transpose_v_tile`. Without it, DMA thread 0 can race ahead into iteration N+1 (advancing the consumed-barrier phase) while warps 1–3 still poll iteration N's expected phase, deadlocking the warpgroup. Canonical leader-ahead fix, the same pattern `push_with_sync` uses.
- `circular_buffer.h`: add ptr-taking overloads of `peek/wait/advance/pop` so callers can wait on a specific slot rather than the reader's internal `_rptr`. Required by the `dma.h` fix above.
decode_exact_dynamic_tile_idwalkscu_q_seqlensand computes each batch element's q-tile count on the fly. The old scheduler assumed a uniformnum_tiles_per_headacross the batch and relied on downstream skip-tile logic to ignore invalid tiles. That worked for FP16/BF16, but on the FP8 path the DMA warpgroup does real work for those ghost tiles (TMA + transpose) and the barrier machinery doesn't unwind cleanly. Gated onDMA_GROUP_TRANSPOSE_Vso FP16/BF16 keeps its tuned scheduling.FP8-output coverage + smem budget —
fmha_library.pynow dropskv_tile_buffersto 1 for FP8-outputhead_dim=256(FP8→FP8 adds two output staging buffers, pushing past H100's 228KB cap). Depth-1 is safe at this config because the newnamed_barrier_waitafterthreadReservekeeps the DMA warpgroup synchronized across iterations.Test
tests/attention/test_fmha_v2_prefill.py.