-
Notifications
You must be signed in to change notification settings - Fork 976
fix(fmha_v2): fix FP8 V-scratch pipeline and varlen scheduler on SM90 #3276
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
74123b2
7d026f3
b5a2cf8
5c50d2a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -99,7 +99,7 @@ struct DMA { | |
| static_assert(STEP_KV % K_ == 0); | ||
| using Transposer = | ||
| Transposer<typename Kernel_traits::Traits_o, typename Kernel_traits::Cta_tile_o, K_, | ||
| (STEP_KV > 128 || SLIDING_OR_CHUNKED_ATTENTION) ? 1 : 2 /* UNROLL */>; | ||
| (STEP_KV >= 128 || SLIDING_OR_CHUNKED_ATTENTION) ? 1 : 2 /* UNROLL */>; | ||
|
|
||
| struct Device { | ||
| // Only the warpgroup leader initiates mbarriers & TMA operations. | ||
|
|
@@ -147,6 +147,33 @@ struct DMA { | |
| return std::make_pair(kv_idx_start, kv_idx_end); | ||
| } | ||
|
|
||
| static inline __device__ int compute_dynamic_q_tiles_per_head(int actual_q_seqlen) { | ||
| return (actual_q_seqlen + STEP_Q * NUM_COMPUTE_GROUPS - 1) / (STEP_Q * NUM_COMPUTE_GROUPS); | ||
| } | ||
|
|
||
| static inline __device__ bool decode_exact_dynamic_tile_id( | ||
| bert::Fused_multihead_attention_params_v2 const& params, uint32_t tile_id, int& bidb, | ||
| int& bidh, int& q_step_offset) { | ||
| int remaining = static_cast<int>(tile_id); | ||
|
|
||
| #pragma unroll 1 | ||
| for (int batch_idx = 0; batch_idx < params.b; ++batch_idx) { | ||
| int const actual_q_seqlen = | ||
| params.cu_q_seqlens[batch_idx + 1] - params.cu_q_seqlens[batch_idx]; | ||
|
Comment on lines
+161
to
+162
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The loop performs redundant global memory loads from `params.cu_q_seqlens`: each iteration re-reads the prefix sum that was already loaded on the previous iteration. Consider caching it in a register instead:
int prev_cu_len = params.cu_q_seqlens[0];
#pragma unroll 1
for (int batch_idx = 0; batch_idx < params.b; ++batch_idx) {
int const next_cu_len = params.cu_q_seqlens[batch_idx + 1];
int const actual_q_seqlen = next_cu_len - prev_cu_len;
prev_cu_len = next_cu_len; |
||
| int const q_tiles_per_head = compute_dynamic_q_tiles_per_head(actual_q_seqlen); | ||
| int const batch_tiles = q_tiles_per_head * params.h; | ||
| if (remaining < batch_tiles) { | ||
| bidb = batch_idx; | ||
| bidh = remaining / q_tiles_per_head; | ||
| q_step_offset = (remaining % q_tiles_per_head) * NUM_COMPUTE_GROUPS; | ||
| return true; | ||
| } | ||
| remaining -= batch_tiles; | ||
| } | ||
|
|
||
| return false; | ||
| } | ||
|
Comment on lines
+154
to
+175
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The |
||
|
|
||
| //////////////////////////////////////////////////////////////////////////////////////////// | ||
|
|
||
| // Packed contiguous QKV input. | ||
|
|
@@ -181,20 +208,27 @@ struct DMA { | |
| bidh = tile_id_ % params.h; | ||
| bidb = tile_id_ / params.h; | ||
| } else { | ||
| // Balanced dynamic scheduling | ||
| if (CAUSAL_MASK && !SLIDING_OR_CHUNKED_ATTENTION && params.use_balanced_scheduling) { | ||
| q_step_offset = (params.num_tiles_per_head - 1 - tile_id_ / (params.b * params.h)) * | ||
| NUM_COMPUTE_GROUPS; | ||
| tmp = tile_id_ % (params.b * params.h); | ||
| bidh = tmp / params.b; | ||
| bidb = tmp % params.b; | ||
| q_steps = NUM_COMPUTE_GROUPS; | ||
| } else { // Unbalanced dynamic scheduling | ||
| bidb = tile_id_ / (params.h * params.num_tiles_per_head); | ||
| tmp = tile_id_ % (params.h * params.num_tiles_per_head); | ||
| bidh = tmp / params.num_tiles_per_head; | ||
| q_step_offset = tmp % params.num_tiles_per_head * NUM_COMPUTE_GROUPS; | ||
| if constexpr (DMA_GROUP_TRANSPOSE_V) { | ||
| q_steps = NUM_COMPUTE_GROUPS; | ||
| if (!decode_exact_dynamic_tile_id(params, tile_id_, bidb, bidh, q_step_offset)) { | ||
| break; | ||
| } | ||
| } else { | ||
| // Balanced dynamic scheduling | ||
| if (CAUSAL_MASK && !SLIDING_OR_CHUNKED_ATTENTION && params.use_balanced_scheduling) { | ||
| q_step_offset = (params.num_tiles_per_head - 1 - tile_id_ / (params.b * params.h)) * | ||
| NUM_COMPUTE_GROUPS; | ||
| tmp = tile_id_ % (params.b * params.h); | ||
| bidh = tmp / params.b; | ||
| bidb = tmp % params.b; | ||
| q_steps = NUM_COMPUTE_GROUPS; | ||
| } else { // Unbalanced dynamic scheduling | ||
| bidb = tile_id_ / (params.h * params.num_tiles_per_head); | ||
| tmp = tile_id_ % (params.h * params.num_tiles_per_head); | ||
| bidh = tmp / params.num_tiles_per_head; | ||
| q_step_offset = tmp % params.num_tiles_per_head * NUM_COMPUTE_GROUPS; | ||
| q_steps = NUM_COMPUTE_GROUPS; | ||
| } | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -330,6 +364,13 @@ struct DMA { | |
| if (SCHEDULING_MODE == 0) { | ||
| bidh = tile_id_ % params.h; | ||
| bidb = tile_id_ / params.h; | ||
| } else if constexpr (DMA_GROUP_TRANSPOSE_V) { | ||
| q_steps = NUM_COMPUTE_GROUPS; | ||
| int q_step_offset; | ||
| if (!decode_exact_dynamic_tile_id(params, tile_id_, bidb, bidh, q_step_offset)) { | ||
| break; | ||
| } | ||
| local_q_tile_offset = q_step_offset * STEP_Q; | ||
| } else if (SCHEDULING_MODE == 1) { | ||
| bidb = tile_id_ / (params.h * params.num_tiles_per_head); | ||
| tmp = tile_id_ % (params.h * params.num_tiles_per_head); | ||
|
|
@@ -616,14 +657,22 @@ struct DMA { | |
| Transposer transposer(threadIdx.x % NUM_THREADS_IN_DMA_GROUP); | ||
|
|
||
| // Src buffer available | ||
| int ready = cbr_v_scratch.peek(); | ||
| int ready = cbr_v_scratch.peek(v_scratch_barrier_id); | ||
| if (!ready) { | ||
| cbr_v_scratch.wait(); | ||
| cbr_v_scratch.wait(v_scratch_barrier_id); | ||
| } | ||
| uint32_t smem_v_src = __cvta_generic_to_shared(&shared->smem_v_scratch[v_scratch_barrier_id]); | ||
| uint32_t smem_v_src = | ||
| __cvta_generic_to_shared(&shared->smem_v_scratch[v_scratch_barrier_id * TILE_SIZE_V]); | ||
|
|
||
| // Dst buffer available | ||
| int v_barrier_id = cbw_v.threadReserve(); | ||
| // NOTE(bobboli): Sync all DMA threads after consumer-bar wait to prevent phase-flip race. | ||
| // Without this, thread 0 can race ahead, commit V (triggering compute to consume the slot | ||
| // and flip the consumed-barrier phase), then wrap around in threadReserve() with a new | ||
| // expected phase, while slow DMA warps are still waiting on the old expected phase of the | ||
| // now-flipped barrier -> deadlock at bar.sync 1, 128 in transpose_v_tile. Same hazard as | ||
| // described for push_with_sync (see comment near run_packed_qkv). | ||
| named_barrier_wait(SYNC_BARRIER, NUM_THREADS_IN_DMA_GROUP); | ||
| uint32_t smem_v_dst = __cvta_generic_to_shared(&shared->smem_v[v_barrier_id * TILE_SIZE_V]); | ||
|
|
||
| // Explicitly transpose the v buffer in smem for fp8. | ||
|
|
@@ -682,7 +731,7 @@ struct DMA { | |
| fence_view_async_shared(); // Commit STSM | ||
| named_barrier_wait(SYNC_BARRIER, NUM_THREADS_IN_DMA_GROUP); // Sync before signaling | ||
| cbw_v.threadCommit(elect_one_, v_barrier_id); // Signal readiness | ||
| cbr_v_scratch.pop(elect_one_); // Advance to next phase | ||
| cbr_v_scratch.pop(elect_one_, v_scratch_barrier_id); // Advance to next phase | ||
| } | ||
|
|
||
| inline __device__ void get_next_tile_id(int local_wid, int tiw, uint32_t smem_tile_id, | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -186,7 +186,9 @@ def generate_kernel_spec( | |
|
|
||
| # Override class defaults that always differ | ||
| spec["flash_attention"] = True # Class default is False | ||
| spec["scheduling_mode"] = 1 # Class default is 0 | ||
| # Warp-specialized fp8 kernels use an exact dynamic tile decode in the DMA path, so they can | ||
| # stay on the persistent scheduler even when a launch mixes different q-tile counts. | ||
| spec["scheduling_mode"] = 1 | ||
|
Comment on lines
+189
to
+191
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Scope the persistent-scheduler override to the FP8 Hopper path. Line 191 now forces `scheduling_mode = 1` unconditionally, for every kernel spec this function generates rather than only the warp-specialized FP8 kernels that actually use the exact dynamic tile decode. Suggested fix- # Warp-specialized fp8 kernels use an exact dynamic tile decode in the DMA path, so they can
- # stay on the persistent scheduler even when a launch mixes different q-tile counts.
- spec["scheduling_mode"] = 1
+ # Warp-specialized FP8 Hopper kernels use an exact dynamic tile decode in the DMA path, so
+ # they can stay on the persistent scheduler even when a launch mixes different q-tile counts.
+ if warp_specialization and dtype == "e4m3":
+ spec["scheduling_mode"] = 1🤖 Prompt for AI Agents |
||
|
|
||
| # # SM-specific configuration | ||
| # if warp_specialization: | ||
|
|
@@ -234,6 +236,8 @@ def generate_kernel_spec( | |
| # FP8: round_up to multiples of 128 -> D in {32, 64, 128, 256} | ||
| # (head_size=160 pads to D=256 for FP8 vs D=192 for FP16) | ||
| # | ||
| effective_output_dtype = output_dtype if output_dtype is not None else dtype | ||
|
|
||
| if warp_specialization: | ||
| spec["warp_specialization"] = True | ||
| spec["sm_mma"] = 90 | ||
|
|
@@ -267,7 +271,16 @@ def generate_kernel_spec( | |
| spec["kv_loop_step"] = 256 | ||
| else: | ||
| # D=256 (FP8 pads head_size>128 to 256 due to 128-byte alignment): | ||
| # smem = 32 + 64 + 64 + 32 = ~192KB with KV_BUF=2 | ||
| # base smem = 32 + 64 + 64 + 32 = ~192KB with KV_BUF=2. | ||
| # | ||
| # FP8->FP8 output kernels add two output staging buffers in shared memory | ||
| # (kernel_traits.h:514-523), which pushes STEP_KV=128 over H100's 228KB | ||
| # dynamic shared-memory budget and causes cudaFuncSetAttribute(...) | ||
| # to fail with cudaErrorInvalidValue. Keep STEP_KV=128 for numerical | ||
| # stability, but drop KV buffering depth to 1 for fp8 output so the | ||
| # kernel fits H100's smem budget. | ||
| if effective_output_dtype in ["e4m3", "e4m3_fp32"]: | ||
| spec["kv_tile_buffers"] = 1 | ||
| spec["kv_loop_step"] = 128 | ||
| else: | ||
| raise ValueError(f"Unsupported dtype: {dtype}") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`tile_id` is a `uint32_t`, but it is cast to a signed `int` for `remaining`. While the total number of tiles is unlikely to exceed $2^{31}-1$ in current workloads, using a signed integer here is risky. If `remaining` becomes negative, the loop condition `remaining < batch_tiles` could still evaluate to true even when `batch_tiles` is 0 (e.g., for an empty sequence), leading to a division or modulo by zero at lines 167-168. It is safer to keep `remaining` as `uint32_t`.