Reenable MNNVL backend for FlashInfer allreduce fusion #23402
Conversation
# Conflicts:
#   python/sglang/srt/layers/flashinfer_comm_fusion.py
| "Enable FlashInfer allreduce fusion and choose backend. " | ||
| "When not set the feature is disabled. " | ||
| "'auto': choose best backend (trtllm single-node, mnnvl multi-node). " | ||
| "'trtllm': single-node only, supports fused quantization. " |
Currently, I'm not sure we use the quantize fusion in the code (maybe only in benchmarks); can you help verify?
Verified — only kARResidualRMSNorm (no quant) is called in SGLang. The quant patterns appear only in flashinfer benchmarks. Removed the misleading line from the help text.
- simplify flashinfer.comm imports (drop legacy fallbacks; rely on pinned version)
- drop "supports fused quantization" from trtllm backend help text
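For context, a minimal sketch of what the updated flag declaration could look like after the quantization wording is dropped. The choices and the first three help lines are quoted from the review above; the parser wiring, default, and the 'mnnvl' help line are assumptions, and the real definition lives in server_args.py.

import argparse

parser = argparse.ArgumentParser()
# Sketch only: choices and the quoted help lines come from the review above;
# the default and the 'mnnvl' description are assumptions.
parser.add_argument(
    "--flashinfer-allreduce-fusion-backend",
    type=str,
    choices=["auto", "trtllm", "mnnvl"],
    default=None,
    help=(
        "Enable FlashInfer allreduce fusion and choose backend. "
        "When not set the feature is disabled. "
        "'auto': choose best backend (trtllm single-node, mnnvl multi-node). "
        "'trtllm': single-node only. "
        "'mnnvl': multi-node (MNNVL) backend."
    ),
)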
@wenscarl could you resolve the conflicts?
Head branch was pushed to by a user without write access
Gist for debugging the hang.
@wenscarl The MTP failure looks related on DSV32. Can you help take a look? Edit: it's not related. Just tested it locally.
BBuf left a comment:
One more concern: the deterministic inference path does not seem fully migrated to the new backend field.
In server_args.py, _handle_deterministic_inference() still only disables the old enable_flashinfer_allreduce_fusion flag, but it does not clear flashinfer_allreduce_fusion_backend. This matters especially for rl_on_policy_target, because deterministic inference is enabled inside _handle_deterministic_inference(), after the earlier enforce_disable_flashinfer_allreduce_fusion handling has already run.
So some deterministic inference configurations may still end up with the new FlashInfer allreduce fusion backend enabled. Should we also set flashinfer_allreduce_fusion_backend = None in this path?
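A minimal sketch of the suggested change, assuming the method and field names mentioned above plus a plain attribute for the deterministic-inference toggle; the actual ServerArgs code may differ.

from typing import Optional

class ServerArgs:
    enable_deterministic_inference: bool = False  # assumed toggle name
    enable_flashinfer_allreduce_fusion: bool = False
    flashinfer_allreduce_fusion_backend: Optional[str] = None

    def _handle_deterministic_inference(self) -> None:
        if not self.enable_deterministic_inference:
            return
        # Current behavior: only the legacy flag is cleared.
        self.enable_flashinfer_allreduce_fusion = False
        # Suggested addition: also clear the new backend field so deterministic
        # configs (e.g. rl_on_policy_target) cannot keep fusion enabled.
        self.flashinfer_allreduce_fusion_backend = None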
The current CI also does not look clean yet. A few failures seem worth checking before merge:
I ran the B200 test locally, and it passed.


Summary
Reintroduces FlashInfer unified fused allreduce + residual + RMSNorm with an explicit backend flag --flashinfer-allreduce-fusion-backend (auto|trtllm|mnnvl), and fixes interaction with piecewise CUDA graph when the MNNVL path is used.

Background
Problem
Hangs were correlated with piecewise CUDA graph capture/replay while MNNVL-style allreduce fusion was active, including configurations where the fused op was already a torch.compile split op relative to PCG. Reducing the number of piecewise graphs made the hang much rarer. Eager execution and the decode (non–piecewise) CUDA graph path did not reproduce the issue in our testing.
When stuck, stacks pointed to lack of progress in the Lamport-style wait loop in FlashInfer's MNNVL allreduce implementation, e.g. trtllm_mnnvl_allreduce.cuh (lines ~557–573), consistent with oneshot/twoshot progress assumptions conflicting with PCG replay.

What this PR does
- Adds --flashinfer-allreduce-fusion-backend and gates fusion on flashinfer_allreduce_fusion_backend is not None (see communicator.py, server_args.py; sketch below).
- mnnvl or auto on SM100 (where MNNVL may be selected): disables piecewise CUDA graph in model_runner.py so MNNVL fusion is not replayed inside PCG; keeps flashinfer_allreduce_residual_rmsnorm registered as a split op so it runs eagerly between graph pieces.
- trtllm: removes that split-op name from PCG split_ops when MNNVL split is not required, so fusion can stay in-graph for piecewise compile.
- On the (None, None) fallback, always performs tensor_model_parallel_all_reduce before RMSNorm (fixes missing allreduce on fallback).
- --enable-flashinfer-allreduce-fusion: if set with no backend, maps to --flashinfer-allreduce-fusion-backend=auto and logs a warning.
- Calls create_allreduce_fusion_workspace with backend=… and an optional gpus_per_node, and preserves the existing NCCL device + GLOO cpu TorchDistBackend workaround where applicable.
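A rough sketch of how the backend gating and the legacy-flag mapping described above could fit together; the function names and warning text are illustrative, and the real logic lives in server_args.py and communicator.py.

import logging

logger = logging.getLogger(__name__)

def map_legacy_flashinfer_ar_flag(args) -> None:
    # Legacy --enable-flashinfer-allreduce-fusion with no backend maps to
    # 'auto' and logs a warning (helper name and message are illustrative).
    if args.enable_flashinfer_allreduce_fusion and args.flashinfer_allreduce_fusion_backend is None:
        args.flashinfer_allreduce_fusion_backend = "auto"
        logger.warning(
            "--enable-flashinfer-allreduce-fusion is deprecated; "
            "falling back to --flashinfer-allreduce-fusion-backend=auto"
        )

def flashinfer_ar_fusion_enabled(args) -> bool:
    # Fusion is gated solely on the new backend field being set.
    return args.flashinfer_allreduce_fusion_backend is not None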
Benchmarks (GB200, 4 GPUs)

Server:
Client (bench_serving):

python3 -m sglang.bench_serving \
  --backend sglang \
  --dataset-name random \
  --random-input-len 1024 \
  --random-output-len 1024 \
  --random-range-ratio 1.0 \
  --num-prompts 128 \
  --max-concurrency ${BS} \
  --request-rate inf \
  --disable-ignore-eos

Notes for reviewers
- … auto where applicable.
- … flashinfer_comm_fusion.py.

Accuracy:
How to reproduce the hang
You want both:
FlashInfer MNNVL (or auto on SM100 where MNNVL is chosen) for fused allreduce + residual + RMSNorm, and
Piecewise CUDA graph capture/replay still enabled (the path this PR disables in model_runner.py).
On this PR’s default code, piecewise CUDA graph is skipped when flashinfer_ar_needs_piecewise_cuda_graph_split(server_args) is true, so you will not see the hang until you relax that guard.
Option A — turn piecewise CUDA graph back on (most direct)
In python/sglang/srt/model_executor/model_runner.py, comment out or remove the early return that runs when flashinfer_ar_needs_piecewise_cuda_graph_split is true (the block that logs “Disable piecewise CUDA graph because MNNVL allreduce fusion is enabled”).
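For orientation, a sketch of the early return Option A removes, built from the helper and log message quoted above; the surrounding ModelRunner code and initializer name are assumptions.

import logging

logger = logging.getLogger(__name__)

def maybe_init_piecewise_cuda_graph(model_runner) -> None:
    # Sketch of the guard in model_runner.py.
    from sglang.srt.layers.flashinfer_comm_fusion import (
        flashinfer_ar_needs_piecewise_cuda_graph_split,
    )
    if flashinfer_ar_needs_piecewise_cuda_graph_split(model_runner.server_args):
        logger.info(
            "Disable piecewise CUDA graph because MNNVL allreduce fusion is enabled"
        )
        return  # Option A: comment out this early return to reproduce the hang
    model_runner.init_piecewise_cuda_graph()  # assumed initializer name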
Option B — make the “disable PCG” helper always false
In python/sglang/srt/layers/flashinfer_comm_fusion.py, change flashinfer_ar_needs_piecewise_cuda_graph_split so it always returns False (or only returns False for your MNNVL / auto case). Then model_runner will not skip piecewise CUDA graph, while piecewise_cuda_graph_runner can still treat the fused op as a split op for MNNVL when your branch logic keeps it in split_ops.
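And for Option B, a minimal illustration of the debug-only edit, assuming the helper takes server_args and returns a bool; the real implementation in flashinfer_comm_fusion.py contains the MNNVL/auto branch logic.

def flashinfer_ar_needs_piecewise_cuda_graph_split(server_args) -> bool:
    # Debug-only: never ask model_runner to disable piecewise CUDA graph,
    # so PCG stays on even when the MNNVL backend is selected.
    return False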
You reported that even Option B still hung; that is why the shipped fix disables PCG entirely for MNNVL instead of relying on split-op alone.
@nvpohanh
CI States
Latest PR Test (Base): Run #25962704277
Latest PR Test (Extra): ⚠️ Not enabled — add run-ci-extra label to opt in.