-
-
Notifications
You must be signed in to change notification settings - Fork 16.5k
[offloader] v2: Hide weight onloading latency via prefetching #29941
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 41 commits
a3dfa47
9bf4097
152af73
67cd6cc
c21063e
f43b577
975e972
0ed3570
b07eb3f
ac5fb49
80e764e
9436a66
21e0813
5bc88d8
719af1b
91c180e
20db2a1
2debd78
20ef9d8
f536bc4
021acc1
0525c42
6506706
55a9934
8478e55
4520428
fcaa9da
79ec7d5
5fe6d7c
08bab8a
0d18ec2
e5d375e
1d3d404
3c3ba9b
96301e3
6c6600d
ff52cd4
2a77385
20696cd
f537f21
18b576e
3df4d98
150dc9f
9f49711
0478a6a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
#!/usr/bin/env bash
set -euxo pipefail

# Nightly e2e test for prefetch offloading with a MoE model.
# Runs DeepSeek-V2-Lite with prefetch offloading of MoE expert weights
# and validates GSM8K accuracy matches baseline (no offloading).
#
# args: [THRESHOLD] [NUM_QUESTIONS] [START_PORT]
THRESHOLD=${1:-0.25}
NUM_Q=${2:-1319}
PORT=${3:-8030}
OUT_DIR=${OUT_DIR:-/tmp/vllm-scheduled}
mkdir -p "${OUT_DIR}"

# Poll the server's /health endpoint until it answers (10 minute timeout).
wait_for_server() {
    local port=$1
    timeout 600 bash -c '
        until curl -sf "http://127.0.0.1:'"$port"'/health" > /dev/null; do
            sleep 1
        done'
}

MODEL="deepseek-ai/DeepSeek-V2-Lite"

# Stop the vLLM server gracefully; escalate to SIGKILL after ~10s.
cleanup() {
    if [[ -n "${SERVER_PID:-}" ]] && kill -0 "${SERVER_PID}" 2>/dev/null; then
        kill "${SERVER_PID}" 2>/dev/null || true
        for _ in {1..20}; do
            kill -0 "${SERVER_PID}" 2>/dev/null || break
            sleep 0.5
        done
        kill -9 "${SERVER_PID}" 2>/dev/null || true
    fi
}
trap cleanup EXIT

# Serve with prefetch offloading of the MoE expert weights only.
vllm serve "$MODEL" \
    --max-model-len 2048 \
    --offload-group-size 8 \
    --offload-num-in-group 2 \
    --offload-prefetch-step 1 \
    --offload-params w13_weight w2_weight \
    --port "$PORT" &
SERVER_PID=$!
wait_for_server "$PORT"

# Sanitize the model name into a filesystem-safe tag.
TAG=$(echo "$MODEL" | tr '/: \\n' '_____')
OUT="${OUT_DIR}/${TAG}_prefetch_offload.json"
python3 tests/evals/gsm8k/gsm8k_eval.py --host http://127.0.0.1 --port "$PORT" --num-questions "${NUM_Q}" --save-results "${OUT}"
# Pass values through the environment and quote the heredoc delimiter so
# that quotes/spaces in MODEL, OUT, or THRESHOLD cannot corrupt the
# generated Python source (direct shell interpolation would).
OUT="${OUT}" MODEL="${MODEL}" THRESHOLD="${THRESHOLD}" python3 - <<'PY'
import json
import os

out = os.environ["OUT"]
model = os.environ["MODEL"]
threshold = float(os.environ["THRESHOLD"])
with open(out) as f:
    acc = json.load(f)["accuracy"]
print(f"{model} prefetch_offload: accuracy {acc:.3f}")
assert acc >= threshold, (
    f"{model} prefetch_offload accuracy {acc} below threshold {threshold}"
)
PY

