Merged
45 commits
a3dfa47
offloader v2
minosfuture Dec 2, 2025
9bf4097
cleanup
minosfuture Dec 3, 2025
152af73
cleanup: extract func; move imports
minosfuture Dec 3, 2025
67cd6cc
fix import
minosfuture Dec 3, 2025
c21063e
Merge remote-tracking branch 'origin/main' into offloader
minosfuture Dec 3, 2025
f43b577
address comment: remove legacy code
minosfuture Dec 4, 2025
975e972
Merge remote-tracking branch 'origin/main' into offloader
minosfuture Dec 18, 2025
0ed3570
refactor: move offload config from CacheConfig to OffloadConfig
minosfuture Dec 20, 2025
b07eb3f
refactor: generalize MoE detection for offloading
minosfuture Dec 20, 2025
ac5fb49
cleanup: remove NVFP4 scale params from offload whitelist
minosfuture Dec 20, 2025
80e764e
cleanup: remove dead code in UVAOffloader
minosfuture Dec 20, 2025
9436a66
Merge remote-tracking branch 'origin/main' into offloader
minosfuture Jan 5, 2026
21e0813
Merge remote-tracking branch 'origin/main' into offloader
minosfuture Jan 22, 2026
5bc88d8
[Core] V2 offloader: static buffers and custom ops for torch.compile
minosfuture Jan 25, 2026
719af1b
[Core] V2 offloader: CUDA graph capture support
minosfuture Jan 27, 2026
91c180e
clean up unnecessary code
minosfuture Jan 27, 2026
20db2a1
[Core] V2 offloader: simplify API by using get_offloader() directly
minosfuture Jan 27, 2026
2debd78
[Core] V2 offloader: fix CUDA graph bugs in Eagle and UBatch
minosfuture Jan 27, 2026
20ef9d8
Merge remote-tracking branch 'origin/main' into offloader
minosfuture Feb 19, 2026
f536bc4
minor param order adjustment
minosfuture Feb 19, 2026
021acc1
Keep cpu_offload_params on CacheConfig as deprecated field
minosfuture Feb 19, 2026
0525c42
Fix OffloadConfig.compute_hash() returning constant hash
minosfuture Feb 19, 2026
6506706
Fix OffloaderV2 accuracy bug: re-pin CPU storage after process_weight…
minosfuture Feb 20, 2026
55a9934
pre-commit fixes
minosfuture Feb 20, 2026
8478e55
Fix OffloaderV2 crash with dotted parameter names
minosfuture Feb 22, 2026
4520428
Specify removal version for deprecated CacheConfig offload fields
minosfuture Feb 22, 2026
fcaa9da
Use @config decorator alone for OffloadConfig
minosfuture Feb 22, 2026
79ec7d5
Add explicit offload_backend selector with nested sub-configs
minosfuture Feb 22, 2026
5fe6d7c
Rename OffloaderV2 to PrefetchOffloader
minosfuture Feb 22, 2026
08bab8a
Move PrefetchOffloader parameterization from model code to CLI
minosfuture Feb 22, 2026
0d18ec2
Make prefetch custom ops return None instead of tensor
minosfuture Feb 22, 2026
e5d375e
Switch prefetch offload test from DeepSeek-V2-Lite to Llama-3.2-1B-In…
minosfuture Feb 23, 2026
1d3d404
Merge remote-tracking branch 'origin/main' into offloader
minosfuture Feb 23, 2026
3c3ba9b
Remove incorrect comment
minosfuture Feb 23, 2026
96301e3
Use per-layer event sync in eager mode to enable H2D overlap
minosfuture Feb 24, 2026
6c6600d
Add offload_params to LLM interface
minosfuture Feb 24, 2026
ff52cd4
Use torch.finfo instead of creating tensor to get element size
minosfuture Feb 24, 2026
2a77385
Initialize global offloader as None to fail loudly before set_offloader
minosfuture Feb 24, 2026
20696cd
Add nightly e2e test for prefetch offloading with DeepSeek-V2-Lite
minosfuture Feb 24, 2026
f537f21
Fix sync_prev_onload comments in run_fullgraph
minosfuture Feb 24, 2026
18b576e
Update .buildkite/test_areas/e2e_integration.yaml
mgoin Feb 25, 2026
3df4d98
Default global offloader to NoopOffloader and log on set
minosfuture Feb 25, 2026
150dc9f
Merge remote-tracking branch 'minosfuture/offloader' into offloader
minosfuture Feb 25, 2026
9f49711
Merge remote-tracking branch 'origin/main' into offloader
minosfuture Feb 25, 2026
0478a6a
Merge branch 'main' into offloader
mgoin Feb 25, 2026
31 changes: 31 additions & 0 deletions tests/basic_correctness/test_v2_offload.py
@@ -0,0 +1,31 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test V2 offloading correctness with DeepSeek V2 model."""

from ..utils import compare_two_settings


def test_v2_offload_deepseek():
"""Test V2 CPU offloading with DeepSeek-V2-Lite.

Compares outputs between:
1. Baseline (no offloading)
2. V2 offloading (group_size=8, num_in_group=2, prefetch_step=1)

This tests the advanced offloading with prefetching on a MoE model.
"""
compare_two_settings(
"deepseek-ai/DeepSeek-V2-Lite",
Member:
Could use an fp8 model here to make it faster like RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8

Contributor Author:
serving "RedHatAI/DeepSeek-Coder-V2-Lite-Instruct-FP8" failed at flashinfer autotuning stage on GB200. 😿

Contributor Author:
switched to llama as we support any model now.

[], # Baseline: no offloading
[
# V2 offloading configuration
"--offload-group-size",
"8",
"--offload-num-in-group",
"2",
"--offload-prefetch-step",
"1",
# currently not compatible with torch.compile
"--enforce-eager",
],
)
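
For context, a minimal offline-API sketch of the offloading setting exercised above, using the LLM keyword arguments introduced later in this PR (flag values copied from the test; trust_remote_code is an assumption for DeepSeek-V2-Lite, and the snippet is illustrative rather than part of the test):

```python
from vllm import LLM

# Sketch only: mirrors the "V2 offloading" CLI setting compared by the test.
llm = LLM(
    model="deepseek-ai/DeepSeek-V2-Lite",
    trust_remote_code=True,   # assumed: DeepSeek-V2-Lite ships custom modeling code
    offload_group_size=8,     # group every 8 layers
    offload_num_in_group=2,   # offload the last 2 layers of each group
    offload_prefetch_step=1,  # prefetch one offloaded layer ahead
    enforce_eager=True,       # currently not compatible with torch.compile
)
outputs = llm.generate("The capital of France is")
```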
3 changes: 3 additions & 0 deletions vllm/config/__init__.py
@@ -22,6 +22,7 @@
)
from vllm.config.multimodal import MultiModalConfig
from vllm.config.observability import ObservabilityConfig
from vllm.config.offload import OffloadConfig
from vllm.config.parallel import EPLBConfig, ParallelConfig
from vllm.config.pooler import PoolerConfig
from vllm.config.profiler import ProfilerConfig
@@ -77,6 +78,8 @@
"MultiModalConfig",
# From vllm.config.observability
"ObservabilityConfig",
# From vllm.config.offload
"OffloadConfig",
# From vllm.config.parallel
"EPLBConfig",
"ParallelConfig",
3 changes: 3 additions & 0 deletions vllm/config/cache.py
minosfuture marked this conversation as resolved.
@@ -100,6 +100,9 @@ class CacheConfig:
load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
Note that this requires fast CPU-GPU interconnect, as part of the model is
loaded from CPU memory to GPU memory on the fly in each model forward pass.

DEPRECATED: This field is deprecated and will be removed in a future
release. Please use OffloadConfig.cpu_offload_gb instead.
"""
calculate_kv_scales: bool = False
"""This enables dynamic calculation of `k_scale` and `v_scale` when
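
As a rough migration sketch (not an official guide; names are taken from this PR), the deprecated CacheConfig field maps one-to-one onto the new config introduced in the next file:

```python
from vllm.config import OffloadConfig

# Previously the UVA offload budget lived on CacheConfig(cpu_offload_gb=10).
# With this PR it moves to OffloadConfig; the semantics are unchanged.
offload_config = OffloadConfig(cpu_offload_gb=10)
```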
80 changes: 80 additions & 0 deletions vllm/config/offload.py
@@ -0,0 +1,80 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Configuration for model weight offloading."""

from typing import Any

from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass

from vllm.config.utils import config
from vllm.utils.hashing import safe_hash


@config
@dataclass
minosfuture marked this conversation as resolved.
class OffloadConfig:
"""Configuration for model weight offloading to CPU.

This controls how model parameters are offloaded to CPU memory to reduce
GPU memory usage, at the cost of additional CPU-GPU transfers during
inference.
"""

cpu_offload_gb: float = Field(default=0, ge=0)
"""The space in GiB to offload to CPU, per GPU. Default is 0, which means
no offloading. Intuitively, this argument can be seen as a virtual way to
increase the GPU memory size. For example, if you have one 24 GB GPU and
set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
Note that this requires fast CPU-GPU interconnect, as part of the model is
loaded from CPU memory to GPU memory on the fly in each model forward pass.
This uses UVA (Unified Virtual Addressing) for zero-copy access.
"""

offload_group_size: int = Field(default=0, ge=0)
"""Advanced CPU offloading (V2): Group every N layers together. Offload last
`offload_num_in_group` layers of each group. Default is 0 (disabled).
Example: group_size=8, num_in_group=2 offloads layers 6,7,14,15,22,23,...
Unlike cpu_offload_gb, this uses explicit async prefetching to hide transfer
latency.
"""

offload_num_in_group: int = Field(default=1, ge=1)
"""Advanced CPU offloading (V2): Number of layers to offload per group.
Must be <= offload_group_size. Default is 1."""

offload_prefetch_step: int = Field(default=1, ge=0)
"""Advanced CPU offloading (V2): Number of layers to prefetch ahead.
Higher values hide more latency but use more GPU memory. Default is 1."""

@model_validator(mode="after")
def validate_offload_config(self) -> "OffloadConfig":
"""Validate that offload_num_in_group <= offload_group_size."""
if (
self.offload_group_size > 0
and self.offload_num_in_group > self.offload_group_size
):
raise ValueError(
f"offload_num_in_group ({self.offload_num_in_group}) must be "
f"<= offload_group_size ({self.offload_group_size})"
)
return self
minosfuture marked this conversation as resolved.

def compute_hash(self) -> str:
"""
WARNING: Whenever a new field is added to this config,
ensure that it is included in the factors list if
it affects the computation graph.

Provide a hash that uniquely identifies all the configs
that affect the structure of the computation
graph from input ids/embeddings to the final hidden states,
excluding anything before input ids/embeddings and after
the final hidden states.
"""
# Offload settings don't affect the computation graph structure,
# only the memory layout and transfer patterns.
factors: list[Any] = []
hash_str = safe_hash(str(factors).encode(), usedforsecurity=False).hexdigest()
return hash_str
7 changes: 7 additions & 0 deletions vllm/config/vllm.py
@@ -38,6 +38,7 @@
from .lora import LoRAConfig
from .model import ModelConfig
from .observability import ObservabilityConfig
from .offload import OffloadConfig
from .parallel import ParallelConfig
from .profiler import ProfilerConfig
from .scheduler import SchedulerConfig
@@ -194,6 +195,8 @@ class VllmConfig:
"""Device configuration."""
load_config: LoadConfig = Field(default_factory=LoadConfig)
"""Load configuration."""
offload_config: OffloadConfig = Field(default_factory=OffloadConfig)
"""Model weight offloading configuration."""
attention_config: AttentionConfig = Field(default_factory=AttentionConfig)
"""Attention configuration."""
lora_config: LoRAConfig | None = None
@@ -285,6 +288,10 @@ def compute_hash(self) -> str:
vllm_factors.append(self.load_config.compute_hash())
else:
vllm_factors.append("None")
if self.offload_config:
vllm_factors.append(self.offload_config.compute_hash())
else:
vllm_factors.append("None")
if self.attention_config:
vllm_factors.append(self.attention_config.compute_hash())
else:
36 changes: 33 additions & 3 deletions vllm/engine/arg_utils.py
@@ -48,6 +48,7 @@
ModelConfig,
MultiModalConfig,
ObservabilityConfig,
OffloadConfig,
ParallelConfig,
PoolerConfig,
ProfilerConfig,
@@ -434,7 +435,10 @@ class EngineArgs:
disable_sliding_window: bool = ModelConfig.disable_sliding_window
disable_cascade_attn: bool = ModelConfig.disable_cascade_attn
swap_space: float = CacheConfig.swap_space
cpu_offload_gb: float = CacheConfig.cpu_offload_gb
cpu_offload_gb: float = OffloadConfig.cpu_offload_gb
offload_group_size: int = OffloadConfig.offload_group_size
offload_num_in_group: int = OffloadConfig.offload_num_in_group
offload_prefetch_step: int = OffloadConfig.offload_prefetch_step
gpu_memory_utilization: float = CacheConfig.gpu_memory_utilization
kv_cache_memory_bytes: int | None = CacheConfig.kv_cache_memory_bytes
max_num_batched_tokens: int | None = None
@@ -912,7 +916,6 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
cache_group.add_argument(
"--prefix-caching-hash-algo", **cache_kwargs["prefix_caching_hash_algo"]
)
cache_group.add_argument("--cpu-offload-gb", **cache_kwargs["cpu_offload_gb"])
cache_group.add_argument(
"--calculate-kv-scales", **cache_kwargs["calculate_kv_scales"]
)
@@ -935,6 +938,25 @@
"--kv-offloading-backend", **cache_kwargs["kv_offloading_backend"]
)

# Model weight offload related configs
offload_kwargs = get_kwargs(OffloadConfig)
offload_group = parser.add_argument_group(
title="OffloadConfig",
description=OffloadConfig.__doc__,
)
offload_group.add_argument(
"--cpu-offload-gb", **offload_kwargs["cpu_offload_gb"]
)
offload_group.add_argument(
"--offload-group-size", **offload_kwargs["offload_group_size"]
)
offload_group.add_argument(
"--offload-num-in-group", **offload_kwargs["offload_num_in_group"]
)
offload_group.add_argument(
"--offload-prefetch-step", **offload_kwargs["offload_prefetch_step"]
)

# Multimodal related configs
multimodal_kwargs = get_kwargs(MultiModalConfig)
multimodal_group = parser.add_argument_group(
@@ -1384,7 +1406,6 @@ def create_engine_config(
sliding_window=sliding_window,
enable_prefix_caching=self.enable_prefix_caching,
prefix_caching_hash_algo=self.prefix_caching_hash_algo,
cpu_offload_gb=self.cpu_offload_gb,
calculate_kv_scales=self.calculate_kv_scales,
kv_sharing_fast_prefill=self.kv_sharing_fast_prefill,
mamba_cache_dtype=self.mamba_cache_dtype,
@@ -1715,13 +1736,22 @@ def create_engine_config(
compilation_config.max_cudagraph_capture_size = (
self.max_cudagraph_capture_size
)

offload_config = OffloadConfig(
cpu_offload_gb=self.cpu_offload_gb,
offload_group_size=self.offload_group_size,
offload_num_in_group=self.offload_num_in_group,
offload_prefetch_step=self.offload_prefetch_step,
)

config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
parallel_config=parallel_config,
scheduler_config=scheduler_config,
device_config=device_config,
load_config=load_config,
offload_config=offload_config,
attention_config=attention_config,
lora_config=lora_config,
speculative_config=speculative_config,
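
A small sketch of how the new flags flow from EngineArgs into VllmConfig.offload_config per the wiring above (attribute names are taken from this diff; treat it as illustrative rather than a tested snippet):

```python
from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="deepseek-ai/DeepSeek-V2-Lite",
    offload_group_size=8,
    offload_num_in_group=2,
    offload_prefetch_step=1,
)
vllm_config = engine_args.create_engine_config()

# The offload settings are grouped on the new OffloadConfig, not CacheConfig.
assert vllm_config.offload_config.offload_group_size == 8
assert vllm_config.offload_config.offload_num_in_group == 2
```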
14 changes: 14 additions & 0 deletions vllm/entrypoints/llm.py
@@ -155,6 +155,14 @@ class LLM:
the model weights. This virtually increases the GPU memory space
you can use to hold the model weights, at the cost of CPU-GPU data
transfer for every forward pass.
offload_group_size: Advanced CPU offloading: Group every N layers
together. Offload last `offload_num_in_group` layers of each group.
Default is 0 (disabled).
offload_num_in_group: Advanced CPU offloading: Number of layers to
offload per group. Default is 1.
offload_prefetch_step: Advanced CPU offloading: Number of layers to
prefetch ahead. Higher values hide more latency but use more GPU
memory. Default is 1.
Member:
Looks like we are missing offload_params here in the LLM interface

Contributor Author:
added. Thanks!

enforce_eager: Whether to enforce eager execution. If True, we will
disable CUDA graph and always execute the model in eager mode.
If False, we will use CUDA graph and eager execution in hybrid.
@@ -208,6 +216,9 @@ def __init__(
gpu_memory_utilization: float = 0.9,
swap_space: float = 4,
cpu_offload_gb: float = 0,
offload_group_size: int = 0,
offload_num_in_group: int = 1,
offload_prefetch_step: int = 1,
enforce_eager: bool = False,
disable_custom_all_reduce: bool = False,
hf_token: bool | str | None = None,
@@ -316,6 +327,9 @@ def _make_config(value: Any, cls: type[_R]) -> _R:
kv_cache_memory_bytes=kv_cache_memory_bytes,
swap_space=swap_space,
cpu_offload_gb=cpu_offload_gb,
offload_group_size=offload_group_size,
offload_num_in_group=offload_num_in_group,
offload_prefetch_step=offload_prefetch_step,
enforce_eager=enforce_eager,
disable_custom_all_reduce=disable_custom_all_reduce,
hf_token=hf_token,
2 changes: 2 additions & 0 deletions vllm/model_executor/layers/fused_moe/__init__.py
@@ -14,6 +14,7 @@
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE,
FusedMoeWeightScaleSupported,
find_fused_moe_submodule,
)
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat,
@@ -55,6 +56,7 @@ def get_config() -> dict[str, Any] | None:
"RoutingMethodType",
"SharedFusedMoE",
"activation_without_mul",
"find_fused_moe_submodule",
"override_config",
"get_config",
]
25 changes: 25 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
@@ -2175,3 +2175,28 @@ def moe_forward_shared_fake(
# Mark the FusedMoE weight_loader as supporting MoE-specific parameters
# to avoid expensive runtime reflection in model loading code
FusedMoE.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]


def find_fused_moe_submodule(module: torch.nn.Module) -> torch.nn.Module:
"""Find a FusedMoE submodule for offloading, or return the module itself.

Searches module attributes for instances of FusedMoE (or subclasses like
SharedFusedMoE).

Args:
module: The module to search within (typically layer.mlp).

Returns:
The first FusedMoE instance found, or the original module if none found.
"""
for attr_name in dir(module):
if attr_name.startswith("_"):
continue
try:
attr = getattr(module, attr_name, None)
except Exception:
continue
if isinstance(attr, FusedMoE):
return attr

return module
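
A brief usage sketch (the layer structure is assumed, mirroring the deepseek_v2.py call site below) showing how an offloader can resolve the module whose expert weights should be offloaded:

```python
import torch

from vllm.model_executor.layers.fused_moe import find_fused_moe_submodule


def resolve_offload_target(layer: torch.nn.Module) -> torch.nn.Module:
    """Pick the module to offload from (sketch only; assumes `layer.mlp` exists)."""
    # MoE layers: returns the FusedMoE/SharedFusedMoE experts submodule.
    # Dense layers: falls back to layer.mlp itself.
    return find_fused_moe_submodule(layer.mlp)
```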
20 changes: 19 additions & 1 deletion vllm/model_executor/models/deepseek_v2.py
@@ -49,7 +49,10 @@
from vllm.logger import init_logger
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.layers.fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe import (
SharedFusedMoE,
find_fused_moe_submodule,
)
from vllm.model_executor.layers.layernorm import LayerNorm, RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
@@ -1274,6 +1277,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
vllm_config, prefix, topk_indices_buffer=topk_indices_buffer
),
prefix=f"{prefix}.layers",
offloader_kwargs=dict(
# Extract the MLP submodule - for MoE layers, go deeper to the experts
submodule_accessor=lambda layer: find_fused_moe_submodule(layer.mlp),
# Specify which parameters to offload
whitelist_param_names_creator=lambda module: (
[
# Core MoE expert weights
"w13_weight",
"w2_weight",
]
# Only offload from MoE experts (SharedFusedMoE/FusedMoE)
if hasattr(module, "w13_weight")
else []
),
),
Member:
do we need to do it for every model? this looks intrusive 🤔

Contributor Author:
This is a highly targeted optimization, as it requires carefully overlapping weight onloading (memcpy) with computation. It is therefore affected by model weight size, CPU<>GPU bandwidth, and computation latency (per batch size), so it should be configured at least at the model level.

Or maybe this should be configurable via CLI args. WDYT?

wzhao18 (Contributor), Feb 20, 2026:
In the current offloading, we use the cpu_offload_params CLI arg to support offloading MoE weights only. It is not perfect (we need to figure out parameter names, which differ across models), but it stays non-intrusive and can apply to all models without code modification. Could you see if that could be applied similarly for the V2?

Contributor Author:
Updated as @wzhao18 suggested. Feeling much better about this now. cc @mgoin

)

if get_pp_group().is_last_rank: