30 changes: 19 additions & 11 deletions csrc/fmha_v2/fmha/warpspec/circular_buffer.h
@@ -126,25 +126,31 @@ class CircularBufferReader {
   }
 
   /* Peek at the head */
-  inline __device__ int peek() {
-    return _entryProducedBarriers.bar_peek(_rptr, (_phase >> _rptr) & 1);
+  inline __device__ int peek(int ptr) {
+    return _entryProducedBarriers.bar_peek(ptr, (_phase >> ptr) & 1);
   }
 
+  inline __device__ int peek() { return peek(_rptr); }
+
   /* Wait for the head to be ready */
-  inline __device__ int wait() {
-    _entryProducedBarriers.bar_wait(_rptr, (_phase >> _rptr) & 1);
-    return _rptr;
+  inline __device__ int wait(int ptr) {
+    _entryProducedBarriers.bar_wait(ptr, (_phase >> ptr) & 1);
+    return ptr;
   }
 
+  inline __device__ int wait() { return wait(_rptr); }
+
   /* Advance the head pointer */
-  inline __device__ void advance() {
-    _phase ^= (1 << _rptr);
-    _rptr += 1;
+  inline __device__ void advance(int ptr) {
+    _phase ^= (1 << ptr);
+    _rptr = ptr + 1;
     if (_rptr >= DEPTH) {
       _rptr = 0;
     }
   }
 
+  inline __device__ void advance() { advance(_rptr); }
+
   inline __device__ int ptr() { return _rptr; }
 
   inline __device__ uint32_t phase() { return _phase; }
@@ -165,11 +171,13 @@ class CircularBufferReader {
   /* Simplification of complete and advance for cases
      where they don't need to be reordered/separated for performance
   */
-  inline __device__ void pop(int tid0) {
-    complete(tid0, _rptr);
-    advance();
+  inline __device__ void pop(int tid0, int ptr) {
+    complete(tid0, ptr);
+    advance(ptr);
   }
 
+  inline __device__ void pop(int tid0) { pop(tid0, _rptr); }
+
   /* Overrides for pointer and phase. Used for shared buffers */
   inline __device__ void setPtr(int ptr) { _rptr = ptr; }
85 changes: 67 additions & 18 deletions csrc/fmha_v2/fmha/warpspec/dma.h
@@ -99,7 +99,7 @@ struct DMA {
   static_assert(STEP_KV % K_ == 0);
   using Transposer =
       Transposer<typename Kernel_traits::Traits_o, typename Kernel_traits::Cta_tile_o, K_,
-                 (STEP_KV > 128 || SLIDING_OR_CHUNKED_ATTENTION) ? 1 : 2 /* UNROLL */>;
+                 (STEP_KV >= 128 || SLIDING_OR_CHUNKED_ATTENTION) ? 1 : 2 /* UNROLL */>;
 
   struct Device {
     // Only the warpgroup leader initiates mbarriers & TMA operations.
@@ -147,6 +147,33 @@
       return std::make_pair(kv_idx_start, kv_idx_end);
     }

+    static inline __device__ int compute_dynamic_q_tiles_per_head(int actual_q_seqlen) {
+      return (actual_q_seqlen + STEP_Q * NUM_COMPUTE_GROUPS - 1) / (STEP_Q * NUM_COMPUTE_GROUPS);
+    }
+
+    static inline __device__ bool decode_exact_dynamic_tile_id(
+        bert::Fused_multihead_attention_params_v2 const& params, uint32_t tile_id, int& bidb,
+        int& bidh, int& q_step_offset) {
+      int remaining = static_cast<int>(tile_id);
Contributor comment (medium):

The tile_id is a uint32_t, but it is cast to a signed int for remaining. While the total number of tiles is unlikely to exceed $2^{31}-1$ in current workloads, using a signed integer here is risky: if remaining becomes negative, the loop condition remaining < batch_tiles could evaluate to true when batch_tiles is 0 (e.g., for an empty sequence), leading to a division or modulo by zero at lines 167-168. It is safer to keep remaining as uint32_t:

      uint32_t remaining = tile_id;

+#pragma unroll 1
+      for (int batch_idx = 0; batch_idx < params.b; ++batch_idx) {
+        int const actual_q_seqlen =
+            params.cu_q_seqlens[batch_idx + 1] - params.cu_q_seqlens[batch_idx];
Contributor comment on lines +161 to +162 (medium):

The loop performs redundant global memory loads from params.cu_q_seqlens: each iteration batch_idx loads both params.cu_q_seqlens[batch_idx + 1] and params.cu_q_seqlens[batch_idx]. You can avoid this by loading the initial value before the loop and then loading only the 'next' value in each iteration:

      int prev_cu_len = params.cu_q_seqlens[0];
#pragma unroll 1
      for (int batch_idx = 0; batch_idx < params.b; ++batch_idx) {
        int const next_cu_len = params.cu_q_seqlens[batch_idx + 1];
        int const actual_q_seqlen = next_cu_len - prev_cu_len;
        prev_cu_len = next_cu_len;

+        int const q_tiles_per_head = compute_dynamic_q_tiles_per_head(actual_q_seqlen);
+        int const batch_tiles = q_tiles_per_head * params.h;
+        if (remaining < batch_tiles) {
+          bidb = batch_idx;
+          bidh = remaining / q_tiles_per_head;
+          q_step_offset = (remaining % q_tiles_per_head) * NUM_COMPUTE_GROUPS;
+          return true;
+        }
+        remaining -= batch_tiles;
+      }
+
+      return false;
+    }
Contributor comment on lines +154 to +175 (medium):

The decode_exact_dynamic_tile_id function is executed by all 128 threads in the DMA group (when DMA_GROUP_TRANSPOSE_V is true). Since all threads in the group process the same tile_id_, performing this linear scan in every thread is highly redundant and inefficient, especially as the batch size params.b grows. It would be significantly better to have only the group leader (e.g., thread 0) perform the decode and then broadcast the results (bidb, bidh, q_step_offset) to the rest of the group using __shfl_sync or shared memory.
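
A minimal sketch of that leader-decode-then-broadcast shape, under the stated assumptions: a single __shfl_sync cannot broadcast across the four warps of the 128-thread DMA group, so the leader's result is staged in shared memory. The decoded[] slot and the wrapper name are hypothetical; named_barrier_wait, SYNC_BARRIER, and NUM_THREADS_IN_DMA_GROUP are assumed from dma.h.

// Hedged sketch, not the PR's code.
inline __device__ bool decode_tile_id_group_leader(
    bert::Fused_multihead_attention_params_v2 const& params, uint32_t tile_id, int& bidb,
    int& bidh, int& q_step_offset) {
  __shared__ int decoded[4];  // bidb, bidh, q_step_offset, found
  if (threadIdx.x % NUM_THREADS_IN_DMA_GROUP == 0) {
    int b, h, q;
    bool const found = decode_exact_dynamic_tile_id(params, tile_id, b, h, q);
    decoded[0] = b;
    decoded[1] = h;
    decoded[2] = q;
    decoded[3] = found ? 1 : 0;
  }
  // Publish the leader's stores to the whole group before anyone reads them.
  // (A second barrier before the next iteration's reuse is elided here.)
  named_barrier_wait(SYNC_BARRIER, NUM_THREADS_IN_DMA_GROUP);
  bidb = decoded[0];
  bidh = decoded[1];
  q_step_offset = decoded[2];
  return decoded[3] != 0;
}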


////////////////////////////////////////////////////////////////////////////////////////////

// Packed contiguous QKV input.
@@ -181,20 +208,27 @@ struct DMA {
         bidh = tile_id_ % params.h;
         bidb = tile_id_ / params.h;
       } else {
-        // Balanced dynamic scheduling
-        if (CAUSAL_MASK && !SLIDING_OR_CHUNKED_ATTENTION && params.use_balanced_scheduling) {
-          q_step_offset = (params.num_tiles_per_head - 1 - tile_id_ / (params.b * params.h)) *
-                          NUM_COMPUTE_GROUPS;
-          tmp = tile_id_ % (params.b * params.h);
-          bidh = tmp / params.b;
-          bidb = tmp % params.b;
-          q_steps = NUM_COMPUTE_GROUPS;
-        } else { // Unbalanced dynamic scheduling
-          bidb = tile_id_ / (params.h * params.num_tiles_per_head);
-          tmp = tile_id_ % (params.h * params.num_tiles_per_head);
-          bidh = tmp / params.num_tiles_per_head;
-          q_step_offset = tmp % params.num_tiles_per_head * NUM_COMPUTE_GROUPS;
+        if constexpr (DMA_GROUP_TRANSPOSE_V) {
+          q_steps = NUM_COMPUTE_GROUPS;
+          if (!decode_exact_dynamic_tile_id(params, tile_id_, bidb, bidh, q_step_offset)) {
+            break;
+          }
+        } else {
+          // Balanced dynamic scheduling
+          if (CAUSAL_MASK && !SLIDING_OR_CHUNKED_ATTENTION && params.use_balanced_scheduling) {
+            q_step_offset = (params.num_tiles_per_head - 1 - tile_id_ / (params.b * params.h)) *
+                            NUM_COMPUTE_GROUPS;
+            tmp = tile_id_ % (params.b * params.h);
+            bidh = tmp / params.b;
+            bidb = tmp % params.b;
+            q_steps = NUM_COMPUTE_GROUPS;
+          } else { // Unbalanced dynamic scheduling
+            bidb = tile_id_ / (params.h * params.num_tiles_per_head);
+            tmp = tile_id_ % (params.h * params.num_tiles_per_head);
+            bidh = tmp / params.num_tiles_per_head;
+            q_step_offset = tmp % params.num_tiles_per_head * NUM_COMPUTE_GROUPS;
+            q_steps = NUM_COMPUTE_GROUPS;
+          }
         }
       }
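
The unbalanced-path index math above is easy to sanity-check on the host; the parameter values here are made up for illustration:

#include <cassert>

// Replicates the unbalanced dynamic-scheduling decode with assumed values:
// 8 heads, 4 q-tiles per head, 2 compute groups.
int main() {
  int const h = 8, num_tiles_per_head = 4, num_compute_groups = 2;
  int const tile_id = 37;
  int const bidb = tile_id / (h * num_tiles_per_head);  // 37 / 32 = 1
  int const tmp = tile_id % (h * num_tiles_per_head);   // 37 % 32 = 5
  int const bidh = tmp / num_tiles_per_head;            // 5 / 4  = 1
  int const q_step_offset = tmp % num_tiles_per_head * num_compute_groups;  // 1 * 2 = 2
  assert(bidb == 1 && bidh == 1 && q_step_offset == 2);
  return 0;
}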

@@ -330,6 +364,13 @@ struct DMA {
       if (SCHEDULING_MODE == 0) {
         bidh = tile_id_ % params.h;
         bidb = tile_id_ / params.h;
+      } else if constexpr (DMA_GROUP_TRANSPOSE_V) {
+        q_steps = NUM_COMPUTE_GROUPS;
+        int q_step_offset;
+        if (!decode_exact_dynamic_tile_id(params, tile_id_, bidb, bidh, q_step_offset)) {
+          break;
+        }
+        local_q_tile_offset = q_step_offset * STEP_Q;
       } else if (SCHEDULING_MODE == 1) {
         bidb = tile_id_ / (params.h * params.num_tiles_per_head);
         tmp = tile_id_ % (params.h * params.num_tiles_per_head);
@@ -616,14 +657,22 @@ struct DMA {
     Transposer transposer(threadIdx.x % NUM_THREADS_IN_DMA_GROUP);
 
     // Src buffer available
-    int ready = cbr_v_scratch.peek();
+    int ready = cbr_v_scratch.peek(v_scratch_barrier_id);
     if (!ready) {
-      cbr_v_scratch.wait();
+      cbr_v_scratch.wait(v_scratch_barrier_id);
     }
-    uint32_t smem_v_src = __cvta_generic_to_shared(&shared->smem_v_scratch[v_scratch_barrier_id]);
+    uint32_t smem_v_src =
+        __cvta_generic_to_shared(&shared->smem_v_scratch[v_scratch_barrier_id * TILE_SIZE_V]);
 
     // Dst buffer available
     int v_barrier_id = cbw_v.threadReserve();
+    // NOTE(bobboli): Sync all DMA threads after consumer-bar wait to prevent phase-flip race.
+    // Without this, thread 0 can race ahead, commit V (triggering compute to consume the slot
+    // and flip the consumed-barrier phase), then wrap around in threadReserve() with a new
+    // expected phase, while slow DMA warps are still waiting on the old expected phase of the
+    // now-flipped barrier -> deadlock at bar.sync 1, 128 in transpose_v_tile. Same hazard as
+    // described for push_with_sync (see comment near run_packed_qkv).
+    named_barrier_wait(SYNC_BARRIER, NUM_THREADS_IN_DMA_GROUP);
     uint32_t smem_v_dst = __cvta_generic_to_shared(&shared->smem_v[v_barrier_id * TILE_SIZE_V]);
 
     // Explicitly transpose the v buffer in smem for fp8.
@@ -682,7 +731,7 @@ struct DMA {
     fence_view_async_shared();                                  // Commit STSM
     named_barrier_wait(SYNC_BARRIER, NUM_THREADS_IN_DMA_GROUP); // Sync before signaling
     cbw_v.threadCommit(elect_one_, v_barrier_id);               // Signal readiness
-    cbr_v_scratch.pop(elect_one_);                              // Advance to next phase
+    cbr_v_scratch.pop(elect_one_, v_scratch_barrier_id);        // Advance to next phase
   }

inline __device__ void get_next_tile_id(int local_wid, int tiw, uint32_t smem_tile_id,
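
The scratch-buffer index change above is worth spelling out: smem_v_scratch appears to be a flat element array (matching how smem_v is indexed by v_barrier_id * TILE_SIZE_V), so the old [v_scratch_barrier_id] form only addressed slot 0 correctly. A hedged sketch of the addressing, with a hypothetical helper name and an assumed layout:

// Sketch (assumed layout, not the PR's definitions): for a flat array holding
// DEPTH tiles back to back, slot i starts i * tile_size_v elements in.
template <typename T>
__device__ T* scratch_slot(T* smem_v_scratch, int slot, int tile_size_v) {
  return smem_v_scratch + slot * tile_size_v;  // &buf[slot] is wrong unless tile_size_v == 1
}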
5 changes: 3 additions & 2 deletions csrc/fmha_v2/fmha/warpspec/kernel_traits.h
@@ -139,8 +139,9 @@ struct Kernel_traits {
                 std::is_same<Element_data_type, fmha::e5m2_t>::value)
   };
 
-  // The number of smem scratch buffer for staging V transpose for Hopper QGMMA
-  enum { V_SCRATCH_BUFFERS = DMA_GROUP_TRANSPOSE_V ? 1 : 0 };
+  // Reuse the KV double-buffer depth for fp8 V transpose scratch so the DMA producer/consumer do
+  // not immediately wrap the same slot on every iteration.
+  enum { V_SCRATCH_BUFFERS = DMA_GROUP_TRANSPOSE_V ? KV_BUFFERS : 0 };
 
   // The number of compute warpgroups (128 threads per warpgroup).
   enum { NUM_COMPUTE_GROUPS = NUM_COMPUTE_GROUPS_ };
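
A toy trace of why the scratch depth matters, assuming (per the smem comments elsewhere in this PR) that KV_BUFFERS is 2: with a single slot, the producer's next reserve always lands on the slot the consumer has only just released, so fill and drain can never overlap.

#include <cstdio>
#include <initializer_list>

// Single-threaded illustration only: prints which scratch slot the producer
// targets on successive iterations for depth 1 vs depth 2.
int main() {
  for (int depth : {1, 2}) {
    std::printf("depth %d writes slots:", depth);
    int wptr = 0;
    for (int iter = 0; iter < 4; ++iter) {
      std::printf(" %d", wptr);  // depth 1: 0 0 0 0; depth 2: 0 1 0 1
      wptr = (wptr + 1) % depth;
    }
    std::printf("\n");
  }
  return 0;
}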
17 changes: 15 additions & 2 deletions flashinfer/jit/attention/fmha_v2/fmha_library.py
@@ -186,7 +186,9 @@ def generate_kernel_spec(

     # Override class defaults that always differ
     spec["flash_attention"] = True  # Class default is False
-    spec["scheduling_mode"] = 1  # Class default is 0
+    # Warp-specialized fp8 kernels use an exact dynamic tile decode in the DMA path, so they can
+    # stay on the persistent scheduler even when a launch mixes different q-tile counts.
+    spec["scheduling_mode"] = 1
Contributor comment on lines +189 to +191 (⚠️ Potential issue | 🟠 Major | ⚡ Quick win):

Scope the persistent-scheduler override to the FP8 Hopper path.

Line 191 now forces scheduling_mode = 1 for every generated spec, but the exact tile-id decode added in dma.h is only wired up for the FP8 transpose-V path. That broadens this fix to unrelated SM120/BF16/FP16 kernels and also stops run_separate_q_and_kv() from ever generating its existing SCHEDULING_MODE == 2 balanced specialization.

Suggested fix:

-    # Warp-specialized fp8 kernels use an exact dynamic tile decode in the DMA path, so they can
-    # stay on the persistent scheduler even when a launch mixes different q-tile counts.
-    spec["scheduling_mode"] = 1
+    # Warp-specialized FP8 Hopper kernels use an exact dynamic tile decode in the DMA path, so
+    # they can stay on the persistent scheduler even when a launch mixes different q-tile counts.
+    if warp_specialization and dtype == "e4m3":
+        spec["scheduling_mode"] = 1

🤖 Prompt for AI Agents:
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate.

In flashinfer/jit/attention/fmha_v2/fmha_library.py around lines 189-191, the unconditional spec["scheduling_mode"] = 1 should be made conditional so it only applies to the FP8 Hopper transpose-V path; restrict the override to the FP8/Hopper kernel generation (the path that wires up the exact tile-id decode in dma.h) and do not change scheduling_mode for other dtypes (e.g. SM120/BF16/FP16), so run_separate_q_and_kv() can still emit its SCHEDULING_MODE == 2 balanced specialization. Locate the assignment to spec["scheduling_mode"] and wrap it with a check for the FP8 Hopper/transpose-V condition (the same condition used to select the FP8 transpose-V kernel), leaving other cases untouched.


# # SM-specific configuration
# if warp_specialization:
@@ -234,6 +236,8 @@ def generate_kernel_spec(
     # FP8: round_up to multiples of 128 -> D in {32, 64, 128, 256}
     # (head_size=160 pads to D=256 for FP8 vs D=192 for FP16)
     #
+    effective_output_dtype = output_dtype if output_dtype is not None else dtype
+
     if warp_specialization:
         spec["warp_specialization"] = True
         spec["sm_mma"] = 90
@@ -267,7 +271,16 @@ def generate_kernel_spec(
spec["kv_loop_step"] = 256
else:
# D=256 (FP8 pads head_size>128 to 256 due to 128-byte alignment):
# smem = 32 + 64 + 64 + 32 = ~192KB with KV_BUF=2
# base smem = 32 + 64 + 64 + 32 = ~192KB with KV_BUF=2.
#
# FP8->FP8 output kernels add two output staging buffers in shared memory
# (kernel_traits.h:514-523), which pushes STEP_KV=128 over H100's 228KB
# dynamic shared-memory budget and causes cudaFuncSetAttribute(...)
# to fail with cudaErrorInvalidValue. Keep STEP_KV=128 for numerical
# stability, but drop KV buffering depth to 1 for fp8 output so the
# kernel fits H100's smem budget.
if effective_output_dtype in ["e4m3", "e4m3_fp32"]:
spec["kv_tile_buffers"] = 1
spec["kv_loop_step"] = 128
else:
raise ValueError(f"Unsupported dtype: {dtype}")
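
The budget math behind that comment, as a hedged back-of-envelope check: the 192KB base split is taken from the comment above, while the staging-buffer size below is an assumption made purely for illustration, not a figure read from kernel_traits.h.

// Sizes in KB; kStagingKB is hypothetical -- the real figure lives in
// kernel_traits.h:514-523.
constexpr int kBudgetKB = 228;              // H100 opt-in dynamic smem limit
constexpr int kBaseKB = 32 + 64 + 64 + 32;  // assumed Q + K + V + O split, KV_BUF=2 -> 192
constexpr int kStagingKB = 40;              // assumed fp8 output staging buffers
static_assert(kBaseKB <= kBudgetKB, "base config fits the budget");
static_assert(kBaseKB + kStagingKB > kBudgetKB, "fp8 output staging overflows it");
// Dropping KV depth from 2 to 1 halves the K and V footprints (64 -> 32 each),
// so the base falls to ~128KB, restoring headroom for the staging buffers.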
10 changes: 0 additions & 10 deletions flashinfer/prefill.py
@@ -4618,16 +4618,6 @@ def trtllm_fmha_v2_prefill(
"FP8 (e4m3) is not yet supported for FMHAv2 on SM120 (Blackwell). "
"Use fp16 or bf16 instead."
)
if uses_sliding_window and input_layout in ("PACKED_QKV", "CONTIGUOUS_Q_KV"):
_num_kv_heads = (
num_qo_heads if input_layout == "PACKED_QKV" else k_cache.shape[2]
)
if batch_size == 16 and _num_kv_heads == 4 and head_dim_v == 256:
raise ValueError(
"FP8 (e4m3) sliding window attention with batch_size=16, "
"num_kv_heads=4, head_dim=256 is not supported for "
f"{input_layout} layout due to a known issue."
)
# Always pass 1.0: the C++ auto-detect (scale_softmax == 0.0) handles FP16/INT8/E4M3
# but has no branch for BF16, where 0.0 would zero out the softmax output.
scale_softmax = 1.0
3 changes: 1 addition & 2 deletions tests/attention/test_fmha_v2_prefill.py
@@ -4,6 +4,7 @@
 from typing import Optional, Tuple, Union
 
 import flashinfer
+
 from flashinfer.prefill import fmha_v2_prefill_deepseek
 from tests.utils_fp8 import to_float8
 from flashinfer.utils import is_sm12x_supported, is_sm120a_supported
@@ -827,8 +828,6 @@ def test_trtllm_fmha_v2_prefill(
     pos_encoding_mode: str,
     save_softmax_stats: bool,
 ) -> None:
-    if dtype == torch.float8_e4m3fn:
-        pytest.skip("FP8 (e4m3) FMHA v2 kernels are known to hang on SM90")
     run_trtllm_fmha_v2_prefill_case(
         input_layout=input_layout,
         batch_size=batch_size,