Reenable MNNVL backend for FlashInfer allreduce fusion #23402

Open: wenscarl wants to merge 25 commits into sgl-project:main from wenscarl:ar_debug2

Conversation

wenscarl (Collaborator) commented Apr 21, 2026

Summary

Reintroduces the FlashInfer unified fused allreduce + residual + RMSNorm path with an explicit backend flag, --flashinfer-allreduce-fusion-backend (auto | trtllm | mnnvl), and fixes its interaction with piecewise CUDA graph when the MNNVL path is used.

Background

  • Original integration: PR #12787.
  • Reverted due to hangs (e.g. some models / CI): PR #20792.
  • Related report (fused AR + RMSNorm + piecewise graph): vLLM #35772.

Problem

Hangs were correlated with piecewise CUDA graph capture/replay while MNNVL-style allreduce fusion was active, including configurations where the fused op was already a torch.compile split op relative to PCG. Reducing the number of piecewise graphs made the hang much rarer. Eager execution and the decode (non–piecewise) CUDA graph path did not reproduce the issue in our testing.
When stuck, stacks pointed to lack of progress in the Lamport-style wait loop in FlashInfer’s MNNVL allreduce implementation, e.g. trtllm_mnnvl_allreduce.cuh (lines ~557–573), consistent with oneshot/twoshot progress assumptions conflicting with PCG replay.

What this PR does

  • Adds --flashinfer-allreduce-fusion-backend and gates fusion on flashinfer_allreduce_fusion_backend is not None (see communicator.py, server_args.py); a simplified gating sketch follows this list.
  • mnnvl or auto on SM100 (where MNNVL may be selected): disables piecewise CUDA graph in model_runner.py so MNNVL fusion is not replayed inside PCG; keeps flashinfer_allreduce_residual_rmsnorm registered as a split op so it runs eagerly between graph pieces.
  • trtllm: removes that split-op name from PCG split_ops when MNNVL split is not required, so fusion can stay in-graph for piecewise compile.
  • LayerNorm: if FlashInfer fusion returns (None, None), always performs tensor_model_parallel_all_reduce before RMSNorm (fixes a missing allreduce on the fallback path).
  • Deprecates --enable-flashinfer-allreduce-fusion: if set with no backend, maps to --flashinfer-allreduce-fusion-backend=auto and logs a warning.
  • Workspace creation uses FlashInfer create_allreduce_fusion_workspace with backend=…, optional gpus_per_node, and preserves the existing NCCL device + GLOO cpu TorchDistBackend workaround where applicable.
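For context, here is a simplified sketch of that gating (illustrative only: the flag names, the deprecation mapping, and flashinfer_ar_needs_piecewise_cuda_graph_split come from this PR; the helper _resolve_flashinfer_ar_backend and the exact checks are hypothetical):

def _resolve_flashinfer_ar_backend(server_args):
    # Deprecated boolean flag maps to the new backend field ("auto"); the real
    # code also logs a deprecation warning.
    if getattr(server_args, "enable_flashinfer_allreduce_fusion", False) and \
            server_args.flashinfer_allreduce_fusion_backend is None:
        server_args.flashinfer_allreduce_fusion_backend = "auto"
    # None means FlashInfer allreduce fusion is disabled entirely.
    return server_args.flashinfer_allreduce_fusion_backend


def flashinfer_ar_needs_piecewise_cuda_graph_split(server_args):
    # MNNVL (and "auto" on SM100, where MNNVL may be selected) is not safe to
    # replay inside piecewise CUDA graphs, so model_runner skips PCG for it and
    # the fused op stays registered as a split op, running eagerly between pieces.
    backend = _resolve_flashinfer_ar_backend(server_args)
    return backend in ("mnnvl", "auto")  # the real check also inspects SM/topology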

Benchmarks (GB200, 4 GPUs)

Server:

sglang serve \
  --model-path openai/gpt-oss-120b \
  --tensor-parallel-size 2 \
  --reasoning-parser gpt-oss \
  --tool-call-parser gpt-oss \
  --flashinfer-allreduce-fusion-backend <mnnvl|trtllm> \
  --disable-flashinfer-autotune

Client (bench_serving):

  python3 -m sglang.bench_serving \
    --backend sglang \
    --dataset-name random \
    --random-input-len 1024 \
    --random-output-len 1024 \
    --random-range-ratio 1.0 \
    --num-prompts 128 \
    --max-concurrency ${BS} \
    --request-rate inf \
    --disable-ignore-eos
| Max request concurrency | MNNVL TPOT (ms) | TRT-LLM TPOT (ms), piecewise CUDA graph | Speedup vs TRT-LLM (piecewise CUDA graph) | TRT-LLM TPOT (ms), without piecewise CUDA graph | Speedup vs TRT-LLM (without piecewise CUDA graph) |
|---|---|---|---|---|---|
| 1 | 3.30 | 3.79 | 1.15x | 3.79 | 1.15x |
| 4 | 4.25 | 4.75 | 1.12x | 4.83 | 1.14x |
| 16 | 5.93 | 6.76 | 1.14x | 6.87 | 1.16x |
| 32 | 7.16 | 8.34 | 1.17x | 8.57 | 1.20x |
| 64 | 8.93 | 10.60 | 1.19x | 10.83 | 1.21x |

Notes for reviewers

  • MNNVL + PCG is treated as unsafe (similar policy to other non–graph-safe comm); TRT-LLM path remains the default for single-node auto where applicable.
  • FlashInfer / topology details are documented in comments in flashinfer_comm_fusion.py.

Accuracy:

python3 -m sglang.test.few_shot_gsm8k --num-questions 200

Accuracy: 0.875
Invalid: 0.020
Latency: 27.135 s
Output throughput: 2344.613 token/s

How to reproduce the hang

  1. Goal
    You want both:
      • FlashInfer MNNVL (or auto on SM100 where MNNVL is chosen) for fused allreduce + residual + RMSNorm, and
      • piecewise CUDA graph capture/replay still enabled (the path this PR disables in model_runner.py).
    On this PR's default code, piecewise CUDA graph is skipped when flashinfer_ar_needs_piecewise_cuda_graph_split(server_args) is true, so you will not see the hang until you relax that guard.

  2. Code change (force the bad combination)
    Option A — turn piecewise CUDA graph back on (most direct)

In python/sglang/srt/model_executor/model_runner.py, comment out or remove the early return that runs when flashinfer_ar_needs_piecewise_cuda_graph_split is true (the block that logs “Disable piecewise CUDA graph because MNNVL allreduce fusion is enabled”).

Option B — make the “disable PCG” helper always false

In python/sglang/srt/layers/flashinfer_comm_fusion.py, change flashinfer_ar_needs_piecewise_cuda_graph_split so it always returns False (or only returns False for your MNNVL / auto case). Then model_runner will not skip piecewise CUDA graph, while piecewise_cuda_graph_runner can still treat the fused op as a split op for MNNVL when your branch logic keeps it in split_ops.

You reported that even Option B still hung; that is why the shipped fix disables PCG entirely for MNNVL instead of relying on split-op alone.
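
For illustration, Option B boils down to an edit like the following (a sketch only; the shipped helper's body is more involved than shown here):

# python/sglang/srt/layers/flashinfer_comm_fusion.py -- sketch of Option B, not the shipped code
def flashinfer_ar_needs_piecewise_cuda_graph_split(server_args) -> bool:
    # Shipped behavior: return True for mnnvl (or auto on SM100) so that
    # model_runner.py skips piecewise CUDA graph capture.
    # To reproduce the hang, force the guard off so PCG stays enabled while
    # MNNVL fusion is still selected:
    return False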
@nvpohanh


CI States

Latest PR Test (Base): Run #25962704277
Latest PR Test (Extra): ⚠️ Not enabled — add run-ci-extra label to opt in.

@wenscarl wenscarl marked this pull request as ready for review May 5, 2026 22:04

@wenscarl wenscarl requested a review from b8zhong May 6, 2026 03:38
Comment thread: python/sglang/srt/layers/flashinfer_comm_fusion.py (outdated)
Comment thread: python/sglang/srt/server_args.py (outdated), on the backend help text:

  "Enable FlashInfer allreduce fusion and choose backend. "
  "When not set the feature is disabled. "
  "'auto': choose best backend (trtllm single-node, mnnvl multi-node). "
  "'trtllm': single-node only, supports fused quantization. "

Collaborator: Currently I'm not sure we use the quantize fusion in the code (maybe only in benchmarks); can you help verify?

wenscarl (Author):
Verified — only kARResidualRMSNorm (no quant) is called in SGLang. The quant patterns appear only in flashinfer benchmarks. Removed the misleading line from the help text.

- simplify flashinfer.comm imports (drop legacy fallbacks; rely on pinned version)
- drop "supports fused quantization" from trtllm backend help text
@wenscarl wenscarl requested a review from b8zhong May 7, 2026 14:31
b8zhong (Collaborator) left a comment:
/tag-and-rerun-ci

b8zhong commented May 8, 2026

@wenscarl could you resolve the conflicts?

@b8zhong b8zhong added the run-ci label May 8, 2026
@wenscarl wenscarl requested a review from wisclmy0611 as a code owner May 8, 2026 15:32
auto-merge was automatically disabled May 12, 2026 17:56

Head branch was pushed to by a user without write access

Comment thread: python/sglang/srt/layers/flashinfer_comm_fusion.py (outdated)
b8zhong commented May 13, 2026

[Screenshot 1: benchmark results (MNNVL)]

Even though it does not trigger completion at the end, the E2E speedup will be more influential (the first image is MNNVL), so it's fine. Just verified on B300.

[Screenshot 2: benchmark results]

wenscarl (Author): gist for debugging the hang.

@wenscarl wenscarl requested a review from b8zhong May 13, 2026 16:45
@b8zhong b8zhong self-assigned this May 13, 2026
b8zhong commented May 14, 2026

@wenscarl The MTP failure on DSV32 looks related. Can you help take a look?

Edit: it's not related; just tested it locally.

BBuf (Collaborator) left a comment:
One more concern: the deterministic inference path does not seem fully migrated to the new backend field.

In server_args.py, _handle_deterministic_inference() still only disables the old enable_flashinfer_allreduce_fusion flag, but it does not clear flashinfer_allreduce_fusion_backend. This matters especially for rl_on_policy_target, because deterministic inference is enabled inside _handle_deterministic_inference(), after the earlier enforce_disable_flashinfer_allreduce_fusion handling has already run.

So some deterministic inference configurations may still end up with the new FlashInfer allreduce fusion backend enabled. Should we also set flashinfer_allreduce_fusion_backend = None in this path?
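
Concretely, the suggested change could look like this (a sketch only, assuming the enable_deterministic_inference field and the PR's new backend field; the real method does more than shown):

def _handle_deterministic_inference(self):
    if self.enable_deterministic_inference:
        # Existing behavior: disable the old boolean flag.
        self.enable_flashinfer_allreduce_fusion = False
        # Suggested addition: also clear the new backend field so the
        # FlashInfer allreduce fusion path cannot remain enabled here.
        self.flashinfer_allreduce_fusion_backend = None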

BBuf commented May 15, 2026

The current CI also does not look clean yet. A few failures seem worth checking before merge:

  • B200: test/registered/quant/test_deepseek_v32_fp4_mtp_4gpu.py fails with acc_length=1.00, hitting AssertionError: 1.0 not greater than 2.7. This looks especially relevant to the allreduce/RMSNorm correctness risk.
  • H200: one test fails a speed threshold, with 177.33 < 180.
  • H100: test_gpt_oss_4gpu.py fails and produces CUDA coredumps.

b8zhong commented May 15, 2026

  • I ran the B200 test locally and it passes.
  • This feature is not enabled on H200.
  • The H100 failure is a known failure on main.

@wenscarl wenscarl requested a review from BBuf May 15, 2026 22:57

Labels: bypass-fastfail, documentation, run-ci
