Skip to content

Commit c92b80d

Browse files
committed
Default to FlashInfer GDN decode on SM100+ with bf16 mamba state
On SM100+ with mamba-ssm-dtype=bfloat16, automatically set --linear-attn-decode-backend to flashinfer when not explicitly specified. This gives 1-5% TPOT improvement at higher concurrencies. The prerequisite bug (OOB from negative padding indices in bf16 decode kernel) was fixed in FlashInfer v0.6.7 via flashinfer-ai/flashinfer#2810. Verified on Qwen3.5-397B-A17B-NVFP4 (4xGB200, no_buffer + disable-radix-cache), sa-bench ISL=1024 OSL=1024, conc 2-1024: - GSM8K accuracy: 0.977-0.979 - Mean TPOT: -1.3% (conc=2) to -4.5% (conc=1024) - Excluded when MTP speculative decoding is active (not yet supported) - Output throughput: +1.3% (conc=2) to +4.7% (conc=1024)
1 parent f9a4e2c commit c92b80d

1 file changed

Lines changed: 19 additions & 1 deletion

File tree

python/sglang/srt/server_args.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2557,9 +2557,27 @@ def _handle_mamba_backend(self):
25572557
)
25582558

25592559
def _handle_linear_attn_backend(self):
2560-
# SM100+ FlashInfer GDN decode requires bf16 state; SM90 uses float32.
25612560
import torch
25622561

2562+
# SM100+: default to FlashInfer GDN decode when the user hasn't
2563+
# explicitly chosen a decode backend and mamba-ssm-dtype is bf16
2564+
# (required by FlashInfer GDN on SM100+).
2565+
# Fixed in FlashInfer v0.6.7: flashinfer-ai/flashinfer#2810
2566+
# Excluded when MTP speculative decoding is enabled because
2567+
# FlashInfer GDN MTP verify is not yet supported on SM100+.
2568+
if (
2569+
self.linear_attn_decode_backend is None
2570+
and is_sm100_supported()
2571+
and self.mamba_ssm_dtype == "bfloat16"
2572+
and self.speculative_algorithm is None
2573+
):
2574+
self.linear_attn_decode_backend = "flashinfer"
2575+
logger.info(
2576+
"SM100+ detected with mamba-ssm-dtype=bfloat16, "
2577+
"defaulting --linear-attn-decode-backend to flashinfer."
2578+
)
2579+
2580+
# SM100+ FlashInfer GDN decode requires bf16 state; SM90 uses float32.
25632581
decode = self.linear_attn_decode_backend or self.linear_attn_backend
25642582
if (
25652583
decode == "flashinfer"

0 commit comments

Comments
 (0)