
Add DeepSeek V4 sparse MLA TRTLLM-GEN kernels #3269

Open
PerkzZheng wants to merge 2 commits into flashinfer-ai:main from PerkzZheng:dsv4-sparse-mla-trtllm-gen

Conversation

@PerkzZheng
Contributor

@PerkzZheng PerkzZheng commented May 8, 2026

Summary

  • Add DeepSeek V4 sparse MLA TRTLLM-GEN decode support for BF16 and per-tensor FP8 paths.
  • Plumb SWA and compressed KV pools, concatenated sparse indices, and per-query sparse top-k lengths through FlashInfer.
  • Add DeepSeek V4 sparse MLA tests covering SWA-only and compressed top-k cases with variable Q/KV lengths.

Tests

  • python -m pytest -q --tb=short -k 'not xqa and not cute and not trtllm-native' tests/attention/test_trtllm_gen_sparse_mla_dsv4.py: 57 passed
  • python -m pytest -q --tb=short -k 'not xqa and not cute and not trtllm-native' tests/attention/test_attention_sink_blackwell.py: 144 passed
  • python -m pytest -q --tb=short -k 'not xqa and not cute and not trtllm-native' tests/attention/test_trtllm_gen_mla.py: 7686 passed, 12672 deselected
  • python -m pytest -q --tb=short -n 8 -k 'not xqa and not cute and not trtllm-native' tests/attention/test_trtllm_gen_attention.py: 75736 passed, 30800 skipped

Summary by CodeRabbit

  • New Features

    • Added DeepSeek V4 sparse MLA decode support with variable top-k behavior and sliding-window KV cache integration.
    • Enhanced kernel selection with dynamic token-per-page support for improved performance flexibility.
  • Tests

    • Added comprehensive test suite for DeepSeek V4 sparse MLA decode across multiple configurations.
  • Chores

    • Updated environment variable priority for CUBIN directory selection (a sketch of the new lookup order follows this list).
    • Added backward-compatibility alias for MLA decode function.
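The walkthrough below notes that _get_cubin_dir() now checks FLASHINFER_CUBIN_DIR first, then the flashinfer_cubin package, then the cache. A rough sketch of that priority chain, assuming a hypothetical cache location and simplified error handling (the real helper lives in flashinfer/jit/env.py and may differ in details):

import os
from pathlib import Path

def _get_cubin_dir() -> Path:
    # 1. An explicit FLASHINFER_CUBIN_DIR override wins.
    env_dir = os.environ.get("FLASHINFER_CUBIN_DIR")
    if env_dir:
        return Path(env_dir)
    # 2. Otherwise, prefer cubins bundled with the flashinfer_cubin package, if installed.
    try:
        import flashinfer_cubin
        return Path(flashinfer_cubin.__file__).parent
    except ImportError:
        pass
    # 3. Fall back to a local download cache (the path shown here is a placeholder).
    return Path.home() / ".cache" / "flashinfer" / "cubins"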

Review Change Stack

@coderabbitai
Contributor

coderabbitai Bot commented May 8, 2026

Note

Reviews paused

It looks like this branch is under active development. To avoid overwhelming you with review comments due to an influx of new commits, CodeRabbit has automatically paused this review. You can configure this behavior by changing the reviews.auto_review.auto_pause_after_reviewed_commits setting.

Use the following commands to manage reviews:

  • @coderabbitai resume to resume automatic reviews.
  • @coderabbitai review to trigger a single review.


Walkthrough

This PR adds DeepSeek V4 sparse MLA decoding support to TRTLLM-GEN. Changes span sparse-MLA type system, CUDA kernel updates, kernel selection/hashing, launcher wiring, Python API with validation, and comprehensive tests covering variable-length scenarios and multiple tensor layouts.

Changes

Sparse MLA Feature Implementation

Layer / File(s) | Summary
Type Contracts & Sparse MLA Enum
include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
Introduce TrtllmGenSparseMlaType enum (None/StaticTokenSparse/DynamicTokenSparse) with helper predicates; add slidingWindowKvPoolPtr, sparseMlaTopKLensPtr, mSparseMlaType, mHasSlidingWindowKvPool to runner params; add mDynamicNumTokensPerPage to select kernel params; add isSparseMla() query method.
Kernel Parameters & TMA Descriptors
include/flashinfer/trtllm/fmha/kernelParams.h
Update sparse-MLA gating to use options.isSparseMla(); conditionally build tmaKSlidingWindowKvPool_ descriptor when sliding-window KV pool is enabled; assign ptrSparseMlaTopKLens from runner params; update sparse-MLA top-K validation gating.
FMHA Reduction Kernel Implementation
csrc/fmhaReduction.cu
Remove compile-time HeadDim template parameter and add runtime headDimV/numHeadDimCtasV plus boolean flags (isTokenSparse, groupsTokensHeadsQ, supportsVarSparseMlaTopKLens); recompute work bounds with numHeadsQPerKvCta and headIdxO; handle early return when no valid tokens; update softmax/output offset computation and store indexing for grouped head layout; update kernel-selection macros for Q tile sizes 64/128; generalize runFmhaReduction launch configuration and cudaLaunchKernelEx call wiring.
Kernel Selection & Trait Hashing
include/flashinfer/trtllm/fmha/fmhaKernels.cuh
Rewrite hashID(...) to accept dynamicNumTokensPerPage and sparseMlaType (2-bit); relax numTokensPerPage validation to allow 0; update hash bitfield layout; add isDynamicNumTokensPerPageKernel(...) helper; normalize numTokensPerPage based on sparse MLA state and paging; expand isMlaGenKernel(...) with additional sparse-MLA head-dim pattern; gate multi-CTAs KV window, tile-size upgrade, MLA generation branching, and sliding-window/chunked-causal mask enabling on sparse-MLA state; adjust use-2CTA heuristic for multi-token MLA; update trait debug/info generation.
Paged Attention Launcher & DSV4 Entry
csrc/trtllm_fmha_kernel_launcher.cu
Extend trtllm_paged_attention_launcher signature to accept sliding_window_kv_pool, sparse_mla_top_k_lens, has_sliding_window_kv_pool; derive mSparseMlaType from sparse-MLA top-K configuration; set slidingWindowKvPoolPtr, sparseMlaTopKLensPtr, mHasSlidingWindowKvPool; update head-dimension validation to allow new MLA pairing; reorder skip-softmax parameters. Update trtllm_paged_attention_decode and trtllm_paged_attention_context to pass non-sparse defaults. Add new trtllm_paged_attention_decode_sparse_mla_dsv4 TVM entry with DeepSeek V4 constraint validation (BF16/FP8 queries, head dim=512, one KV head, topK multiple-of-4), optional varlen support, and generation-mode launch wiring.
Python Decode API & Input Validation
flashinfer/mla/_core.py
Add helpers: _normalize_dsv4_sparse_mla_kv_cache (HND/NHD layout normalization), _normalize_dsv4_topk_lens (flattened INT32 conversion and cum_seq_lens_q validation), _validate_dsv4_sync_checks (environment-gated validation). Introduce _check_dsv4_sparse_mla_inputs for comprehensive DSV4 validation (dtype/shape/device, fixed 128 SWA entries, sparse-indices flattening, top-k capacity). Implement trtllm_batch_decode_sparse_mla_dsv4 orchestration: defaults enable_pdl, converts scales to log2, allocates output, derives per-request query lengths, optionally validates seq_lens, loads TRTLLM-GEN op, dispatches DSV4 kernel. Tighten dtype checks in trtllm_batch_decode_with_kv_cache_mla to raise TypeError explicitly.
Module Exports, Artifacts & Environment
flashinfer/decode.py, flashinfer/jit/env.py, .gitignore
Re-export trtllm_batch_decode_sparse_mla_dsv4 alias from .mla in decode module. Change _get_cubin_dir() priority to check FLASHINFER_CUBIN_DIR first, then flashinfer_cubin package, then cache. Add local_cubins/ ignore rule.
Test Coverage & Validation
tests/attention/test_trtllm_gen_sparse_mla_dsv4.py
Add parameterized DSV4 decode tests across variable-length scenarios, dtypes (BF16/FP8), head sizes, KV layouts (HND/NHD). Generate deterministic test data with block tables, flattened KV caches, absolute/relative KV indices, optional top-k constraints. Implement TRTLLM and reference decoders: reference performs gather-by-indices, invalid-index masking, optional top-k masking, log-sum-exp normalization, attention-sink scaling. Test assertions compare TRTLLM output to reference on valid query subset; skip on missing cubins/kernels. (A simplified reference-decoder sketch follows this table.)
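The reference decoder described in the last row boils down to a gather plus a masked softmax. A simplified sketch for a single request, assuming hypothetical tensor names and shapes (the actual test helper additionally handles batching, HND/NHD layouts, and FP8 dequantization):

import torch

def reference_sparse_mla_decode(q, kv, indices, topk_len, scale, sink=None):
    # q: [num_heads, head_dim]; kv: [kv_len, head_dim] (latent KV; V reuses the same buffer)
    # indices: [top_k] absolute KV positions, with -1 marking invalid slots
    # topk_len: number of leading entries of `indices` that are actually valid
    valid = indices >= 0
    valid &= torch.arange(indices.numel(), device=indices.device) < topk_len
    gathered = kv[indices.clamp(min=0)].float()          # gather-by-indices, [top_k, head_dim]
    logits = (q.float() @ gathered.T) * scale            # [num_heads, top_k]
    logits = logits.masked_fill(~valid, float("-inf"))   # invalid-index / top-k masking
    lse = torch.logsumexp(logits, dim=-1, keepdim=True)  # log-sum-exp normalization
    if sink is not None:                                 # attention-sink scaling
        lse = torch.logaddexp(lse, sink.view(-1, 1).float())
    probs = torch.exp(logits - lse)
    return probs @ gathered                              # [num_heads, head_dim]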

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested labels

run-ci

Suggested reviewers

  • sricketts
  • aleozlx
  • yongwww
  • yzh119
  • cyx-6

Poem

🐰 A rabbit hops through sparse tokens bright,
Each deep-seek V4 MLA takes flight,
Cubins compiled, kernels aligned—
Sliding windows and top-K refined!
With validation and tests all tight,
The decode dance shines in the night. ✨

🚥 Pre-merge checks | ✅ 4 | ❌ 1

❌ Failed checks (1 warning)

• Docstring Coverage (⚠️ Warning): docstring coverage is 14.75%, below the required threshold of 80.00%. Resolution: write docstrings for the functions that are missing them.
✅ Passed checks (4 passed)
• Title check (✅ Passed): the title clearly and specifically describes the main change (adding DeepSeek V4 sparse MLA TRTLLM-GEN kernels), which aligns with the core objective and changes throughout the PR.
• Description check (✅ Passed): the description provides a clear summary of changes and lists test results demonstrating comprehensive testing, but lacks the pre-commit checklist completion and the explicit related-issues/reviewer notes from the template.
• Linked Issues check (✅ Passed): check skipped because no linked issues were found for this pull request.
• Out of Scope Changes check (✅ Passed): check skipped because no linked issues were found for this pull request.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.



Comment @coderabbitai help to get the list of available commands and usage tips.

@PerkzZheng
Contributor Author

Please don't merge this PR until #3259 is merged.

@PerkzZheng
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !649 has been created, and the CI pipeline #50630504 is currently running. I'll report back once the pipeline job completes.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request adds support for DeepSeek V4 sparse MLA kernels, enabling efficient attention computation with dynamic top-k lengths. The changes include updates to the kernel launchers, parameter structures, and the addition of a new public API, trtllm_batch_decode_sparse_mla_dsv4, along with comprehensive test coverage. I have reviewed the changes and suggest replacing the assert statements used for input validation in the public API with explicit TypeError raises to ensure robust validation, as assertions can be disabled in production environments.
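As an illustration of the suggested pattern, here is a minimal sketch with a hypothetical helper name; the constraints shown (BF16/FP8 queries, head dim 512) come from the PR description, and the real validation code in flashinfer/mla/_core.py differs:

import torch

def _check_dsv4_query(query: torch.Tensor) -> None:
    # Explicit raises stay active even under `python -O`, unlike assert statements.
    if query.dtype not in (torch.bfloat16, torch.float8_e4m3fn):
        raise TypeError(f"query must be bf16 or fp8_e4m3fn, got {query.dtype}")
    if query.shape[-1] != 512:
        raise ValueError(f"DeepSeek V4 sparse MLA expects head_dim=512, got {query.shape[-1]}")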

Contributor

@coderabbitai coderabbitai Bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
include/flashinfer/trtllm/fmha/fmhaKernels.cuh (1)

157-205: ⚠️ Potential issue | 🟠 Major | 🏗️ Heavy lift

Disambiguate dynamic-page kernels from static 128-page kernels in the hash.

selectNumTokensPerPage() rewrites every eligible paged decode with page_size >= 128 to the sentinel key 128, but hashID() still hashes only numTokensPerPageLog2. That makes a dynamic 256/512/... request indistinguishable from a real static-128 kernel, so we can silently load the wrong cubin or hit metadata hash conflicts when both kernel families are present. Please encode mDynamicNumTokensPerPage in the kernel hash, or reserve a sentinel that can never overlap a real page size.

Also applies to: 384-401, 900-935

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@include/flashinfer/trtllm/fmha/fmhaKernels.cuh` around lines 157 - 205,
hashID() currently only uses numTokensPerPageLog2 so dynamic paged kernels
rewritten to sentinel 128 by selectNumTokensPerPage() collide with real
static-128 kernels; update hashID (and the other affected hash builders around
the same logic) to include the dynamic-page flag (mDynamicNumTokensPerPage or
equivalent) into the key so dynamic vs static-128 are distinguishable: locate
hashID and the other hash construction sites that use numTokensPerPageLog2 and
add a dedicated bit field (or a reserved sentinel bit) for
mDynamicNumTokensPerPage (or a boolean computed from it) when packing the
uint64_t key, ensuring the chosen bit position does not overlap existing fields
(and update any comments documenting bit layout).
🧹 Nitpick comments (5)
flashinfer/mla/_core.py (2)

549-553: 💤 Low value

Unreachable seq_lens is None check.

The signature declares seq_lens: torch.Tensor (non-Optional, line 443), so this branch is effectively dead code. Either drop it, or change the parameter typing to Optional[torch.Tensor] if None is genuinely intended to be supported.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/mla/_core.py` around lines 549 - 553, The check for "if seq_lens
is None" is unreachable because the function signature declares seq_lens:
torch.Tensor (non-Optional); either remove the dead branch or make the parameter
Optional to reflect intended behavior. Fix by either (A) deleting the if-block
and its ValueError if callers always pass a tensor, or (B) change the parameter
annotation to Optional[torch.Tensor] (and import Optional) and keep the runtime
check; reference the seq_lens parameter in the relevant function in mla/_core.py
to apply the chosen change consistently.

561-565: 💤 Low value

torch.any(...).item() forces a CUDA→CPU sync on every decode call.

This bound check synchronizes the stream on every invocation, which can hurt decode throughput when the kernel is otherwise designed to be PDL-friendly. If this is a defensive validation rather than a hard contract, consider running it only when explicit input validation is enabled (e.g., via an env flag or debug guard) so the hot path stays async.
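One way to keep the hot path asynchronous, sketched under the assumption of an opt-in environment flag (FLASHINFER_VALIDATE_INPUTS is the reviewer's example name, not an existing flag):

import os
import torch

_VALIDATE_INPUTS = os.environ.get("FLASHINFER_VALIDATE_INPUTS", "0") == "1"

def _maybe_check_seq_lens(seq_lens: torch.Tensor, q_lens: torch.Tensor) -> None:
    # Only pay for the CUDA->CPU sync when validation is explicitly enabled.
    if _VALIDATE_INPUTS and torch.any(seq_lens < q_lens).item():
        raise ValueError("seq_lens must be >= the per-request query lengths")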

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@flashinfer/mla/_core.py` around lines 561 - 565, The torch.any(...).item()
call forces a CUDA→CPU sync in the hot path (the check comparing seq_lens and
q_lens) — move this validation out of the async decode path and run it only when
explicit input validation/debugging is enabled: wrap the existing comparison and
ValueError raise (the block referencing seq_lens, q_lens and
torch.any(...).item()) behind a debug/env guard (e.g., an opt-in flag like
FLASHINFER_VALIDATE_INPUTS or a module-level debug boolean) so normal decoding
avoids the synchronous .item() call; keep the same error message and logic under
that guard so behavior is unchanged when validation is turned on.
csrc/fmhaReduction.cu (1)

405-406: ⚡ Quick win

Avoid magic-number coupling to TrtllmGenSparseMlaType.

kernelMeta.mSparseAttn == 2 implicitly assumes TrtllmGenSparseMlaType::DynamicTokenSparse == 2. If the enum is reordered or extended, this check silently breaks. Cast to the enum or use the existing isDynamicTokenSparseMla predicate from fmhaRunnerParams.h so the relationship is checked by the type system.

-  bool const supportsVarSparseMlaTopKLens =
-      kernelMeta.mSparseAttn == 2 && kernelMeta.mHeadDimQk == 512 && kernelMeta.mHeadDimV == 512;
+  bool const supportsVarSparseMlaTopKLens =
+      isDynamicTokenSparseMla(static_cast<TrtllmGenSparseMlaType>(kernelMeta.mSparseAttn)) &&
+      kernelMeta.mHeadDimQk == 512 && kernelMeta.mHeadDimV == 512;
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@csrc/fmhaReduction.cu` around lines 405 - 406, The boolean
supportsVarSparseMlaTopKLens currently compares kernelMeta.mSparseAttn to the
magic number 2; change this to use the enum or predicate instead—replace the
check "kernelMeta.mSparseAttn == 2" with a type-safe check using
TrtllmGenSparseMlaType (casted) or, preferably, call the existing
isDynamicTokenSparseMla(...) predicate from fmhaRunnerParams.h so the condition
becomes isDynamicTokenSparseMla(kernelMeta.mSparseAttn) && kernelMeta.mHeadDimQk
== 512 && kernelMeta.mHeadDimV == 512; keep the rest of the expression and the
variable name supportsVarSparseMlaTopKLens unchanged.
include/flashinfer/trtllm/fmha/fmhaRunnerParams.h (1)

334-339: 💤 Low value

Redundant mSparseMla boolean duplicates mSparseMlaType state.

mSparseMla is fully derivable from isSparseMla(mSparseMlaType), and keeping both fields as sources of truth invites the two to drift apart (e.g., a future call site sets one and forgets the other). Since you are already touching the runner-params surface here, consider replacing the field with a small helper accessor and updating the call sites that read mSparseMla to call isSparseMla(mSparseMlaType) directly.

🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@include/flashinfer/trtllm/fmha/fmhaRunnerParams.h` around lines 334 - 339,
Remove the redundant member mSparseMla and stop storing sparse state separately;
derive it from mSparseMlaType instead. Add a small const accessor on the class
(e.g., isSparseMla() const) that returns isSparseMla(mSparseMlaType), keep
mSparseMlaType and mSparseMlaTopK as-is, and update all call sites that read
mSparseMla to call the new accessor (or call isSparseMla(mSparseMlaType)
directly). Ensure no assignments to the removed mSparseMla remain.
csrc/trtllm_fmha_kernel_launcher.cu (1)

785-790: 💤 Low value

Inconsistent is_4bit argument for sliding-window head-dim computation.

head_dim_k checks is_4bit(kv_data_type) while head_dim_sw is computed against is_4bit(q_data_type). These happen to be equivalent here because of the dtype equality checks at lines 759-761, but reading the head-dim of the SWA tensor against the query's dtype is misleading. Use dl_dtype_to_tllm_data_type(sliding_window_kv_cache.dtype()) (or kv_data_type since it's enforced equal) for consistency.

-  int const head_dim_sw = is_4bit(q_data_type) ? sliding_window_kv_cache.size(-1) * 2
-                                               : sliding_window_kv_cache.size(-1);
+  int const head_dim_sw = is_4bit(kv_data_type) ? sliding_window_kv_cache.size(-1) * 2
+                                                : sliding_window_kv_cache.size(-1);
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@csrc/trtllm_fmha_kernel_launcher.cu` around lines 785 - 790, The computation
of head_dim_sw uses is_4bit(q_data_type) which is misleading; update the
condition to check the sliding-window tensor's actual dtype (either use
is_4bit(kv_data_type) since dtype equality is enforced, or call
is_4bit(dl_dtype_to_tllm_data_type(sliding_window_kv_cache.dtype()))).
Specifically, modify the head_dim_sw expression so it references
sliding_window_kv_cache's dtype (or kv_data_type) with is_4bit instead of
q_data_type to make the intent and calculation consistent (see head_dim_sw,
is_4bit, sliding_window_kv_cache, kv_data_type).
🤖 Prompt for all review comments with AI agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

Inline comments:
In `@flashinfer/mla/_core.py`:
- Around line 306-321: The mypy error comes from assigning different-length
tuples to out_shape; fix it by adding an explicit widened type annotation (e.g.
typing.Tuple[int, ...]) for out_shape (and similarly sparse_indices_prefix_shape
if needed) in the function scope so both the 3-tuple and 4-tuple assignments are
type-compatible; locate the assignments to out_shape in the _core.py branch that
handles dense vs sparse queries (symbols: out_shape,
sparse_indices_prefix_shape, query_flat) and change their declaration to use
Tuple[int, ...] from typing.
- Around line 246-253: The device check in the call to check_shape_dtype_device
for cum_seq_lens_q is tautological because it passes cum_seq_lens_q.device as
the expected device; change this to compare against the caller's reference
device (for example use topk_lens.device or the query tensor's device) so
cum_seq_lens_q is validated to live on the same device as the other tensors, or
remove the device argument entirely if no cross-tensor device validation is
needed; update the call site where cum_seq_lens_q is validated (the call to
check_shape_dtype_device) to use the chosen reference device instead of
cum_seq_lens_q.device.

In `@tests/attention/test_trtllm_gen_sparse_mla_dsv4.py`:
- Around line 727-731: The _skip_unless_sm100_or_sm103 helper currently checks
only major CUDA compute capability 10; change it to require exact (10,0) or
(10,3) tuples returned by get_compute_capability(torch.device("cuda")) so only
SM100 or SM103 pass. In other words, update the conditional in
_skip_unless_sm100_or_sm103 to compare the full tuple against (10, 0) and (10,
3) (using the existing get_compute_capability call) and call pytest.skip when
the device tuple is not one of those two values.


ℹ️ Review info
⚙️ Run configuration

Configuration used: defaults

Review profile: CHILL

Plan: Pro

Run ID: 454a6d85-dba9-42db-ac45-254e5f2a4cba

📥 Commits

Reviewing files that changed from the base of the PR and between 2fe28ca and a78471cf62219dde9bf062065bf06f0a7c94ea87.

📒 Files selected for processing (13)
  • .gitignore
  • csrc/fmhaReduction.cu
  • csrc/trtllm_fmha_kernel_launcher.cu
  • flashinfer/artifacts.py
  • flashinfer/decode.py
  • flashinfer/jit/env.py
  • flashinfer/mla/_core.py
  • include/flashinfer/trtllm/fmha/fmhaKernels.cuh
  • include/flashinfer/trtllm/fmha/fmhaRunnerParams.h
  • include/flashinfer/trtllm/fmha/kernelParams.h
  • tests/attention/test_attention_sink_blackwell.py
  • tests/attention/test_trtllm_gen_attention.py
  • tests/attention/test_trtllm_gen_sparse_mla_dsv4.py

@PerkzZheng PerkzZheng force-pushed the dsv4-sparse-mla-trtllm-gen branch 2 times, most recently from f80b7d2 to b351a08 on May 11, 2026 06:16
@PerkzZheng
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !649 has been updated with latest changes, and the CI pipeline #50885037 is currently running. I'll report back once the pipeline job completes.

@PerkzZheng PerkzZheng force-pushed the dsv4-sparse-mla-trtllm-gen branch from b351a08 to 255870c on May 11, 2026 08:40
@PerkzZheng
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !649 has been updated with latest changes, and the CI pipeline #50895754 is currently running. I'll report back once the pipeline job completes.

@PerkzZheng
Contributor Author

@saltyminty @qsang-nv please help review this. Thanks!

Add the TRTLLM-GEN dynamic sparse MLA launch path for DeepSeek V4, including the Python wrapper, sparse/SWA metadata plumbing, FlashMLA-style coverage for BF16 and per-tensor FP8 cases, updated remote FMHA cubin metadata, and an ignore rule for local generated cubins.

Tests: python -m py_compile flashinfer/artifacts.py

Tests: python -m ruff check flashinfer/artifacts.py

Tests: FLASHINFER_CUBINS_REPOSITORY=<private URM repository> FLASHINFER_CUBIN_DIR unset FLASHINFER_CUBIN_CHECKSUM_DISABLED unset FLASHINFER_DISABLE_VERSION_CHECK=1 FLASHINFER_WORKSPACE_BASE=/home/scratch.perkzz_gpu/dpsv4/flashinfer/.jit_dsv4_remote_artifacts CUDA_LAUNCH_BLOCKING=1 python -m pytest -q --tb=short tests/attention/test_trtllm_gen_sparse_mla_dsv4.py
@PerkzZheng PerkzZheng force-pushed the dsv4-sparse-mla-trtllm-gen branch from 255870c to f578d94 on May 12, 2026 02:44
@PerkzZheng
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !649 has been updated with latest changes, and the CI pipeline #50989195 is currently running. I'll report back once the pipeline job completes.

@PerkzZheng
Contributor Author

/bot run

@flashinfer-bot
Collaborator

GitLab MR !649 has been created, and the CI pipeline #50989237 is currently running. I'll report back once the pipeline job completes.

Comment thread flashinfer/mla/_core.py
query.device,
cum_seq_lens_q,
)
if normalized_sparse_lens.numel() > 0:
Collaborator


The sparse_topk_lens min/max check in _check_dsv4_sparse_mla_inputs does two .item() D2H syncs per call. The sibling seq_lens < q_lens check is already gated by _validate_dsv4_sync_checks() — is the gate intentionally omitted here, or a miss?
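If the gate was simply missed, the same helper could wrap this check as well. A rough sketch, assuming _validate_dsv4_sync_checks() takes no arguments and returns a bool (its actual signature may differ):

def _maybe_check_topk_lens(normalized_sparse_lens, sparse_mla_top_k):
    # Skip the two .item() device-to-host syncs unless explicit validation is enabled.
    # _validate_dsv4_sync_checks is the existing env-gated helper in flashinfer/mla/_core.py.
    if not _validate_dsv4_sync_checks():
        return
    if normalized_sparse_lens.numel() > 0:
        lo = normalized_sparse_lens.min().item()
        hi = normalized_sparse_lens.max().item()
        if lo < 0 or hi > sparse_mla_top_k:
            raise ValueError(
                f"sparse top-k lengths must lie in [0, {sparse_mla_top_k}], got [{lo}, {hi}]"
            )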

return 512 if h_q == 64 else 1024


def gen_testcase() -> tuple[RawTestParamForDecode, ...]:
Collaborator


All add_case(...) calls in gen_testcase() hardcode is_varlen=True, so the dense 4D-query branch in _check_dsv4_sparse_mla_inputs (the else arm with query.ndim == 4, out_shape = (batch_size, q_len_per_request, num_heads, 512)) and the not is_varlen arm of _make_q_lens are never exercised. Could you add at least one is_varlen=False case so this path has coverage?

