66
77import vllm .envs as envs
88from vllm .attention import get_attn_backend
9- from vllm .config import (CacheConfig , CompilationConfig , DeviceConfig ,
10- ModelConfig , ParallelConfig , VllmConfig )
9+ from vllm .config import (CacheConfig , DeviceConfig , ModelConfig ,
10+ ParallelConfig , VllmConfig )
1111from vllm .distributed import (ensure_model_parallel_initialized ,
1212 init_distributed_environment )
1313from vllm .logger import init_logger
@@ -33,8 +33,8 @@ class CPUCacheEngine:
3333 """
3434
3535 def __init__ (self , cache_config : CacheConfig , model_config : ModelConfig ,
36- parallel_config : ParallelConfig , device_config : DeviceConfig ,
37- compilation_config : CompilationConfig ) -> None :
36+ parallel_config : ParallelConfig ,
37+ device_config : DeviceConfig ) -> None :
3838 assert device_config .device_type == "cpu"
3939 self .cache_config = cache_config
4040 self .model_config = model_config
@@ -66,8 +66,6 @@ def __init__(self, cache_config: CacheConfig, model_config: ModelConfig,
6666
6767 # Initialize the cache.
6868 self .cpu_cache = self ._allocate_kv_cache (self .num_cpu_blocks )
69- bind_kv_cache (compilation_config .static_forward_context ,
70- self .cpu_cache )
7169
7270 def _allocate_kv_cache (
7371 self ,
@@ -292,13 +290,15 @@ def _init_cache_engine(self) -> None:
292290 self .model_config ,
293291 self .parallel_config ,
294292 self .device_config ,
295- self .compilation_config ,
296293 ) for _ in range (self .parallel_config .pipeline_parallel_size )
297294 ]
298295 self .cpu_cache = [
299296 self .cache_engine [ve ].cpu_cache
300297 for ve in range (self .parallel_config .pipeline_parallel_size )
301298 ]
299+ for ve in range (self .parallel_config .pipeline_parallel_size ):
300+ bind_kv_cache (self .compilation_config .static_forward_context ,
301+ self .cpu_cache [ve ], ve )
302302 self .model_runner .block_size = self .cache_engine [0 ].block_size
303303
304304 assert all (
0 commit comments