Add DeepSeek V4 sparse MLA TRTLLM-GEN kernels #3269
Conversation
Note: Reviews paused. It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior in the settings.
📝 Walkthrough
This PR adds DeepSeek V4 sparse MLA decoding support to TRTLLM-GEN. Changes span the sparse-MLA type system, CUDA kernel updates, kernel selection/hashing, launcher wiring, the Python API with validation, and comprehensive tests covering variable-length scenarios and multiple tensor layouts.
Changes
Sparse MLA Feature Implementation
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Please don't merge this PR until #3259 is merged.

/bot run
Code Review
This pull request adds support for DeepSeek V4 sparse MLA kernels, enabling efficient attention computation with dynamic top-k lengths. The changes include updates to the kernel launchers, parameter structures, and the addition of a new public API, trtllm_batch_decode_sparse_mla_dsv4, along with comprehensive test coverage. I have reviewed the changes and suggest replacing the assert statements used for input validation in the public API with explicit TypeError raises to ensure robust validation, as assertions can be disabled in production environments.
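As an illustration of that suggestion, here is a minimal sketch of explicit-raise validation with hypothetical parameter checks (not the PR's actual validation code); unlike `assert`, these raises still fire when Python runs with `-O`:

```python
import torch


def _validate_query(query: torch.Tensor) -> None:
    # Explicit raises keep working when assertions are disabled (python -O).
    if not isinstance(query, torch.Tensor):
        raise TypeError(f"query must be a torch.Tensor, got {type(query).__name__}")
    # The accepted dtypes below are illustrative; the PR covers BF16 and FP8 paths.
    if query.dtype not in (torch.bfloat16, torch.float8_e4m3fn):
        raise ValueError(f"unsupported query dtype: {query.dtype}")
```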
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (1)
157-205: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift: Disambiguate dynamic-page kernels from static 128-page kernels in the hash.
`selectNumTokensPerPage()` rewrites every eligible paged decode with `page_size >= 128` to the sentinel key `128`, but `hashID()` still hashes only `numTokensPerPageLog2`. That makes a dynamic 256/512/... request indistinguishable from a real static-128 kernel, so we can silently load the wrong cubin or hit metadata hash conflicts when both kernel families are present. Please encode `mDynamicNumTokensPerPage` in the kernel hash, or reserve a sentinel that can never overlap a real page size.
Also applies to: 384-401, 900-935
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh` around lines 157 - 205, hashID() currently only uses numTokensPerPageLog2 so dynamic paged kernels rewritten to sentinel 128 by selectNumTokensPerPage() collide with real static-128 kernels; update hashID (and the other affected hash builders around the same logic) to include the dynamic-page flag (mDynamicNumTokensPerPage or equivalent) into the key so dynamic vs static-128 are distinguishable: locate hashID and the other hash construction sites that use numTokensPerPageLog2 and add a dedicated bit field (or a reserved sentinel bit) for mDynamicNumTokensPerPage (or a boolean computed from it) when packing the uint64_t key, ensuring the chosen bit position does not overlap existing fields (and update any comments documenting bit layout).
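For intuition, a toy sketch of the proposed fix (shown in Python for brevity; the field widths and bit positions are hypothetical, not the real `hashID()` layout): give the dynamic-page flag its own bit so a dynamic request rewritten to the 128-token sentinel can no longer hash the same as a genuine static-128 kernel.

```python
def pack_page_fields(num_tokens_per_page_log2: int, is_dynamic: bool) -> int:
    """Toy 64-bit key packing; the real bit layout lives in hashID()."""
    key = num_tokens_per_page_log2 & 0xF  # example: 4 bits for log2(page size)
    key |= int(is_dynamic) << 4           # dedicated flag bit outside that field
    return key


# log2(128) == 7: a dynamic request rewritten to 128 and a static-128 kernel
# now produce different keys instead of colliding.
assert pack_page_fields(7, True) != pack_page_fields(7, False)
```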
🧹 Nitpick comments (5)
flashinfer/mla/_core.py (2)
549-553: 💤 Low value: Unreachable `seq_lens is None` check.
The signature declares `seq_lens: torch.Tensor` (non-Optional, line 443), so this branch is effectively dead code. Either drop it, or change the parameter typing to `Optional[torch.Tensor]` if `None` is genuinely intended to be supported.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/mla/_core.py` around lines 549 - 553, The check for "if seq_lens is None" is unreachable because the function signature declares seq_lens: torch.Tensor (non-Optional); either remove the dead branch or make the parameter Optional to reflect intended behavior. Fix by either (A) deleting the if-block and its ValueError if callers always pass a tensor, or (B) change the parameter annotation to Optional[torch.Tensor] (and import Optional) and keep the runtime check; reference the seq_lens parameter in the relevant function in mla/_core.py to apply the chosen change consistently.
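If option (B) is taken, a minimal sketch of the widened annotation (the signature is abbreviated and hypothetical; the real entry point is `trtllm_batch_decode_sparse_mla_dsv4` with many more parameters):

```python
from typing import Optional

import torch


def sparse_mla_decode(seq_lens: Optional[torch.Tensor] = None) -> None:
    # With Optional in the annotation, this guard is reachable rather than dead code.
    if seq_lens is None:
        raise ValueError("seq_lens must be provided for sparse MLA decode")
```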
561-565: 💤 Low value: `torch.any(...).item()` forces a CUDA→CPU sync on every decode call.
This bound check synchronizes the stream on every invocation, which can hurt decode throughput when the kernel is otherwise designed to be PDL-friendly. If this is a defensive validation rather than a hard contract, consider running it only when explicit input validation is enabled (e.g., via an env flag or debug guard) so the hot path stays async.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@flashinfer/mla/_core.py` around lines 561 - 565, The torch.any(...).item() call forces a CUDA→CPU sync in the hot path (the check comparing seq_lens and q_lens) — move this validation out of the async decode path and run it only when explicit input validation/debugging is enabled: wrap the existing comparison and ValueError raise (the block referencing seq_lens, q_lens and torch.any(...).item()) behind a debug/env guard (e.g., an opt-in flag like FLASHINFER_VALIDATE_INPUTS or a module-level debug boolean) so normal decoding avoids the synchronous .item() call; keep the same error message and logic under that guard so behavior is unchanged when validation is turned on.
csrc/fmhaReduction.cu (1)
405-406: ⚡ Quick win: Avoid magic-number coupling to `TrtllmGenSparseMlaType`.
`kernelMeta.mSparseAttn == 2` implicitly assumes `TrtllmGenSparseMlaType::DynamicTokenSparse == 2`. If the enum is reordered or extended, this check silently breaks. Cast to the enum or use the existing `isDynamicTokenSparseMla` predicate from `fmhaRunnerParams.h` so the relationship is checked by the type system.

```diff
-  bool const supportsVarSparseMlaTopKLens =
-      kernelMeta.mSparseAttn == 2 && kernelMeta.mHeadDimQk == 512 && kernelMeta.mHeadDimV == 512;
+  bool const supportsVarSparseMlaTopKLens =
+      isDynamicTokenSparseMla(static_cast<TrtllmGenSparseMlaType>(kernelMeta.mSparseAttn)) &&
+      kernelMeta.mHeadDimQk == 512 && kernelMeta.mHeadDimV == 512;
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@csrc/fmhaReduction.cu` around lines 405 - 406, The boolean supportsVarSparseMlaTopKLens currently compares kernelMeta.mSparseAttn to the magic number 2; change this to use the enum or predicate instead—replace the check "kernelMeta.mSparseAttn == 2" with a type-safe check using TrtllmGenSparseMlaType (casted) or, preferably, call the existing isDynamicTokenSparseMla(...) predicate from fmhaRunnerParams.h so the condition becomes isDynamicTokenSparseMla(kernelMeta.mSparseAttn) && kernelMeta.mHeadDimQk == 512 && kernelMeta.mHeadDimV == 512; keep the rest of the expression and the variable name supportsVarSparseMlaTopKLens unchanged.
include/flashinfer/trtllm/fmha/fmhaRunnerParams.h (1)
334-339: 💤 Low value: Redundant `mSparseMla` boolean duplicates `mSparseMlaType` state.
`mSparseMla` is fully derivable from `isSparseMla(mSparseMlaType)`, and keeping both fields as sources of truth invites the two drifting apart (e.g., a future call site sets one and forgets the other). Since you are already touching the runner-params surface here, consider replacing the field with a small helper accessor and updating the call sites that read `mSparseMla` to call `isSparseMla(mSparseMlaType)` directly.
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@include/flashinfer/trtllm/fmha/fmhaRunnerParams.h` around lines 334 - 339, Remove the redundant member mSparseMla and stop storing sparse state separately; derive it from mSparseMlaType instead. Add a small const accessor on the class (e.g., isSparseMla() const) that returns isSparseMla(mSparseMlaType), keep mSparseMlaType and mSparseMlaTopK as-is, and update all call sites that read mSparseMla to call the new accessor (or call isSparseMla(mSparseMlaType) directly). Ensure no assignments to the removed mSparseMla remain.
csrc/trtllm_fmha_kernel_launcher.cu (1)
785-790: 💤 Low value: Inconsistent `is_4bit` argument for sliding-window head-dim computation.
`head_dim_k` checks `is_4bit(kv_data_type)` while `head_dim_sw` is computed against `is_4bit(q_data_type)`. These happen to be equivalent here because of the dtype equality checks at lines 759-761, but reading the head-dim of the SWA tensor against the query's dtype is misleading. Use `dl_dtype_to_tllm_data_type(sliding_window_kv_cache.dtype())` (or `kv_data_type`, since it's enforced equal) for consistency.

```diff
- int const head_dim_sw = is_4bit(q_data_type) ? sliding_window_kv_cache.size(-1) * 2
-                                               : sliding_window_kv_cache.size(-1);
+ int const head_dim_sw = is_4bit(kv_data_type) ? sliding_window_kv_cache.size(-1) * 2
+                                               : sliding_window_kv_cache.size(-1);
```
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the rest with a brief reason, keep changes minimal, and validate. In `@csrc/trtllm_fmha_kernel_launcher.cu` around lines 785 - 790, The computation of head_dim_sw uses is_4bit(q_data_type) which is misleading; update the condition to check the sliding-window tensor's actual dtype (either use is_4bit(kv_data_type) since dtype equality is enforced, or call is_4bit(dl_dtype_to_tllm_data_type(sliding_window_kv_cache.dtype()))). Specifically, modify the head_dim_sw expression so it references sliding_window_kv_cache's dtype (or kv_data_type) with is_4bit instead of q_data_type to make the intent and calculation consistent (see head_dim_sw, is_4bit, sliding_window_kv_cache, kv_data_type).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.
Inline comments:
In `@flashinfer/mla/_core.py`:
- Around line 306-321: The mypy error comes from assigning different-length
tuples to out_shape; fix it by adding an explicit widened type annotation (e.g.
typing.Tuple[int, ...]) for out_shape (and similarly sparse_indices_prefix_shape
if needed) in the function scope so both the 3-tuple and 4-tuple assignments are
type-compatible; locate the assignments to out_shape in the _core.py branch that
handles dense vs sparse queries (symbols: out_shape,
sparse_indices_prefix_shape, query_flat) and change their declaration to use
Tuple[int, ...] from typing.
- Around line 246-253: The device check in the call to check_shape_dtype_device
for cum_seq_lens_q is tautological because it passes cum_seq_lens_q.device as
the expected device; change this to compare against the caller's reference
device (for example use topk_lens.device or the query tensor's device) so
cum_seq_lens_q is validated to live on the same device as the other tensors, or
remove the device argument entirely if no cross-tensor device validation is
needed; update the call site where cum_seq_lens_q is validated (the call to
check_shape_dtype_device) to use the chosen reference device instead of
cum_seq_lens_q.device.
In `@tests/attention/test_trtllm_gen_sparse_mla_dsv4.py`:
- Around line 727-731: The _skip_unless_sm100_or_sm103 helper currently checks
only major CUDA compute capability 10; change it to require exact (10,0) or
(10,3) tuples returned by get_compute_capability(torch.device("cuda")) so only
SM100 or SM103 pass. In other words, update the conditional in
_skip_unless_sm100_or_sm103 to compare the full tuple against (10, 0) and (10,
3) (using the existing get_compute_capability call) and call pytest.skip when
the device tuple is not one of those two values.
ℹ️ Review info
⚙️ Run configuration
Configuration used: defaults
Review profile: CHILL
Plan: Pro
Run ID: 454a6d85-dba9-42db-ac45-254e5f2a4cba
📥 Commits
Reviewing files that changed from the base of the PR and between 2fe28ca and a78471cf62219dde9bf062065bf06f0a7c94ea87.
📒 Files selected for processing (13)
.gitignore
csrc/fmhaReduction.cu
csrc/trtllm_fmha_kernel_launcher.cu
flashinfer/artifacts.py
flashinfer/decode.py
flashinfer/jit/env.py
flashinfer/mla/_core.py
include/flashinfer/trtllm/fmha/fmhaKernels.cuh
include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
include/flashinfer/trtllm/fmha/kernelParams.h
tests/attention/test_attention_sink_blackwell.py
tests/attention/test_trtllm_gen_attention.py
tests/attention/test_trtllm_gen_sparse_mla_dsv4.py
f80b7d2 to b351a08
/bot run
b351a08 to 255870c
/bot run
@saltyminty @qsang-nv please help review this. Thanks!
Add the TRTLLM-GEN dynamic sparse MLA launch path for DeepSeek V4, including the Python wrapper, sparse/SWA metadata plumbing, FlashMLA-style coverage for BF16 and per-tensor FP8 cases, updated remote FMHA cubin metadata, and an ignore rule for local generated cubins.
Tests:
python -m py_compile flashinfer/artifacts.py
python -m ruff check flashinfer/artifacts.py
FLASHINFER_CUBINS_REPOSITORY=<private URM repository> FLASHINFER_CUBIN_DIR unset FLASHINFER_CUBIN_CHECKSUM_DISABLED unset FLASHINFER_DISABLE_VERSION_CHECK=1 FLASHINFER_WORKSPACE_BASE=/home/scratch.perkzz_gpu/dpsv4/flashinfer/.jit_dsv4_remote_artifacts CUDA_LAUNCH_BLOCKING=1 python -m pytest -q --tb=short tests/attention/test_trtllm_gen_sparse_mla_dsv4.py
255870c to f578d94
/bot run
/bot run
```python
        query.device,
        cum_seq_lens_q,
    )
    if normalized_sparse_lens.numel() > 0:
```
The sparse_topk_lens min/max check in _check_dsv4_sparse_mla_inputs does two .item() D2H syncs per call. The sibling seq_lens < q_lens check is already gated by _validate_dsv4_sync_checks() — is the gate intentionally omitted here, or a miss?
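For reference, a sketch of how the same gate could wrap these checks, assuming `_validate_dsv4_sync_checks()` reads an opt-in environment flag as suggested elsewhere in this review; the flag name and bounds below are illustrative:

```python
import os

import torch


def _validate_dsv4_sync_checks() -> bool:
    # Opt-in guard so the hot decode path stays asynchronous by default.
    return os.environ.get("FLASHINFER_VALIDATE_INPUTS", "0") == "1"


def _check_sparse_topk_lens(sparse_topk_lens: torch.Tensor, topk: int) -> None:
    if _validate_dsv4_sync_checks():
        # .min()/.max() followed by .item() each force a device-to-host sync,
        # so only pay that cost when validation is explicitly enabled.
        lo = int(sparse_topk_lens.min().item())
        hi = int(sparse_topk_lens.max().item())
        if lo < 0 or hi > topk:
            raise ValueError(f"sparse_topk_lens must lie in [0, {topk}], got [{lo}, {hi}]")
```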
```python
    return 512 if h_q == 64 else 1024


def gen_testcase() -> tuple[RawTestParamForDecode, ...]:
```
All add_case(...) calls in gen_testcase() hardcode is_varlen=True, so the dense 4D-query branch in _check_dsv4_sparse_mla_inputs (the else arm with query.ndim == 4, out_shape = (batch_size, q_len_per_request, num_heads, 512)) and the not is_varlen arm of _make_q_lens are never exercised. Could you add at least one is_varlen=False case so this path has coverage?
Summary
Tests
python -m pytest -q --tb=short -k 'not xqa and not cute and not trtllm-native' tests/attention/test_trtllm_gen_sparse_mla_dsv4.py: 57 passed
python -m pytest -q --tb=short -k 'not xqa and not cute and not trtllm-native' tests/attention/test_attention_sink_blackwell.py: 144 passed
python -m pytest -q --tb=short -k 'not xqa and not cute and not trtllm-native' tests/attention/test_trtllm_gen_mla.py: 7686 passed, 12672 deselected
python -m pytest -q --tb=short -n 8 -k 'not xqa and not cute and not trtllm-native' tests/attention/test_trtllm_gen_attention.py: 75736 passed, 30800 skipped
Summary by CodeRabbit
New Features
Tests
Chores