Skip to content

Commit f578d94

Browse files
committed
Address sparse MLA review comments
1 parent 3171596 commit f578d94

7 files changed

Lines changed: 140 additions & 104 deletions

File tree

csrc/fmhaReduction.cu

Lines changed: 36 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,7 @@ template <int32_t TileSizePerCtaQ, int32_t HeadDimPerCta, bool IsE4m3Bmm, typena
3636
__global__ void __launch_bounds__(NumThreadsPerCta, 2)
3737
fmhaReductionKernel(KernelParams const params, bool isTokenSparse, bool groupsTokensHeadsQ,
3838
bool supportsVarSparseMlaTopKLens, int32_t numCtasForReduction,
39-
int32_t numCtasForAllHeads, int32_t headDimV,
40-
int32_t numHeadDimCtasV) {
39+
int32_t numCtasForAllHeads, int32_t headDimV, int32_t numHeadDimCtasV) {
4140
// clang-format off
4241
// The shape of partialO buffer: [batchSize, numHeadCtas, numCtasQ, numCtasKv, TileSizePerCtaQ, headDimPerCta].
4342
// The shape of final O buffer: [batchSize, numCtasQ, numHeadsQ, headDim].
@@ -285,37 +284,37 @@ __global__ void __launch_bounds__(NumThreadsPerCta, 2)
285284

286285
////////////////////////////////////////////////////////////////////////////////////////////////////
287286

288-
#define SELECT_FMHA_REDUCTION_KERNEL(TileSizePerCtaQ, HeadDimPerCta) \
289-
if (kernelMeta.mDataTypeQ == DATA_TYPE_E4M3) { \
290-
if (kernelMeta.mDataTypeO == DATA_TYPE_E4M3) { \
291-
kernel = &fmhaReductionKernel<TileSizePerCtaQ, HeadDimPerCta, true, __nv_fp8_e4m3, half>; \
292-
} else if (kernelMeta.mDataTypeO == DATA_TYPE_FP16) { \
293-
kernel = &fmhaReductionKernel<TileSizePerCtaQ, HeadDimPerCta, true, half, half>; \
294-
} else if (kernelMeta.mDataTypeO == DATA_TYPE_BF16) { \
295-
kernel = \
296-
&fmhaReductionKernel<TileSizePerCtaQ, HeadDimPerCta, true, __nv_bfloat16, __nv_bfloat16>; \
297-
} else { \
298-
FLASHINFER_CHECK(false, "Not implemented"); \
299-
} \
300-
} else { \
301-
FLASHINFER_CHECK(kernelMeta.mDataTypeQ == kernelMeta.mDataTypeO, "Not implemented"); \
302-
if (kernelMeta.mDataTypeQ == DATA_TYPE_FP16) { \
303-
kernel = &fmhaReductionKernel<TileSizePerCtaQ, HeadDimPerCta, false, half, half>; \
304-
} else if (kernelMeta.mDataTypeQ == DATA_TYPE_BF16) { \
305-
kernel = \
306-
&fmhaReductionKernel<TileSizePerCtaQ, HeadDimPerCta, false, __nv_bfloat16, __nv_bfloat16>; \
307-
} else { \
308-
FLASHINFER_CHECK(false, "Not implemented"); \
309-
} \
287+
#define SELECT_FMHA_REDUCTION_KERNEL(TileSizePerCtaQ, HeadDimPerCta) \
288+
if (kernelMeta.mDataTypeQ == DATA_TYPE_E4M3) { \
289+
if (kernelMeta.mDataTypeO == DATA_TYPE_E4M3) { \
290+
kernel = &fmhaReductionKernel<TileSizePerCtaQ, HeadDimPerCta, true, __nv_fp8_e4m3, half>; \
291+
} else if (kernelMeta.mDataTypeO == DATA_TYPE_FP16) { \
292+
kernel = &fmhaReductionKernel<TileSizePerCtaQ, HeadDimPerCta, true, half, half>; \
293+
} else if (kernelMeta.mDataTypeO == DATA_TYPE_BF16) { \
294+
kernel = &fmhaReductionKernel<TileSizePerCtaQ, HeadDimPerCta, true, __nv_bfloat16, \
295+
__nv_bfloat16>; \
296+
} else { \
297+
FLASHINFER_CHECK(false, "Not implemented"); \
298+
} \
299+
} else { \
300+
FLASHINFER_CHECK(kernelMeta.mDataTypeQ == kernelMeta.mDataTypeO, "Not implemented"); \
301+
if (kernelMeta.mDataTypeQ == DATA_TYPE_FP16) { \
302+
kernel = &fmhaReductionKernel<TileSizePerCtaQ, HeadDimPerCta, false, half, half>; \
303+
} else if (kernelMeta.mDataTypeQ == DATA_TYPE_BF16) { \
304+
kernel = &fmhaReductionKernel<TileSizePerCtaQ, HeadDimPerCta, false, __nv_bfloat16, \
305+
__nv_bfloat16>; \
306+
} else { \
307+
FLASHINFER_CHECK(false, "Not implemented"); \
308+
} \
310309
}
311310

312-
#define SELECT_FMHA_REDUCTION_KERNEL_WITH_HEAD_DIM_PER_CTA(HeadDimPerCta) \
313-
if (kernelMeta.mTileSizeQ == 64) { \
314-
SELECT_FMHA_REDUCTION_KERNEL(64, HeadDimPerCta); \
315-
} else if (kernelMeta.mTileSizeQ == 128) { \
316-
SELECT_FMHA_REDUCTION_KERNEL(128, HeadDimPerCta); \
317-
} else { \
318-
FLASHINFER_CHECK(false, "Not implemented"); \
311+
#define SELECT_FMHA_REDUCTION_KERNEL_WITH_HEAD_DIM_PER_CTA(HeadDimPerCta) \
312+
if (kernelMeta.mTileSizeQ == 64) { \
313+
SELECT_FMHA_REDUCTION_KERNEL(64, HeadDimPerCta); \
314+
} else if (kernelMeta.mTileSizeQ == 128) { \
315+
SELECT_FMHA_REDUCTION_KERNEL(128, HeadDimPerCta); \
316+
} else { \
317+
FLASHINFER_CHECK(false, "Not implemented"); \
319318
}
320319

321320
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -403,7 +402,12 @@ void runFmhaReduction(TllmGenFmhaKernelMetaInfo const& kernelMeta, KernelParams
403402

404403
// Launch the kernel.
405404
bool const supportsVarSparseMlaTopKLens =
406-
kernelMeta.mSparseAttn == 2 && kernelMeta.mHeadDimQk == 512 && kernelMeta.mHeadDimV == 512;
405+
isDynamicTokenSparseMla(static_cast<TrtllmGenSparseMlaType>(kernelMeta.mSparseAttn)) &&
406+
kernelMeta.mHeadDimQk == 512 && kernelMeta.mHeadDimV == 512;
407+
if (supportsVarSparseMlaTopKLens) {
408+
FLASHINFER_CHECK(params.ptrSparseMlaTopKLens != nullptr,
409+
"Dynamic sparse MLA reduction requires sparseMlaTopkLengths.");
410+
}
407411
cudaLaunchKernelEx(&config, kernel, params, kernelMeta.mSparseAttn != 0,
408412
kernelMeta.mGroupsTokensHeadsQ, supportsVarSparseMlaTopKLens,
409413
numCtasForReduction, numCtasForAllHeads, kernelMeta.mHeadDimV,

csrc/trtllm_fmha_kernel_launcher.cu

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -169,11 +169,10 @@ void trtllm_paged_attention_launcher(
169169

170170
// The sparse MLA parameters.
171171
runner_params.mSparseMlaType =
172-
sparse_mla_top_k <= 0 ? TrtllmGenSparseMlaType::None
173-
: (sparse_mla_top_k_lens != nullptr
174-
? TrtllmGenSparseMlaType::DynamicTokenSparse
175-
: TrtllmGenSparseMlaType::StaticTokenSparse);
176-
runner_params.mSparseMla = isSparseMla(runner_params.mSparseMlaType);
172+
sparse_mla_top_k <= 0
173+
? TrtllmGenSparseMlaType::None
174+
: (sparse_mla_top_k_lens != nullptr ? TrtllmGenSparseMlaType::DynamicTokenSparse
175+
: TrtllmGenSparseMlaType::StaticTokenSparse);
177176
runner_params.mSparseMlaTopK = sparse_mla_top_k;
178177
runner_params.sparseMlaTopKLensPtr = sparse_mla_top_k_lens;
179178
runner_params.mHasSlidingWindowKvPool = has_sliding_window_kv_pool;
@@ -737,8 +736,8 @@ void trtllm_paged_attention_decode_sparse_mla_dsv4(
737736
TensorView sliding_window_kv_cache, TensorView workspace_buffer, TensorView sparse_indices,
738737
TensorView seq_lens, TensorView sparse_mla_top_k_lens, Variant<double, ffi::Tensor> bmm1_scale,
739738
Variant<double, ffi::Tensor> bmm2_scale, int64_t batch_size, int64_t max_q_len,
740-
int64_t sm_count, bool enable_pdl, int64_t workspace_size,
741-
Optional<TensorView> attention_sinks, Optional<TensorView> cum_seq_lens_q) {
739+
int64_t sm_count, bool enable_pdl, int64_t workspace_size, Optional<TensorView> attention_sinks,
740+
Optional<TensorView> cum_seq_lens_q) {
742741
auto q_data_type = dl_dtype_to_tllm_data_type(query.dtype());
743742
auto kv_data_type = dl_dtype_to_tllm_data_type(primary_kv_cache.dtype());
744743
auto o_data_type = dl_dtype_to_tllm_data_type(out.dtype());
@@ -777,6 +776,10 @@ void trtllm_paged_attention_decode_sparse_mla_dsv4(
777776
if (is_varlen_q) {
778777
TVM_FFI_ICHECK_EQ(cum_seq_lens_q.value().ndim(), 1);
779778
TVM_FFI_ICHECK_EQ(cum_seq_lens_q.value().dtype(), dl_int32);
779+
TVM_FFI_ICHECK_EQ(cum_seq_lens_q.value().device().device_type, query.device().device_type)
780+
<< "cum_seq_lens_q must be on the same device as query";
781+
TVM_FFI_ICHECK_EQ(cum_seq_lens_q.value().device().device_id, query.device().device_id)
782+
<< "cum_seq_lens_q must be on the same device as query";
780783
TVM_FFI_ICHECK_EQ(cum_seq_lens_q.value().size(0), batch_size + 1);
781784
} else {
782785
TVM_FFI_ICHECK_EQ(sum_seq_q, batch_size * max_q_len);
@@ -785,8 +788,8 @@ void trtllm_paged_attention_decode_sparse_mla_dsv4(
785788
int const head_dim_q = is_4bit(q_data_type) ? query.size(-1) * 2 : query.size(-1);
786789
int const head_dim_k =
787790
is_4bit(kv_data_type) ? primary_kv_cache.size(-1) * 2 : primary_kv_cache.size(-1);
788-
int const head_dim_sw = is_4bit(q_data_type) ? sliding_window_kv_cache.size(-1) * 2
789-
: sliding_window_kv_cache.size(-1);
791+
int const head_dim_sw = is_4bit(kv_data_type) ? sliding_window_kv_cache.size(-1) * 2
792+
: sliding_window_kv_cache.size(-1);
790793
int const head_dim_o = is_4bit(o_data_type) ? out.size(-1) * 2 : out.size(-1);
791794
TVM_FFI_ICHECK_EQ(head_dim_q, 512);
792795
TVM_FFI_ICHECK_EQ(head_dim_k, 512);
@@ -836,9 +839,9 @@ void trtllm_paged_attention_decode_sparse_mla_dsv4(
836839
out.data_ptr(), /*out_scale_factor=*/nullptr, query.data_ptr(), primary_kv_cache.data_ptr(),
837840
primary_kv_cache.data_ptr(), workspace_buffer.data_ptr(),
838841
static_cast<int*>(sparse_indices.data_ptr()), /*k_block_scales_ptr=*/nullptr,
839-
/*v_block_scales_ptr=*/nullptr, static_cast<int*>(seq_lens.data_ptr()),
840-
cum_seq_lens_q_ptr, /*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type,
841-
kv_data_type, o_data_type, TllmPagedAttentionMode::ForGen, batch_size, max_q_len,
842+
/*v_block_scales_ptr=*/nullptr, static_cast<int*>(seq_lens.data_ptr()), cum_seq_lens_q_ptr,
843+
/*cum_seq_lens_kv=*/nullptr, attention_sinks_ptr, q_data_type, kv_data_type, o_data_type,
844+
TllmPagedAttentionMode::ForGen, batch_size, max_q_len,
842845
/*max_kv_len=*/sparse_mla_top_k, num_pages_in_mem_pool, num_qo_heads, num_kv_heads,
843846
head_dim_q, head_dim_o, page_size, q_stride_tokens, q_stride_heads, kv_stride_keys_values,
844847
kv_stride_heads, kv_stride_batch, /*max_num_blocks_per_seq=*/sparse_mla_top_k,

flashinfer/mla/_core.py

Lines changed: 46 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from dataclasses import dataclass
1818
import functools
1919
import math
20+
import os
2021
from typing import List, Literal, Optional, Tuple, Union, overload
2122

2223
import torch
@@ -233,10 +234,13 @@ def _normalize_dsv4_topk_lens(
233234
q_len_per_request: int,
234235
sum_seq_q: int,
235236
name: str,
237+
device: torch.device,
236238
cum_seq_lens_q: Optional[torch.Tensor] = None,
237239
) -> torch.Tensor:
238240
if topk_lens.dtype != torch.int32:
239241
raise ValueError(f"{name} must have dtype torch.int32, got {topk_lens.dtype}")
242+
if topk_lens.device != device:
243+
raise ValueError(f"{name} must be on device {device}, got {topk_lens.device}")
240244
if topk_lens.ndim != 1:
241245
raise ValueError(f"Expected flattened {name}.ndim == 1, got {topk_lens.ndim}")
242246
if topk_lens.size(0) != sum_seq_q:
@@ -249,12 +253,16 @@ def _normalize_dsv4_topk_lens(
249253
cum_seq_lens_q,
250254
(batch_size + 1,),
251255
torch.int32,
252-
cum_seq_lens_q.device,
256+
device,
253257
"cum_seq_lens_q",
254258
)
255259
return topk_lens
256260

257261

262+
def _validate_dsv4_sync_checks() -> bool:
263+
return os.environ.get("FLASHINFER_VALIDATE_INPUTS", "0") not in ("0", "")
264+
265+
258266
def _check_dsv4_sparse_mla_inputs(
259267
query: torch.Tensor,
260268
swa_kv_cache: torch.Tensor,
@@ -277,13 +285,16 @@ def _check_dsv4_sparse_mla_inputs(
277285
Optional[torch.Tensor],
278286
]:
279287
is_varlen_q = cum_seq_lens_q is not None
288+
out_shape: Tuple[int, ...]
289+
sparse_indices_prefix_shape: Tuple[int, ...]
280290
if is_varlen_q:
281291
if query.ndim != 3:
282292
raise ValueError(
283293
"Expected query.ndim == 3 when cum_seq_lens_q is provided, "
284294
f"got {query.ndim}"
285295
)
286-
assert cum_seq_lens_q is not None
296+
if cum_seq_lens_q is None:
297+
raise ValueError("cum_seq_lens_q is required for varlen query input")
287298
if cum_seq_lens_q.dtype != torch.int32:
288299
raise ValueError(
289300
f"cum_seq_lens_q must have dtype torch.int32, got {cum_seq_lens_q.dtype}"
@@ -297,6 +308,11 @@ def _check_dsv4_sparse_mla_inputs(
297308
raise ValueError(
298309
f"Expected cum_seq_lens_q.numel() >= 2, got {cum_seq_lens_q.numel()}"
299310
)
311+
if cum_seq_lens_q.device != query.device:
312+
raise ValueError(
313+
f"cum_seq_lens_q must be on query device {query.device}, "
314+
f"got {cum_seq_lens_q.device}"
315+
)
300316
sum_seq_q, num_heads, head_dim = query.shape
301317
if max_q_len is None:
302318
max_q_len = int((cum_seq_lens_q[1:] - cum_seq_lens_q[:-1]).max().item())
@@ -331,7 +347,11 @@ def _check_dsv4_sparse_mla_inputs(
331347
if num_heads not in (64, 128):
332348
raise ValueError(f"Expected 64 or 128 query heads, got {num_heads}")
333349

334-
if sparse_indices is None or compressed_kv_cache is None or sparse_topk_lens is None:
350+
if (
351+
sparse_indices is None
352+
or compressed_kv_cache is None
353+
or sparse_topk_lens is None
354+
):
335355
raise ValueError(
336356
"sparse_indices, compressed_kv_cache, and sparse_topk_lens are required"
337357
)
@@ -346,8 +366,7 @@ def _check_dsv4_sparse_mla_inputs(
346366
)
347367
if sparse_indices.ndim != 2:
348368
raise ValueError(
349-
"Expected flattened sparse_indices.ndim == 2, got "
350-
f"{sparse_indices.ndim}"
369+
f"Expected flattened sparse_indices.ndim == 2, got {sparse_indices.ndim}"
351370
)
352371
if sparse_indices.shape[:-1] != sparse_indices_prefix_shape:
353372
raise ValueError(
@@ -396,8 +415,9 @@ def _check_dsv4_sparse_mla_inputs(
396415
q_len_per_request,
397416
sum_seq_q,
398417
"sparse_topk_lens",
418+
query.device,
399419
cum_seq_lens_q,
400-
).to(query.device)
420+
)
401421
if normalized_sparse_lens.numel() > 0:
402422
min_sparse_len = int(normalized_sparse_lens.min().item())
403423
max_sparse_len = int(normalized_sparse_lens.max().item())
@@ -419,9 +439,6 @@ def _check_dsv4_sparse_mla_inputs(
419439
sinks, (num_heads,), torch.float32, query.device, "sinks"
420440
)
421441

422-
if cum_seq_lens_q is not None:
423-
cum_seq_lens_q = cum_seq_lens_q.to(query.device)
424-
425442
return (
426443
swa_kv_cache,
427444
compressed_kv_cache,
@@ -515,10 +532,12 @@ def trtllm_batch_decode_sparse_mla_dsv4(
515532
if enable_pdl is None:
516533
enable_pdl = device_support_pdl(query.device)
517534
if isinstance(bmm1_scale, torch.Tensor):
518-
assert bmm1_scale.dtype == torch.float32
535+
if bmm1_scale.dtype != torch.float32:
536+
raise TypeError("bmm1_scale tensor must have dtype torch.float32")
519537
bmm1_scale = bmm1_scale * log2e
520538
if isinstance(bmm2_scale, torch.Tensor):
521-
assert bmm2_scale.dtype == torch.float32
539+
if bmm2_scale.dtype != torch.float32:
540+
raise TypeError("bmm2_scale tensor must have dtype torch.float32")
522541

523542
(
524543
swa_kv_cache,
@@ -529,37 +548,30 @@ def trtllm_batch_decode_sparse_mla_dsv4(
529548
query_flat,
530549
expected_out_shape,
531550
cum_seq_lens_q,
532-
) = (
533-
_check_dsv4_sparse_mla_inputs(
534-
query,
535-
swa_kv_cache,
536-
sparse_indices,
537-
compressed_kv_cache,
538-
sparse_topk_lens,
539-
out,
540-
sinks,
541-
kv_layout,
542-
cum_seq_lens_q,
543-
max_q_len,
544-
)
551+
) = _check_dsv4_sparse_mla_inputs(
552+
query,
553+
swa_kv_cache,
554+
sparse_indices,
555+
compressed_kv_cache,
556+
sparse_topk_lens,
557+
out,
558+
sinks,
559+
kv_layout,
560+
cum_seq_lens_q,
561+
max_q_len,
545562
)
546563

547564
if out is None:
548565
out = torch.empty(expected_out_shape, dtype=torch.bfloat16, device=query.device)
549566

550-
if seq_lens is None:
551-
raise ValueError(
552-
"seq_lens is required for DeepSeek V4 sparse MLA because TRTLLM-GEN "
553-
"uses it to mask the fixed SWA-128 tile"
554-
)
555567
check_shape_dtype_device(
556568
seq_lens, (batch_size,), torch.int32, query.device, "seq_lens"
557569
)
558570
if cum_seq_lens_q is None:
559571
q_lens = seq_lens.new_full((batch_size,), q_len_per_request)
560572
else:
561573
q_lens = cum_seq_lens_q[1:] - cum_seq_lens_q[:-1]
562-
if torch.any(seq_lens < q_lens).item():
574+
if _validate_dsv4_sync_checks() and torch.any(seq_lens < q_lens).item():
563575
raise ValueError(
564576
"seq_lens must be greater than or equal to the per-request query "
565577
"lengths so TRTLLM-GEN can derive the SWA-128 valid window"
@@ -1120,10 +1132,12 @@ def trtllm_batch_decode_with_kv_cache_mla(
11201132
"trtllm-gen" if get_compute_capability(query.device)[0] == 10 else "xqa"
11211133
)
11221134
if isinstance(bmm1_scale, torch.Tensor):
1123-
assert bmm1_scale.dtype == torch.float32
1135+
if bmm1_scale.dtype != torch.float32:
1136+
raise TypeError("bmm1_scale tensor must have dtype torch.float32")
11241137
bmm1_scale = bmm1_scale * log2e
11251138
if isinstance(bmm2_scale, torch.Tensor):
1126-
assert bmm2_scale.dtype == torch.float32
1139+
if bmm2_scale.dtype != torch.float32:
1140+
raise TypeError("bmm2_scale tensor must have dtype torch.float32")
11271141
if backend == "xqa":
11281142
if not is_sm12x_supported(query.device):
11291143
raise ValueError(

0 commit comments

Comments
 (0)