Skip to content

Commit 16ae632

Browse files
Etelis, EtelisIBM, and orozery
authored and committed
[Perf] Batch KV cache swap copies via cuMemcpyBatchAsync (vllm-project#38460)
Signed-off-by: Itay Etelis <itay.etelis@ibm.com> Co-authored-by: Itay Etelis <itay.etelis@ibm.com> Co-authored-by: Or Ozeri <oro@il.ibm.com> Signed-off-by: Rishi Puri <riship@nvidia.com>
1 parent 54114b5 commit 16ae632

5 files changed

Lines changed: 118 additions & 15 deletions

File tree

csrc/cache.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
1010
int64_t block_size_in_bytes,
1111
const torch::Tensor& block_mapping);
1212

13+
// Batched raw-pointer copy: for each i, copies sizes[i] bytes from the address
// stored in src_ptrs[i] to the address stored in dst_ptrs[i]. All three
// arguments are int64 CPU tensors of equal length (checked by the
// implementation); their elements are addresses / byte counts, not payload.
void swap_blocks_batch(const torch::Tensor& src_ptrs,
14+
const torch::Tensor& dst_ptrs,
15+
const torch::Tensor& sizes);
16+
1317
void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
1418
torch::Tensor& key_cache, torch::Tensor& value_cache,
1519
torch::Tensor& slot_mapping,

csrc/cache_kernels.cu

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#ifdef USE_ROCM
2525
#include <hip/hip_bf16.h>
2626
typedef __hip_bfloat16 __nv_bfloat16;
27+
#else
28+
#include <cuda.h>
2729
#endif
2830

2931
#if defined(__gfx942__)
@@ -73,6 +75,59 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
7375
}
7476
}
7577

78+
// Enqueues a batch of raw memory copies on the current CUDA stream.
//
// For every index i, sizes[i] bytes are copied from the address stored in
// src_ptrs[i] to the address stored in dst_ptrs[i].
//
// Arguments:
//   src_ptrs  1-D contiguous int64 CPU tensor of source addresses.
//   dst_ptrs  1-D contiguous int64 CPU tensor of destination addresses.
//   sizes     1-D contiguous int64 CPU tensor of byte counts.
// All three must have the same length. An empty batch is a no-op.
//
// Raises (via TORCH_CHECK) when an argument fails validation or when the
// driver/runtime rejects a copy. No stream synchronization is performed here;
// the copies are only submitted.
void swap_blocks_batch(const torch::Tensor& src_ptrs,
                       const torch::Tensor& dst_ptrs,
                       const torch::Tensor& sizes) {
  TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
  TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
  TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
  TORCH_CHECK(src_ptrs.dtype() == torch::kInt64, "src_ptrs must be int64");
  TORCH_CHECK(dst_ptrs.dtype() == torch::kInt64, "dst_ptrs must be int64");
  TORCH_CHECK(sizes.dtype() == torch::kInt64, "sizes must be int64");

  // The raw data pointers below are walked as flat arrays of n elements, which
  // is only valid for 1-D contiguous tensors. A strided view would otherwise
  // hand unrelated values to the driver as addresses.
  TORCH_CHECK(src_ptrs.dim() == 1, "src_ptrs must be 1-D");
  TORCH_CHECK(dst_ptrs.dim() == 1, "dst_ptrs must be 1-D");
  TORCH_CHECK(sizes.dim() == 1, "sizes must be 1-D");
  TORCH_CHECK(src_ptrs.is_contiguous(), "src_ptrs must be contiguous");
  TORCH_CHECK(dst_ptrs.is_contiguous(), "dst_ptrs must be contiguous");
  TORCH_CHECK(sizes.is_contiguous(), "sizes must be contiguous");

  const int64_t n = src_ptrs.size(0);
  TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
  TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");

  if (n == 0) return;

  const int64_t* src_data = src_ptrs.data_ptr<int64_t>();
  const int64_t* dst_data = dst_ptrs.data_ptr<int64_t>();
  const int64_t* size_data = sizes.data_ptr<int64_t>();

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Use cuMemcpyBatchAsync (CUDA 12.8+) to submit all copies in a single
  // driver call, amortizing per-copy submission overhead.
  // int64_t and CUdeviceptr/size_t are both 8 bytes on 64-bit platforms,
  // so we reinterpret_cast the tensor data directly to avoid copies.
  static_assert(sizeof(CUdeviceptr) == sizeof(int64_t));
  static_assert(sizeof(size_t) == sizeof(int64_t));
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
  // NOTE(review): newer CUDA major versions may expose a cuMemcpyBatchAsync
  // signature without the failIdx out-parameter -- confirm against the
  // toolkit headers before building with CUDA 13.x.
  CUmemcpyAttributes attr = {};
  attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
  // A single attribute entry, applied starting from copy index 0.
  size_t attrs_idx = 0;
  size_t fail_idx = 0;
  CUresult result = cuMemcpyBatchAsync(
      reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(dst_data)),
      reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(src_data)),
      reinterpret_cast<size_t*>(const_cast<int64_t*>(size_data)),
      static_cast<size_t>(n), &attr, &attrs_idx, 1, &fail_idx,
      static_cast<CUstream>(stream));
  TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ",
              fail_idx, " with error ", result);
#else
  // Fallback for CUDA < 12.8 and ROCm: individual async copies.
  // cudaMemcpyDefault lets the driver infer direction from pointer types.
  for (int64_t i = 0; i < n; i++) {
    const cudaError_t err = cudaMemcpyAsync(
        reinterpret_cast<void*>(dst_data[i]),
        reinterpret_cast<void*>(src_data[i]),
        static_cast<size_t>(size_data[i]), cudaMemcpyDefault, stream);
    // Surface submission failures instead of silently dropping them, matching
    // the error reporting of the batch path above.
    TORCH_CHECK(err == cudaSuccess, "cudaMemcpyAsync failed at index ", i,
                " with error ", cudaGetErrorString(err));
  }
