Skip to content

Commit 56ab101

Browse files
Add swap_blocks_batch op with batched async memcpy (#265)
Signed-off-by: chaojun-zhang <chaojun.zhang@intel.com>
1 parent d9dc454 commit 56ab101

7 files changed

Lines changed: 366 additions & 0 deletions

File tree

csrc/cache.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1175,6 +1175,37 @@ void swap_blocks(
11751175
return;
11761176
}
11771177

1178+
/**
1179+
* @brief Batch version of swap_blocks: copies N independent (src, dst, size)
1180+
* triples in a single call, amortising per-copy overhead.
1181+
*
1182+
* Thin wrapper that validates the CPU tensor inputs and delegates to
1183+
* vllm::xpu::xpuAsyncMemcpyBatch for the actual copy logic.
1184+
*/
1185+
void swap_blocks_batch(
1186+
const torch::Tensor& src_ptrs,
1187+
const torch::Tensor& dst_ptrs,
1188+
const torch::Tensor& sizes) {
1189+
TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
1190+
TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
1191+
TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
1192+
TORCH_CHECK(src_ptrs.dtype() == torch::kUInt64, "src_ptrs must be uint64");
1193+
TORCH_CHECK(dst_ptrs.dtype() == torch::kUInt64, "dst_ptrs must be uint64");
1194+
TORCH_CHECK(sizes.dtype() == torch::kUInt64, "sizes must be uint64");
1195+
1196+
const int64_t n = src_ptrs.size(0);
1197+
TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
1198+
TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");
1199+
1200+
if (n == 0) return;
1201+
1202+
vllm::xpu::xpuAsyncMemcpyBatch(
1203+
src_ptrs.data_ptr<uint64_t>(),
1204+
dst_ptrs.data_ptr<uint64_t>(),
1205+
sizes.data_ptr<uint64_t>(),
1206+
n);
1207+
}
1208+
11781209
namespace vllm {
11791210

11801211
// Kernel for FP8 conversion (matches CUDA convert_fp8_kernel pattern).

csrc/ops.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,11 @@ void swap_blocks(
158158
int64_t block_size_in_bytes,
159159
const torch::Tensor& block_mapping);
160160

161+
void swap_blocks_batch(
162+
const torch::Tensor& src_ptrs,
163+
const torch::Tensor& dst_ptrs,
164+
const torch::Tensor& sizes);
165+
161166
void top_k_per_row_decode(
162167
const torch::Tensor& logits,
163168
int64_t next_n,

csrc/torch_bindings.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
206206
"swap_blocks(Tensor src, Tensor! dst,"
207207
" int block_size_in_bytes, Tensor block_mapping) -> ()");
208208
cache_ops.impl("swap_blocks", torch::kXPU, &swap_blocks);
209+
// Batch swap: copies N (src_ptr, dst_ptr, size) triples in one call.
210+
// The target XPU device is auto-inferred from the device pointer.
211+
cache_ops.def(
212+
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
213+
" Tensor sizes) -> ()");
214+
cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
209215
cache_ops.def(
210216
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache,"
211217
"Tensor slot_mapping, int quant_block_size, str scale_fmt) -> ()");

csrc/utils/mem_cpy.cpp

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,136 @@ void xpuAsyncMemcpy(
151151
}
152152
}
153153

154+
// Infer which XPU device a USM device pointer was allocated on by probing
155+
// each device's SYCL context. Returns the device index on success.
156+
// This is O(num_xpu_devices) but avoids threading an explicit device argument
157+
// through the entire call chain when all callers already have the pointer.
158+
static at::DeviceIndex infer_xpu_device_from_ptr(const void* device_ptr) {
159+
const int n_devs = c10::xpu::device_count();
160+
for (int i = 0; i < n_devs; i++) {
161+
auto ctx = vllm::xpu::vllmGetQueue(i).get_context();
162+
auto type = sycl::get_pointer_type(device_ptr, ctx);
163+
if (type == sycl::usm::alloc::device || type == sycl::usm::alloc::shared) {
164+
return static_cast<at::DeviceIndex>(i);
165+
}
166+
}
167+
TORCH_CHECK(false, "Cannot determine XPU device from pointer");
168+
return -1;
169+
}
170+
171+
void xpuAsyncMemcpyBatch(
172+
const uint64_t* src_ptrs,
173+
const uint64_t* dst_ptrs,
174+
const uint64_t* sizes,
175+
int64_t n) {
176+
if (n == 0) return;
177+
178+
// Scan the first non-zero entry to determine copy direction.
179+
// Also capture the device-side pointer so we can infer which XPU to use.
180+
const void* device_probe = nullptr;
181+
bool needs_staging = false;
182+
bool dst_is_pageable = false; // D2H to pageable host -> sync copy
183+
for (int64_t i = 0; i < n; i++) {
184+
if (sizes[i] == 0) continue;
185+
const void* first_src = reinterpret_cast<const void*>(src_ptrs[i]);
186+
const void* first_dst = reinterpret_cast<const void*>(dst_ptrs[i]);
187+
188+
// Use device 0's context as a probe: we only need pointer *type* here,
189+
// and USM pointer types are consistent across all devices on the same
190+
// platform (host/unknown are always host; device is always device on its
191+
// own platform). The actual device index is resolved below via
192+
// infer_xpu_device_from_ptr().
193+
auto probe_ctx = vllm::xpu::vllmGetQueue(0).get_context();
194+
auto src_type = sycl::get_pointer_type(first_src, probe_ctx);
195+
auto dst_type = sycl::get_pointer_type(first_dst, probe_ctx);
196+
bool src_is_host =
197+
(src_type == sycl::usm::alloc::host ||
198+
src_type == sycl::usm::alloc::unknown);
199+
bool dst_is_device = (dst_type == sycl::usm::alloc::device);
200+
needs_staging = src_is_host && dst_is_device;
201+
// D2H to pageable host requires synchronous copy to avoid corruption.
202+
dst_is_pageable = !dst_is_device && (dst_type == sycl::usm::alloc::unknown);
203+
// Device-side pointer: dst for H2D, src for D2H or D2D.
204+
device_probe = needs_staging ? first_dst : first_src;
205+
break;
206+
}
207+
208+
if (device_probe == nullptr) return; // all sizes are zero
209+
210+
// Infer the target XPU device from the device pointer and set the guard so
211+
// that vllmGetQueue() returns the correct in-order queue.
212+
const at::DeviceIndex dev = infer_xpu_device_from_ptr(device_probe);
213+
const at::DeviceGuard device_guard(at::Device(at::kXPU, dev));
214+
215+
auto& queue = vllm::xpu::vllmGetQueue();
216+
217+
// Compute total bytes needed for the H2D staging buffer.
218+
uint64_t total_bytes = 0;
219+
for (int64_t i = 0; i < n; i++) {
220+
total_bytes += sizes[i];
221+
}
222+
223+
if (needs_staging) {
224+
// H2D: allocate one contiguous pinned staging buffer, snapshot all source
225+
// blocks, then submit all async DMAs. This avoids N separate allocator
226+
// round-trips and protects against caller mutation after return.
227+
auto staging = at::getHostAllocator(at::kXPU)->allocate(
228+
static_cast<size_t>(total_bytes));
229+
char* staging_ptr = static_cast<char*>(staging.get());
230+
TORCH_CHECK(staging_ptr, "Failed to allocate pinned staging buffer");
231+
232+
// Phase 1: snapshot all source blocks into staging (pure CPU work).
233+
size_t staging_offset = 0;
234+
for (int64_t i = 0; i < n; i++) {
235+
size_t sz = static_cast<size_t>(sizes[i]);
236+
if (sz == 0) continue;
237+
std::memcpy(
238+
staging_ptr + staging_offset,
239+
reinterpret_cast<const void*>(src_ptrs[i]),
240+
sz);
241+
staging_offset += sz;
242+
}
243+
244+
// Phase 2: submit async DMA from staging to device in a tight loop,
245+
// maximising PCIe/copy-engine throughput without interleaved CPU work.
246+
staging_offset = 0;
247+
for (int64_t i = 0; i < n; i++) {
248+
size_t sz = static_cast<size_t>(sizes[i]);
249+
if (sz == 0) continue;
250+
queue.memcpy(
251+
reinterpret_cast<void*>(dst_ptrs[i]),
252+
staging_ptr + staging_offset,
253+
sz);
254+
staging_offset += sz;
255+
}
256+
257+
// Keep the staging buffer alive until all submitted DMAs complete.
258+
if (staging.get_context() != nullptr) {
259+
at::getHostAllocator(at::kXPU)->record_event(
260+
staging_ptr,
261+
const_cast<void*>(staging.get_context()),
262+
at::xpu::getCurrentXPUStream());
263+
}
264+
} else {
265+
// D2H or D2D: dst_is_pageable was probed once from the first non-zero
266+
// entry (all entries share the same direction and memory class).
267+
// Pageable D2H is unsafe with async DMA; fall back to sync copy.
268+
for (int64_t i = 0; i < n; i++) {
269+
size_t sz = static_cast<size_t>(sizes[i]);
270+
if (sz == 0) continue;
271+
272+
const void* src = reinterpret_cast<const void*>(src_ptrs[i]);
273+
void* dst = reinterpret_cast<void*>(dst_ptrs[i]);
274+
275+
if (dst_is_pageable) {
276+
queue.memcpy(dst, src, sz).wait();
277+
} else {
278+
queue.memcpy(dst, src, sz);
279+
}
280+
}
281+
}
282+
}
283+
154284
} // namespace xpu
155285
} // namespace vllm
156286