# Shut the server down explicitly and clear the PID so the EXIT trap
# does not try to kill it a second time.
cleanup
SERVER_PID=
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Test prefetch offloading correctness with Llama model."""

from ..utils import compare_two_settings


def test_prefetch_offload_llama():
    """Test prefetch CPU offloading with Llama-3.2-1B-Instruct.

    Compares outputs between:
    1. Baseline (no offloading)
    2. Prefetch offloading (group_size=8, num_in_group=2, prefetch_step=1)

    This tests prefetching-based offloading on a dense model.
    """
    # Offload the last 2 layers of every group of 8, prefetching 1 layer
    # ahead; restrict offloading to the MLP projection weights only.
    prefetch_args = [
        "--offload-group-size",
        "8",
        "--offload-num-in-group",
        "2",
        "--offload-prefetch-step",
        "1",
        "--offload-params",
        "gate_up_proj",
        "down_proj",
    ]
    # Baseline run: no offloading flags at all.
    baseline_args: list[str] = []
    compare_two_settings(
        "meta-llama/Llama-3.2-1B-Instruct",
        prefetch_args,
        baseline_args,
    )
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -17,6 +17,7 @@ | |
| from vllm.distributed.device_communicators.pynccl_allocator import set_graph_pool_id | ||
| from vllm.forward_context import BatchDescriptor, get_forward_context | ||
| from vllm.logger import init_logger | ||
| from vllm.model_executor.offloader.base import get_offloader | ||
| from vllm.platforms import current_platform | ||
| from vllm.utils.torch_utils import current_stream, weak_ref_tensors | ||
|
|
||
|
|
@@ -265,6 +266,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None: | |
| set_graph_pool_id(self.graph_pool) | ||
| else: | ||
| set_graph_pool_id(current_platform.graph_pool_handle()) | ||
|
|
||
| # Sync offloader's copy stream before capture. | ||
| # Ensure any pre-capture prefetches from offloader are complete. | ||
| get_offloader().sync_prev_onload() | ||
|
Comment on lines
+270
to
+272
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @eellison and/or @BoyuanFeng, could you take a look at this too please? These look reasonable to me but I'm new to this
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just to validate: the same sync, occurs at runtime, prior to graph replay ?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes. I updated the comment in The
|
||
|
|
||
| # mind-exploding: carefully manage the reference and memory. | ||
| with torch.cuda.graph( | ||
| cudagraph, | ||
|
|
@@ -273,6 +279,11 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None: | |
| ): | ||
| # `output` is managed by pytorch's cudagraph pool | ||
| output = self.runnable(*args, **kwargs) | ||
| # Join offloader's copy stream after forward to avoid | ||
| # unjoined stream error. The last layer's start_prefetch | ||
| # forks copy_stream, but wait_prefetch only happens in | ||
| # the next forward pass. | ||
| get_offloader().join_after_forward() | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. where is the CUDAGraph requires side stream with torch.cuda.graph(g):
copy_stream.wait_stream(torch.cuda.current_stream()) # branch the copy_stream from the current_stream
with torch.cuda.stream(copy_stream):
x_cpu.copy_(x_cuda, non_blocking=True)
# any computation
out = y_cuda + 1
torch.cuda.current_stream().wait_stream(copy_stream) # join the copy_stream
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. hi Boyuan, it's in
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. thanks. for future-proof, could we expose an api to call it in
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's asymmetric with the current implementation: events are recorded at per-layer level, and then waited on by the later layers that need the weights. But unfortunately, CUDA graph require events to be joined before capture ends, and this is why we need |
||
| if self.cudagraph_options.weak_ref_output: | ||
| # by converting it to weak ref, | ||
| # the original `output` will immediately be released | ||
|
|
@@ -305,5 +316,8 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None: | |
| f"got {new_input_addresses}" | ||
| ) | ||
|
|
||
| # Sync offloader before replay - ensures any external dependencies | ||
| # from pre-capture prefetches are satisfied. | ||
| get_offloader().sync_prev_onload() | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. note: this is runtime overhead on hot path. when offloader is not on, worth making this cheaper ?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. it's no op when offload is not configured. see BaseOffloader and NoopOffloader. |
||
| entry.cudagraph.replay() | ||
| return entry.output | ||
|
minosfuture marked this conversation as resolved.
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,153 @@ | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
| # SPDX-FileCopyrightText: Copyright contributors to the vLLM project | ||
| """Configuration for model weight offloading.""" | ||
|
|
||
| import warnings | ||
| from typing import Literal | ||
|
|
||
| from pydantic import Field, model_validator | ||
|
|
||
| from vllm.config.utils import config | ||
|
|
||
| OffloadBackend = Literal["auto", "uva", "prefetch"] | ||
|
|
||
|
|
||
@config
class UVAOffloadConfig:
    """Configuration for UVA (Unified Virtual Addressing) CPU offloading.

    Uses zero-copy access from CPU-pinned memory. Simple but requires
    fast CPU-GPU interconnect, since offloaded weights are read from
    host memory on every forward pass.
    """

    cpu_offload_gb: float = Field(default=0, ge=0)
    """The space in GiB to offload to CPU, per GPU. Default is 0, which means
    no offloading. Intuitively, this argument can be seen as a virtual way to
    increase the GPU memory size. For example, if you have one 24 GB GPU and
    set this to 10, virtually you can think of it as a 34 GB GPU. Then you can
    load a 13B model with BF16 weight, which requires at least 26GB GPU memory.
    Note that this requires fast CPU-GPU interconnect, as part of the model is
    loaded from CPU memory to GPU memory on the fly in each model forward pass.
    This uses UVA (Unified Virtual Addressing) for zero-copy access.
    """

    cpu_offload_params: set[str] = Field(default_factory=set)
    """The set of parameter name segments to target for CPU offloading.
    Unmatched parameters are not offloaded. If this set is empty, parameters
    are offloaded non-selectively until the memory limit defined by
    `cpu_offload_gb` is reached.

    Matching is done on whole dot-separated name segments, not substrings.
    Examples:
    - For parameter name "mlp.experts.w2_weight":
      - "experts" or "experts.w2_weight" will match.
      - "expert" or "w2" will NOT match (must be exact segments).
    This allows distinguishing parameters like "w2_weight" and "w2_weight_scale".
    """
|
|
||
|
|
||
@config
class PrefetchOffloadConfig:
    """Configuration for prefetch-based CPU offloading.

    Groups layers and uses async H2D prefetch to hide transfer latency.
    Disabled unless `offload_group_size` is set to a positive value.
    """

    offload_group_size: int = Field(default=0, ge=0)
    """Group every N layers together. Offload last `offload_num_in_group`
    layers of each group. Default is 0 (disabled).
    Example: group_size=8, num_in_group=2 offloads layers 6,7,14,15,22,23,...
    Unlike cpu_offload_gb, this uses explicit async prefetching to hide transfer
    latency.
    """

    offload_num_in_group: int = Field(default=1, ge=1)
    """Number of layers to offload per group.
    Must be <= offload_group_size. Default is 1."""

    offload_prefetch_step: int = Field(default=1, ge=0)
    """Number of layers to prefetch ahead.
    Higher values hide more latency but use more GPU memory. Default is 1."""

    offload_params: set[str] = Field(default_factory=set)
    """The set of parameter name segments to target for prefetch offloading.
    Unmatched parameters are not offloaded. If this set is empty, ALL
    parameters of each offloaded layer are offloaded.
    Uses segment matching: "w13_weight" matches "mlp.experts.w13_weight"
    but not "mlp.experts.w13_weight_scale".
    """
|
|
||
|
|
||
@config
class OffloadConfig:
    """Configuration for model weight offloading to reduce GPU memory usage."""

    offload_backend: OffloadBackend = "auto"
    """The backend for weight offloading. Options:
    - "auto": Selects based on which sub-config has non-default values
      (prefetch if offload_group_size > 0, uva if cpu_offload_gb > 0).
    - "uva": UVA (Unified Virtual Addressing) zero-copy offloading.
    - "prefetch": Async prefetch with group-based layer offloading.
    """

    uva: UVAOffloadConfig = Field(default_factory=UVAOffloadConfig)
    """Parameters for UVA offloading backend."""

    prefetch: PrefetchOffloadConfig = Field(default_factory=PrefetchOffloadConfig)
    """Parameters for prefetch offloading backend."""

    @model_validator(mode="after")
    def validate_offload_config(self) -> "OffloadConfig":
        """Validate offload configuration constraints.

        Raises:
            ValueError: If prefetch offloading is requested with an
                inconsistent or incomplete prefetch configuration.
        """
        prefetch = self.prefetch
        prefetch_requested = (
            self.offload_backend == "prefetch" or prefetch.offload_group_size > 0
        )
        if prefetch_requested:
            # An explicit "prefetch" backend with the disabled default
            # group size would otherwise surface as a confusing
            # "num_in_group must be <= group_size (0)" error; report the
            # actual cause directly.
            if prefetch.offload_group_size <= 0:
                raise ValueError(
                    "offload_backend='prefetch' requires offload_group_size > 0"
                )
            if prefetch.offload_num_in_group > prefetch.offload_group_size:
                raise ValueError(
                    f"offload_num_in_group ({prefetch.offload_num_in_group})"
                    " must be <= offload_group_size"
                    f" ({prefetch.offload_group_size})"
                )
            if prefetch.offload_prefetch_step < 1:
                raise ValueError(
                    "offload_prefetch_step"
                    f" ({prefetch.offload_prefetch_step})"
                    " must be >= 1 when prefetch offloading is enabled"
                    " (offload_group_size > 0)"
                )

        # Warn if both backends have non-default values, so users notice
        # settings that will be silently ignored by the selected backend.
        uva_active = self.uva.cpu_offload_gb > 0
        prefetch_active = prefetch.offload_group_size > 0
        if self.offload_backend == "uva" and prefetch_active:
            warnings.warn(
                "Prefetch offload fields are set but offload_backend='uva'. "
                "Prefetch settings will be ignored.",
                stacklevel=2,
            )
        elif self.offload_backend == "prefetch" and uva_active:
            warnings.warn(
                "UVA offload fields are set but offload_backend='prefetch'. "
                "UVA settings will be ignored.",
                stacklevel=2,
            )
        elif self.offload_backend == "auto" and uva_active and prefetch_active:
            warnings.warn(
                "Both UVA and prefetch offload fields are set with "
                "offload_backend='auto'. Prefetch backend will be selected. "
                "Set offload_backend explicitly to suppress this warning.",
                stacklevel=2,
            )
        return self

    def compute_hash(self) -> str:
        """
        Provide a hash that uniquely identifies all the offload configs.

        All fields are included because PrefetchOffloader patches module
        forwards and inserts custom ops (wait_prefetch, start_prefetch)
        into the computation graph. Changing any offload setting can
        alter which layers are hooked and how prefetch indices are
        computed, so the compilation cache must distinguish them.
        """
        from vllm.config.utils import get_hash_factors, hash_factors

        factors = get_hash_factors(self, ignored_factors=set())
        return hash_factors(factors)
Uh oh!
There was an error while loading. Please reload this page.