#endif
}
130+
76131
namespace vllm {
77132

78133
// Grid: (num_layers, num_pairs)

csrc/torch_bindings.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
508508
" int block_size_in_bytes, Tensor block_mapping) -> ()");
509509
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);
510510

511+
// Batch swap: submit all block copies in a single driver call.
512+
cache_ops.def(
513+
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
514+
" Tensor sizes) -> ()");
515+
cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
516+
511517
// Reshape the key and value tensors and cache them.
512518
cache_ops.def(
513519
"reshape_and_cache(Tensor key, Tensor value,"

vllm/_custom_ops.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2641,6 +2641,22 @@ def swap_blocks(
26412641
torch.ops._C_cache_ops.swap_blocks(src, dst, block_size_in_bytes, block_mapping)
26422642

26432643

2644+
def swap_blocks_batch(
    src_ptrs: torch.Tensor,
    dst_ptrs: torch.Tensor,
    sizes: torch.Tensor,
) -> None:
    """
    Submit a batch of raw-pointer copies through the `_C_cache_ops` extension.

    Entry ``i`` of the batch copies ``sizes[i]`` bytes from the address
    ``src_ptrs[i]`` to the address ``dst_ptrs[i]``.

    Args:
        src_ptrs: int64 CPU tensor holding source addresses.
        dst_ptrs: int64 CPU tensor holding destination addresses.
        sizes: int64 CPU tensor holding the byte count of each copy.

    With CUDA 12.8 or newer the native op issues one cuMemcpyBatchAsync
    call for the whole batch; otherwise it loops over cudaMemcpyAsync.
    """
    batch_copy_op = torch.ops._C_cache_ops.swap_blocks_batch
    batch_copy_op(src_ptrs, dst_ptrs, sizes)
2658+
2659+
26442660
def convert_fp8(
26452661
output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
26462662
) -> None:

vllm/v1/kv_offload/worker/cpu_gpu.py

Lines changed: 37 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,17 @@ def __init__(
149149
# list of CUDA events available for re-use
150150
self._event_pool: list[torch.Event] = []
151151

152+
# Pre-compute base pointers and block sizes for batch copies.
153+
self._src_base_ptrs = np.array(
154+
[t.data_ptr() for t in self.src_tensors], dtype=np.int64
155+
)
156+
self._dst_base_ptrs = np.array(
157+
[t.data_ptr() for t in self.dst_tensors], dtype=np.int64
158+
)
159+
self._block_size_in_bytes_arr = np.array(
160+
self.tensor_block_size_in_bytes, dtype=np.int64
161+
)
162+
152163
def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
153164
src_spec, dst_spec = transfer_spec
154165
assert isinstance(src_spec, BlockIDsLoadStoreSpec)
@@ -165,15 +176,35 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
165176

166177
assert dst_sub_block_count == src_sub_block_count - src_sub_blocks_to_skip
167178

168-
src_to_dst = np.empty((dst_sub_block_count, 2), dtype=np.int64)
179+
src_block_ids = np.empty(dst_sub_block_count, dtype=np.int64)
180+
dst_block_ids = np.empty(dst_sub_block_count, dtype=np.int64)
169181
expand_block_ids(
170182
src_blocks,
171183
self.src_block_size_factor,
172-
src_to_dst[:, 0],
184+
src_block_ids,
173185
skip_count=src_sub_blocks_to_skip,
174186
)
175-
expand_block_ids(dst_blocks, self.dst_block_size_factor, src_to_dst[:, 1])
176-
src_to_dst_tensor = torch.from_numpy(src_to_dst)
187+
expand_block_ids(dst_blocks, self.dst_block_size_factor, dst_block_ids)
188+
189+
# Build flat pointer arrays for all tensors × all block pairs.
190+
num_pairs = dst_sub_block_count
191+
num_tensors = len(self.src_tensors)
192+
total = num_pairs * num_tensors
193+
194+
all_src = np.empty(total, dtype=np.int64)
195+
all_dst = np.empty(total, dtype=np.int64)
196+
all_sizes = np.empty(total, dtype=np.int64)
197+
198+
for t_idx, bsz in enumerate(self._block_size_in_bytes_arr):
199+
start = t_idx * num_pairs
200+
end = start + num_pairs
201+
all_src[start:end] = self._src_base_ptrs[t_idx] + src_block_ids * bsz
202+
all_dst[start:end] = self._dst_base_ptrs[t_idx] + dst_block_ids * bsz
203+
all_sizes[start:end] = bsz
204+
205+
batch_src = torch.from_numpy(all_src)
206+
batch_dst = torch.from_numpy(all_dst)
207+
batch_sizes = torch.from_numpy(all_sizes)
177208

178209
stream = self._stream_pool.pop() if self._stream_pool else torch.cuda.Stream()
179210
start_event = (
@@ -197,17 +228,8 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
197228
stream.wait_event(last_event)
198229
with torch.cuda.stream(stream):
199230
start_event.record(stream)
200-
for src_tensor, dst_tensor, block_size_in_bytes in zip(
201-
self.src_tensors,
202-
self.dst_tensors,
203-
self.tensor_block_size_in_bytes,
204-
):
205-
ops.swap_blocks(
206-
src_tensor,
207-
dst_tensor,
208-
block_size_in_bytes,
209-
src_to_dst_tensor,
210-
)
231+
if total > 0:
232+
ops.swap_blocks_batch(batch_src, batch_dst, batch_sizes)
211233
end_event.record(stream)
212234

213235
self._transfer_events[job_id] = end_event

0 commit comments

Comments
 (0)