Merged

45 commits
a3dfa47
offloader v2
minosfuture Dec 2, 2025
9bf4097
cleanup
minosfuture Dec 3, 2025
152af73
cleanup: extract func; move imports
minosfuture Dec 3, 2025
67cd6cc
fix import
minosfuture Dec 3, 2025
c21063e
Merge remote-tracking branch 'origin/main' into offloader
minosfuture Dec 3, 2025
f43b577
address comment: remove legacy code
minosfuture Dec 4, 2025
975e972
Merge remote-tracking branch 'origin/main' into offloader
minosfuture Dec 18, 2025
0ed3570
refactor: move offload config from CacheConfig to OffloadConfig
minosfuture Dec 20, 2025
b07eb3f
refactor: generalize MoE detection for offloading
minosfuture Dec 20, 2025
ac5fb49
cleanup: remove NVFP4 scale params from offload whitelist
minosfuture Dec 20, 2025
80e764e
cleanup: remove dead code in UVAOffloader
minosfuture Dec 20, 2025
9436a66
Merge remote-tracking branch 'origin/main' into offloader
minosfuture Jan 5, 2026
21e0813
Merge remote-tracking branch 'origin/main' into offloader
minosfuture Jan 22, 2026
5bc88d8
[Core] V2 offloader: static buffers and custom ops for torch.compile
minosfuture Jan 25, 2026
719af1b
[Core] V2 offloader: CUDA graph capture support
minosfuture Jan 27, 2026
91c180e
clean up unnecessary code
minosfuture Jan 27, 2026
20db2a1
[Core] V2 offloader: simplify API by using get_offloader() directly
minosfuture Jan 27, 2026
2debd78
[Core] V2 offloader: fix CUDA graph bugs in Eagle and UBatch
minosfuture Jan 27, 2026
20ef9d8
Merge remote-tracking branch 'origin/main' into offloader
minosfuture Feb 19, 2026
f536bc4
minor param order adjustment
minosfuture Feb 19, 2026
021acc1
Keep cpu_offload_params on CacheConfig as deprecated field
minosfuture Feb 19, 2026
0525c42
Fix OffloadConfig.compute_hash() returning constant hash
minosfuture Feb 19, 2026
6506706
Fix OffloaderV2 accuracy bug: re-pin CPU storage after process_weight…
minosfuture Feb 20, 2026
55a9934
pre-commit fixes
minosfuture Feb 20, 2026
8478e55
Fix OffloaderV2 crash with dotted parameter names
minosfuture Feb 22, 2026
4520428
Specify removal version for deprecated CacheConfig offload fields
minosfuture Feb 22, 2026
fcaa9da
Use @config decorator alone for OffloadConfig
minosfuture Feb 22, 2026
79ec7d5
Add explicit offload_backend selector with nested sub-configs
minosfuture Feb 22, 2026
5fe6d7c
Rename OffloaderV2 to PrefetchOffloader
minosfuture Feb 22, 2026
08bab8a
Move PrefetchOffloader parameterization from model code to CLI
minosfuture Feb 22, 2026
0d18ec2
Make prefetch custom ops return None instead of tensor
minosfuture Feb 22, 2026
e5d375e
Switch prefetch offload test from DeepSeek-V2-Lite to Llama-3.2-1B-In…
minosfuture Feb 23, 2026
1d3d404
Merge remote-tracking branch 'origin/main' into offloader
minosfuture Feb 23, 2026
3c3ba9b
Remove incorrect comment
minosfuture Feb 23, 2026
96301e3
Use per-layer event sync in eager mode to enable H2D overlap
minosfuture Feb 24, 2026
6c6600d
Add offload_params to LLM interface
minosfuture Feb 24, 2026
ff52cd4
Use torch.finfo instead of creating tensor to get element size
minosfuture Feb 24, 2026
2a77385
Initialize global offloader as None to fail loudly before set_offloader
minosfuture Feb 24, 2026
20696cd
Add nightly e2e test for prefetch offloading with DeepSeek-V2-Lite
minosfuture Feb 24, 2026
f537f21
Fix sync_prev_onload comments in run_fullgraph
minosfuture Feb 24, 2026
18b576e
Update .buildkite/test_areas/e2e_integration.yaml
mgoin Feb 25, 2026
3df4d98
Default global offloader to NoopOffloader and log on set
minosfuture Feb 25, 2026
150dc9f
Merge remote-tracking branch 'minosfuture/offloader' into offloader
minosfuture Feb 25, 2026
9f49711
Merge remote-tracking branch 'origin/main' into offloader
minosfuture Feb 25, 2026
0478a6a
Merge branch 'main' into offloader
mgoin Feb 25, 2026
57 changes: 57 additions & 0 deletions .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_prefetch_offload.sh
@@ -0,0 +1,57 @@
#!/usr/bin/env bash
set -euxo pipefail

# Nightly e2e test for prefetch offloading with a MoE model.
# Runs DeepSeek-V2-Lite with prefetch offloading of MoE expert weights
# and validates GSM8K accuracy matches baseline (no offloading).
#
# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT]
THRESHOLD=${1:-0.25}
NUM_Q=${2:-1319}
PORT=${3:-8030}
OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled}
mkdir -p "${OUT_DIR}"

wait_for_server() {
local port=$1
timeout 600 bash -c '
until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do
sleep 1
done'
}

MODEL="deepseek-ai/DeepSeek-V2-Lite"

cleanup() {
if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
kill "${SERVER_PID}" 2>/dev/null || true
for _ in {1..20}; do
kill -0 "${SERVER_PID}" 2>/dev/null || break
sleep 0.5
done
kill -9 "${SERVER_PID}" 2>/dev/null || true
fi
}
trap cleanup EXIT

vllm serve "$MODEL" \
--max-model-len 2048 \
--offload-group-size 8 \
--offload-num-in-group 2 \
--offload-prefetch-step 1 \
--offload-params w13_weight w2_weight \
--port "$PORT" &
SERVER_PID=$!
wait_for_server "$PORT"

TAG=$(echo "$MODEL" | tr '/: \\n' '_____')
OUT="${OUT_DIR}/${TAG}_prefetch_offload.json"
python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port "$PORT" --num-questions "${NUM_Q}" --save-results "${OUT}"
python3 - <<PY
import json; acc=json.load(open('${OUT}'))['accuracy']
print(f"${MODEL} prefetch_offload: accuracy {acc:.3f}")
assert acc >= ${THRESHOLD}, f"${MODEL} prefetch_offload accuracy {acc}"
PY

cleanup
SERVER_PID=
9 changes: 9 additions & 0 deletions .buildkite/test_areas/e2e_integration.yaml
@@ -28,3 +28,12 @@ steps:
working_dir: "/vllm-workspace"
commands:
- bash .buildkite/scripts/scheduled_integration_test/qwen30b_a3b_fp8_block_ep_eplb.sh 0.8 200 8020 2 1

- label: DeepSeek V2-Lite Prefetch Offload Accuracy (H100)
timeout_in_minutes: 60
device: h100
optional: true
num_devices: 1
working_dir: "/vllm-workspace"
commands:
- bash .buildkite/scripts/scheduled_integration_test/deepseek_v2_lite_prefetch_offload.sh 0.25 200 8030
33 changes: 33 additions & 0 deletions tests/basic_correctness/test_prefetch_offload.py
@@ -0,0 +1,33 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test prefetch offloading correctness with Llama model."""

from ..utils import compare_two_settings


def test_prefetch_offload_llama():
"""Test prefetch CPU offloading with Llama-3.2-1B-Instruct.

Compares outputs between:
1. Baseline (no offloading)
2. Prefetch offloading (group_size=8, num_in_group=2, prefetch_step=1)

This tests prefetching-based offloading on a dense model.
"""
compare_two_settings(
"meta-llama/Llama-3.2-1B-Instruct",
[
# Prefetch offloading configuration
"--offload-group-size",
"8",
"--offload-num-in-group",
"2",
"--offload-prefetch-step",
"1",
# Selective offloading: only MLP weights
"--offload-params",
"gate_up_proj",
"down_proj",
],
[], # Baseline: no offloading
)
mgoin marked this conversation as resolved.
14 changes: 14 additions & 0 deletions vllm/compilation/cuda_graph.py
@@ -17,6 +17,7 @@
from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id
from vllm.forward_context import BatchDescriptor, get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.offloader.base import get_offloader
from vllm.platforms import current_platform
from vllm.utils.torch_utils import current_stream, weak_ref_tensors

@@ -265,6 +266,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
set_graph_pool_id(self.graph_pool)
else:
set_graph_pool_id(current_platform.graph_pool_handle())

# Sync offloader's copy stream before capture.
# Ensure any pre-capture prefetches from offloader are complete.
get_offloader().sync_prev_onload()
Comment on lines +270 to +272
Collaborator

@zou3519 zou3519 Feb 20, 2026

@eellison and/or @BoyuanFeng, could you take a look at this too please? These look reasonable to me but I'm new to this

Contributor

Just to validate: the same sync occurs at runtime, prior to graph replay?

Contributor Author

Yes. I updated the comment in run_fullgraph to clarify this a bit more.

The sync_prev_onload before replay is only needed when vLLM transitions from non-CUDA-graph execution to CUDA-graph execution.

  1. non-cg -> non-cg: the first layers of the next forward wait on the correct copy events
  2. non-cg -> cg: the first layers of the next forward don't wait, since events cannot cross the capture boundary, so we need this sync
  3. cg -> non-cg: we have called join_after_forward, where the event sync is captured (the copy stream is joined)
  4. cg -> cg: same, we have called join_after_forward, where the event sync is captured (the copy stream is joined)
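
To make case 2 concrete, here is a minimal sketch of what the pre-capture/pre-replay sync could look like; the attribute names (copy_stream, _pending_events) are assumptions for illustration, not the PR's exact internals:

import torch


class PrefetchOffloaderSketch:
    """Illustrative only; the real offloader lives under vllm/model_executor/offloader."""

    def __init__(self) -> None:
        # Side stream used for async H2D weight copies.
        self.copy_stream = torch.cuda.Stream()
        # Events recorded after each layer's copy finishes on copy_stream.
        self._pending_events: list[torch.cuda.Event] = []

    def sync_prev_onload(self) -> None:
        # Before CUDA graph capture or replay: events recorded outside the
        # graph cannot be waited on from inside it, so make the current
        # stream wait on every outstanding copy now (case 2: non-cg -> cg).
        for event in self._pending_events:
            torch.cuda.current_stream().wait_event(event)
        self._pending_events.clear()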


# mind-exploding: carefully manage the reference and memory.
with torch.cuda.graph(
cudagraph,
@@ -273,6 +279,11 @@
):
# `output` is managed by pytorch's cudagraph pool
output = self.runnable(*args, **kwargs)
# Join offloader's copy stream after forward to avoid
# unjoined stream error. The last layer's start_prefetch
# forks copy_stream, but wait_prefetch only happens in
# the next forward pass.
get_offloader().join_after_forward()
Collaborator

where is the copy_stream branched from the current stream?

CUDAGraph requires the side stream copy_stream to branch off and join back to the CUDA graph capture stream within the capture region:

with torch.cuda.graph(g):
    copy_stream.wait_stream(torch.cuda.current_stream())    # branch the copy_stream from the current_stream
    with torch.cuda.stream(copy_stream):
        x_cpu.copy_(x_cuda, non_blocking=True)

    # any computation
    out = y_cuda + 1

    torch.cuda.current_stream().wait_stream(copy_stream)  # join the copy_stream

Contributor Author

Hi Boyuan, it's in start_onload_to_static, which records the _copy_done_event that this join_after_forward waits on before finishing capture.

Collaborator

Thanks. For future-proofing, could we expose an API to call it in cuda_graph.py for symmetry?

Contributor Author

It's asymmetric with the current implementation: events are recorded at the per-layer level and then waited on by the later layers that need the weights. But unfortunately, CUDA graphs require events to be joined before capture ends, which is why we need join_after_forward: it waits on any open events that would otherwise have been synced by the following layers outside the current CUDA graph capture scope.
It is indeed confusing that we only have the sync here. Lemme at least clarify a bit more in the comments for now.
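
To summarize the thread, a rough sketch of the fork/record/join pattern being described: start_onload_to_static forks the copy stream and records a per-layer event, a later layer waits on that event, and join_after_forward closes any still-open fork before capture ends. The bodies below are illustrative assumptions, not the PR's actual code:

import torch


class _PrefetchStreamSketch:
    def __init__(self) -> None:
        self.copy_stream = torch.cuda.Stream()
        self._copy_done_events: dict[int, torch.cuda.Event] = {}

    def start_onload_to_static(self, layer_id: int, cpu_src: torch.Tensor, gpu_dst: torch.Tensor) -> None:
        # Fork: copy_stream waits on the current (capture) stream, then the
        # H2D copy runs on copy_stream and a per-layer event is recorded.
        self.copy_stream.wait_stream(torch.cuda.current_stream())
        with torch.cuda.stream(self.copy_stream):
            gpu_dst.copy_(cpu_src, non_blocking=True)
            event = torch.cuda.Event()
            event.record(self.copy_stream)
        self._copy_done_events[layer_id] = event

    def wait_prefetch(self, layer_id: int) -> None:
        # A later layer waits on the copy event before reading the weights.
        event = self._copy_done_events.pop(layer_id, None)
        if event is not None:
            torch.cuda.current_stream().wait_event(event)

    def join_after_forward(self) -> None:
        # The last layer's start_prefetch has no consumer inside the current
        # capture, so join the copy stream back before capture ends.
        torch.cuda.current_stream().wait_stream(self.copy_stream)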

if self.cudagraph_options.weak_ref_output:
# by converting it to weak ref,
# the original `output` will immediately be released
@@ -305,5 +316,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None:
f"got {new_input_addresses}"
)

# Sync offloader before replay - ensures any external dependencies
# from pre-capture prefetches are satisfied.
get_offloader().sync_prev_onload()
Contributor

Note: this is runtime overhead on the hot path. When the offloader is not on, is it worth making this cheaper?

Contributor Author

It's a no-op when offload is not configured. See BaseOffloader and NoopOffloader.
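
For context, a minimal sketch of why this stays cheap on the hot path: when no offload config is given, get_offloader() can return a no-op implementation whose hooks do nothing. The class shape below is an assumption based on the names mentioned above, not the PR's exact code:

class NoopOffloaderSketch:
    """Stand-in for the NoopOffloader mentioned above; every hook is a no-op."""

    def sync_prev_onload(self) -> None:
        # Nothing to synchronize when offloading is disabled.
        return None

    def join_after_forward(self) -> None:
        # No copy stream was forked, so there is nothing to join.
        return None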

entry.cudagraph.replay()
return entry.output
11 changes: 11 additions & 0 deletions vllm/config/__init__.py
@@ -24,6 +24,12 @@
)
from vllm.config.multimodal import MultiModalConfig
from vllm.config.observability import ObservabilityConfig
from vllm.config.offload import (
OffloadBackend,
OffloadConfig,
PrefetchOffloadConfig,
UVAOffloadConfig,
)
from vllm.config.parallel import EPLBConfig, ParallelConfig
from vllm.config.pooler import PoolerConfig
from vllm.config.profiler import ProfilerConfig
@@ -85,6 +91,11 @@
"MultiModalConfig",
# From vllm.config.observability
"ObservabilityConfig",
# From vllm.config.offload
"OffloadBackend",
"OffloadConfig",
"PrefetchOffloadConfig",
"UVAOffloadConfig",
# From vllm.config.parallel
"EPLBConfig",
"ParallelConfig",
16 changes: 7 additions & 9 deletions vllm/config/cache.py
minosfuture marked this conversation as resolved.
@@ -100,17 +100,15 @@ class CacheConfig:
load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
Note that this requires fast CPU-GPU interconnect, as part of the model is
loaded from CPU memory to GPU memory on the fly in each model forward pass.

DEPRECATED: This field is deprecated and will be removed in v0.16.
Please use OffloadConfig.uva.cpu_offload_gb instead.
"""
cpu_offload_params: set[str] = Field(default_factory=set)
""" The set of parameter name segments to target for CPU offloading.
Unmatched parameters are not offloaded. If this set is empty, parameters
are offloaded non-selectively until the memory limit defined by
`cpu_offload_gb` is reached.
Examples:
- For parameter name "mlp.experts.w2_weight":
- "experts" or "experts.w2_weight" will match.
- "expert" or "w2" will NOT match (must be exact segments).
This allows distinguishing parameters like "w2_weight" and "w2_weight_scale".
"""The set of parameter name segments to target for CPU offloading.

DEPRECATED: This field is deprecated and will be removed in v0.16.
Please use OffloadConfig.uva.cpu_offload_params instead.
"""
calculate_kv_scales: bool = False
"""This enables dynamic calculation of `k_scale` and `v_scale` when
153 changes: 153 additions & 0 deletions vllm/config/offload.py
@@ -0,0 +1,153 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Configuration for model weight offloading."""

import warnings
from typing import Literal

from pydantic import Field, model_validator

from vllm.config.utils import config

OffloadBackend = Literal["auto", "uva", "prefetch"]


@config
class UVAOffloadConfig:
"""Configuration for UVA (Unified Virtual Addressing) CPU offloading.

Uses zero-copy access from CPU-pinned memory. Simple but requires
fast CPU-GPU interconnect.
"""

cpu_offload_gb: float = Field(default=0, ge=0)
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
no offloading. Intuitively, this argument can be seen as a virtual way to
increase the GPU memory size. For example, if you have one 24 GB GPU and
set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
Note that this requires fast CPU-GPU interconnect, as part of the model is
loaded from CPU memory to GPU memory on the fly in each model forward pass.
This uses UVA (Unified Virtual Addressing) for zero-copy access.
"""

cpu_offload_params: set[str] = Field(default_factory=set)
"""The set of parameter name segments to target for CPU offloading.
Unmatched parameters are not offloaded. If this set is empty, parameters
are offloaded non-selectively until the memory limit defined by
`cpu_offload_gb` is reached.
Examples:
- For parameter name "mlp.experts.w2_weight":
- "experts" or "experts.w2_weight" will match.
- "expert" or "w2" will NOT match (must be exact segments).
This allows distinguishing parameters like "w2_weight" and "w2_weight_scale".
"""


@config
class PrefetchOffloadConfig:
"""Configuration for prefetch-based CPU offloading.

Groups layers and uses async H2D prefetch to hide transfer latency.
"""

offload_group_size: int = Field(default=0, ge=0)
"""Group every N layers together. Offload last `offload_num_in_group`
layers of each group. Default is 0 (disabled).
Example: group_size=8, num_in_group=2 offloads layers 6,7,14,15,22,23,...
Unlike cpu_offload_gb, this uses explicit async prefetching to hide transfer
latency.
"""

offload_num_in_group: int = Field(default=1, ge=1)
"""Number of layers to offload per group.
Must be <= offload_group_size. Default is 1."""

offload_prefetch_step: int = Field(default=1, ge=0)
"""Number of layers to prefetch ahead.
Higher values hide more latency but use more GPU memory. Default is 1."""

offload_params: set[str] = Field(default_factory=set)
"""The set of parameter name segments to target for prefetch offloading.
Unmatched parameters are not offloaded. If this set is empty, ALL
parameters of each offloaded layer are offloaded.
Uses segment matching: "w13_weight" matches "mlp.experts.w13_weight"
but not "mlp.experts.w13_weight_scale".
"""


@config
class OffloadConfig:
"""Configuration for model weight offloading to reduce GPU memory usage."""

offload_backend: OffloadBackend = "auto"
"""The backend for weight offloading. Options:
- "auto": Selects based on which sub-config has non-default values
(prefetch if offload_group_size > 0, uva if cpu_offload_gb > 0).
- "uva": UVA (Unified Virtual Addressing) zero-copy offloading.
- "prefetch": Async prefetch with group-based layer offloading.
"""

uva: UVAOffloadConfig = Field(default_factory=UVAOffloadConfig)
"""Parameters for UVA offloading backend."""

prefetch: PrefetchOffloadConfig = Field(default_factory=PrefetchOffloadConfig)
"""Parameters for prefetch offloading backend."""

@model_validator(mode="after")
def validate_offload_config(self) -> "OffloadConfig":
"""Validate offload configuration constraints."""
if self.offload_backend == "prefetch" or self.prefetch.offload_group_size > 0:
if self.prefetch.offload_num_in_group > self.prefetch.offload_group_size:
raise ValueError(
f"offload_num_in_group ({self.prefetch.offload_num_in_group})"
f" must be <= offload_group_size"
f" ({self.prefetch.offload_group_size})"
)
if self.prefetch.offload_prefetch_step < 1:
raise ValueError(
f"offload_prefetch_step"
f" ({self.prefetch.offload_prefetch_step})"
f" must be >= 1 when prefetch offloading is enabled"
f" (offload_group_size > 0)"
)

# Warn if both backends have non-default values
uva_active = self.uva.cpu_offload_gb > 0
prefetch_active = self.prefetch.offload_group_size > 0
if self.offload_backend == "uva" and prefetch_active:
warnings.warn(
"Prefetch offload fields are set but offload_backend='uva'. "
"Prefetch settings will be ignored.",
stacklevel=2,
)
elif self.offload_backend == "prefetch" and uva_active:
warnings.warn(
"UVA offload fields are set but offload_backend='prefetch'. "
"UVA settings will be ignored.",
stacklevel=2,
)
elif self.offload_backend == "auto" and uva_active and prefetch_active:
warnings.warn(
"Both UVA and prefetch offload fields are set with "
"offload_backend='auto'. Prefetch backend will be selected. "
"Set offload_backend explicitly to suppress this warning.",
stacklevel=2,
)
return self
minosfuture marked this conversation as resolved.

def compute_hash(self) -> str:
"""
Provide a hash that uniquely identifies all the offload configs.

All fields are included because PrefetchOffloader patches module
forwards and inserts custom ops (wait_prefetch, start_prefetch)
into the computation graph. Changing any offload setting can
alter which layers are hooked and how prefetch indices are
computed, so the compilation cache must distinguish them.
"""
from vllm.config.utils import get_hash_factors, hash_factors

factors = get_hash_factors(self, ignored_factors=set())
hash_str = hash_factors(factors)
return hash_str
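
As a worked example of the docstrings in offload.py above, here is a small sketch of how the prefetch group settings could select layers, how parameter-name segment matching could behave, and how the "auto" backend choice could resolve. The helper names are illustrative, not part of the PR:

def offloaded_layer_ids(num_layers: int, group_size: int, num_in_group: int) -> list[int]:
    # Offload the last `num_in_group` layers of every `group_size`-layer group,
    # e.g. group_size=8, num_in_group=2 -> layers 6, 7, 14, 15, 22, 23, ...
    if group_size <= 0:
        return []
    return [i for i in range(num_layers) if i % group_size >= group_size - num_in_group]


def matches_param(param_name: str, targets: set[str]) -> bool:
    # Segment matching: "w13_weight" matches "mlp.experts.w13_weight" but not
    # "mlp.experts.w13_weight_scale"; an empty target set matches everything.
    if not targets:
        return True
    segments = param_name.split(".")
    return any(
        ".".join(segments[i : i + len(target.split("."))]) == target
        for target in targets
        for i in range(len(segments))
    )


def resolve_backend(backend: str, cpu_offload_gb: float, group_size: int) -> str | None:
    # "auto": prefetch wins if offload_group_size > 0, else uva if
    # cpu_offload_gb > 0, else no offloading at all.
    if backend != "auto":
        return backend
    if group_size > 0:
        return "prefetch"
    if cpu_offload_gb > 0:
        return "uva"
    return None


assert offloaded_layer_ids(24, group_size=8, num_in_group=2) == [6, 7, 14, 15, 22, 23]
assert matches_param("mlp.experts.w2_weight", {"experts.w2_weight"})
assert not matches_param("mlp.experts.w2_weight_scale", {"w2_weight"})
assert resolve_backend("auto", cpu_offload_gb=0, group_size=8) == "prefetch"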
7 changes: 7 additions & 0 deletions vllm/config/vllm.py
@@ -37,6 +37,7 @@
from .lora import LoRAConfig
from .model import ModelConfig
from .observability import ObservabilityConfig
from .offload import OffloadConfig
from .parallel import ParallelConfig
from .profiler import ProfilerConfig
from .scheduler import SchedulerConfig
@@ -259,6 +260,8 @@ class VllmConfig:
"""Device configuration."""
load_config: LoadConfig = Field(default_factory=LoadConfig)
"""Load configuration."""
offload_config: OffloadConfig = Field(default_factory=OffloadConfig)
"""Model weight offloading configuration."""
attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
"""Attention configuration."""
kernel_config: KernelConfig = Field(default_factory=KernelConfig)
@@ -361,6 +364,10 @@ def compute_hash(self) -> str:
vllm_factors.append(self.load_config.compute_hash())
else:
vllm_factors.append("None")
if self.offload_config:
vllm_factors.append(self.offload_config.compute_hash())
else:
vllm_factors.append("None")
if self.attention_config:
vllm_factors.append(self.attention_config.compute_hash())
else: