[Perf] Batch KV cache swap copies via cuMemcpyBatchAsync #38460
orozery merged 32 commits into vllm-project:main from …
Conversation
Replace per-layer per-block `swap_blocks` calls with a single batched `swap_blocks_batch` call that submits all copies in one driver invocation.

The existing offloading path issues L×N individual `cudaMemcpyAsync` calls (layers × block pairs), each incurring ~2 μs of CPU submission overhead. For typical configurations (80 layers × 32 blocks = 2,560 calls), this overhead dominates the actual transfer time.

`swap_blocks_batch` collects all source/destination pointers and sizes into flat arrays, then:

- On CUDA 12.8+: submits them via `cuMemcpyBatchAsync` (one driver call)
- On older CUDA/ROCm: falls back to a loop of `cudaMemcpyAsync` with `cudaMemcpyDefault` (same behavior as before, no regression)

The Python side pre-computes base pointers and block sizes at init time, then builds the flat pointer arrays using vectorized numpy arithmetic.

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
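For a rough sense of the overhead: at ~2 μs per submission, 80 layers × 32 blocks = 2,560 calls already cost about 5 ms of CPU time per swap before any bytes move. Below is a minimal sketch of the dual path described above (illustrative names, error handling omitted, not the PR's exact code). The batched branch is written against the `cuMemcpyBatchAsync` signature documented for CUDA 12.8; as noted later in this thread, the signature changed in CUDA 13.0. The fallback branch is the plain per-pair `cudaMemcpyAsync` loop with `cudaMemcpyDefault`.

```cpp
// Illustrative sketch only: names are hypothetical, error handling omitted.
#include <cuda.h>
#include <cuda_runtime.h>
#include <cstdint>

void swap_blocks_batch_sketch(const int64_t* src_ptrs, const int64_t* dst_ptrs,
                              const int64_t* sizes, size_t count,
                              cudaStream_t stream) {
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
  // Single driver call for the whole batch (CUDA 12.8 signature; 13.0 differs).
  CUmemcpyAttributes attr = {};
  attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
  size_t attr_idx = 0;  // attrs[0] applies to all copies starting at index 0
  size_t fail_idx = 0;
  cuMemcpyBatchAsync(
      reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(dst_ptrs)),
      reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(src_ptrs)),
      reinterpret_cast<size_t*>(const_cast<int64_t*>(sizes)),
      count, &attr, &attr_idx, /*numAttrs=*/1, &fail_idx, stream);
#else
  // Fallback: one async copy per block pair; cudaMemcpyDefault infers the
  // direction from the (unified-addressing) pointers, matching old behavior.
  for (size_t i = 0; i < count; ++i) {
    cudaMemcpyAsync(reinterpret_cast<void*>(dst_ptrs[i]),
                    reinterpret_cast<void*>(src_ptrs[i]),
                    static_cast<size_t>(sizes[i]), cudaMemcpyDefault, stream);
  }
#endif
}
```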
Code Review
This pull request introduces swap_blocks_batch, a batched memory copy operation designed to reduce overhead during KV cache offloading. The implementation leverages cuMemcpyBatchAsync on supported CUDA versions (12.8+) and provides a fallback for older environments. The CPUGPUWorker is updated to aggregate transfer requests into a single batch call. A review comment suggests that the operation should be registered for the CUDA device in the Torch bindings instead of CPU.
```cpp
  cache_ops.def(
      "swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
      " Tensor sizes) -> ()");
  cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
```
The swap_blocks_batch operation is implemented in csrc/cache_kernels.cu and involves CUDA operations. However, it's being registered here for the torch::kCPU device. This is likely incorrect as the implementation relies on CUDA streams and memory copies. It should be registered for torch::kCUDA to ensure it's dispatched correctly when called on CUDA tensors.
Suggested change:

```diff
- cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
+ cache_ops.impl("swap_blocks_batch", torch::kCUDA, &swap_blocks_batch);
```
The input tensors (src_ptrs, dst_ptrs, sizes) are CPU tensors — they're numpy arrays of raw pointers/sizes converted via torch.from_numpy(). PyTorch dispatches based on the input tensor device, so kCPU is correct here. The existing swap_blocks uses kCUDA because its inputs are the actual GPU KV cache tensors. Registering with kCUDA would actually break dispatch since no input tensor lives on GPU.
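To make the dispatch argument concrete, here is a simplified sketch of how the two registrations sit side by side (not the exact vLLM bindings; the declared kernel signatures are assumptions). `swap_blocks` receives the GPU KV cache tensors themselves, so it registers under the CUDA key; `swap_blocks_batch` only ever sees CPU int64 tensors of raw pointers and sizes built via `torch.from_numpy()`, so the CPU key is the one the dispatcher will actually look up:

```cpp
// Simplified sketch of the dispatch-key contrast; signatures are illustrative.
#include <torch/all.h>
#include <torch/library.h>

void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
                 const torch::Tensor& block_mapping);
void swap_blocks_batch(const torch::Tensor& src_ptrs,
                       const torch::Tensor& dst_ptrs,
                       const torch::Tensor& sizes);

TORCH_LIBRARY_FRAGMENT(sketch_cache_ops, cache_ops) {
  // swap_blocks takes the GPU KV cache tensors, so the dispatcher sees CUDA
  // tensors and the kernel is registered under the CUDA key.
  cache_ops.def("swap_blocks(Tensor src, Tensor! dst, Tensor block_mapping) -> ()");
  cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);

  // swap_blocks_batch only receives CPU int64 tensors of raw device pointers
  // and copy sizes (built on the Python side), so the dispatcher sees only
  // CPU tensors and the CPU key is the correct registration.
  cache_ops.def("swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs, Tensor sizes) -> ()");
  cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
}
```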
```cpp
  static_assert(sizeof(size_t) == sizeof(int64_t));
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
  CUmemcpyAttributes attr = {};
  attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
```
Minor comment: curious, have you tried CU_MEMCPY_SRC_ACCESS_ORDER_ANY (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g6f1ff58e3065df3eb4b573dba77ad31f)? I found it gives me better CPU->GPU bandwidth on Grace Blackwell nodes.
I see you also applied this parameter to GPU srcs.
According to the documentation this means access to srcs can be out of stream, so potentially not waiting for the compute (default) stream to complete?
@Etelis Anyhow, for CPU->GPU this seems safe. Let's test it in a follow-up.
Thanks @ivanium for this suggestion!
Right, it won't wait for previous ops in the stream. Since we typically call this API in a separate copy stream, I guess we cannot rely on this AccessOrder param anyway. If we want to stay safe as a general purpose API, maybe we can expose a configurable param to users.
I'll test it as a follow-up.
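A small sketch of the configurable access-order idea floated above (the parameter is hypothetical, not something this PR adds): default to stream ordering, and let callers opt in to CU_MEMCPY_SRC_ACCESS_ORDER_ANY only when the source is known to be stable, e.g. pinned host memory for CPU->GPU offload.

```cpp
// Hypothetical helper for the follow-up discussed above; requires CUDA 12.8+ headers.
#include <cuda.h>

CUmemcpyAttributes make_copy_attrs(bool relaxed_src_order) {
  CUmemcpyAttributes attr = {};
  // STREAM: source reads are ordered after prior work in the stream (safe default).
  // ANY: source may be read out of stream order; only use when the source
  //      buffer is guaranteed not to be written by in-flight work.
  attr.srcAccessOrder = relaxed_src_order ? CU_MEMCPY_SRC_ACCESS_ORDER_ANY
                                          : CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
  return attr;
}
```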
Yes, the …
It is indeed disabled for older CUDA versions and ROCm. I think his issue comes from having a venv with a misconfigured CUDA version?
The venv only has the CUDA 12.8 runtime but it still has to use the system CUDA 12.1 driver. I think a more reliable way is to check the CUDA driver version rather than the CUDA runtime version.
Right now we have a compile-time check (using ifdef).
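A hedged sketch of the runtime check suggested above (not what the merged code does; the PR gates on CUDA_VERSION at compile time): query the installed driver with cudaDriverGetVersion and only take the batched path when the driver actually supports CUDA 12.8.

```cpp
// Sketch of a runtime driver-version gate (illustrative; the PR uses an #ifdef).
#include <cuda_runtime.h>

bool driver_supports_batch_memcpy() {
  int driver_version = 0;  // encoded as 1000 * major + 10 * minor, e.g. 12080 for 12.8
  if (cudaDriverGetVersion(&driver_version) != cudaSuccess) {
    return false;
  }
  return driver_version >= 12080;
}
```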
FYI, after this merged I can no longer build vLLM from main on my DGX Spark as I could before. The error: …
@orozery, having the same issue as @bbrowning after this PR on DGX Spark (TORCH_CUDA_ARCH_LIST=12.1a): … @mgoin, @johnnynunez - FYI
@bbrowning @eugr This CUDA 13.0 compilation error is caused by a signature change in `cuMemcpyBatchAsync`. I've opened #38919, which fixes both this issue and the …
Referenced by vllm-project/vllm-ascend#7819:

### What this PR does / why we need it?

Refer to vllm-project/vllm#38460 and vllm-project/vllm#38915. CANN 8.5.0+ uses aclrtMemcpyBatchAsync; older CANN versions use aclrtMemcpyAsync for KV cache offloading. The build automatically compiles and selects the appropriate transfer function based on the CANN environment, and also supports choosing it manually via a parameter:

1. Batch memcpy (requires CANN >= 8.5):
   `export VLLM_ASCEND_ENABLE_BATCH_MEMCPY=1` then `pip install -e .`
2. Normal memcpy:
   `export VLLM_ASCEND_ENABLE_BATCH_MEMCPY=0` then `pip install -e .`

### How was this patch tested?

Test results:

- main: TTFT 307 ms, TPOT 49.96 ms
- this PR: TTFT 272.82 ms, TPOT 41.04 ms

Model script:

```bash
export TP=1
export MODEL_PATH=/nas/disk1/Qwen3-14B
export MODEL_NAME=Qwen3-14B
export PORT=10113
export CUDA_VISIBLE_DEVICES=3
export ASCEND_RT_VISIBLE_DEVICES=3
python3 -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port ${PORT} \
  --dtype bfloat16 --model ${MODEL_PATH} --served-model-name ${MODEL_NAME} \
  --tensor-parallel-size ${TP} --gpu-memory-utilization 0.7 \
  --no-enable-prefix-caching --max-model-len 32768 --trust-remote-code \
  --block-size 128 \
  --kv-transfer-config '{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"block_size": 128, "num_cpu_blocks": 1000, "spec_name":"NPUOffloadingSpec", "spec_module_path": "vllm_ascend.kv_offload.npu"}}'
```

Test script:

```bash
export MODEL_NAME=/nas/disk1/Qwen3-14B
python /model/xk/vllm/benchmarks/multi_turn/benchmark_serving_multi_turn.py \
  --url http://127.0.0.1:10113 --model $MODEL_NAME --served-model-name Qwen3-14B \
  --seed 1234 --input-file /model/xk/vllm/benchmarks/multi_turn/generate_multi_turn.json \
  --num-clients 8 --max-active-conversations 24
```

- vLLM version: v0.18.0
- vLLM main: vllm-project/vllm@35141a7
Replace per-layer per-block `swap_blocks` calls in the KV cache offloading handler with a single `swap_blocks_batch` call that submits all copies in one driver invocation.

On CUDA 12.8+ this uses `cuMemcpyBatchAsync`; on older CUDA/ROCm it falls back to a flat `cudaMemcpyAsync` loop with `cudaMemcpyDefault`. Zero extra GPU memory. No behavior change.

Supersedes #38216 (rebased clean).
Benchmark Results
Hardware: 8xH100 80GB HBM3, CUDA 12.8/12.9.
Baseline runs the original files via `pip install vllm==0.18.0 --force-reinstall --no-deps`. Each mode ran in a fresh Python process to avoid stale module state.
Handler-level benchmark
Directly measures the transfer path in isolation — no model inference.
Setup: Instantiate `CpuGpuOffloadingHandler`s with BF16 GPU tensors shaped by `FlashAttentionBackend.get_kv_cache_shape()` using each model's real architecture.

Call `handler.transfer_async()`, poll `get_finished()` until complete. Measure with `time.perf_counter()` (includes CUDA sync). 5 warmup + 100 measured iterations. Baseline and batched ran as separate processes.
Per-layer tensors (FlashAttn K/V split, no cross-layer allocation):
E2E vLLM serve — KV transfer bandwidth
Setup: `vllm serve` with each real model, offloading enabled:

```bash
python -m vllm.entrypoints.openai.api_server \
  --model <model> --gpu-memory-utilization <0.4-0.5> \
  --kv-transfer-config '{"kv_connector":"OffloadingConnector", "kv_role":"kv_both", "kv_connector_extra_config":{"cpu_bytes_to_use":2147483648, "block_size":48}}' \
  --max-model-len 4096 --enforce-eager
```

`total_time` is measured by CUDA events (`start_event.elapsed_time(end_event)`) inside `SingleDirectionOffloadingHandler.get_finished()`. Bandwidth = `total_bytes / total_time`.
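For reference, the event-based timing pattern behind these bandwidth numbers (the actual measurement lives in the Python handler via torch.cuda events; this is just the equivalent CUDA C++ pattern, names illustrative):

```cpp
// Illustrative timing of a batched swap on a copy stream with CUDA events.
#include <cuda_runtime.h>
#include <cstdio>

void time_swap(cudaStream_t stream, size_t total_bytes /*, copy args... */) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);

  cudaEventRecord(start, stream);
  // ... submit the batched copies on `stream` here ...
  cudaEventRecord(stop, stream);
  cudaEventSynchronize(stop);

  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, stop);  // elapsed time in milliseconds
  double gib_per_s = (total_bytes / (ms / 1e3)) / (1024.0 * 1024.0 * 1024.0);
  std::printf("%.2f GiB/s over %.3f ms\n", gib_per_s, ms);

  cudaEventDestroy(start);
  cudaEventDestroy(stop);
}
```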
E2E serving throughput

Qwen2.5-7B shows an 18% throughput improvement and a p99 TTFT reduction from 6 seconds to 117 ms under concurrent load. My guess is that this is because the model has 4 KV heads (vs 8 for LLaMA/Mistral), producing smaller per-block copies where submission overhead dominates.
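For a rough sense of scale (assuming a head dimension of 128 and a BF16 KV cache, neither stated here, plus the block_size of 48 from the serve command above): one block for one layer is 2 (K and V) × num_kv_heads × 48 tokens × 128 × 2 bytes, i.e. roughly 96 KiB with 4 KV heads versus roughly 192 KiB with 8. At a typical ~20-25 GB/s host-device link the smaller copy moves in only a few microseconds, the same order as the ~2 μs per-call submission overhead, which is consistent with batching helping most on models with fewer KV heads.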