
[Temporarily unblock spec v2 qwen3.5] #18906

Closed
vincentzed wants to merge 2 commits into sgl-project:main from bzhng-development:vz/spec-v2-qwen3.5-next

Conversation

@vincentzed (Contributor) commented Feb 16, 2026

Motivation

Modifications

Builds on #15591 and https://github.com/sgl-project/sglang/pull/18808/changes, with a change to multimodal (mm) input handling on top of them.

❯ python -m sglang.test.send_one --stream

Spec v2

SGLANG_ENABLE_SPEC_V2=1 python -m sglang.launch_server \
  --model-path Qwen/Qwen3.5-397B-A17B \
  --tp-size 8 \
  --mem-fraction-static 0.8 \
  --context-length 262144 \
  --reasoning-parser qwen3 \
  --speculative-algo NEXTN \
  --speculative-num-steps 3 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 4

+-------------+--------+------------+-----------------+
| Latency (s) | Tokens | Acc Length | Speed (token/s) |
+-------------+--------+------------+-----------------+
|    1.582    |  512   |   3.368    |     323.72      |
+-------------+--------+------------+-----------------+

Spec v1

+-------------+--------+------------+-----------------+
| Latency (s) | Tokens | Acc Length | Speed (token/s) |
+-------------+--------+------------+-----------------+
|    1.717    |  512   |   3.368    |     298.19      |
+-------------+--------+------------+-----------------+
python3 -m lm_eval --model local-completions --model_args "model=Qwen/Qwen3.5-397B-A17B,base_url=http://127.0.0.1:30000/v1/completions,num_concurrent=512" --tasks gsm8k_platinum

local-completions ({'model': 'Qwen/Qwen3.5-397B-A17B', 'base_url': 'http://127.0.0.1:30000/v1/completions', 'num_concurrent': 512}), gen_kwargs: ({}), limit: None, num_fewshot: None, batch_size: 1
|    Tasks     |Version|     Filter     |n-shot|  Metric   |   |Value |   |Stderr|
|--------------|------:|----------------|-----:|-----------|---|-----:|---|-----:|
|gsm8k_platinum|      3|flexible-extract|     5|exact_match|↑  |0.8842|±  |0.0092|
|              |       |strict-match    |     5|exact_match|↑  |0.8668|±  |0.0098|

I suspect there is still some opportunity to further improve spec v2 performance, as the advantage over v1 is still small (about 8.6%: 323.72 vs 298.19 token/s at the same acceptance length).
Edit: it's because of this nontrivial sync:
[Screenshot: CleanShot 2026-02-16 at 13 35 20@2x]
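
For context on where such a sync comes from: boolean-mask indexing in PyTorch has a data-dependent output shape, so it forces a device-to-host round trip. A minimal standalone sketch, not the PR's code (tensor names are made up):

import torch

accepted_steps = torch.randint(0, 4, (512,), device="cuda")

# Masked path: the output size of accepted_steps[mask] depends on the data,
# so PyTorch must read the element count back to the host -- an implicit
# cudaStreamSynchronize on every call.
mask = accepted_steps >= 0
valid = accepted_steps[mask]

# Sync-free path: when the values are guaranteed non-negative (the spec v2
# invariant behind skip_masking), the mask is all-True and can be skipped;
# clamping keeps a static shape and stays on-device.
valid = torch.clamp(accepted_steps, min=0)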


Accuracy Tests

Benchmarking and Profiling

Checklist

Review Process

  1. Ping Merge Oncalls to start the PR flow. See the PR Merge Process.
  2. Get approvals from CODEOWNERS and other reviewers.
  3. Trigger CI tests with comments or contact authorized users to do so.
    • /tag-run-ci-label, /rerun-failed-ci, /tag-and-rerun-ci
  4. After green CI and required approvals, ask Merge Oncalls to merge.

Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com>
@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @vincentzed, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces several enhancements to support speculative decoding v2 (EAGLEWorkerV2) for Qwen3.5, focusing on optimizing Mamba state management and improving multimodal input handling. It integrates an overlap scheduling mechanism to enhance performance and refines the Mamba prefix-cache update logic, particularly after token verification.

Highlights

  • Mamba State Optimization: Introduced a skip_masking parameter in Mamba state updates post-verification to prevent unnecessary cudaStreamSynchronize calls, improving performance for Spec V2 paths where accepted steps are guaranteed non-negative.
  • Overlap Scheduling Integration: Added an enable_overlap_schedule flag to Mamba memory pools and decode request handling, allowing for more efficient buffer management and potentially better performance (see the ping-pong sketch after this list).
  • Multimodal Input Handling for Spec V2: Enhanced the _draft_extend_for_prefill function in EAGLEWorkerV2 to correctly process multimodal input embeddings during speculative decoding, addressing 'mm handling' mentioned in the PR description.
  • Speculative Decoding V2 Support: Extended _forward_metadata to properly handle DRAFT_EXTEND_V2 mode for filling the draft KV cache and integrated Mamba state tracking into the V2 verification process.
  • Mamba Prefix-Cache Tracking Refinement: Refined the logic for updating mamba_next_track_idx using a new utility method and added mamba_track_indices to batch information for more precise Mamba prefix-cache state management.
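
To make the ping-pong buffering above concrete, here is a minimal sketch of the general double-buffering idea. The function name echoes get_mamba_ping_pong_other_idx from the changelog below, but the low-bit pairing scheme is an assumption for illustration, not the PR's actual implementation:

# Hypothetical sketch: each request owns a pair of Mamba track slots and
# alternates between them, so under overlap scheduling the next step can
# write one slot while the previous step's state in the twin slot is still
# being read.
def get_mamba_ping_pong_other_idx(idx: int) -> int:
    # Assume slots come in pairs (2k, 2k+1); flipping the low bit selects
    # the twin slot.
    return idx ^ 1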


Changelog
  • python/sglang/srt/disaggregation/decode.py
    • Added enable_overlap_schedule parameter to DecodeReqToTokenPool initialization.
    • Updated mamba_ping_pong_track_buffer_size calculation based on enable_overlap_schedule.
  • python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py
    • Modified _forward_metadata to include DRAFT_EXTEND_V2 handling for is_extend() and calculate query_start_loc accordingly.
    • Added skip_masking parameter to update_mamba_state_after_mtp_verify and implemented conditional masking logic.
  • python/sglang/srt/managers/scheduler_output_processor_mixin.py
    • Updated mamba_next_track_idx assignment to use get_mamba_ping_pong_other_idx for both non-spec and speculative decode paths.
  • python/sglang/srt/mem_cache/memory_pool.py
    • Added enable_overlap_schedule parameter to MambaLinearStatePool initialization.
    • Updated mamba_ping_pong_track_buffer_size calculation based on enable_overlap_schedule.
  • python/sglang/srt/model_executor/forward_batch_info.py
    • Adjusted _compute_mrope_positions to correctly handle DRAFT_EXTEND_V2 mode.
  • python/sglang/srt/model_executor/model_runner_kv_cache_mixin.py
    • Integrated the enable_overlap_schedule parameter into the initialization of Mamba-related memory pools.
  • python/sglang/srt/speculative/eagle_info_v2.py
    • Added logic to prepare mamba_track_indices for Mamba prefix-cache state tracking in speculative decoding v2.
  • python/sglang/srt/speculative/eagle_worker_v2.py
    • Modified _draft_extend_for_prefill to accept and pass mm_input_embeds.
    • Modified forward_batch_generation to pass mm_input_embeds to _draft_extend_for_prefill.
    • Introduced _mamba_verify_update method to handle Mamba state updates after verification, including calculating accepted_steps and mamba_steps_to_track, and calling update_mamba_state_after_mtp_verify with skip_masking=True (see the sketch after this changelog).
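
As a rough illustration of the accepted-steps arithmetic (a sketch under the usual EAGLE convention that a successful verify step always emits at least one token; not the PR's exact code):

import torch

# Per-request acceptance lengths include the bonus token, so each entry is
# >= 1 whenever verification emits anything at all.
accept_lengths = torch.tensor([1, 3, 4, 2])

# The number of extra draft steps actually kept is accept_length - 1, which
# is therefore guaranteed non-negative -- the invariant that makes
# skip_masking=True safe in update_mamba_state_after_mtp_verify.
accepted_steps = accept_lengths - 1
assert (accepted_steps >= 0).all()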

@gemini-code-assist (Bot) left a comment


Code Review

This pull request introduces changes to support spec v2 for qwen3.5, particularly for models with Mamba-like architectures. The changes include adding a new forward mode DRAFT_EXTEND_V2, optimizing mamba state updates, and improving multimodal input handling in speculative decoding. The code is generally well-structured and the changes are consistent across files. I have a couple of suggestions for minor improvements to enhance maintainability and code clarity.

Comment on lines +563 to +567:

req.mamba_next_track_idx = (
    batch.req_to_token_pool.get_mamba_ping_pong_other_idx(
        req.mamba_next_track_idx
    )
)
medium

This block of code for updating req.mamba_next_track_idx is identical to the one on lines 546-550. To improve maintainability and reduce code duplication, consider extracting this logic into a private helper method.
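
One possible shape for such a helper (the method name is hypothetical):

def _advance_mamba_track_idx(self, req, batch):
    # Flip the request's track slot to the other half of its ping-pong pair.
    req.mamba_next_track_idx = (
        batch.req_to_token_pool.get_mamba_ping_pong_other_idx(
            req.mamba_next_track_idx
        )
    )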

Comment on lines +871 to +873:

to_track_ith = torch.clamp(
    tracking_point - seq_lens_pre_verify - 1, min=0
).to(torch.int64)
medium

The torch.clamp(..., min=0) call appears to be redundant. Given that to_track_mask is true, it implies seq_lens_pre_verify // mamba_track_interval < seq_lens_post_verify // mamba_track_interval. This guarantees that tracking_point will be greater than seq_lens_pre_verify, making tracking_point - seq_lens_pre_verify - 1 always non-negative. Removing the clamp can simplify the code without changing the logic.

to_track_ith = (tracking_point - seq_lens_pre_verify - 1).to(torch.int64)
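
As a concrete check (illustrative numbers, assuming tracking_point is the track-interval multiple crossed during verification): with mamba_track_interval = 4, seq_lens_pre_verify = 7, and seq_lens_post_verify = 9, we get 7 // 4 = 1 < 9 // 4 = 2, so tracking_point = 8 and to_track_ith = 8 - 7 - 1 = 0, exactly the minimum the clamp would have enforced.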

Signed-off-by: vincentzed <207368749+vincentzed@users.noreply.github.com>
aabbccddwasd added a commit to aabbccddwasd/sglang that referenced this pull request Feb 20, 2026
@aabbccddwasd

Not working on my machine:

[2026-02-20 05:28:53 TP0] Scheduler hit an exception: Traceback (most recent call last):
  File "/home/aabbccddwasd/sglang/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py", line 172, in _capture_graph
    out = run_once_fn()
          ^^^^^^^^^^^^^
  File "/home/aabbccddwasd/sglang/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py", line 297, in run_once
    ret = self.eagle_worker.draft_forward(forward_batch)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/aabbccddwasd/sglang/python/sglang/srt/speculative/eagle_worker_v2.py", line 431, in draft_forward
    detect_nan(logits_output)
  File "/home/aabbccddwasd/sglang/python/sglang/srt/speculative/spec_utils.py", line 711, in detect_nan
    if torch.any(torch.isnan(logits)):
       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
torch.AcceleratorError: CUDA error: operation not permitted when stream is capturing
Search for `cudaErrorStreamCaptureUnsupported' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/aabbccddwasd/sglang/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py", line 130, in __init__
    self.capture()
  File "/home/aabbccddwasd/sglang/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py", line 179, in capture
    CudaGraphRunner.capture(self)
  File "/home/aabbccddwasd/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 526, in capture
    _capture_one_stream()
  File "/home/aabbccddwasd/sglang/python/sglang/srt/model_executor/cuda_graph_runner.py", line 513, in _capture_one_stream
    ) = self.capture_one_batch_size(bs, forward, stream_idx)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/aabbccddwasd/sglang/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py", line 307, in capture_one_batch_size
    out = self._capture_graph(
          ^^^^^^^^^^^^^^^^^^^^
  File "/home/aabbccddwasd/sglang/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py", line 171, in _capture_graph
    with torch.cuda.graph(graph, pool=pool, stream=stream):
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/aabbccddwasd/anaconda3/envs/sgl-test/lib/python3.12/site-packages/torch/cuda/graphs.py", line 265, in __exit__
    self.cuda_graph.capture_end()
  File "/home/aabbccddwasd/anaconda3/envs/sgl-test/lib/python3.12/site-packages/torch/cuda/graphs.py", line 128, in capture_end
    super().capture_end()
torch.AcceleratorError: CUDA error: operation failed due to a previous error during capture
Search for `cudaErrorStreamCaptureInvalidated' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.


During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/aabbccddwasd/sglang/python/sglang/srt/managers/scheduler.py", line 3105, in run_scheduler_process
    scheduler = Scheduler(
                ^^^^^^^^^^
  File "/home/aabbccddwasd/sglang/python/sglang/srt/managers/scheduler.py", line 366, in __init__
    self.init_model_worker()
  File "/home/aabbccddwasd/sglang/python/sglang/srt/managers/scheduler.py", line 563, in init_model_worker
    self.maybe_init_draft_worker()
  File "/home/aabbccddwasd/sglang/python/sglang/srt/managers/scheduler.py", line 559, in maybe_init_draft_worker
    self.draft_worker = DraftWorkerClass(**draft_worker_kwargs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/aabbccddwasd/sglang/python/sglang/srt/speculative/eagle_worker_v2.py", line 622, in __init__
    self._draft_worker = EagleDraftWorker(
                         ^^^^^^^^^^^^^^^^^
  File "/home/aabbccddwasd/sglang/python/sglang/srt/speculative/eagle_worker_v2.py", line 171, in __init__
    self.init_cuda_graphs()
  File "/home/aabbccddwasd/sglang/python/sglang/srt/speculative/eagle_worker_v2.py", line 264, in init_cuda_graphs
    self.cuda_graph_runner = Device2DraftCudaGraphRunner[
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/aabbccddwasd/sglang/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py", line 132, in __init__
    raise Exception(
Exception: Capture cuda graph failed: CUDA error: operation failed due to a previous error during capture
Search for `cudaErrorStreamCaptureInvalidated' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Possible solutions:
1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)
2. set --cuda-graph-max-bs to a smaller value (e.g., 16)
3. disable torch compile by not using --enable-torch-compile
4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)
Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose


[2026-02-20 05:28:53] Received sigquit from a child process. It usually means the child failed.
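
For anyone hitting the same trace: the proximate failure is detect_nan's if torch.any(torch.isnan(logits)), which converts a device tensor to a Python bool and therefore synchronizes the stream, and synchronization is forbidden inside CUDA graph capture. A minimal sketch that reproduces the same error class (illustrative only, not the sglang code path):

import torch

g = torch.cuda.CUDAGraph()
x = torch.randn(8, device="cuda")
with torch.cuda.graph(g):
    # The implicit bool() here forces a device-to-host sync and raises
    # "operation not permitted when stream is capturing".
    if torch.any(torch.isnan(x)):
        pass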

@b8zhong b8zhong closed this Mar 9, 2026
@b8zhong b8zhong deleted the vz/spec-v2-qwen3.5-next branch March 9, 2026 17:04
