
[Perf] Batch KV cache swap copies via cuMemcpyBatchAsync#38460

Merged
orozery merged 32 commits into vllm-project:main from Etelis:swap-blocks-batch-v2
Apr 3, 2026

Conversation

Contributor

@Etelis Etelis commented Mar 29, 2026

Replace per-layer per-block swap_blocks calls in the KV cache offloading
handler with a single swap_blocks_batch call that submits all copies in
one driver invocation.

On CUDA 12.8+ this uses cuMemcpyBatchAsync; on older CUDA/ROCm it falls
back to a flat cudaMemcpyAsync loop with cudaMemcpyDefault. Zero extra
GPU memory. No behavior change.

Supersedes #38216 (rebased clean).

Benchmark Results

Hardware: 8xH100 80GB HBM3, CUDA 12.8/12.9.

Baseline runs the original files via
pip install vllm==0.18.0 --force-reinstall --no-deps. Each mode ran in a
fresh Python process to avoid stale module state.

Handler-level benchmark

Directly measures the transfer path in isolation — no model inference.

Setup: Instantiate CpuGpuOffloadingHandlers with BF16 GPU tensors shaped
by FlashAttentionBackend.get_kv_cache_shape() using each model's real
architecture.
Call handler.transfer_async(), poll get_finished() until complete. Measure
with time.perf_counter() (includes CUDA sync). 5 warmup + 100 measured
iterations. Baseline and batched ran as separate processes.
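
For reference, the measurement loop looks roughly like this (a sketch of the pattern described above; handler construction and the transfer_async() arguments are placeholders, not the exact vLLM API):

```python
import time

import torch

def bench_transfer(handler, submit, warmup: int = 5, iters: int = 100) -> float:
    """Average seconds per transfer, end to end (including CUDA sync).

    `handler` and `submit` are placeholders: `submit(handler)` is whatever call
    kicks off handler.transfer_async() for the blocks under test.
    """
    def run_once() -> float:
        start = time.perf_counter()
        submit(handler)                        # handler.transfer_async(...)
        while not handler.get_finished():      # poll until the job is reported done
            pass
        torch.cuda.synchronize()               # make sure the copies really finished
        return time.perf_counter() - start

    for _ in range(warmup):
        run_once()
    return sum(run_once() for _ in range(iters)) / iters
```

The configurations exercised: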

| Config | Layers | KV Heads | Head Dim | Tensors | Per-tensor block size |
|---|---|---|---|---|---|
| LLaMA-8B | 32 | 8 | 128 | 64 | 32 KB |
| LLaMA-70B | 80 | 8 | 128 | 160 | 32 KB |
| LLaMA-70B TP=4 | 80 | 2 | 128 | 160 | 8 KB |
| Qwen2.5-0.5B | 24 | 2 | 64 | 48 | 4 KB |
| Qwen2.5-1.5B | 28 | 2 | 128 | 56 | 8 KB |
| Qwen2.5-3B | 36 | 2 | 128 | 72 | 8 KB |
| Phi-3-mini | 32 | 8 | 96 | 64 | 24 KB |

Per-layer tensors (FlashAttn K/V split, no cross-layer allocation):

| Config | Tensors | Baseline | Batched | Speedup |
|---|---|---|---|---|
| LLaMA-8B, 64 blocks | 64 | 16,358 us | 4,519 us | 3.6x |
| LLaMA-70B, 64 blocks | 160 | 41,124 us | 10,239 us | 4.0x |
| LLaMA-70B TP=4, 32 blocks | 160 | 20,427 us | 3,299 us | 6.2x |
| Qwen2.5-0.5B, 128 blocks | 48 | 23,484 us | 3,184 us | 7.4x |
| Qwen2.5-1.5B, 64 blocks | 56 | 14,042 us | 2,272 us | 6.2x |
| Qwen2.5-3B, 64 blocks | 72 | 18,045 us | 2,935 us | 6.2x |
| Phi-3-mini, 64 blocks | 64 | 15,963 us | 3,888 us | 4.1x |

E2E vLLM serve — KV transfer bandwidth

Setup:

  1. Start vllm serve with each real model, offloading enabled:
    python -m vllm.entrypoints.openai.api_server \
        --model <model> --gpu-memory-utilization <0.4-0.5> \
        --kv-transfer-config '{"kv_connector":"OffloadingConnector",
            "kv_role":"kv_both",
            "kv_connector_extra_config":{"cpu_bytes_to_use":2147483648,
            "block_size":48}}' \
        --max-model-len 4096 --enforce-eager
  2. Send sustained load: 8 concurrent Python threads, each sending prompts in a loop for 45 seconds.
  3. Read the KV transfer bandwidth from vLLM's server log. total_time is
    measured with CUDA events (start_event.elapsed_time(end_event)) inside
    SingleDirectionOffloadingHandler.get_finished(); bandwidth = total_bytes / total_time.
    The timing pattern is sketched below.
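
The CUDA-event timing referenced in step 3 follows this pattern (a minimal sketch using one large copy as a stand-in for the batched block copies; not the vLLM code itself):

```python
import torch

stream = torch.cuda.Stream()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

total_bytes = 512 * 1024 * 1024          # example payload: 512 MiB
src = torch.empty(total_bytes, dtype=torch.uint8, device="cuda")
dst = torch.empty(total_bytes, dtype=torch.uint8, pin_memory=True)

with torch.cuda.stream(stream):
    start_event.record(stream)
    dst.copy_(src, non_blocking=True)     # stand-in for the batched block copies
    end_event.record(stream)

end_event.synchronize()
total_time_s = start_event.elapsed_time(end_event) / 1e3   # elapsed_time() is in ms
print(f"bandwidth: {total_bytes / total_time_s / 1e9:.1f} GB/s")
```

Measured bandwidth: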
| Model | KV Heads | Baseline BW | Batched BW | Improvement |
|---|---|---|---|---|
| LLaMA-3.2-1B | 8 | 27.9 GB/s | 32.4 GB/s | +16% |
| Qwen2.5-1.5B | 2 | 25.8 GB/s | 31.8 GB/s | +23% |
| Qwen2.5-3B | 2 | 29.5 GB/s | 34.3 GB/s | +16% |
| Gemma-2-2B | 4 | 40.2 GB/s | 43.3 GB/s | +8% |
| Phi-3.5-mini | 8 | 51.6 GB/s | 53.3 GB/s | +3% |
| Qwen2.5-7B | 4 | 34.2 GB/s | 40.7 GB/s | +19% |
| LLaMA-3.1-8B | 8 | 44.0 GB/s | 47.0 GB/s | +7% |
| Mistral-7B | 8 | 45.1 GB/s | 48.6 GB/s | +8% |

E2E serving throughput

| Model | Baseline req/s | Batched req/s | Baseline p99 TTFT | Batched p99 TTFT |
|---|---|---|---|---|
| Qwen2.5-7B | 6.0 | 7.1 (+18%) | 6,077 ms | 117 ms |
| LLaMA-3.1-8B | 6.7 | 6.7 | 100.7 ms | 105.7 ms |
| Mistral-7B | 6.7 | 6.6 | 91.4 ms | 82.5 ms |

Qwen2.5-7B shows an 18% throughput improvement and a p99 TTFT reduction from
~6 seconds to 117 ms under concurrent load. This is likely because the model has
4 KV heads (vs 8 for LLaMA/Mistral), producing smaller per-block copies where
submission overhead dominates.

Replace per-layer per-block swap_blocks calls with a single batched
swap_blocks_batch call that submits all copies in one driver invocation.

The existing offloading path issues L×N individual cudaMemcpyAsync calls
(layers × block pairs), each incurring ~2μs of CPU submission overhead.
For typical configurations (80 layers × 32 blocks = 2,560 calls), this
overhead dominates the actual transfer time.

swap_blocks_batch collects all source/destination pointers and sizes into
flat arrays, then:
- On CUDA 12.8+: submits them via cuMemcpyBatchAsync (one driver call)
- On older CUDA/ROCm: falls back to a loop of cudaMemcpyAsync with
  cudaMemcpyDefault (same behavior as before, no regression)
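
A condensed sketch of that control flow (simplified; argument handling in the actual csrc/cache_kernels.cu differs in detail, and the CUDA 13.0 signature change discussed later in this thread is not handled here):

```cpp
#include <cuda.h>
#include <cuda_runtime.h>

void swap_blocks_batch_sketch(CUdeviceptr* dsts, CUdeviceptr* srcs,
                              size_t* sizes, size_t n, cudaStream_t stream) {
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
  // CUDA 12.8+: submit every copy in one driver call.
  CUmemcpyAttributes attr = {};
  attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
  size_t attrs_idx = 0;  // a single attribute record covers all copies
  size_t fail_idx = 0;
  cuMemcpyBatchAsync(dsts, srcs, sizes, n, &attr, &attrs_idx, 1, &fail_idx,
                     static_cast<CUstream>(stream));
#else
  // Older CUDA / ROCm: flat loop of async copies, direction inferred from the
  // pointers via cudaMemcpyDefault (same behavior as before this change).
  for (size_t i = 0; i < n; ++i) {
    cudaMemcpyAsync(reinterpret_cast<void*>(dsts[i]),
                    reinterpret_cast<void*>(srcs[i]), sizes[i],
                    cudaMemcpyDefault, stream);
  }
#endif
}
```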

The Python side pre-computes base pointers and block sizes at init time,
then builds the flat pointer arrays using vectorized numpy arithmetic.
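
Roughly, the construction on the Python side looks like this (a sketch with hypothetical names such as build_copy_lists and block_size_bytes; not the PR code verbatim):

```python
import numpy as np
import torch

def build_copy_lists(gpu_tensor: torch.Tensor, cpu_tensor: torch.Tensor,
                     gpu_block_ids: np.ndarray, cpu_block_ids: np.ndarray,
                     block_size_bytes: int):
    """Vectorized construction of the flat (src, dst, size) arrays for offload."""
    # Base pointers can be cached once at init time.
    gpu_base = gpu_tensor.data_ptr()
    cpu_base = cpu_tensor.data_ptr()
    # Per-block byte offsets computed with numpy rather than a Python loop.
    src_ptrs = gpu_base + gpu_block_ids.astype(np.int64) * block_size_bytes
    dst_ptrs = cpu_base + cpu_block_ids.astype(np.int64) * block_size_bytes
    sizes = np.full(len(gpu_block_ids), block_size_bytes, dtype=np.int64)
    # torch.from_numpy() yields CPU int64 tensors of raw pointers/sizes, which
    # is why the op is registered for the CPU dispatch key (see review thread).
    return (torch.from_numpy(src_ptrs), torch.from_numpy(dst_ptrs),
            torch.from_numpy(sizes))
```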

Signed-off-by: Itay Etelis <itay.etelis@ibm.com>
@Etelis Etelis requested review from ApostaC and orozery as code owners March 29, 2026 09:58

@claude claude Bot left a comment


Claude Code Review

This pull request is from a fork — automated review is disabled. A repository maintainer can comment @claude review to run a one-time review.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces swap_blocks_batch, a batched memory copy operation designed to reduce overhead during KV cache offloading. The implementation leverages cuMemcpyBatchAsync on supported CUDA versions (12.8+) and provides a fallback for older environments. The CPUGPUWorker is updated to aggregate transfer requests into a single batch call. A review comment suggests that the operation should be registered for the CUDA device in the Torch bindings instead of CPU.

Comment thread csrc/torch_bindings.cpp
cache_ops.def(
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
" Tensor sizes) -> ()");
cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
Contributor


critical

The swap_blocks_batch operation is implemented in csrc/cache_kernels.cu and involves CUDA operations. However, it's being registered here for the torch::kCPU device. This is likely incorrect as the implementation relies on CUDA streams and memory copies. It should be registered for torch::kCUDA to ensure it's dispatched correctly when called on CUDA tensors.

Suggested change:
- cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
+ cache_ops.impl("swap_blocks_batch", torch::kCUDA, &swap_blocks_batch);

Contributor Author


The input tensors (src_ptrs, dst_ptrs, sizes) are CPU tensors — they're numpy arrays of raw pointers/sizes converted via torch.from_numpy(). PyTorch dispatches based on the input tensor device, so kCPU is correct here. The existing swap_blocks uses kCUDA because its inputs are the actual GPU KV cache tensors. Registering with kCUDA would actually break dispatch since no input tensor lives on GPU.

@orozery orozery added the ready ONLY add when PR is ready to merge/full CI is needed label Mar 29, 2026
@orozery orozery enabled auto-merge (squash) March 30, 2026 09:58
Comment thread csrc/cache_kernels.cu
static_assert(sizeof(size_t) == sizeof(int64_t));
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
CUmemcpyAttributes attr = {};
attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
Contributor

@ivanium ivanium Apr 1, 2026


Minor comment: Curious have you tried CU_MEMCPY_SRC_ACCESS_ORDER_ANY (https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g6f1ff58e3065df3eb4b573dba77ad31f)? I found it gives me better CPU->GPU bandwidth on Grace Blackwell nodes.

Collaborator


I see you also applied this parameter to GPU srcs.
According to the documentation this means access to srcs can be out of stream, so potentially not waiting for the compute (default) stream to complete?

@Etelis Anyhow, for CPU->GPU this seems safe. Let's test it in a follow-up.

Thanks @ivanium for this suggestion!

Contributor


Right, it won't wait for previous ops in the stream. Since we typically call this API in a separate copy stream, I guess we cannot rely on this AccessOrder param anyway. If we want to stay safe as a general purpose API, maybe we can expose a configurable param to users.
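
For illustration, such a knob could reduce to selecting the attribute value (a sketch of the idea, not part of this PR):

```cpp
#include <cuda.h>

// Hypothetical knob: relax source access ordering only when the caller
// guarantees nothing else is still writing the source buffers.
CUmemcpyAttributes make_copy_attr(bool allow_out_of_order_src) {
  CUmemcpyAttributes attr = {};
  attr.srcAccessOrder = allow_out_of_order_src
                            ? CU_MEMCPY_SRC_ACCESS_ORDER_ANY
                            : CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
  return attr;
}
```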

Contributor Author


I'll test it as a followup

Contributor

ivanium commented Apr 3, 2026

Yes the cuMemcpyBatchAsync() API was introduced in CUDA 12.8, and we need to disable it for older CUDA driver versions.

Contributor Author

Etelis commented Apr 3, 2026

> Yes the cuMemcpyBatchAsync() API was introduced in CUDA 12.8, and we need to disable it for older CUDA driver versions.

It is indeed disabled for older CUDA versions and ROCm.

I think his issue comes from a venv with a misconfigured CUDA version?

Contributor

ivanium commented Apr 3, 2026

> Yes the cuMemcpyBatchAsync() API was introduced in CUDA 12.8, and we need to disable it for older CUDA driver versions.
>
> It is indeed disabled for older CUDA versions and ROCm.
>
> I think his issue comes from a venv with a misconfigured CUDA version?

The venv only has CUDA 12.8 runtime but it still has to use the system CUDA 12.1 driver. I think a more reliable way is to check the CUDA driver version rather than the CUDA runtime version.
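
For reference, a runtime check of the driver-supported CUDA version (as opposed to the compile-time CUDA_VERSION of the toolkit the wheel was built with) could look like this sketch:

```cpp
#include <cuda_runtime.h>

// cudaDriverGetVersion() reports the CUDA version supported by the installed
// driver (e.g. 12010 for 12.1), independent of the runtime the wheel was
// compiled against.
bool driver_supports_batch_memcpy() {
  int driver_version = 0;
  if (cudaDriverGetVersion(&driver_version) != cudaSuccess) return false;
  return driver_version >= 12080;  // cuMemcpyBatchAsync needs a 12.8+ driver
}
```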

Collaborator

orozery commented Apr 3, 2026

Right now we have a compile-time check (using ifdef).
If a user is running a CUDA version older than 12.8, it will fail if they're using an image built with CUDA 12.8+ (which is the case for the pre-built wheels, currently built with 12.9+).
We can use cuGetProcAddress to check if the user actually has the batched API available, but I'm wondering if it's worth the added complexity.
The general question is whether vLLM aims to support CUDA older than 12.8 on its pre-built wheels.
Users can always compile on their system to get a compatible version.
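
A sketch of the cuGetProcAddress approach, assuming the five-argument CUDA 12 signature (not the code that was eventually merged):

```cpp
#include <cuda.h>

// Resolve the batched-copy entry point at runtime. If the installed driver is
// too old, the symbol is reported as unavailable and the caller can fall back
// to the per-copy cudaMemcpyAsync path instead of failing on a missing symbol.
static void* resolve_cuMemcpyBatchAsync() {
  void* fn = nullptr;
  CUdriverProcAddressQueryResult status;
  CUresult rc = cuGetProcAddress("cuMemcpyBatchAsync", &fn, 12080,
                                 CU_GET_PROC_ADDRESS_DEFAULT, &status);
  if (rc != CUDA_SUCCESS || status != CU_GET_PROC_ADDRESS_SUCCESS) {
    return nullptr;
  }
  return fn;
}
```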

@bbrowning
Collaborator

FYI, after this merged I can no longer build vLLM from main on my DGX Spark, as I could before.

NVIDIA-SMI 580.126.09             Driver Version: 580.126.09     CUDA Version: 13.0  

The error:

      /home/bbrowning/src/vllm/csrc/cache_kernels.cu -o CMakeFiles/_C.dir/csrc/cache_kernels.cu.o                                     
      /home/bbrowning/src/vllm/csrc/cache_kernels.cu(115): error: argument of type "size_t *" (aka "unsigned long *") is incompatible with parameter of type "CUstream" (aka "CUstream_st *")                                                                               
              static_cast<size_t>(n), &attr, &attrs_idx, 1, &fail_idx,                                                 
                                                            ^                                                                         
                                                                                                                                      
      /home/bbrowning/src/vllm/csrc/cache_kernels.cu(116): error: too many arguments in function call
              static_cast<CUstream>(stream));                                                                                         
              ^                                                                                                                       

      2 errors detected in the compilation of "/home/bbrowning/src/vllm/csrc/cache_kernels.cu".


eugr commented Apr 3, 2026

@orozery, having the same issue as @bbrowning after this PR on DGX Spark (TORCH_CUDA_ARCH_LIST=12.1a):

/workspace/vllm/vllm/csrc/cache_kernels.cu(115): error: argument of type "size_t *" (aka "unsigned long *") is incompatible with parameter of type "CUstream" (aka "CUstream_st *")

@mgoin, @johnnynunez - FYI

Contributor Author

Etelis commented Apr 3, 2026

@bbrowning @eugr This CUDA 13.0 compilation error is caused by a signature change in cuMemcpyBatchAsync — CUDA 13.0 removed the failIdx parameter (the unversioned symbol now maps to cuMemcpyBatchAsync_v2 with 8 params instead of 9).

I've opened #38919 which fixes both this issue and the undefined symbol crash on older drivers (reported by @JaheimLee above). The fix uses cuGetProcAddress to dynamically resolve the function at runtime with a function pointer — this bypasses the header #define remapping entirely, so it compiles and works correctly on CUDA 12.8, 13.0, and older drivers alike.

khluu added a commit that referenced this pull request Apr 3, 2026
HenryTangDev pushed a commit to HenryTangMain/vllm that referenced this pull request Apr 6, 2026
puririshi98 pushed a commit to puririshi98/vllm that referenced this pull request Apr 7, 2026
mtparet pushed a commit to blackfuel-ai/vllm that referenced this pull request Apr 9, 2026
wangxiyuan added a commit to vllm-project/vllm-ascend that referenced this pull request Apr 21, 2026
### What this PR does / why we need it?
Referring to vllm-project/vllm#38460 and
vllm-project/vllm#38915: CANN 8.5.0+ uses aclrtMemcpyBatchAsync, while older
CANN versions use aclrtMemcpyAsync for KV cache offloading.

The build automatically selects the appropriate copy function based on the
CANN environment, and also supports choosing one explicitly via a build flag:
1. batch memcpy (needs CANN ≥ 8.5): export VLLM_ASCEND_ENABLE_BATCH_MEMCPY=1, then pip install -e .
2. normal memcpy: export VLLM_ASCEND_ENABLE_BATCH_MEMCPY=0, then pip install -e .

### How was this patch tested?

Test results:
main:    TTFT 307 ms      TPOT 49.96 ms
this PR: TTFT 272.82 ms   TPOT 41.04 ms

model script:
export TP=1
export MODEL_PATH=/nas/disk1/Qwen3-14B
export MODEL_NAME=Qwen3-14B
export PORT=10113
export CUDA_VISIBLE_DEVICES=3
export ASCEND_RT_VISIBLE_DEVICES=3
python3 -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port
${PORT} --dtype bfloat16 --model ${MODEL_PATH} --served-model-name
${MODEL_NAME} --tensor-parallel-size ${TP} --gpu-memory-utilization 0.7
--no-enable-prefix-caching --max-model-len 32768 --trust-remote-code \
    --block-size 128 \
--kv-transfer-config
'{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"block_size":
128, "num_cpu_blocks": 1000, "spec_name":"NPUOffloadingSpec",
"spec_module_path": "vllm_ascend.kv_offload.npu"}}'

test script:
export MODEL_NAME=/nas/disk1/Qwen3-14B
python
/model/xk/vllm/benchmarks/multi_turn/benchmark_serving_multi_turn.py
--url http://127.0.0.1:10113 --model $MODEL_NAME --served-model-name
Qwen3-14B --seed 1234 --input-file
/model/xk/vllm/benchmarks/multi_turn/generate_multi_turn.json \
--num-clients 8 --max-active-conversations 24



- vLLM version: v0.18.0
- vLLM main:
vllm-project/vllm@35141a7

---------

Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
Signed-off-by: HF-001 <1670186653@qq.com>
Signed-off-by: kx <1670186653@qq.com>
Co-authored-by: 01267596 <xiongkai123@cmbchina.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
weijinqian0 pushed a commit to weijinqian0/vllm-ascend that referenced this pull request Apr 21, 2026
anning-2026 pushed a commit to anning-2026/vllm-ascend that referenced this pull request Apr 21, 2026
guxin108 pushed a commit to guxin108/vllm-ascend that referenced this pull request Apr 24, 2026
zouyida2052 pushed a commit to zouyida2052/vllm-ascend that referenced this pull request Apr 28, 2026
yangzhe-2026 pushed a commit to yangzhe-2026/vllm-ascend that referenced this pull request May 6, 2026

Labels

ready ONLY add when PR is ready to merge/full CI is needed v1


7 participants