# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""
This example shows how to use FlexKV with vLLM for prefix caching.

FlexKV is a distributed KV store and multi-level cache management system for
ultra-large-scale LLM inference.

Requirements:
  - Install FlexKV (https://github.com/taco-project/FlexKV):
      1. git clone git@github.com:taco-project/FlexKV.git
      2. cd FlexKV && bash build.sh
  - Ensure FlexKV is compatible with your vLLM version.

Usage:
  1. Run this script:
      python examples/offline_inference/prefix_caching_flexkv.py \
          --model /path/to/your/model

  2. Arguments:
      --model             Path or name of the model (required)
      --tp-size           Tensor parallel size (default: 1)
      --gpu-memory-util   GPU memory utilization (default: 0.4)

  3. The script will:
      - Create a FlexKV configuration file.
      - Set the FLEXKV_CONFIG_PATH environment variable.
      - Run vLLM with FlexKVConnectorV1 enabled.
      - Compare results between regular execution, vLLM's default prefix
        caching, and FlexKV.
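
The config written by this script looks like the following (fields as used
in main() below; consult the FlexKV documentation for the full schema):

    {
        "server_recv_port": "ipc:///tmp/flexkv_test_<pid>",
        "cache_config": {"enable_cpu": true, "num_cpu_blocks": 10240},
        "num_log_interval_requests": 200
    }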
| 31 | +""" |

import argparse
import json
import os
import time

from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory

# NOTE: This is just a running example. For benchmarking purposes,
# please see benchmarks/benchmark_prefix_caching.py.


def parse_args():
    parser = argparse.ArgumentParser(
        description="Example of using FlexKV with vLLM for prefix caching."
    )
    parser.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path or name of the model to use.",
    )
    parser.add_argument(
        "--tp-size",
        type=int,
        default=1,
        help="Tensor parallel size (default: 1).",
    )
    parser.add_argument(
        "--gpu-memory-util",
        type=float,
        default=0.4,
        help="GPU memory utilization fraction (default: 0.4).",
    )
    return parser.parse_args()


def main():
    args = parse_args()

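    # A minimal FlexKV configuration: an IPC endpoint for the FlexKV server
    # plus a CPU-memory cache tier. The PID suffix keeps concurrent runs from
    # colliding on the same socket or config file. (Field meanings are as used
    # in this example; see the FlexKV docs for the full set of options.)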
    flexkv_config = {
        "server_recv_port": f"ipc:///tmp/flexkv_test_{os.getpid()}",
        "cache_config": {
            "enable_cpu": True,
            "num_cpu_blocks": 10240,
        },
        "num_log_interval_requests": 200,
    }
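    # The FlexKV connector inside vLLM locates this config through the
    # FLEXKV_CONFIG_PATH environment variable.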
    flexkv_config_path = f"./flexkv_config_{os.getpid()}.json"
    with open(flexkv_config_path, "w") as f:
        json.dump(flexkv_config, f)
    os.environ["FLEXKV_CONFIG_PATH"] = flexkv_config_path

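    # Make sure the temporary config file is removed even if the run fails.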
    try:
        _run(args)
    finally:
        if os.path.exists(flexkv_config_path):
            os.remove(flexkv_config_path)


def _run(args):
    # Common prefix shared by all prompts.
    prefix = (
        "You are an expert school principal, skilled in effectively managing "
        "faculty and staff. Draft 10-15 questions for a potential first grade "
        "Head Teacher for my K-12, all-girls', independent school that "
        "emphasizes community, joyful discovery, and life-long learning. The "
        "candidate is coming in for a first-round panel interview for an 8th "
        "grade Math teaching role. They have 5 years of previous teaching "
        "experience as an assistant teacher at a co-ed, public school with "
        "experience in middle school math teaching. Based on this "
        "information, fulfill the following paragraph: "
    )

    # Sample prompts.
    prompts = [
        "Hello, my name is",
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]

    generating_prompts = [prefix + prompt for prompt in prompts]

    # Greedy sampling so that outputs are deterministic and comparable.
    sampling_params = SamplingParams(temperature=0.0)

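    # Route KV-cache transfer through the FlexKV connector. "kv_both" means
    # this instance both saves (offloads) and loads (reuses) KV blocks.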
    kv_transfer_config = {
        "kv_connector": "FlexKVConnectorV1",
        "kv_role": "kv_both",
    }

    # Create an LLM without prefix caching as a baseline.
    regular_llm = LLM(
        model=args.model,
        enable_prefix_caching=False,
        gpu_memory_utilization=args.gpu_memory_util,
        tensor_parallel_size=args.tp_size,
    )

    print("Results without `enable_prefix_caching`")

    # ruff: noqa: E501
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    outputs = regular_llm.generate(generating_prompts, sampling_params)

    regular_generated_texts = []
    # Print the outputs.
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        regular_generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)

    # Destroy the LLM object and free up the GPU memory.
    del regular_llm
    cleanup_dist_env_and_memory()

    # Create an LLM with prefix caching enabled and the FlexKV connector
    # attached.
    prefix_cached_llm = LLM(
        model=args.model,
        enable_prefix_caching=True,
        gpu_memory_utilization=args.gpu_memory_util,
        tensor_parallel_size=args.tp_size,
        kv_transfer_config=kv_transfer_config,
    )

    # Warm up so that the shared prefix's KV cache is computed.
    prefix_cached_llm.generate(generating_prompts[0], sampling_params)

    # Give the asynchronous KV offload to FlexKV time to finish.
    time.sleep(2)

    # Generate with prefix caching.
    outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)

    print("Results with `enable_prefix_caching`")

    cached_generated_texts = []
    # Print the outputs. You should see the same outputs as before.
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        cached_generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)

    # Compare the results.
    generated_same = all(
        regular_generated_texts[i] == cached_generated_texts[i]
        for i in range(len(prompts))
    )
    print(f"Generated answers are the same: {generated_same}")

    # Give the asynchronous KV offload to FlexKV time to finish.
    time.sleep(2)

    # Reset vLLM's local prefix cache so the next run has to fetch the KV
    # blocks back from FlexKV.
    prefix_cached_llm.reset_prefix_cache()

    # Generate again; matching prefixes should now be served from FlexKV.
    outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)

    print("Results with `flexkv`")

    flexkv_generated_texts = []
    # Print the outputs. You should see the same outputs as before.
    print("-" * 50)
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text
        flexkv_generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}\nGenerated text: {generated_text!r}")
        print("-" * 50)

    # Compare the results.
    generated_same = all(
        regular_generated_texts[i] == flexkv_generated_texts[i]
        for i in range(len(prompts))
    )
    print(f"Generated answers are the same: {generated_same}")
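
    # Optional: to observe the speedup rather than just correctness, time
    # each generate() call yourself, e.g. (standard-library timing; not part
    # of the original example):
    #     start = time.perf_counter()
    #     outputs = prefix_cached_llm.generate(generating_prompts, sampling_params)
    #     print(f"generation took {time.perf_counter() - start:.2f}s")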


if __name__ == "__main__":
    main()