
Commit 775df35

Add swap_blocks_batch op with batched async memcpy
Signed-off-by: chaojun-zhang <chaojun.zhang@intel.com>
1 parent 6792890 commit 775df35

7 files changed

Lines changed: 312 additions & 3 deletions


csrc/cache.cpp

Lines changed: 34 additions & 0 deletions
@@ -1175,6 +1175,40 @@ void swap_blocks(
  return;
}

+/**
+ * @brief Batch version of swap_blocks: copies N independent (src, dst, size)
+ * triples in a single call, amortising per-copy overhead.
+ *
+ * Thin wrapper that validates tensor inputs and delegates to
+ * vllm::xpu::xpuAsyncMemcpyBatch for the actual copy logic.
+ *
+ * @param src_ptrs CPU uint64 tensor of N raw source addresses
+ * @param dst_ptrs CPU uint64 tensor of N raw destination addresses
+ * @param sizes    CPU uint64 tensor of N byte counts
+ */
+void swap_blocks_batch(
+    const torch::Tensor& src_ptrs,
+    const torch::Tensor& dst_ptrs,
+    const torch::Tensor& sizes) {
+  TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
+  TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
+  TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
+  TORCH_CHECK(src_ptrs.dtype() == torch::kUInt64, "src_ptrs must be uint64");
+  TORCH_CHECK(dst_ptrs.dtype() == torch::kUInt64, "dst_ptrs must be uint64");
+  TORCH_CHECK(sizes.dtype() == torch::kUInt64, "sizes must be uint64");
+
+  const int64_t n = src_ptrs.size(0);
+  TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
+  TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");
+
+  if (n == 0) return;
+
+  vllm::xpu::xpuAsyncMemcpyBatch(
+      src_ptrs.data_ptr<uint64_t>(),
+      dst_ptrs.data_ptr<uint64_t>(),
+      sizes.data_ptr<uint64_t>(),
+      n);
+}
+
namespace vllm {

// Kernel for FP8 conversion (matches CUDA convert_fp8_kernel pattern).
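
For intuition, each batch entry behaves like one independent raw copy of sizes[i] bytes from address src_ptrs[i] to address dst_ptrs[i]; the wrapper only validates the three uint64 CPU tensors and hands the raw arrays to the XPU copy routine. A host-only Python analogue of those semantics, purely illustrative and not part of this commit (it is only valid when both addresses point at host memory):

import ctypes

import torch


def swap_blocks_batch_reference(src_ptrs: torch.Tensor,
                                dst_ptrs: torch.Tensor,
                                sizes: torch.Tensor) -> None:
    # Each entry is an independent (src, dst, size) copy; ctypes.memmove
    # stands in here for the asynchronous queue.memcpy used on the XPU path.
    for s, d, n in zip(src_ptrs.tolist(), dst_ptrs.tolist(), sizes.tolist()):
        ctypes.memmove(d, s, n)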

csrc/ops.h

Lines changed: 5 additions & 0 deletions
@@ -158,6 +158,11 @@ void swap_blocks(
    int64_t block_size_in_bytes,
    const torch::Tensor& block_mapping);

+void swap_blocks_batch(
+    const torch::Tensor& src_ptrs,
+    const torch::Tensor& dst_ptrs,
+    const torch::Tensor& sizes);
+
void top_k_per_row_decode(
    const torch::Tensor& logits,
    int64_t next_n,

csrc/torch_bindings.cpp

Lines changed: 5 additions & 0 deletions
@@ -206,6 +206,11 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
      "swap_blocks(Tensor src, Tensor! dst,"
      " int block_size_in_bytes, Tensor block_mapping) -> ()");
  cache_ops.impl("swap_blocks", torch::kXPU, &swap_blocks);
+  // Batch swap: copies N (src_ptr, dst_ptr, size) triples in one call.
+  cache_ops.def(
+      "swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs, Tensor sizes) -> "
+      "()");
+  cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
  cache_ops.def(
      "indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache,"
      "Tensor slot_mapping, int quant_block_size, str scale_fmt) -> ()");

csrc/utils/mem_cpy.cpp

Lines changed: 81 additions & 3 deletions
@@ -47,9 +47,7 @@ inline void async_h2d_with_staging(

  memcpy_async(queue, dst_device, staging_ptr, n_bytes);

-  // The staging buffer is managed by the allocator,
-  // so record the event on it to ensure the staging buffer remains alive
-  // until the DMA transfer completes.
+  // Keep staging buffer alive until the DMA transfer completes.
  record_host_alloc_event_if_possible(staging_ptr, staging.get_context());
}

@@ -151,6 +149,86 @@ void xpuAsyncMemcpy(
  }
}

+void xpuAsyncMemcpyBatch(
+    const uint64_t* src_ptrs,
+    const uint64_t* dst_ptrs,
+    const uint64_t* sizes,
+    int64_t n) {
+  if (n == 0) return;
+
+  auto& queue = vllm::xpu::vllmGetQueue();
+  auto sycl_ctx = queue.get_context();
+
+  // Determine copy direction from the first non-zero entry's pointer types.
+  // All entries in a batch are expected to share the same direction.
+  // Staging is only needed for H2D: the source is host memory that the caller
+  // may mutate after this function returns, while the DMA is still in flight.
+  bool needs_staging = false;
+  for (int64_t i = 0; i < n; i++) {
+    if (sizes[i] > 0) {
+      const void* first_src = reinterpret_cast<const void*>(src_ptrs[i]);
+      const void* first_dst = reinterpret_cast<const void*>(dst_ptrs[i]);
+      auto src_type = sycl::get_pointer_type(first_src, sycl_ctx);
+      auto dst_type = sycl::get_pointer_type(first_dst, sycl_ctx);
+      bool src_is_host =
+          (src_type == sycl::usm::alloc::host ||
+           src_type == sycl::usm::alloc::unknown);
+      bool dst_is_device = (dst_type == sycl::usm::alloc::device);
+      needs_staging = src_is_host && dst_is_device;
+      break;
+    }
+  }
+
+  // Compute total bytes for staging allocation.
+  uint64_t total_bytes = 0;
+  for (int64_t i = 0; i < n; i++) {
+    total_bytes += sizes[i];
+  }
+  if (total_bytes == 0) return;
+
+  if (needs_staging) {
+    // H2D: allocate one contiguous pinned staging buffer, snapshot all source
+    // blocks, then submit all async DMAs. This avoids N separate allocator
+    // round-trips and protects against caller mutation after return.
+    auto staging = at::getHostAllocator(at::kXPU)->allocate(
+        static_cast<size_t>(total_bytes));
+    char* staging_ptr = static_cast<char*>(staging.get());
+    TORCH_CHECK(staging_ptr, "Failed to allocate pinned staging buffer");
+
+    size_t staging_offset = 0;
+    for (int64_t i = 0; i < n; i++) {
+      size_t sz = static_cast<size_t>(sizes[i]);
+      if (sz == 0) continue;
+
+      const void* src = reinterpret_cast<const void*>(src_ptrs[i]);
+      void* dst = reinterpret_cast<void*>(dst_ptrs[i]);
+
+      std::memcpy(staging_ptr + staging_offset, src, sz);
+      queue.memcpy(dst, staging_ptr + staging_offset, sz);
+      staging_offset += sz;
+    }
+
+    // Keep the staging buffer alive until all submitted DMAs complete.
+    if (staging.get_context() != nullptr) {
+      at::getHostAllocator(at::kXPU)->record_event(
+          staging_ptr,
+          const_cast<void*>(staging.get_context()),
+          at::xpu::getCurrentXPUStream());
+    }
+  } else {
+    // D2H / D2D: direct async DMA, no staging needed.
+    for (int64_t i = 0; i < n; i++) {
+      size_t sz = static_cast<size_t>(sizes[i]);
+      if (sz == 0) continue;
+
+      const void* src = reinterpret_cast<const void*>(src_ptrs[i]);
+      void* dst = reinterpret_cast<void*>(dst_ptrs[i]);
+
+      queue.memcpy(dst, src, sz);
+    }
+  }
+}
+
} // namespace xpu
} // namespace vllm

csrc/utils/mem_cpy.h

Lines changed: 23 additions & 0 deletions
@@ -1,5 +1,6 @@
#pragma once
#include <cstddef>
+#include <cstdint>

namespace vllm {
namespace xpu {
@@ -32,5 +33,27 @@
    const void* hctx,
    bool is_pinned);

+/**
+ * @brief Batch async memcpy: copies N independent (src, dst, size) triples
+ * in a single call, amortising per-copy overhead.
+ *
+ * The copy direction is auto-detected from the first non-zero entry's USM
+ * pointer types. All entries must share the same direction.
+ *
+ * For H2D: snapshots all source blocks through a single contiguous pinned
+ * staging buffer so the caller may safely mutate host memory immediately.
+ * For D2H / D2D: direct async DMA without staging.
+ *
+ * @param src_ptrs Array of N raw source addresses
+ * @param dst_ptrs Array of N raw destination addresses
+ * @param sizes    Array of N byte counts
+ * @param n        Number of entries
+ */
+void xpuAsyncMemcpyBatch(
+    const uint64_t* src_ptrs,
+    const uint64_t* dst_ptrs,
+    const uint64_t* sizes,
+    int64_t n);
+
} // namespace xpu
} // namespace vllm
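
One consequence of the contract documented above: only the H2D path snapshots the source, so for D2H (and D2D) the destination is not guaranteed to be populated until the queue has been synchronized. A hedged Python-level sketch of the D2H pattern (the shapes, the 4-row batch, and the plain pageable destination are illustrative choices, not part of this commit):

import torch

src = torch.randn(16, 256, device="xpu")    # device-side source
dst = torch.empty_like(src, device="cpu")   # host-side destination

row_bytes = src.stride(0) * src.element_size()
src_ptrs = torch.tensor([src.data_ptr() + i * row_bytes for i in range(4)],
                        dtype=torch.uint64)
dst_ptrs = torch.tensor([dst.data_ptr() + i * row_bytes for i in range(4)],
                        dtype=torch.uint64)
sizes = torch.full((4,), row_bytes, dtype=torch.uint64)

torch.ops._C_cache_ops.swap_blocks_batch(src_ptrs, dst_ptrs, sizes)
torch.xpu.synchronize()   # D2H is async: synchronize before reading dst
torch.testing.assert_close(dst[:4], src[:4].cpu())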

tests/register_ops.py

Lines changed: 10 additions & 0 deletions
@@ -487,6 +487,16 @@ def swap_blocks(
                                       block_mapping)


+def swap_blocks_batch(
+    src_ptrs: torch.Tensor,
+    dst_ptrs: torch.Tensor,
+    sizes: torch.Tensor,
+) -> None:
+    """Batch version of swap_blocks: copies N independent (src, dst, size)
+    triples in a single call."""
+    torch.ops._C_cache_ops.swap_blocks_batch(src_ptrs, dst_ptrs, sizes)
+
+
def topk_sigmoid(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                 token_expert_indices: torch.Tensor,
                 gating_output: torch.Tensor, renormalize: bool,
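
With the binding and the swap_blocks_batch wrapper above in place, swapping many KV-cache blocks becomes a single call. A usage sketch of the H2D path, mirroring the tests that follow (the cache shape, the mapping, and the way ops refers to this module are illustrative assumptions, not part of the commit):

import torch
import register_ops as ops  # this module; adjust the import to your layout

# Hypothetical host- and device-side caches: 64 blocks of 8 x 8 x 64 bf16.
cpu_cache = torch.randn(64, 8, 8, 64, dtype=torch.bfloat16)
xpu_cache = torch.empty_like(cpu_cache, device="xpu")

block_bytes = cpu_cache.stride(0) * cpu_cache.element_size()
mapping = [(0, 3), (1, 7), (5, 2)]  # (src_block, dst_block) pairs

src_ptrs = torch.tensor(
    [cpu_cache.data_ptr() + s * block_bytes for s, _ in mapping],
    dtype=torch.uint64)
dst_ptrs = torch.tensor(
    [xpu_cache.data_ptr() + d * block_bytes for _, d in mapping],
    dtype=torch.uint64)
sizes = torch.full((len(mapping),), block_bytes, dtype=torch.uint64)

ops.swap_blocks_batch(src_ptrs, dst_ptrs, sizes)
# The H2D path snapshots the source into a pinned staging buffer, so the host
# cache may be reused immediately; synchronize before reading the XPU side.
cpu_cache.zero_()
torch.xpu.synchronize()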

tests/test_cache.py

Lines changed: 154 additions & 0 deletions
@@ -2,6 +2,7 @@

import random

+import numpy as np
import pytest
import torch

@@ -101,6 +102,13 @@
        "device": ["xpu:0"],
        "kv_cache_dtype": KV_CACHE_DTYPE,
    },
+    "test_swap_blocks_batch": {
+        "direction": [("cpu", "xpu")],
+        "device": ["xpu:0"],
+    },
+    "test_swap_blocks_batch_h2d_mutation_race": {
+        "device": ["xpu:0"],
+    },
}

@@ -948,3 +956,149 @@ def test_swap_blocks_mla(
            msg=f"Block {src} from src should have been swapped to block "
            f"{dst} in dst_cache.",
        )
+
+
+# ---------------------------------------------------------------------------
+# swap_blocks_batch tests
+# ---------------------------------------------------------------------------
+
+
+def _build_batch_args(
+    src_cache: torch.Tensor,
+    dst_cache: torch.Tensor,
+    block_mapping: list[tuple[int, int]],
+    block_size_in_bytes: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Build (src_ptrs, dst_ptrs, sizes) tensors for swap_blocks_batch."""
+    n = len(block_mapping)
+    src_arr = np.empty(n, dtype=np.uint64)
+    dst_arr = np.empty(n, dtype=np.uint64)
+    sz_arr = np.full(n, block_size_in_bytes, dtype=np.uint64)
+
+    src_base = src_cache.data_ptr()
+    dst_base = dst_cache.data_ptr()
+    stride = src_cache.stride(0) * src_cache.element_size()
+
+    for i, (sb, db) in enumerate(block_mapping):
+        src_arr[i] = src_base + sb * stride
+        dst_arr[i] = dst_base + db * stride
+
+    return (torch.from_numpy(src_arr), torch.from_numpy(dst_arr),
+            torch.from_numpy(sz_arr))
+
+
+@pytest.mark.parametrize("direction", COPYING_DIRECTION)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_swap_blocks_batch(
+    direction: tuple[str, str],
+    device: str,
+) -> None:
+    """Test swap_blocks_batch for H2D, D2H and D2D directions."""
+    num_mappings = 64
+    num_heads = 8
+    head_size = 64
+    block_size = 8
+    num_blocks = 256
+    dtype = torch.bfloat16
+    seed = 0
+
+    seed_everything(seed)
+
+    src_device = device if direction[0] == "xpu" else "cpu"
+    dst_device = device if direction[1] == "xpu" else "cpu"
+
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    if src_device == dst_device:
+        remaining = list(set(range(num_blocks)) - set(src_blocks))
+        dst_blocks = random.sample(remaining, num_mappings)
+    else:
+        dst_blocks = random.sample(range(num_blocks), num_mappings)
+    block_mapping = list(zip(src_blocks, dst_blocks))
+
+    src_key, src_val = create_kv_caches_with_random(num_blocks, block_size, 1,
+                                                    num_heads, head_size,
+                                                    "auto", dtype, seed,
+                                                    src_device)
+    dst_key, dst_val = create_kv_caches_with_random(num_blocks, block_size, 1,
+                                                    num_heads, head_size,
+                                                    "auto", dtype, seed,
+                                                    dst_device)
+
+    src_key_clone = src_key[0].clone()
+    src_val_clone = src_val[0].clone()
+
+    block_size_in_bytes = src_key[0].element_size() * src_key[0].stride(0)
+
+    # Build batch args and call
+    for src_cache, dst_cache in [(src_key[0], dst_key[0]),
+                                 (src_val[0], dst_val[0])]:
+        sp, dp, sz = _build_batch_args(src_cache, dst_cache, block_mapping,
+                                       block_size_in_bytes)
+        ops.swap_blocks_batch(sp, dp, sz)
+
+    torch.xpu.synchronize()
+
+    for sb, db in block_mapping:
+        torch.testing.assert_close(src_key_clone[sb].cpu(),
+                                   dst_key[0][db].cpu())
+        torch.testing.assert_close(src_val_clone[sb].cpu(),
+                                   dst_val[0][db].cpu())
+
+
+@torch.inference_mode()
+def test_swap_blocks_batch_h2d_mutation_race() -> None:
+    """Verify staging buffer protects against caller mutation for H2D batch."""
+    num_mappings = 256
+    num_heads = 8
+    head_size = 128
+    block_size = 32
+    num_blocks = 512
+    dtype = torch.bfloat16
+    seed = 0
+
+    seed_everything(seed)
+
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    dst_blocks = random.sample(range(num_blocks), num_mappings)
+    block_mapping = list(zip(src_blocks, dst_blocks))
+
+    # Source: pinned CPU memory
+    src_key, src_val = create_kv_caches_with_pinned(num_blocks, block_size, 1,
+                                                    num_heads, head_size,
+                                                    "auto", dtype, seed, "cpu")
+    assert src_key[0].is_pinned()
+
+    # Destination: XPU
+    dst_key, dst_val = create_kv_caches_with_random(num_blocks, block_size, 1,
+                                                    num_heads, head_size,
+                                                    "auto", dtype, seed)
+
+    src_key_clone = src_key[0].clone()
+    src_val_clone = src_val[0].clone()
+
+    block_size_in_bytes = src_key[0].element_size() * src_key[0].stride(0)
+
+    for src_cache, dst_cache in [(src_key[0], dst_key[0]),
+                                 (src_val[0], dst_val[0])]:
+        sp, dp, sz = _build_batch_args(src_cache, dst_cache, block_mapping,
+                                       block_size_in_bytes)
+        ops.swap_blocks_batch(sp, dp, sz)
+
+    # Immediately mutate source — should not affect destination.
+    src_key[0].fill_(0)
+    src_val[0].fill_(0)
+
+    torch.xpu.synchronize()
+
+    for sb, db in block_mapping:
+        torch.testing.assert_close(
+            src_key_clone[sb].cpu(),
+            dst_key[0][db].cpu(),
+            msg=f"Key block {sb}->{db} corrupted by post-call mutation",
+        )
+        torch.testing.assert_close(
+            src_val_clone[sb].cpu(),
+            dst_val[0][db].cpu(),
+            msg=f"Value block {sb}->{db} corrupted by post-call mutation",
+        )
