Commit 9f7adfb
[Fmha] Add head_dim=512 support for trtllm attention kernels (#2959)
Add support for `head_dim=512` in the trtllm FMHA kernel selection.

### Changes

- Add an SDPA-based reference implementation for `head_dim > 256` in tests (FlashInfer FA2/FA3 kernels don't support `head_dim > 256`)
- Add `test_trtllm_batch_prefill_head_dim_512` and `test_trtllm_batch_decode_head_dim_512` covering BF16, FP16, and FP8 dtypes

### Follow-up

- NVFP4 coverage at `head_dim=512` is deferred to a follow-up PR.

Signed-off-by: Duncan Moss <djm.moss@gmail.com>
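For context on the test changes, here is a minimal sketch (not the PR's actual test code) of what an SDPA-based reference for `head_dim > 256` can look like, assuming PyTorch's `torch.nn.functional.scaled_dot_product_attention`; the `sdpa_reference` helper, shapes, and dtypes are illustrative:

```python
import torch
import torch.nn.functional as F

def sdpa_reference(q, k, v, causal=True):
    # q/k/v: [batch, num_heads, seq_len, head_dim]; head_dim may exceed 256,
    # which the FA2/FA3 reference kernels cannot handle.
    # Compute in float32 for a stable reference, then cast back to the input dtype.
    out = F.scaled_dot_product_attention(
        q.float(), k.float(), v.float(), is_causal=causal
    )
    return out.to(q.dtype)

# Reference output at head_dim=512, e.g. for comparison against the trtllm kernels.
q = torch.randn(1, 8, 128, 512, dtype=torch.bfloat16)
k = torch.randn(1, 8, 128, 512, dtype=torch.bfloat16)
v = torch.randn(1, 8, 128, 512, dtype=torch.bfloat16)
ref = sdpa_reference(q, k, v)
```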
1 parent 6ddbdb0 commit 9f7adfb

5 files changed

Lines changed: 295 additions & 12 deletions


csrc/fmhaReduction.cu

Lines changed: 2 additions & 2 deletions
```diff
@@ -81,7 +81,7 @@ __global__ void __launch_bounds__(NumThreadsPerCta, 2)
   seqLenKv = seqLenKv - ((params.mMaxSeqLenQ - 1) - ctaIdxQ);
   // Consider sparseMlaTopK.
   if (sparseMla) {
-    seqLenKv = min(seqLenKv, params.mSparseMlaTopK);
+    seqLenKv = min(seqLenKv, params.mSparseAttnTopK);
   }
   // The actual number of CtasKv (TileSizeKv is always 128 for now).
   int32_t numCtasKv{min((seqLenKv + 127) / 128, params.mMaxNumCtasKv)};
@@ -361,7 +361,7 @@ void runFmhaReduction(TllmGenFmhaKernelMetaInfo const& kernelMeta, KernelParams
   }

   // Launch the kernel.
-  cudaLaunchKernelEx(&config, kernel, params, kernelMeta.mSparseMla, numCtasForReduction,
+  cudaLaunchKernelEx(&config, kernel, params, kernelMeta.mSparseAttn != 0, numCtasForReduction,
                      numCtasForAllHeads, numHeadDimCtasV);
   cudaError_t err = cudaGetLastError();
   FLASHINFER_CHECK(err == cudaSuccess, "Failed to launch kernel: ", cudaGetErrorString(err));
```
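A Python sketch of the launch arithmetic above, assuming (per the diff's comment) that TileSizeKv is fixed at 128; the names mirror the kernel parameters, but `num_ctas_kv` itself is illustrative, not the CUDA code:

```python
def num_ctas_kv(seq_len_kv: int, max_num_ctas_kv: int,
                sparse_attn: bool, sparse_attn_top_k: int) -> int:
    if sparse_attn:
        # With sparse attention only the top-K KV tokens participate, so the
        # effective KV length is clamped to the top-K value.
        seq_len_kv = min(seq_len_kv, sparse_attn_top_k)
    # Ceiling division by the KV tile size (128), capped at the launch limit.
    return min((seq_len_kv + 127) // 128, max_num_ctas_kv)

# Sparse top-K=256 needs only ceil(256/128) = 2 KV CTAs even for a long sequence.
assert num_ctas_kv(4096, 16, sparse_attn=True, sparse_attn_top_k=256) == 2
# Dense attention over the same sequence hits the max-CTA cap instead.
assert num_ctas_kv(4096, 16, sparse_attn=False, sparse_attn_top_k=256) == 16
```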

flashinfer/artifacts.py

Lines changed: 2 additions & 2 deletions
```diff
@@ -135,7 +135,7 @@ class ArtifactPath:
     When compiling new cubins for backend directories, update the corresponding path.
     """

-    TRTLLM_GEN_FMHA: str = "55bba55929d4093682e32d817bd11ffb0441c749/fmha/trtllm-gen/"
+    TRTLLM_GEN_FMHA: str = "82f4c77d9cf83e3fcf105feda4ce3445100ab491/fmha/trtllm-gen/"
     TRTLLM_GEN_BMM: str = (
         "39a9d28268f43475a757d5700af135e1e58c9849/batched_gemm-5ee61af-2b9855b/"
     )
@@ -155,7 +155,7 @@ class CheckSumHash:
     """

     TRTLLM_GEN_FMHA: str = (
-        "f2c0aad1e74391c4267a2f9a20ec819358b59e04588385cffb452ed341500b99"
+        "56c95fbe5d1b5d0d9ded7706e1c0b7ebf0582d9cfd2f9382acd878b6b9d58c89"
     )
     TRTLLM_GEN_BMM: str = (
         "db06db7f36a2a9395a2041ff6ac016fe664874074413a2ed90797f91ef17e0f6"
```

include/flashinfer/trtllm/fmha/fmhaKernels.cuh

Lines changed: 9 additions & 1 deletion
```diff
@@ -191,7 +191,7 @@ class TllmGenFmhaKernel {
         kernelMeta.mTileScheduler, kernelMeta.mMultiCtasKvMode,
         kernelMeta.mHeadDimPerCtaV, kernelMeta.mHeadDimQk, kernelMeta.mHeadDimV,
         kernelMeta.mTileSizeQ, kernelMeta.mTileSizeKv, kernelMeta.mNumTokensPerPage,
-        kernelMeta.mReuseSmemKForV, kernelMeta.m2CtaMma, kernelMeta.mSparseMla,
+        kernelMeta.mReuseSmemKForV, kernelMeta.m2CtaMma, kernelMeta.mSparseAttn != 0,
         kernelMeta.mSkipsSoftmaxWhenPossible);
   }

@@ -789,6 +789,14 @@ class TllmGenFmhaKernel {
       selectGqGenerationKernel(params, selectKernelParams);
     }

+    // For headDimV > 256, set headDimPerCtaV to 256 for context and keepsMmaAbForGeneration
+    // kernels. swapsMmaAbForGeneration has enough TMEM resources to hold the full headDimV.
+    // Called for context and GQA generation; MLA sets headDimPerCtaV separately.
+    if (params.mHeadDimV > 256 && !isMlaGenKernel(params) &&
+        !isSwapsMmaAbForGenerationKernel(selectKernelParams.mKernelType)) {
+      selectKernelParams.mHeadDimPerCtaV = 256;
+    }
+
     // Enable sliding window or chunked causal if the max kv sequence length exceeds attention
     // window size or chunked attention size. This is supported by causal-mask context kernels and
     // generation-phase kernels.
```
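A small sketch of the head-dimension split the new branch implies: kernels that cap `headDimPerCtaV` at 256 must cover `headDimV = 512` in two per-CTA slices (the `numHeadDimCtasV` passed to the reduction launch above), while kernels that hold the full head dim use one. `num_head_dim_ctas_v` is an illustrative name, not the C++ helper:

```python
def num_head_dim_ctas_v(head_dim_v: int, head_dim_per_cta_v: int) -> int:
    # Ceiling division: how many CTAs are needed to cover the V head dim.
    return (head_dim_v + head_dim_per_cta_v - 1) // head_dim_per_cta_v

# head_dim=512 with the clamp to 256 means two CTA slices along head dim V.
assert num_head_dim_ctas_v(512, 256) == 2
# Kernels with enough TMEM for the full head dim (e.g. swapsMmaAbForGeneration
# per the comment above) need only a single slice.
assert num_head_dim_ctas_v(512, 512) == 1
```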

include/flashinfer/trtllm/fmha/kernelParams.h

Lines changed: 4 additions & 3 deletions
```diff
@@ -169,13 +169,14 @@ struct KernelParams {
   float mScaleSfO;
   // Threshold to decide whether warp skips softmax ops
   float mSkipSoftmaxThresholdScaleFactor;
+  // The sparse attention topK value. Must immediately follow mSkipSoftmaxThresholdScaleFactor
+  // to match the GPU struct layout expected by trtllm-gen kernels (changed in a339772b).
+  int32_t mSparseAttnTopK;
   // The start token index in SF tensor. Used for FP4 SF offset calculation in generation phase
   // kernel when inflight batching is enabled in TRT-LLM.
   int32_t mStartTokenIdxSfO;
   // The sum of sequence lengths for Q and K/V.
   int32_t mSumOfSeqLensQ, mSumOfSeqLensKv;
-  // The sparseMla topK value.
-  int32_t mSparseMlaTopK;
   // The flag to use block sparse attention.
   bool mUseBlockSparseAttention;
   // Whether the indices for K & V pages are shared as unified index.
@@ -854,7 +855,7 @@ struct KernelParams {
     // indices.
     FLASHINFER_CHECK(!options.mSparseMla || (options.mSparseMlaTopK % 4) == 0,
                      "SparseMlaTopK must be a multiple of 4");
-    params.mSparseMlaTopK = options.mSparseMlaTopK;
+    params.mSparseAttnTopK = options.mSparseMlaTopK;
     // TODO: Integrate trtllm block-sparse attention kernels when needed.
     params.mUseBlockSparseAttention = false;
     // Whether the indices for K & V pages are shared as unified index (vLLM/FlashInfer).
```
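The layout comment in the struct is the crux of the field move: the host-side `KernelParams` must agree byte-for-byte with the layout baked into the trtllm-gen cubins. A toy `ctypes` sketch (fields trimmed to the ones involved; not the real struct) of why `mSparseAttnTopK` has to sit immediately after the threshold float:

```python
import ctypes

class KernelParamsTail(ctypes.Structure):
    # Mirrors the relevant slice of the C++ struct after this change.
    _fields_ = [
        ("mSkipSoftmaxThresholdScaleFactor", ctypes.c_float),
        ("mSparseAttnTopK", ctypes.c_int32),  # moved up next to the float
        ("mStartTokenIdxSfO", ctypes.c_int32),
        ("mSumOfSeqLensQ", ctypes.c_int32),
        ("mSumOfSeqLensKv", ctypes.c_int32),
    ]

# Reordering fields shifts every later offset, so the GPU would read garbage
# if host and device disagreed on where mSparseAttnTopK lives.
assert KernelParamsTail.mSparseAttnTopK.offset == 4
```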
