Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
d9f57cb
[Perf] Batch KV cache swap copies via cuMemcpyBatchAsync
EtelisIBM Mar 25, 2026
bb8c84c
Merge branch 'main' into swap-blocks-batch-v2
Etelis Mar 29, 2026
48e1981
Merge branch 'main' into swap-blocks-batch-v2
orozery Mar 30, 2026
43371f7
Merge branch 'main' into swap-blocks-batch-v2
orozery Mar 30, 2026
cfd1990
Merge branch 'main' into swap-blocks-batch-v2
Etelis Mar 30, 2026
f4e6fd3
Merge branch 'main' into swap-blocks-batch-v2
orozery Mar 30, 2026
01cf9a8
Merge branch 'main' into swap-blocks-batch-v2
Etelis Mar 30, 2026
ce30ed0
Merge branch 'main' into swap-blocks-batch-v2
orozery Mar 31, 2026
0733407
Merge branch 'main' into swap-blocks-batch-v2
Etelis Mar 31, 2026
6926c12
Merge branch 'main' into swap-blocks-batch-v2
Etelis Mar 31, 2026
1ac17a2
Merge branch 'main' into swap-blocks-batch-v2
Etelis Mar 31, 2026
f238c88
Merge branch 'main' into swap-blocks-batch-v2
Etelis Mar 31, 2026
d42ace5
Merge branch 'main' into swap-blocks-batch-v2
Etelis Mar 31, 2026
24f9716
Merge branch 'main' into swap-blocks-batch-v2
Etelis Mar 31, 2026
47b1d6c
Merge branch 'main' into swap-blocks-batch-v2
orozery Apr 1, 2026
b161b1a
Merge branch 'main' into swap-blocks-batch-v2
orozery Apr 1, 2026
2712093
Merge branch 'main' into swap-blocks-batch-v2
orozery Apr 1, 2026
602cdc1
Merge branch 'main' into swap-blocks-batch-v2
orozery Apr 1, 2026
a78f1e2
Merge branch 'main' into swap-blocks-batch-v2
orozery Apr 1, 2026
5c15e76
Merge branch 'main' into swap-blocks-batch-v2
orozery Apr 1, 2026
10714d8
Merge branch 'main' into swap-blocks-batch-v2
orozery Apr 1, 2026
d88add5
Merge branch 'main' into swap-blocks-batch-v2
orozery Apr 1, 2026
735aefa
Merge remote-tracking branch 'upstream/main' into swap-blocks-batch-v2
EtelisIBM Apr 1, 2026
895816b
Merge remote-tracking branch 'upstream/main' into swap-blocks-batch-v2
EtelisIBM Apr 1, 2026
99bba24
Merge remote-tracking branch 'upstream/main' into swap-blocks-batch-v2
EtelisIBM Apr 2, 2026
7415a98
Merge remote-tracking branch 'upstream/main' into swap-blocks-batch-v2
EtelisIBM Apr 2, 2026
9347eca
Retry CI
EtelisIBM Apr 2, 2026
f695c1e
Merge remote-tracking branch 'upstream/main' into swap-blocks-batch-v2
EtelisIBM Apr 2, 2026
6e90ee6
Merge branch 'main' into swap-blocks-batch-v2
orozery Apr 2, 2026
7d25c74
Merge branch 'main' into swap-blocks-batch-v2
orozery Apr 2, 2026
c27f38f
Merge remote-tracking branch 'upstream/main' into swap-blocks-batch-v2
EtelisIBM Apr 2, 2026
15ab2fe
Merge remote-tracking branch 'upstream/main' into swap-blocks-batch-v2
EtelisIBM Apr 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions csrc/cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
int64_t block_size_in_bytes,
const torch::Tensor& block_mapping);

// Batched companion to swap_blocks: for each index i, copy sizes[i] bytes
// from the raw address in src_ptrs[i] to the raw address in dst_ptrs[i].
// All three tensors must be 1-D int64 CPU tensors of equal length; the
// pointers may reference host or device memory (the driver infers the
// copy direction).
void swap_blocks_batch(const torch::Tensor& src_ptrs,
                       const torch::Tensor& dst_ptrs,
                       const torch::Tensor& sizes);

void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
Expand Down
55 changes: 55 additions & 0 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#ifdef USE_ROCM
#include <hip/hip_bf16.h>
typedef __hip_bfloat16 __nv_bfloat16;
#else
#include <cuda.h>
#endif

#if defined(__gfx942__)
Expand Down Expand Up @@ -73,6 +75,59 @@ void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
}
}