csrc/utils/mem_cpy.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22
#include <cstddef>
3+
#include <cstdint>
34

45
namespace vllm {
56
namespace xpu {
@@ -32,5 +33,27 @@ void xpuAsyncMemcpy(
3233
const void* hctx,
3334
bool is_pinned);
3435

36+
/**
37+
* @brief Batch async memcpy: copies N independent (src, dst, size) triples
38+
* in a single call, amortising per-copy overhead.
39+
*
40+
* The copy direction is auto-detected from the first non-zero entry's USM
41+
* pointer types. All entries must share the same direction.
42+
*
43+
* For H2D: snapshots all source blocks through a single contiguous pinned
44+
* staging buffer so the caller may safely mutate host memory immediately.
45+
* For D2H / D2D: direct async DMA without staging.
46+
*
47+
* @param src_ptrs Array of N raw source addresses
48+
* @param dst_ptrs Array of N raw destination addresses
49+
* @param sizes Array of N byte counts
50+
* @param n Number of entries
51+
*/
52+
void xpuAsyncMemcpyBatch(
53+
const uint64_t* src_ptrs,
54+
const uint64_t* dst_ptrs,
55+
const uint64_t* sizes,
56+
int64_t n);
57+
3558
} // namespace xpu
3659
} // namespace vllm

tests/register_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,17 @@ def swap_blocks(
487487
block_mapping)
488488

489489

490+
def swap_blocks_batch(
491+
src_ptrs: torch.Tensor,
492+
dst_ptrs: torch.Tensor,
493+
sizes: torch.Tensor,
494+
) -> None:
495+
"""Batch version of swap_blocks: copies N independent (src, dst, size)
496+
triples in a single call. The target XPU device is auto-inferred from the
497+
device-side pointers in src_ptrs/dst_ptrs."""
498+
torch.ops._C_cache_ops.swap_blocks_batch(src_ptrs, dst_ptrs, sizes)
499+
500+
490501
def topk_sigmoid(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
491502
token_expert_indices: torch.Tensor,
492503
gating_output: torch.Tensor, renormalize: bool,

0 commit comments

Comments
 (0)