Skip to content

Commit b2a0268

Browse files
HF-00101267596wangxiyuan
authored
[Performance]Batch kvcache offloading via aclrtMemcpyBatchAsync (vllm-project#7819)
### What this PR does / why we need it? Following vllm-project/vllm#38460 and vllm-project/vllm#38915, KV cache offloading now uses aclrtMemcpyBatchAsync on CANN 8.5.0+ and falls back to aclrtMemcpyAsync on older CANN versions. The build automatically detects the CANN environment and compiles in the appropriate copy function; the choice can also be forced manually through an environment variable at install time. Manual override: 1. batch memcpy (requires CANN ≥ 8.5): export VLLM_ASCEND_ENABLE_BATCH_MEMCPY=1 && pip install -e . 2. normal memcpy: export VLLM_ASCEND_ENABLE_BATCH_MEMCPY=0 && pip install -e . ### How was this patch tested? Test results: main: TTFT 307 ms, TPOT 49.96 ms; this PR: TTFT 272.82 ms, TPOT 41.04 ms. Model script: export TP=1 export MODEL_PATH=/nas/disk1/Qwen3-14B export MODEL_NAME=Qwen3-14B export PORT=10113 export CUDA_VISIBLE_DEVICES=3 export ASCEND_RT_VISIBLE_DEVICES=3 python3 -m vllm.entrypoints.openai.api_server --host 0.0.0.0 --port ${PORT} --dtype bfloat16 --model ${MODEL_PATH} --served-model-name ${MODEL_NAME} --tensor-parallel-size ${TP} --gpu-memory-utilization 0.7 --no-enable-prefix-caching --max-model-len 32768 --trust-remote-code \ --block-size 128 \ --kv-transfer-config '{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"block_size": 128, "num_cpu_blocks": 1000, "spec_name":"NPUOffloadingSpec", "spec_module_path": "vllm_ascend.kv_offload.npu"}}' Test script: export MODEL_NAME=/nas/disk1/Qwen3-14B python /model/xk/vllm/benchmarks/multi_turn/benchmark_serving_multi_turn.py --url http://127.0.0.1:10113 --model $MODEL_NAME --served-model-name Qwen3-14B --seed 1234 --input-file /model/xk/vllm/benchmarks/multi_turn/generate_multi_turn.json \ --num-clients 8 --max-active-conversations 24 - vLLM version: v0.18.0 - vLLM main: vllm-project/vllm@35141a7 --------- Signed-off-by: 01267596 <xiongkai123@cmbchina.com> Signed-off-by: HF-001 <1670186653@qq.com> Signed-off-by: kx <1670186653@qq.com> Co-authored-by: 01267596 
<xiongkai123@cmbchina.com> Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
1 parent 50da4c7 commit b2a0268

5 files changed

Lines changed: 306 additions & 49 deletions

File tree

CMakeLists.txt

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,47 @@ set(
111111

112112
pybind11_add_module(vllm_ascend_C ${VLLM_ASCEND_SRC})
113113

114+
# Detect aclrtMemcpyBatchAsync availability (CANN 8.5+) and, when available,
# define CANN_MEMCPY_BATCH_ASYNC for the vllm_ascend_C target.
#
# Can be overridden via the VLLM_ASCEND_ENABLE_BATCH_MEMCPY env var (registered
# in vllm_ascend/envs.py, forwarded by setup.py as a CMake variable):
#   VLLM_ASCEND_ENABLE_BATCH_MEMCPY=1 -> force enable
#   VLLM_ASCEND_ENABLE_BATCH_MEMCPY=0 -> force disable
#   unset                             -> auto-detect from CANN headers
include(CheckCXXSourceCompiles)
include(CMakePushCheckState)

if(DEFINED VLLM_ASCEND_ENABLE_BATCH_MEMCPY)
  if("${VLLM_ASCEND_ENABLE_BATCH_MEMCPY}" STREQUAL "1")
    message(STATUS "aclrtMemcpyBatchAsync: force enabled via VLLM_ASCEND_ENABLE_BATCH_MEMCPY=1")
    target_compile_definitions(vllm_ascend_C PRIVATE CANN_MEMCPY_BATCH_ASYNC)
  else()
    # Only "1" and "0" are documented values. Anything else keeps the
    # historical behaviour (disabled), but is reported loudly so a typo such
    # as "ON" or "true" does not silently turn the batch path off.
    if(NOT "${VLLM_ASCEND_ENABLE_BATCH_MEMCPY}" STREQUAL "0")
      message(WARNING
        "VLLM_ASCEND_ENABLE_BATCH_MEMCPY='${VLLM_ASCEND_ENABLE_BATCH_MEMCPY}' is not "
        "'0' or '1'; treating it as '0' (batch memcpy disabled)")
    endif()
    message(STATUS "aclrtMemcpyBatchAsync: force disabled via VLLM_ASCEND_ENABLE_BATCH_MEMCPY=${VLLM_ASCEND_ENABLE_BATCH_MEMCPY}")
  endif()
else()
  # Scope the CMAKE_REQUIRED_* settings to this one probe so they do not leak
  # into any other check_* call made later in the configure step.
  cmake_push_check_state()
  set(CMAKE_REQUIRED_INCLUDES ${ASCEND_HOME_PATH}/include)
  set(CMAKE_REQUIRED_LIBRARIES ascendcl)
  set(CMAKE_REQUIRED_LINK_OPTIONS "-L${ASCEND_HOME_PATH}/lib64")

  # Test the full code pattern we actually use, including struct member access.
  # This ensures the macro is only defined when the API is fully compatible.
  check_cxx_source_compiles("
    #include <acl/acl_rt.h>
    int main() {
      aclrtMemLocation loc = {};
      loc.type = ACL_MEM_LOCATION_TYPE_HOST;
      loc.id = 0;
      aclrtMemcpyBatchAttr attr = {};
      attr.srcLoc = loc;
      attr.dstLoc = loc;
      (void)aclrtMemcpyBatchAsync;
      return 0;
    }
  " HAVE_ACLRT_MEMCPY_BATCH_ASYNC)

  cmake_pop_check_state()

  if(HAVE_ACLRT_MEMCPY_BATCH_ASYNC)
    message(STATUS "aclrtMemcpyBatchAsync: detected in CANN headers, enabling batch memcpy path")
    target_compile_definitions(vllm_ascend_C PRIVATE CANN_MEMCPY_BATCH_ASYNC)
  else()
    message(STATUS "aclrtMemcpyBatchAsync: not found in CANN headers, using fallback aclrtMemcpyAsync loop")
  endif()
endif()
114155
# Prefer the CANN ACL headers over torch_npu's bundled third_party ACL copy.
115156
# torch_npu 2.9.0 ships an older acl_rt.h that does not declare
116157
# aclrtLaunchHostFunc, which breaks host-print compilation.

csrc/torch_binding.cpp

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,128 @@ void swap_blocks(torch::Tensor &x, torch::Tensor &y, const torch::Tensor &z)
168168
return;
169169
}
170170

171+
/**
 * @brief Enqueue a batch of raw memory copies on the current NPU stream.
 *
 * For every index i, sizes[i] bytes are copied from the address stored in
 * src_ptrs[i] to the address stored in dst_ptrs[i]. The copies are submitted
 * asynchronously; this function does not synchronize the stream.
 * NOTE(review): callers presumably must keep the referenced buffers alive
 * until the stream has drained -- confirm at the call sites.
 *
 * When built with CANN_MEMCPY_BATCH_ASYNC, host<->device batches are submitted
 * with a single aclrtMemcpyBatchAsync call. Device-to-device batches, and all
 * batches in builds without that macro, are submitted one aclrtMemcpyAsync
 * call per entry.
 *
 * @param src_ptrs  CPU int64 tensor holding the source addresses.
 * @param dst_ptrs  CPU int64 tensor holding the destination addresses; same
 *                  length as src_ptrs.
 * @param sizes     CPU int64 tensor holding the copy sizes; same length as
 *                  src_ptrs, every entry non-negative.
 * @param direction 0 = host-to-device, 1 = device-to-host,
 *                  2 = device-to-device.
 *
 * Raises a TORCH_CHECK error on invalid arguments or when an ACL call fails.
 * In the per-entry fallback, copies enqueued before the failing entry stay
 * enqueued.
 */
void swap_blocks_batch(const torch::Tensor& src_ptrs,
                       const torch::Tensor& dst_ptrs,
                       const torch::Tensor& sizes,
                       int64_t direction) {
    TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
    TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
    TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
    TORCH_CHECK(src_ptrs.dtype() == torch::kInt64, "src_ptrs must be int64");
    TORCH_CHECK(dst_ptrs.dtype() == torch::kInt64, "dst_ptrs must be int64");
    TORCH_CHECK(sizes.dtype() == torch::kInt64, "sizes must be int64");

    const int64_t n = src_ptrs.size(0);
    TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
    TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");

    if (n == 0) return;

    // The raw data pointers below are indexed as flat arrays of n elements.
    // That is only valid for contiguous tensors holding exactly n values; a
    // strided view would otherwise yield the wrong addresses.
    TORCH_CHECK(src_ptrs.is_contiguous() && src_ptrs.numel() == n,
                "src_ptrs must be contiguous with exactly one entry per copy");
    TORCH_CHECK(dst_ptrs.is_contiguous() && dst_ptrs.numel() == n,
                "dst_ptrs must be contiguous with exactly one entry per copy");
    TORCH_CHECK(sizes.is_contiguous() && sizes.numel() == n,
                "sizes must be contiguous with exactly one entry per copy");

    const int64_t* src_data = src_ptrs.data_ptr<int64_t>();
    const int64_t* dst_data = dst_ptrs.data_ptr<int64_t>();
    const int64_t* size_data = sizes.data_ptr<int64_t>();

    // Sizes are handed to ACL as size_t; a negative int64 would silently turn
    // into an enormous unsigned byte count.
    for (int64_t i = 0; i < n; ++i) {
        TORCH_CHECK(size_data[i] >= 0,
                    "swap_blocks_batch: negative size ", size_data[i],
                    " at index ", i);
    }

    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();

    // Initialised so the variable is never read indeterminate; the default
    // branch below always throws for unknown directions.
    aclrtMemcpyKind memcpy_kind = ACL_MEMCPY_HOST_TO_DEVICE;
    switch (direction) {
        case 0:
            memcpy_kind = ACL_MEMCPY_HOST_TO_DEVICE;
            break;
        case 1:
            memcpy_kind = ACL_MEMCPY_DEVICE_TO_HOST;
            break;
        case 2:
            memcpy_kind = ACL_MEMCPY_DEVICE_TO_DEVICE;
            break;
        default:
            TORCH_CHECK(false,
                        "swap_blocks_batch: invalid direction ", direction,
                        " (expected 0=H2D, 1=D2H, 2=D2D)");
    }

    // =========================================================================
    // path 1: aclrtMemcpyBatchAsync (CANN 8.5+)
    // =========================================================================
#if defined(CANN_MEMCPY_BATCH_ASYNC)
    if (memcpy_kind != ACL_MEMCPY_DEVICE_TO_DEVICE) {
        // The int64 tensor storage is handed to ACL directly as arrays of
        // pointers / sizes, which requires identical element widths.
        static_assert(sizeof(void*) == sizeof(int64_t),
                      "void* and int64_t must be the same size");
        static_assert(sizeof(size_t) == sizeof(int64_t),
                      "size_t and int64_t must be the same size");

        void** dst_arr = reinterpret_cast<void**>(
            const_cast<int64_t*>(dst_data));
        void** src_arr = reinterpret_cast<void**>(
            const_cast<int64_t*>(src_data));
        size_t* size_arr = reinterpret_cast<size_t*>(
            const_cast<int64_t*>(size_data));
        // Each destination is given exactly the copy size as its capacity.
        size_t* dest_maxs = size_arr;

        // aclrtMemcpyBatchAttr uses srcLoc/dstLoc (aclrtMemLocation)
        // to specify memory locations, not aclrtMemcpyKind.
        int32_t device_id = 0;
        const aclError dev_ret = aclrtGetDevice(&device_id);
        TORCH_CHECK(dev_ret == ACL_SUCCESS,
                    "aclrtGetDevice failed with error code ", dev_ret);

        aclrtMemLocation host_loc = {};
        host_loc.type = ACL_MEM_LOCATION_TYPE_HOST;
        host_loc.id = 0;

        aclrtMemLocation device_loc = {};
        device_loc.type = ACL_MEM_LOCATION_TYPE_DEVICE;
        device_loc.id = device_id;

        aclrtMemcpyBatchAttr attr = {};
        if (memcpy_kind == ACL_MEMCPY_HOST_TO_DEVICE) {
            attr.srcLoc = host_loc;
            attr.dstLoc = device_loc;
        } else {  // ACL_MEMCPY_DEVICE_TO_HOST
            attr.srcLoc = device_loc;
            attr.dstLoc = host_loc;
        }

        // A single attribute entry, starting at copy index 0, is supplied for
        // the whole batch.
        size_t attrs_index = 0;
        size_t fail_index = 0;

        aclError result = aclrtMemcpyBatchAsync(
            dst_arr, dest_maxs, src_arr, size_arr,
            static_cast<size_t>(n),
            &attr, &attrs_index, 1,
            &fail_index, stream);

        TORCH_CHECK(result == ACL_SUCCESS,
                    "aclrtMemcpyBatchAsync failed at index ", fail_index,
                    " with error code ", result);
        return;
    }
#endif

    // =========================================================================
    // path 2: aclrtMemcpyAsync
    // =========================================================================
    for (int64_t i = 0; i < n; i++) {
        void* dst = reinterpret_cast<void*>(dst_data[i]);
        const void* src = reinterpret_cast<const void*>(src_data[i]);
        const size_t copy_size = static_cast<size_t>(size_data[i]);

        aclError ret = aclrtMemcpyAsync(
            dst,
            copy_size,
            src,
            copy_size,
            memcpy_kind,
            stream);

        TORCH_CHECK(ret == ACL_SUCCESS,
                    "aclrtMemcpyAsync failed at index ", i,
                    " with error code ", ret,
                    ", src=", src_data[i],
                    ", dst=", dst_data[i],
                    ", size=", size_data[i]);
    }
}
292+
171293
AscendType get_dtype_from_torch(at::ScalarType scalarType)
172294
{
173295
if (scalarType == at::ScalarType::Float) {
@@ -962,6 +1084,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
9621084
ops.def("swap_blocks(Tensor! x, Tensor! y, Tensor z) -> ()");
9631085
ops.impl("swap_blocks", torch::kPrivateUse1, &vllm_ascend::swap_blocks);
9641086

1087+
// swap_blocks_batch takes CPU tensors (int64 pointer/size arrays), not NPU
1088+
// tensors, so dispatch must be registered on the CPU backend. The function
1089+
// internally submits async memcpy on the current NPU stream.
1090+
ops.def("swap_blocks_batch(Tensor x, Tensor y, Tensor z, int direction) -> ()");
1091+
ops.impl("swap_blocks_batch", torch::kCPU, &vllm_ascend::swap_blocks_batch);
9651092
ops.def("device_print(str msg) -> ()");
9661093
ops.impl("device_print", c10::DispatchKey::CompositeExplicitAutograd,
9671094
static_cast<void (*)(c10::string_view)>(&vllm_ascend::device_print));

setup.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -338,6 +338,11 @@ def configure(self, ext: CMakeExtension) -> None:
338338
# add TORCH_NPU_PATH
339339
cmake_args += [f"-DTORCH_NPU_PATH={torch_npu_path}"]
340340

341+
# Pass VLLM_ASCEND_ENABLE_BATCH_MEMCPY to CMake if explicitly set.
342+
# When unset (None), CMake will auto-detect from CANN headers.
343+
if envs.VLLM_ASCEND_ENABLE_BATCH_MEMCPY is not None:
344+
cmake_args += [f"-DVLLM_ASCEND_ENABLE_BATCH_MEMCPY={envs.VLLM_ASCEND_ENABLE_BATCH_MEMCPY}"]
345+
341346
build_tool = []
342347
# TODO(ganyi): ninja and ccache support for ascend c auto codegen. now we can only use make build
343348
# if which('ninja') is not None:

vllm_ascend/envs.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,9 @@
111111
"VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK": lambda: bool(
112112
int(os.getenv("VLLM_ASCEND_FUSION_OP_TRANSPOSE_KV_CACHE_BY_BLOCK", "1"))
113113
),
114+
# Control the aclrtMemcpyBatchAsync compile path for KV cache offloading.
115+
# "1": force enable, "0": force disable, None: auto-detect from CANN headers.
116+
"VLLM_ASCEND_ENABLE_BATCH_MEMCPY": lambda: os.getenv("VLLM_ASCEND_ENABLE_BATCH_MEMCPY", None),
114117
}
115118

116119
# end-env-vars-definition

0 commit comments

Comments
 (0)