Commit 5e3ebeb

feiqiangs authored and phaedonsun committed
[KV Connector] Support using FlexKV as KV Cache Offloading option. (vllm-project#34328)
Signed-off-by: phaedonsun <phaedonsun@tencent.com>
Co-authored-by: phaedonsun <phaedonsun@tencent.com>
1 parent 9d537f4 commit 5e3ebeb

5 files changed

Lines changed: 725 additions & 0 deletions

File tree

docs/features/disagg_prefill.md

Lines changed: 6 additions & 0 deletions

@@ -44,6 +44,12 @@ For NixlConnector, you may also specify one or multiple NIXL_Backend. Such as:

     --kv-transfer-config '{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"block_size": 64, "cpu_bytes_to_use": 1000000000}}'
     ```

+- **FlexKVConnectorV1**: refer to [examples/offline_inference/prefix_caching_flexkv.py](../../examples/offline_inference/prefix_caching_flexkv.py) for the example usage of FlexKVConnectorV1. FlexKV is a distributed KV Store and multi-level cache management system for ultra-large-scale LLM inference.
+
+    ```bash
+    --kv-transfer-config '{"kv_connector":"FlexKVConnectorV1","kv_role":"kv_both"}'
+    ```
+
 ## Benchmarks

 Please refer to [benchmarks/disagg_benchmarks](../../benchmarks/disagg_benchmarks) for disaggregated prefilling benchmarks.
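
For online serving, the same connector can be enabled on the command line. A minimal sketch, assuming a placeholder model name and FlexKV config path (neither is part of this commit); `FLEXKV_CONFIG_PATH` is the environment variable consumed by the example script added below:

```bash
# Hypothetical launch; the model name and config path are placeholders.
export FLEXKV_CONFIG_PATH=/path/to/flexkv_config.json
vllm serve meta-llama/Llama-3.1-8B-Instruct \
    --kv-transfer-config '{"kv_connector":"FlexKVConnectorV1","kv_role":"kv_both"}'
```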
examples/offline_inference/prefix_caching_flexkv.py
Lines changed: 221 additions & 0 deletions

@@ -0,0 +1,221 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use FlexKV with vLLM for prefix caching.

FlexKV is a distributed KV Store and multi-level cache management system for
ultra-large-scale LLM inference.

Requirements:
- Install FlexKV (https://github.com/taco-project/FlexKV):
    1. git clone git@github.com:taco-project/FlexKV.git
    2. cd FlexKV && bash build.sh
- Ensure FlexKV is compatible with your vLLM version.

Usage:
1. Run this script:
    python examples/offline_inference/prefix_caching_flexkv.py \
        --model /path/to/your/model

2. Arguments:
    --model             Path or name of the model (required)
    --tp-size           Tensor parallel size (default: 1)
    --gpu-memory-util   GPU memory utilization (default: 0.4)

3. The script will:
    - Create a FlexKV configuration file.
    - Set the FLEXKV_CONFIG_PATH environment variable.
    - Run vLLM with FlexKVConnectorV1 enabled.
    - Compare results between regular execution, vLLM's default prefix
      caching, and FlexKV.
"""

import argparse
import json
import os
import time

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory

# NOTE: This is just a running example. For benchmarking purposes,
# please see benchmarks/benchmark_prefix_caching.py


def parse_args():
    parser = argparse.ArgumentParser(
        description="Example of using FlexKV with vLLM for prefix caching."
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path or name of the model to use.",
    )
    parser.add_argument(
        "--tp-size",
        type=int,
        default=1,
        help="Tensor parallel size (default: 1).",
    )
    parser.add_argument(
        "--gpu-memory-util",
        type=float,
        default=0.4,
        help="GPU memory utilization fraction (default: 0.4).",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    # Write a minimal FlexKV configuration and point FLEXKV_CONFIG_PATH at it.
    flexkv_config = {
        "server_recv_port": f"ipc:///tmp/flexkv_test_{os.getpid()}",
        "cache_config": {
            "enable_cpu": True,
            "num_cpu_blocks": 10240,
        },
        "num_log_interval_requests": 200,
    }
    flexkv_config_path = f"./flexkv_config_{os.getpid()}.json"
    with open(flexkv_config_path, "w") as f:
        json.dump(flexkv_config, f)
    os.environ["FLEXKV_CONFIG_PATH"] = flexkv_config_path

    try:
        _run(args)
    finally:
        if os.path.exists(flexkv_config_path):
            os.remove(flexkv_config_path)


def _run(args):
    # Common prefix.
    prefix = (
        "You are an expert school principal, skilled in effectively managing "
        "faculty and staff. Draft 10-15 questions for a potential first grade "
        "Head Teacher for my K-12, all-girls', independent school that emphasizes "
        "community, joyful discovery, and life-long learning. The candidate is "
        "coming in for a first-round panel interview for a 8th grade Math "
        "teaching role. They have 5 years of previous teaching experience "
        "as an assistant teacher at a co-ed, public school with experience "
        "in middle school math teaching. Based on these information, fulfill "
        "the following paragraph: "
    )

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    generating_prompts = [prefix + prompt for prompt in prompts]

    # Create a sampling params object.
    sampling_params = SamplingParams(temperature=0.0)

    kv_transfer_config = {
        "kv_connector": "FlexKVConnectorV1",
        "kv_role": "kv_both",
    }

    # Create an LLM without prefix caching as a baseline.
    regular_llm = LLM(
        model=args.model,
        enable_prefix_caching=False,
        gpu_memory_utilization=args.gpu_memory_util,
        tensor_parallel_size=args.tp_size,
    )

    print("Results without `enable_prefix_caching`")

    # ruff: noqa: E501
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    outputs = regular_llm.generate(generating_prompts, sampling_params)

    regular_generated_texts = []
    # Print the outputs.
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        regular_generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)

    # Destroy the LLM object and free up the GPU memory.
    del regular_llm
    cleanup_dist_env_and_memory()

    # Create an LLM with prefix caching enabled.
    prefix_cached_llm = LLM(
        model=args.model,
        enable_prefix_caching=True,
        gpu_memory_utilization=args.gpu_memory_util,
        tensor_parallel_size=args.tp_size,
        kv_transfer_config=kv_transfer_config,
    )

    # Warm up so that the shared prompt's KV cache is computed.
    prefix_cached_llm.generate(generating_prompts[0], sampling_params)

    # Wait for the KV offload task to finish.
    time.sleep(2)

    # Generate with prefix caching.
    outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)

    print("Results with `enable_prefix_caching`")

    cached_generated_texts = []
    # Print the outputs. You should see the same outputs as before.
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        cached_generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)

    # Compare the results.
    generated_same = all(
        regular_generated_texts[i] == cached_generated_texts[i]
        for i in range(len(prompts))
    )
    print(f"Generated answers are the same: {generated_same}")

    # Wait for the KV offload task to finish.
    time.sleep(2)

    # Reset the GPU prefix cache so the next run is served from FlexKV.
    prefix_cached_llm.reset_prefix_cache()

    # Generate again, this time hitting FlexKV.
    outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)

    print("Results with `flexkv`")

    flexkv_generated_texts = []
    # Print the outputs. You should see the same outputs as before.
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        flexkv_generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
    print("-" * 50)

    # Compare the results.
    generated_same = all(
        regular_generated_texts[i] == flexkv_generated_texts[i]
        for i in range(len(prompts))
    )
    print(f"Generated answers are the same: {generated_same}")


if __name__ == "__main__":
    main()
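
To try the example end to end, it can be run as described in the module docstring; a minimal sketch (the model path is a placeholder, and the flag values shown are the script's defaults):

```bash
# Run the FlexKV prefix-caching example; replace the model path as needed.
python examples/offline_inference/prefix_caching_flexkv.py \
    --model /path/to/your/model \
    --tp-size 1 \
    --gpu-memory-util 0.4
```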