// Batched block swap: for each index i, copy sizes[i] bytes from the raw
// address in src_ptrs[i] to the raw address in dst_ptrs[i].
//
// Args:
//   src_ptrs: 1-D int64 CPU tensor of source addresses (host or device).
//   dst_ptrs: 1-D int64 CPU tensor of destination addresses, same length.
//   sizes:    1-D int64 CPU tensor of per-copy byte counts, same length.
//
// All copies are enqueued asynchronously on the current CUDA stream.
// Raises (via TORCH_CHECK) on invalid inputs or if a copy submission fails.
void swap_blocks_batch(const torch::Tensor& src_ptrs,
                       const torch::Tensor& dst_ptrs,
                       const torch::Tensor& sizes) {
  TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
  TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
  TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
  TORCH_CHECK(src_ptrs.dtype() == torch::kInt64, "src_ptrs must be int64");
  TORCH_CHECK(dst_ptrs.dtype() == torch::kInt64, "dst_ptrs must be int64");
  TORCH_CHECK(sizes.dtype() == torch::kInt64, "sizes must be int64");

  const int64_t n = src_ptrs.size(0);
  TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
  TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");

  if (n == 0) return;

  const int64_t* src_data = src_ptrs.data_ptr<int64_t>();
  const int64_t* dst_data = dst_ptrs.data_ptr<int64_t>();
  const int64_t* size_data = sizes.data_ptr<int64_t>();

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  // Use cuMemcpyBatchAsync (CUDA 12.8+) to submit all copies in a single
  // driver call, amortizing per-copy submission overhead.
  // int64_t and CUdeviceptr/size_t are both 8 bytes on 64-bit platforms,
  // so we reinterpret_cast the tensor data directly to avoid copies.
  static_assert(sizeof(CUdeviceptr) == sizeof(int64_t));
  static_assert(sizeof(size_t) == sizeof(int64_t));
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
  CUmemcpyAttributes attr = {};
  // Stream-ordered source access: copies respect prior work on `stream`.
  attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
  // A single attribute entry (numAttrs == 1) applies to the whole batch.
  size_t attrs_idx = 0;
  size_t fail_idx = 0;
  CUresult result = cuMemcpyBatchAsync(
      reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(dst_data)),
      reinterpret_cast<CUdeviceptr*>(const_cast<int64_t*>(src_data)),
      reinterpret_cast<size_t*>(const_cast<int64_t*>(size_data)),
      static_cast<size_t>(n), &attr, &attrs_idx, 1, &fail_idx,
      static_cast<CUstream>(stream));
  TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ",
              fail_idx, " with error ", result);
#else
  // Fallback for CUDA < 12.8 and ROCm: individual async copies.
  // cudaMemcpyDefault lets the driver infer direction from pointer types.
  for (int64_t i = 0; i < n; i++) {
    cudaError_t err =
        cudaMemcpyAsync(reinterpret_cast<void*>(dst_data[i]),
                        reinterpret_cast<void*>(src_data[i]),
                        static_cast<size_t>(size_data[i]), cudaMemcpyDefault,
                        stream);
    // Surface submission failures immediately; silently dropping a copy
    // would corrupt the KV cache.
    TORCH_CHECK(err == cudaSuccess, "cudaMemcpyAsync failed at index ", i,
                " with error ", cudaGetErrorString(err));
  }
#endif
}

