csrc/cache.cpp (31 additions, 0 deletions)

@@ -1175,6 +1175,37 @@ void swap_blocks(
return;
}

/**
* @brief Batch version of swap_blocks: copies N independent (src, dst, size)
* triples in a single call, amortising per-copy overhead.
*
* Thin wrapper that validates the CPU tensor inputs and delegates to
* vllm::xpu::xpuAsyncMemcpyBatch for the actual copy logic.
*/
void swap_blocks_batch(
const torch::Tensor& src_ptrs,
const torch::Tensor& dst_ptrs,
const torch::Tensor& sizes) {
TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
TORCH_CHECK(src_ptrs.dtype() == torch::kUInt64, "src_ptrs must be uint64");
TORCH_CHECK(dst_ptrs.dtype() == torch::kUInt64, "dst_ptrs must be uint64");
TORCH_CHECK(sizes.dtype() == torch::kUInt64, "sizes must be uint64");

const int64_t n = src_ptrs.size(0);
TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");

if (n == 0) return;

vllm::xpu::xpuAsyncMemcpyBatch(
src_ptrs.data_ptr<uint64_t>(),
dst_ptrs.data_ptr<uint64_t>(),
sizes.data_ptr<uint64_t>(),
n);
}
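
For illustration, a minimal Python sketch of the host-to-device path through this op (the schema and the CPU-dispatch registration appear in csrc/torch_bindings.cpp below). The block names and shapes are hypothetical, and it assumes a PyTorch build with torch.uint64 and XPU pin-memory support; this is a sketch, not part of the change itself:

import torch

# Hypothetical swap-in: copy two pinned CPU cache blocks into XPU memory.
cpu_blocks = [
    torch.randn(16, 128, dtype=torch.float16, pin_memory=True) for _ in range(2)
]
xpu_blocks = [torch.empty_like(b, device="xpu") for b in cpu_blocks]

src_ptrs = torch.tensor([b.data_ptr() for b in cpu_blocks], dtype=torch.uint64)
dst_ptrs = torch.tensor([b.data_ptr() for b in xpu_blocks], dtype=torch.uint64)
sizes = torch.tensor(
    [b.numel() * b.element_size() for b in cpu_blocks], dtype=torch.uint64)

torch.ops._C_cache_ops.swap_blocks_batch(src_ptrs, dst_ptrs, sizes)
# The H2D path snapshots the sources into a pinned staging buffer, so the CPU
# blocks may be reused as soon as the call returns; the device-side copies
# complete asynchronously on the op's in-order XPU queue.
cpu_blocks[0].zero_()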

namespace vllm {

// Kernel for FP8 conversion (matches CUDA convert_fp8_kernel pattern).
csrc/ops.h (5 additions, 0 deletions)

@@ -158,6 +158,11 @@ void swap_blocks(
int64_t block_size_in_bytes,
const torch::Tensor& block_mapping);

void swap_blocks_batch(
const torch::Tensor& src_ptrs,
const torch::Tensor& dst_ptrs,
const torch::Tensor& sizes);

void top_k_per_row_decode(
const torch::Tensor& logits,
int64_t next_n,
csrc/torch_bindings.cpp (6 additions, 0 deletions)

@@ -206,6 +206,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
"swap_blocks(Tensor src, Tensor! dst,"
" int block_size_in_bytes, Tensor block_mapping) -> ()");
cache_ops.impl("swap_blocks", torch::kXPU, &swap_blocks);
// Batch swap: copies N (src_ptr, dst_ptr, size) triples in one call.
// The target XPU device is auto-inferred from the device pointer.
cache_ops.def(
"swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
" Tensor sizes) -> ()");
cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
cache_ops.def(
"indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache,"
"Tensor slot_mapping, int quant_block_size, str scale_fmt) -> ()");
csrc/utils/mem_cpy.cpp (130 additions, 0 deletions)

@@ -151,6 +151,136 @@ void xpuAsyncMemcpy(
}
}

// Infer which XPU device a USM device pointer was allocated on by probing
// each device's SYCL context. Returns the device index on success.
// This is O(num_xpu_devices) but avoids threading an explicit device argument
// through the entire call chain when all callers already have the pointer.
static at::DeviceIndex infer_xpu_device_from_ptr(const void* device_ptr) {
const int n_devs = c10::xpu::device_count();
for (int i = 0; i < n_devs; i++) {
auto ctx = vllm::xpu::vllmGetQueue(i).get_context();
auto type = sycl::get_pointer_type(device_ptr, ctx);
if (type == sycl::usm::alloc::device || type == sycl::usm::alloc::shared) {
return static_cast<at::DeviceIndex>(i);
}
}
TORCH_CHECK(false, "Cannot determine XPU device from pointer");
return -1;
}

void xpuAsyncMemcpyBatch(
const uint64_t* src_ptrs,
const uint64_t* dst_ptrs,
const uint64_t* sizes,
int64_t n) {
if (n == 0) return;

  // Scan for the first non-zero entry to determine the copy direction.
// Also capture the device-side pointer so we can infer which XPU to use.
const void* device_probe = nullptr;
bool needs_staging = false;
bool dst_is_pageable = false; // D2H to pageable host -> sync copy
for (int64_t i = 0; i < n; i++) {
if (sizes[i] == 0) continue;
const void* first_src = reinterpret_cast<const void*>(src_ptrs[i]);
const void* first_dst = reinterpret_cast<const void*>(dst_ptrs[i]);

// Use device 0's context as a probe: we only need pointer *type* here,
// and USM pointer types are consistent across all devices on the same
// platform (host/unknown are always host; device is always device on its
// own platform). The actual device index is resolved below via
// infer_xpu_device_from_ptr().
auto probe_ctx = vllm::xpu::vllmGetQueue(0).get_context();
auto src_type = sycl::get_pointer_type(first_src, probe_ctx);
auto dst_type = sycl::get_pointer_type(first_dst, probe_ctx);
bool src_is_host =
(src_type == sycl::usm::alloc::host ||
src_type == sycl::usm::alloc::unknown);
bool dst_is_device = (dst_type == sycl::usm::alloc::device);
needs_staging = src_is_host && dst_is_device;
// D2H to pageable host requires synchronous copy to avoid corruption.
dst_is_pageable = !dst_is_device && (dst_type == sycl::usm::alloc::unknown);
// Device-side pointer: dst for H2D, src for D2H or D2D.
device_probe = needs_staging ? first_dst : first_src;
break;
}

if (device_probe == nullptr) return; // all sizes are zero

// Infer the target XPU device from the device pointer and set the guard so
// that vllmGetQueue() returns the correct in-order queue.
const at::DeviceIndex dev = infer_xpu_device_from_ptr(device_probe);
const at::DeviceGuard device_guard(at::Device(at::kXPU, dev));

auto& queue = vllm::xpu::vllmGetQueue();

// Compute total bytes needed for the H2D staging buffer.
uint64_t total_bytes = 0;
for (int64_t i = 0; i < n; i++) {
total_bytes += sizes[i];
}

if (needs_staging) {
// H2D: allocate one contiguous pinned staging buffer, snapshot all source
// blocks, then submit all async DMAs. This avoids N separate allocator
// round-trips and protects against caller mutation after return.
auto staging = at::getHostAllocator(at::kXPU)->allocate(
static_cast<size_t>(total_bytes));
char* staging_ptr = static_cast<char*>(staging.get());
TORCH_CHECK(staging_ptr, "Failed to allocate pinned staging buffer");

// Phase 1: snapshot all source blocks into staging (pure CPU work).
size_t staging_offset = 0;
for (int64_t i = 0; i < n; i++) {
size_t sz = static_cast<size_t>(sizes[i]);
if (sz == 0) continue;
std::memcpy(
staging_ptr + staging_offset,
reinterpret_cast<const void*>(src_ptrs[i]),
sz);
staging_offset += sz;
}

// Phase 2: submit async DMA from staging to device in a tight loop,
// maximising PCIe/copy-engine throughput without interleaved CPU work.
staging_offset = 0;
for (int64_t i = 0; i < n; i++) {
size_t sz = static_cast<size_t>(sizes[i]);
if (sz == 0) continue;
queue.memcpy(
reinterpret_cast<void*>(dst_ptrs[i]),
staging_ptr + staging_offset,
sz);
staging_offset += sz;
}

// Keep the staging buffer alive until all submitted DMAs complete.
if (staging.get_context() != nullptr) {
at::getHostAllocator(at::kXPU)->record_event(
staging_ptr,
const_cast<void*>(staging.get_context()),
at::xpu::getCurrentXPUStream());
}
} else {
// D2H or D2D: dst_is_pageable was probed once from the first non-zero
// entry (all entries share the same direction and memory class).
// Pageable D2H is unsafe with async DMA; fall back to sync copy.
for (int64_t i = 0; i < n; i++) {
size_t sz = static_cast<size_t>(sizes[i]);
if (sz == 0) continue;

const void* src = reinterpret_cast<const void*>(src_ptrs[i]);
void* dst = reinterpret_cast<void*>(dst_ptrs[i]);

if (dst_is_pageable) {
queue.memcpy(dst, src, sz).wait();
} else {
queue.memcpy(dst, src, sz);
}
}
}
}

} // namespace xpu
} // namespace vllm

csrc/utils/mem_cpy.h (23 additions, 0 deletions)

@@ -1,5 +1,6 @@
#pragma once
#include <cstddef>
#include <cstdint>

namespace vllm {
namespace xpu {
@@ -32,5 +33,27 @@ void xpuAsyncMemcpy(
const void* hctx,
bool is_pinned);

/**
* @brief Batch async memcpy: copies N independent (src, dst, size) triples
* in a single call, amortising per-copy overhead.
*
* The copy direction is auto-detected from the first non-zero entry's USM
* pointer types. All entries must share the same direction.
*
* For H2D: snapshots all source blocks through a single contiguous pinned
* staging buffer so the caller may safely mutate host memory immediately.
* For D2H / D2D: direct async DMA without staging.
*
* @param src_ptrs Array of N raw source addresses
* @param dst_ptrs Array of N raw destination addresses
* @param sizes Array of N byte counts
* @param n Number of entries
*/
void xpuAsyncMemcpyBatch(
const uint64_t* src_ptrs,
const uint64_t* dst_ptrs,
const uint64_t* sizes,
int64_t n);

} // namespace xpu
} // namespace vllm
tests/register_ops.py (11 additions, 0 deletions)

@@ -487,6 +487,17 @@ def swap_blocks(
block_mapping)


def swap_blocks_batch(
src_ptrs: torch.Tensor,
dst_ptrs: torch.Tensor,
sizes: torch.Tensor,
) -> None:
"""Batch version of swap_blocks: copies N independent (src, dst, size)
triples in a single call. The target XPU device is auto-inferred from the
device-side pointers in src_ptrs/dst_ptrs."""
torch.ops._C_cache_ops.swap_blocks_batch(src_ptrs, dst_ptrs, sizes)
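
A possible swap-out (device-to-host) use of this wrapper, sketched under two assumptions that are not part of this change: the destination CPU tensors are pinned (so the op submits the copies asynchronously rather than taking the synchronous pageable path), and the op's internal SYCL queue is covered by torch.xpu.synchronize(); tensor names and shapes are hypothetical.

def _example_swap_out() -> None:
    # Hypothetical: copy two XPU cache blocks back into pinned CPU blocks.
    xpu_blocks = [
        torch.randn(16, 128, dtype=torch.float16, device="xpu") for _ in range(2)
    ]
    cpu_blocks = [
        torch.empty(16, 128, dtype=torch.float16, pin_memory=True)
        for _ in range(2)
    ]

    src_ptrs = torch.tensor([b.data_ptr() for b in xpu_blocks],
                            dtype=torch.uint64)
    dst_ptrs = torch.tensor([b.data_ptr() for b in cpu_blocks],
                            dtype=torch.uint64)
    sizes = torch.tensor([b.numel() * b.element_size() for b in xpu_blocks],
                         dtype=torch.uint64)

    swap_blocks_batch(src_ptrs, dst_ptrs, sizes)
    # Pinned-destination D2H copies are asynchronous; wait for the device
    # before reading the CPU data. (Pageable destinations are instead copied
    # synchronously inside the op.)
    torch.xpu.synchronize()
    assert all(
        torch.equal(c, x.cpu()) for c, x in zip(cpu_blocks, xpu_blocks))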


def topk_sigmoid(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
token_expert_indices: torch.Tensor,
gating_output: torch.Tensor, renormalize: bool,