
[offloader] v2: Hide weight onloading latency via prefetching #29941

Merged: vllm-bot merged 45 commits into vllm-project:main from minosfuture:offloader on Feb 26, 2026
Conversation

@minosfuture (Contributor) commented Dec 3, 2025

Purpose

This PR adds a CPU weight offloader that hides weight onloading latency by prefetching weights, avoiding the performance cost of zero-copy UVA access. The technique was first developed in SGLang for GB200 (https://lmsys.org/blog/2025-09-25-gb200-part-2/) and is adapted here to support torch.compile and CUDA graphs within vLLM.

It also refactors the offloading code to be extensible.

Demonstrated in the trace (screenshot omitted): the H2D stream prefetches the next offloaded weights while compute proceeds.
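To make the overlap concrete, here is a minimal PyTorch sketch of the idea (illustrative only; the onload helper and buffer names are hypothetical, not this PR's API): the next layer's offloaded weights are copied host-to-device on a side stream while the current layer computes, and the compute stream waits on the copy's event before touching them.

import torch

# Hypothetical sketch of stream-based weight prefetching; not this PR's API.
prefetch_stream = torch.cuda.Stream()

def onload(cpu_weight: torch.Tensor, gpu_buf: torch.Tensor) -> torch.cuda.Event:
    # Pinned host memory is required for a truly asynchronous H2D copy.
    assert cpu_weight.is_pinned()
    done = torch.cuda.Event()
    with torch.cuda.stream(prefetch_stream):
        gpu_buf.copy_(cpu_weight, non_blocking=True)  # H2D on the side stream
        done.record()
    return done

# Per layer i: kick off the copy for layer i + 1, compute layer i, then have
# the compute stream wait on the event before layer i + 1 uses its weights:
#
#   next_ready = onload(cpu_weights[i + 1], gpu_bufs[(i + 1) % 2])
#   x = layers[i](x)
#   torch.cuda.current_stream().wait_event(next_ready)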

Test Plan

  • DeepSeek FP4 on 2 GPUs with offloading
  • DeepSeek FP4 on 4 GPUs without offloading (the model does not fit on 2 GPUs without offloading)

Example serving recipe for GB200:

HF_HUB_OFFLINE=1 NVIDIA_GDRCOPY=1 NVSHMEM_IB_ENABLE_IBGDA=1 VLLM_SKIP_P2P_CHECK=1 \
NCCL_CUMEM_ENABLE=1 NCCL_MNNVL_ENABLE=1 NCCL_NVLS_ENABLE=1 \
VLLM_RANDOMIZE_DP_DUMMY_INPUTS=1 VLLM_USE_FLASHINFER_MOE_FP4=1 \
VLLM_FLASHINFER_MOE_BACKEND=latency VLLM_ENABLE_FUSED_MOE_ACTIVATION_CHUNKING=0 \
VLLM_USE_NCCL_SYMM_MEM=1 VLLM_ENABLE_MOE_DP_CHUNK=0 \
numactl --cpunodebind=0 --membind=0 \
vllm serve nvidia/DeepSeek-R1-0528-NVFP4-v2 \
  --attention-config.backend FLASHINFER_MLA \
  --attention-config.use_trtllm_ragged_deepseek_prefill=true --kv-cache-dtype fp8 \
  --enable-expert-parallel --data-parallel-rpc-port 13345 --max-model-len 2148 \
  --disable-uvicorn-access-log --port 8000 --trust-remote-code --async-scheduling \
  --disable_custom_all_reduce --disable_nccl_for_dp_synchronization \
  --no-enable-prefix-caching --all2all-backend allgather_reducescatter \
  --gpu-memory-utilization 0.85 --max-num-batched-tokens 65536 --max-num-seqs 1024 \
  --swap-space 16 --data-parallel-size 2 --offload-group-size 2 \
  --offload-num-in-group 1 --offload-prefetch-step 1 \
  --offload-params w13_weight w2_weight \
  --api-server-count 1 \
  --profiler-config '{"profiler": "torch", "torch_profiler_dir": "/tmp/traces"}' \
  2>&1 | tee /tmp/dsr1.log

Test Result

QPS per 2 GPUs: 18.04 (2-GPU run)

This PR, with --offload-group-size 2 --offload-num-in-group 1 --offload-prefetch-step 1:

============ Serving Benchmark Result ============
Successful requests:                     1024
Failed requests:                         0
Maximum request concurrency:             1024
Benchmark duration (s):                  56.78
Total input tokens:                      2096128
Total generated tokens:                  1024
Request throughput (req/s):              18.04
Output token throughput (tok/s):         18.04
Peak output token throughput (tok/s):    63.00
Peak concurrent requests:                1024.00
Total Token throughput (tok/s):          36937.17
---------------Time to First Token----------------
Mean TTFT (ms):                          30965.61
Median TTFT (ms):                        30190.67
P99 TTFT (ms):                           56685.92
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          0.00
Median TPOT (ms):                        0.00
P99 TPOT (ms):                           0.00
---------------Inter-token Latency----------------
Mean ITL (ms):                           0.00
Median ITL (ms):                         0.00
P99 ITL (ms):                            0.00
==================================================

QPS per 2 GPUs: 16.9 (4-GPU run)

Baseline, without --offload-group-size 2 --offload-num-in-group 1 --offload-prefetch-step 1:

============ Serving Benchmark Result ============
Successful requests:                     1024
Failed requests:                         0
Maximum request concurrency:             1024
Benchmark duration (s):                  30.28
Total input tokens:                      2096128
Total generated tokens:                  1024
Request throughput (req/s):              33.82
Output token throughput (tok/s):         33.82
Peak output token throughput (tok/s):    64.00
Peak concurrent requests:                1024.00
Total Token throughput (tok/s):          69259.31
---------------Time to First Token----------------
Mean TTFT (ms):                          17335.01
Median TTFT (ms):                        17177.41
P99 TTFT (ms):                           30175.24
-----Time per Output Token (excl. 1st token)------
Mean TPOT (ms):                          0.00
Median TPOT (ms):                        0.00
P99 TPOT (ms):                           0.00
---------------Inter-token Latency----------------
Mean ITL (ms):                           0.00
Median ITL (ms):                         0.00
P99 ITL (ms):                            0.00
==================================================

Accuracy:

local-completions (model=nvidia/DeepSeek-R1-0528-FP4-v2,base_url=http://127.0.0.1:8000/v1/completions,num_concurrent=32), gen_kwargs: (None), limit: 100.0, num_fewshot: None, batch_size: 1

Tasks  Version  Filter            n-shot  Metric       Value  Stderr
gsm8k  3        flexible-extract  5       exact_match  0.98   ± 0.0141
gsm8k  3        strict-match      5       exact_match  0.98   ± 0.0141

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

Signed-off-by: Ming Yang <minos.future@gmail.com>
@mergify bot added labels: deepseek (Related to DeepSeek models), frontend, v1 (Dec 3, 2025)
@chatgpt-codex-connector (bot) left a comment

💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment thread: vllm/model_executor/offloader/uva.py (outdated)
Signed-off-by: Ming Yang <minos.future@gmail.com>
Comment thread: vllm/config/cache.py (outdated)
@ywang96 (Member) left a comment

IIUC this is mostly an RL feature, correct? Maybe @youkaichao can take a look?

@elvircrn (Contributor) commented Dec 5, 2025

Why are TPOT and ITL 0 in the benchmark result?

EDIT:

Ah, I see: the number of output tokens was probably set to 1 (Total generated tokens 1024 = one per request), so TPOT and ITL, which exclude the first token, come out as 0.

@minosfuture (Contributor, Author) commented:

> IIUC this is mostly an RL feature, correct? Maybe @youkaichao can take a look?

Mostly for fitting the model onto fewer GPUs at the moment, but it could be useful for weight updating in RL.
Pinging @youkaichao for review, since you did the first offloading implementation.

Comment thread: vllm/model_executor/offloader/uva.py (outdated)
Comment thread: .buildkite/test_areas/e2e_integration.yaml (outdated)
Signed-off-by: Michael Goin <mgoin64@gmail.com>
@mgoin (Member) left a comment

LGTM, thanks for the changes. I think it is obvious that the get_offloader().sync_prev_onload() and get_offloader().join_after_forward() insertions in the model runner are fragile and prone to failure in the future, but I'm not sure how to structure this better. Maybe @benchislett or @LucasWilkinson have strong opinions against this, but we do need this performant feature to land sooner or later.
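For readers skimming the thread, a rough sketch of where those hooks sit (the hook names come from the comment above; the surrounding model-runner structure is simplified and assumed, not the actual vLLM code):

# Rough sketch of the hook placement discussed above; the surrounding
# structure is assumed, not the actual vLLM model-runner code.
def execute_model(model, inputs):
    offloader = get_offloader()     # get_offloader() as referenced above
    offloader.sync_prev_onload()    # wait for prefetched weights to land
    output = model(**inputs)        # forward pass; prefetching overlaps it
    offloader.join_after_forward()  # drain outstanding onload work
    return output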

@github-project-automation github-project-automation Bot moved this to Ready in NVIDIA Feb 25, 2026
minosfuture and others added 4 commits February 24, 2026 23:05
Set _instance default to NoopOffloader() so get_offloader() always
returns a valid instance. Log the offloader type in set_offloader()
for visibility into which backend is active.

Signed-off-by: Ming Yang <minos.future@gmail.com>
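That commit describes a null-object default for the module-level singleton. A minimal sketch of the pattern, assuming only the names given in the commit message and the hooks discussed above (everything else is illustrative):

import logging

logger = logging.getLogger(__name__)

class NoopOffloader:
    # Null object: a safe default so callers never need a None check.
    def sync_prev_onload(self) -> None:
        pass

    def join_after_forward(self) -> None:
        pass

_instance = NoopOffloader()

def get_offloader():
    # Always returns a valid instance; defaults to the no-op backend.
    return _instance

def set_offloader(offloader) -> None:
    global _instance
    _instance = offloader
    # Log the active backend for visibility, per the commit message.
    logger.info("Using offloader: %s", type(offloader).__name__)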
@zou3519 (Collaborator) left a comment

compile pieces lgtm

@vllm-bot vllm-bot merged commit 6831650 into vllm-project:main Feb 26, 2026
69 of 70 checks passed
@github-project-automation github-project-automation Bot moved this from Ready to Done in NVIDIA Feb 26, 2026
@github-project-automation github-project-automation Bot moved this from In review to Done in Large-Scale Serving Feb 26, 2026
Referenced by downstream commits:

  • haanjack/vllm (Feb 26, 2026)
  • machov/vllm (Mar 10, 2026)
  • jiangkuaixue123/vllm (Apr 28, 2026)

Labels

ci/build, deepseek (Related to DeepSeek models), frontend, nvidia, ready (ONLY add when PR is ready to merge/full CI is needed), v1

Projects

Status: Done (NVIDIA)
Status: Done (Large-Scale Serving)