namespace vllm {

// Grid: (num_layers, num_pairs)
Expand Down
6 changes: 6 additions & 0 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" int block_size_in_bytes, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kCUDA, &swap_blocks);

// Batch swap: submit all block copies in a single driver call.
cache_ops.def(
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
" Tensor sizes) -> ()");
cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The swap_blocks_batch operation is implemented in csrc/cache_kernels.cu and involves CUDA operations. However, it's being registered here for the torch::kCPU device. This is likely incorrect as the implementation relies on CUDA streams and memory copies. It should be registered for torch::kCUDA to ensure it's dispatched correctly when called on CUDA tensors.

Suggested change
cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
cache_ops.impl("swap_blocks_batch", torch::kCUDA, &swap_blocks_batch);

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The input tensors (src_ptrs, dst_ptrs, sizes) are CPU tensors — they're numpy arrays of raw pointers/sizes converted via torch.from_numpy(). PyTorch dispatches based on the input tensor device, so kCPU is correct here. The existing swap_blocks uses kCUDA because its inputs are the actual GPU KV cache tensors. Registering with kCUDA would actually break dispatch since no input tensor lives on GPU.


// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache(Tensor key, Tensor value,"
Expand Down
16 changes: 16 additions & 0 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2595,6 +2595,22 @@ def swap_blocks(
torch.ops._C_cache_ops.swap_blocks(src, dst, block_size_in_bytes, block_mapping)


def swap_blocks_batch(
    src_ptrs: torch.Tensor,
    dst_ptrs: torch.Tensor,
    sizes: torch.Tensor,
) -> None:
    """Submit a batch of raw-pointer memory copies in one native call.

    For every index ``i``, ``sizes[i]`` bytes are copied from the address
    stored in ``src_ptrs[i]`` to the address stored in ``dst_ptrs[i]``.
    All three arguments must be int64 tensors resident on the CPU.

    When built against CUDA 12.8+ the native op issues a single
    ``cuMemcpyBatchAsync`` driver call for the whole batch; on older CUDA
    it falls back to one ``cudaMemcpyAsync`` per entry.
    """
    torch.ops._C_cache_ops.swap_blocks_batch(src_ptrs, dst_ptrs, sizes)


def convert_fp8(
output: torch.Tensor, input: torch.Tensor, scale: float = 1.0, kv_dtype: str = "fp8"
) -> None:
Expand Down
52 changes: 37 additions & 15 deletions vllm/v1/kv_offload/worker/cpu_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,17 @@ def __init__(
# list of CUDA events available for re-use
self._event_pool: list[torch.Event] = []

# Pre-compute base pointers and block sizes for batch copies.
self._src_base_ptrs = np.array(
[t.data_ptr() for t in self.src_tensors], dtype=np.int64
)
self._dst_base_ptrs = np.array(
[t.data_ptr() for t in self.dst_tensors], dtype=np.int64
)
self._block_size_in_bytes_arr = np.array(
self.tensor_block_size_in_bytes, dtype=np.int64
)

def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
src_spec, dst_spec = transfer_spec
assert isinstance(src_spec, BlockIDsLoadStoreSpec)
Expand All @@ -165,15 +176,35 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:

assert dst_sub_block_count == src_sub_block_count - src_sub_blocks_to_skip

src_to_dst = np.empty((dst_sub_block_count, 2), dtype=np.int64)
src_block_ids = np.empty(dst_sub_block_count, dtype=np.int64)
dst_block_ids = np.empty(dst_sub_block_count, dtype=np.int64)
expand_block_ids(
src_blocks,
self.src_block_size_factor,
src_to_dst[:, 0],
src_block_ids,
skip_count=src_sub_blocks_to_skip,
)
expand_block_ids(dst_blocks, self.dst_block_size_factor, src_to_dst[:, 1])
src_to_dst_tensor = torch.from_numpy(src_to_dst)
expand_block_ids(dst_blocks, self.dst_block_size_factor, dst_block_ids)

# Build flat pointer arrays for all tensors × all block pairs.
num_pairs = dst_sub_block_count
num_tensors = len(self.src_tensors)
total = num_pairs * num_tensors

all_src = np.empty(total, dtype=np.int64)
all_dst = np.empty(total, dtype=np.int64)
all_sizes = np.empty(total, dtype=np.int64)

for t_idx, bsz in enumerate(self._block_size_in_bytes_arr):
start = t_idx * num_pairs
end = start + num_pairs
all_src[start:end] = self._src_base_ptrs[t_idx] + src_block_ids * bsz
all_dst[start:end] = self._dst_base_ptrs[t_idx] + dst_block_ids * bsz
all_sizes[start:end] = bsz

batch_src = torch.from_numpy(all_src)
batch_dst = torch.from_numpy(all_dst)
batch_sizes = torch.from_numpy(all_sizes)

stream = self._stream_pool.pop() if self._stream_pool else torch.cuda.Stream()
start_event = (
Expand All @@ -197,17 +228,8 @@ def transfer_async(self, job_id: int, transfer_spec: TransferSpec) -> bool:
stream.wait_event(last_event)
with torch.cuda.stream(stream):
start_event.record(stream)
for src_tensor, dst_tensor, block_size_in_bytes in zip(
self.src_tensors,
self.dst_tensors,
self.tensor_block_size_in_bytes,
):
ops.swap_blocks(
src_tensor,
dst_tensor,
block_size_in_bytes,
src_to_dst_tensor,
)
if total > 0:
ops.swap_blocks_batch(batch_src, batch_dst, batch_sizes)
end_event.record(stream)

self._transfer_events[job_id] = end_event
Expand Down
Loading