Hang in oneshotAllreduceFusionKernel During Piecewise CUDA Graph Replay #3053

@wenscarl

Description

A permanent GPU hang occurs during piecewise CUDA graph replay when using FlashInfer's MNNVL (Multi-Node NVLink) one-shot allreduce kernel (oneshotAllreduceFusionKernel<2, bfloat16, true, float4>). The hang manifests as a Lamport-clock spin-wait that never exits, cascading into a warp-level __shfl_xor_sync deadlock and a cluster barrier deadlock. The same bug is independently observed in vLLM: vllm-project/vllm#35772

Environment:

  GPU         2x NVIDIA GB200 (SM100, MNNVL fabric)
  CUDA        12.9
  FlashInfer  0.6.4
  Framework   sglang (also vLLM — see linked issue)
  Container   lmsysorg/sglang:dev

Hang trigger

The hang occurs during piecewise CUDA graph replay — sglang captures one piecewise CUDA graph per prefill num_tokens and replays the best match at inference time. The hang does not occur during decode CUDA graph replay or in eager mode — only during piecewise graph replay.

Observed Behavior

Both GPUs hang permanently. cuda-gdb shows the following state:

  GPU 0 — oneshotAllreduceFusionKernel<2, bfloat16, true, float4> (rank 0)

  Block (1,0,0):

  Thread 18  — STUCK at Lamport spin-wait
    #0  ld.volatile.global.v4.f32  (loadPackedVolatile, trtllm_mnnvl_allreduce.cuh:335)
    #1  oneshotAllreduceFusionKernel (trtllm_mnnvl_allreduce.cuh:563)

  Thread 109 — STUCK at Lamport spin-wait
    #0  oneshotAllreduceFusionKernel (trtllm_mnnvl_allreduce.cuh:567)  [isNegZero check]

  Thread 176 — STUCK at Lamport spin-wait
    #0  oneshotAllreduceFusionKernel (trtllm_mnnvl_allreduce.cuh:560)

  Threads 0–17, 19–31 — DEADLOCKED at __shfl_xor_sync (warpReduceSumFull, warp 0 diverged)
    #0  __shfl_xor_sync(0xffffffff, ...)     (sm_30_intrinsics.hpp:449)
    #1  warpReduceSumFull<float>             (trtllm_mnnvl_allreduce.cuh:364)
    #2  blockReduceSumPartial<float, true>   (trtllm_mnnvl_allreduce.cuh:396)
    #3  blockReduceSum<float, true>          (trtllm_mnnvl_allreduce.cuh:443)
    #4  oneshotAllreduceFusionKernel         (trtllm_mnnvl_allreduce.cuh:616)

  Threads 32–95, 128–159 — BLOCKED at __syncthreads()
    #0  __syncthreads()                      (trtllm_mnnvl_allreduce.cuh:402)
    #1  blockReduceSumPartial<float, true>   (trtllm_mnnvl_allreduce.cuh:402)
    #2  blockReduceSum<float, true>          (trtllm_mnnvl_allreduce.cuh:443)
    #3  oneshotAllreduceFusionKernel         (trtllm_mnnvl_allreduce.cuh:616)

  Block (2,0,0): Same pattern — threads 15, 37, 77 stuck in spin-wait; remaining threads in warp at __shfl_xor_sync or __syncthreads().

  Block (3,0,0): ALL 180 threads at cluster barrier (cluster partner block (3,1,0) never arrives):
    #0  cluster.barrier_wait(...)            (sm_90_rt.hpp:197)
    #1  oneshotAllreduceFusionKernel         (trtllm_mnnvl_allreduce.cuh:631)

  Block (3,1,0): Threads 25 and 98 stuck in spin-wait; remaining threads at __syncthreads().

  GPU 1 — oneshotAllreduceFusionKernel<2, bfloat16, true, float4> (rank 1)

  Symmetric pattern across all blocks. Representative stack:

  Thread 18  (block 1,0,0) — STUCK at Lamport spin-wait
    #0  oneshotAllreduceFusionKernel (trtllm_mnnvl_allreduce.cuh:567)

  Thread 109 (block 1,0,0) — STUCK at loadPackedVolatile
    #0  ld.volatile.global.v4.f32   (trtllm_mnnvl_allreduce.cuh:336)
    #1  oneshotAllreduceFusionKernel (trtllm_mnnvl_allreduce.cuh:563)

  Thread 176 (block 1,0,0) — STUCK in isNegZero
    #0  isNegZero<float>             (trtllm_mnnvl_allreduce.cuh:114)
    #1  oneshotAllreduceFusionKernel (trtllm_mnnvl_allreduce.cuh:567)

  Block (3,0,0) on GPU 1 is also fully blocked at cluster barrier sm_90_rt.hpp:197.

  bufferFlags state at time of hang (both GPUs)

  bufferFlags[0]  = 1   (mCurrentIndex — using Lamport buffer 1)
  bufferFlags[1]  = 0   (mDirtyIndex)
  bufferFlags[2]  = <buffer_size_bytes>

  Both GPUs show identical values, ruling out mCurrentIndex mismatch between ranks.
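Not from the kernel source, but to make the flag layout above concrete: a minimal Python sketch of how a two-slot Lamport buffer scheme of this shape typically advances. The names bufferFlags, mCurrentIndex, and mDirtyIndex come from the dump above; the flip-on-completion logic is an assumption about the state machine, not extracted code.

```python
# Hypothetical model of the two-slot Lamport buffer flags observed above.
# bufferFlags[0] = mCurrentIndex (buffer slot used by the current call)
# bufferFlags[1] = mDirtyIndex   (slot that must be reset to the -0.0f
#                                 sentinel before it can be reused)
# bufferFlags[2] = buffer size in bytes
# The flip logic below is an assumption, not the kernel's actual code.

def advance(buffer_flags):
    current, dirty, size = buffer_flags
    # After a call completes, the just-used slot becomes dirty (needs its
    # sentinel restored) and the other slot becomes current for the next call.
    return [dirty, current, size]

flags = [1, 0, 4096]    # state captured at hang time on both ranks
flags = advance(flags)  # the next call would use slot 0 and clean slot 1
```

Because both ranks show identical flags, a simple desynchronization of this flip between ranks is ruled out, as noted above.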

Brief analysis:

Within each block, a subset of threads is permanently stuck in the Lamport spin-wait (waiting for multicast-delivered data to replace the -0.0f float32 sentinel). The remaining threads in the same warps have already exited the spin-wait and entered blockReduceSumPartial, where they are now blocked — either at __shfl_xor_sync(0xFFFFFFFF, ...) (requires all 32 warp lanes) or at __syncthreads() — waiting for the spin-wait threads to catch up. Since those threads never exit the spin-wait, the entire block is deadlocked. Blocks that fully completed are then blocked at the cluster barrier waiting for their cluster partner that is also deadlocked.
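The sentinel works because float32 negative zero has a bit pattern (0x80000000) that real payload data is assumed never to carry, so a reader can distinguish "slot not yet written" from delivered data with a single bitwise compare. A hedged Python sketch of the check that isNegZero<float> presumably performs (the CUDA implementation is not reproduced here):

```python
import struct

def is_neg_zero(x: float) -> bool:
    """True iff x is exactly float32 -0.0 (bit pattern 0x80000000).

    Models the isNegZero<float> sentinel test from the stack traces:
    each reader lane spins until the multicast write replaces -0.0f
    with real data. If that write never becomes visible (the hang
    above), the loop never exits, and the rest of the warp deadlocks
    at __shfl_xor_sync / __syncthreads() waiting for the stuck lane.
    """
    bits, = struct.unpack("<I", struct.pack("<f", x))
    return bits == 0x80000000
```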

The hang is specific to piecewise CUDA graph replay and does not reproduce in eager mode, suggesting an interaction between the Lamport buffer state machine and CUDA graph capture/replay semantics.

Steps to reproduce:

1. Pull and start the container on a GB200 node:
  docker pull lmsysorg/sglang:dev
  docker run -it --rm \
    --gpus all \
    -w /sgl-workspace \
    --network host \
    --shm-size 16g \
    --cap-add SYS_PTRACE \
    --ulimit memlock=-1 \
    --privileged \
    lmsysorg/sglang:dev \
    bash

2. Inside the container, install FlashInfer 0.6.4:
  pip install flashinfer-python==0.6.4 \
    flashinfer-cubin==0.6.4 \
    flashinfer-jit-cache==0.6.4 \
    --index-url https://flashinfer.ai/whl/cu129
3. Install the debug branch of sglang:
  cd /sgl-workspace
  mv sglang sglang_old
  git clone https://github.com/wenscarl/sglang.git -b ar_debug
  cd sglang
  git checkout ea754b82603eaa9f9f407710945d5ea966053384
4. Run the server:
  CUDA_LAUNCH_BLOCKING=1 sglang serve \
    --model-path openai/gpt-oss-120b \
    --tensor-parallel-size 2 \
    --reasoning-parser gpt-oss \
    --tool-call-parser gpt-oss \
    --enable-flashinfer-allreduce-fusion \
    --disable-flashinfer-autotune
  Wait for the server to print "INFO: Application startup complete" before proceeding.
5. Run the client:
  python3 -m sglang.bench_serving \
    --backend sglang \
    --num-prompts 256 \
    --max-concurrency 64
  The hang typically manifests within the first few requests during prefill (piecewise CUDA graph replay).

6. Attach cuda-gdb:
  nvidia-smi        # find the server PID
  cuda-gdb -p <pid>
