diff --git a/csrc/cache.cpp b/csrc/cache.cpp
index 3dd01a9e2..afff3daa1 100644
--- a/csrc/cache.cpp
+++ b/csrc/cache.cpp
@@ -1175,6 +1175,37 @@ void swap_blocks(
   return;
 }
 
+/**
+ * @brief Batch version of swap_blocks: copies N independent (src, dst, size)
+ * triples in a single call, amortising per-copy overhead.
+ *
+ * Thin wrapper that validates the CPU tensor inputs and delegates to
+ * vllm::xpu::xpuAsyncMemcpyBatch for the actual copy logic.
+ */
+void swap_blocks_batch(
+    const torch::Tensor& src_ptrs,
+    const torch::Tensor& dst_ptrs,
+    const torch::Tensor& sizes) {
+  TORCH_CHECK(src_ptrs.device().is_cpu(), "src_ptrs must be on CPU");
+  TORCH_CHECK(dst_ptrs.device().is_cpu(), "dst_ptrs must be on CPU");
+  TORCH_CHECK(sizes.device().is_cpu(), "sizes must be on CPU");
+  TORCH_CHECK(src_ptrs.dtype() == torch::kUInt64, "src_ptrs must be uint64");
+  TORCH_CHECK(dst_ptrs.dtype() == torch::kUInt64, "dst_ptrs must be uint64");
+  TORCH_CHECK(sizes.dtype() == torch::kUInt64, "sizes must be uint64");
+
+  const int64_t n = src_ptrs.size(0);
+  TORCH_CHECK(dst_ptrs.size(0) == n, "dst_ptrs length must match src_ptrs");
+  TORCH_CHECK(sizes.size(0) == n, "sizes length must match src_ptrs");
+
+  if (n == 0) return;
+
+  vllm::xpu::xpuAsyncMemcpyBatch(
+      src_ptrs.data_ptr<uint64_t>(),
+      dst_ptrs.data_ptr<uint64_t>(),
+      sizes.data_ptr<uint64_t>(),
+      n);
+}
+
 namespace vllm {
 
 // Kernel for FP8 conversion (matches CUDA convert_fp8_kernel pattern).
diff --git a/csrc/ops.h b/csrc/ops.h
index 25d38e67d..9c138dda4 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -158,6 +158,11 @@ void swap_blocks(
     int64_t block_size_in_bytes,
     const torch::Tensor& block_mapping);
 
+void swap_blocks_batch(
+    const torch::Tensor& src_ptrs,
+    const torch::Tensor& dst_ptrs,
+    const torch::Tensor& sizes);
+
 void top_k_per_row_decode(
     const torch::Tensor& logits,
     int64_t next_n,
diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp
index 04e3480fe..4ae1b8ec3 100644
--- a/csrc/torch_bindings.cpp
+++ b/csrc/torch_bindings.cpp
@@ -206,6 +206,12 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
       "swap_blocks(Tensor src, Tensor! dst,"
       " int block_size_in_bytes, Tensor block_mapping) -> ()");
   cache_ops.impl("swap_blocks", torch::kXPU, &swap_blocks);
+  // Batch swap: copies N (src_ptr, dst_ptr, size) triples in one call.
+  // The target XPU device is auto-inferred from the device pointer.
+  cache_ops.def(
+      "swap_blocks_batch(Tensor src_ptrs, Tensor dst_ptrs,"
+      " Tensor sizes) -> ()");
+  cache_ops.impl("swap_blocks_batch", torch::kCPU, &swap_blocks_batch);
   cache_ops.def(
       "indexer_k_quant_and_cache(Tensor k, Tensor! kv_cache,"
       "Tensor slot_mapping, int quant_block_size, str scale_fmt) -> ()");
diff --git a/csrc/utils/mem_cpy.cpp b/csrc/utils/mem_cpy.cpp
index 88998d9d4..661da04e6 100644
--- a/csrc/utils/mem_cpy.cpp
+++ b/csrc/utils/mem_cpy.cpp
@@ -151,6 +151,136 @@ void xpuAsyncMemcpy(
   }
 }
 
+// Infer which XPU device a USM device pointer was allocated on by probing
+// each device's SYCL context. Returns the device index on success.
+// This is O(num_xpu_devices) but avoids threading an explicit device argument
+// through the entire call chain when all callers already have the pointer.
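+//
+// Intended call pattern (a sketch, mirroring xpuAsyncMemcpyBatch below):
+//   const at::DeviceIndex dev = infer_xpu_device_from_ptr(device_ptr);
+//   const at::DeviceGuard guard(at::Device(at::kXPU, dev));
+//   auto& queue = vllm::xpu::vllmGetQueue();  // queue on the inferred device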
+static at::DeviceIndex infer_xpu_device_from_ptr(const void* device_ptr) {
+  const int n_devs = c10::xpu::device_count();
+  for (int i = 0; i < n_devs; i++) {
+    auto ctx = vllm::xpu::vllmGetQueue(i).get_context();
+    auto type = sycl::get_pointer_type(device_ptr, ctx);
+    if (type == sycl::usm::alloc::device || type == sycl::usm::alloc::shared) {
+      return static_cast<at::DeviceIndex>(i);
+    }
+  }
+  TORCH_CHECK(false, "Cannot determine XPU device from pointer");
+  return -1;
+}
+
+void xpuAsyncMemcpyBatch(
+    const uint64_t* src_ptrs,
+    const uint64_t* dst_ptrs,
+    const uint64_t* sizes,
+    int64_t n) {
+  if (n == 0) return;
+
+  // Scan the first non-zero entry to determine copy direction.
+  // Also capture the device-side pointer so we can infer which XPU to use.
+  const void* device_probe = nullptr;
+  bool needs_staging = false;
+  bool dst_is_pageable = false;  // D2H to pageable host -> sync copy
+  for (int64_t i = 0; i < n; i++) {
+    if (sizes[i] == 0) continue;
+    const void* first_src = reinterpret_cast<const void*>(src_ptrs[i]);
+    const void* first_dst = reinterpret_cast<const void*>(dst_ptrs[i]);
+
+    // Use device 0's context as a probe: we only need the pointer *type*
+    // here, and USM pointer types are consistent across all devices on the
+    // same platform (host/unknown are always host; device is always device
+    // on its own platform). The actual device index is resolved below via
+    // infer_xpu_device_from_ptr().
+    auto probe_ctx = vllm::xpu::vllmGetQueue(0).get_context();
+    auto src_type = sycl::get_pointer_type(first_src, probe_ctx);
+    auto dst_type = sycl::get_pointer_type(first_dst, probe_ctx);
+    bool src_is_host =
+        (src_type == sycl::usm::alloc::host ||
+         src_type == sycl::usm::alloc::unknown);
+    bool dst_is_device = (dst_type == sycl::usm::alloc::device);
+    needs_staging = src_is_host && dst_is_device;
+    // D2H to pageable host requires a synchronous copy to avoid corruption.
+    dst_is_pageable =
+        !dst_is_device && (dst_type == sycl::usm::alloc::unknown);
+    // Device-side pointer: dst for H2D, src for D2H or D2D.
+    device_probe = needs_staging ? first_dst : first_src;
+    break;
+  }
+
+  if (device_probe == nullptr) return;  // all sizes are zero
+
+  // Infer the target XPU device from the device pointer and set the guard so
+  // that vllmGetQueue() returns the correct in-order queue.
+  const at::DeviceIndex dev = infer_xpu_device_from_ptr(device_probe);
+  const at::DeviceGuard device_guard(at::Device(at::kXPU, dev));
+
+  auto& queue = vllm::xpu::vllmGetQueue();
+
+  // Compute total bytes needed for the H2D staging buffer.
+  uint64_t total_bytes = 0;
+  for (int64_t i = 0; i < n; i++) {
+    total_bytes += sizes[i];
+  }
+
+  if (needs_staging) {
+    // H2D: allocate one contiguous pinned staging buffer, snapshot all source
+    // blocks, then submit all async DMAs. This avoids N separate allocator
+    // round-trips and protects against caller mutation after return.
+    auto staging = at::getHostAllocator(at::kXPU)->allocate(
+        static_cast<size_t>(total_bytes));
+    char* staging_ptr = static_cast<char*>(staging.get());
+    TORCH_CHECK(staging_ptr, "Failed to allocate pinned staging buffer");
+
+    // Phase 1: snapshot all source blocks into staging (pure CPU work).
+    size_t staging_offset = 0;
+    for (int64_t i = 0; i < n; i++) {
+      size_t sz = static_cast<size_t>(sizes[i]);
+      if (sz == 0) continue;
+      std::memcpy(
+          staging_ptr + staging_offset,
+          reinterpret_cast<const void*>(src_ptrs[i]),
+          sz);
+      staging_offset += sz;
+    }
+
+    // Phase 2: submit async DMA from staging to device in a tight loop,
+    // maximising PCIe/copy-engine throughput without interleaved CPU work.
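+    // Note: the staging buffer must outlive these async copies. The
+    // record_event() call further down hands lifetime tracking to the
+    // caching host allocator, which is expected to keep the block reserved
+    // until the event recorded on the current XPU stream has completed.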
+    staging_offset = 0;
+    for (int64_t i = 0; i < n; i++) {
+      size_t sz = static_cast<size_t>(sizes[i]);
+      if (sz == 0) continue;
+      queue.memcpy(
+          reinterpret_cast<void*>(dst_ptrs[i]),
+          staging_ptr + staging_offset,
+          sz);
+      staging_offset += sz;
+    }
+
+    // Keep the staging buffer alive until all submitted DMAs complete.
+    if (staging.get_context() != nullptr) {
+      at::getHostAllocator(at::kXPU)->record_event(
+          staging_ptr,
+          const_cast<void*>(staging.get_context()),
+          at::xpu::getCurrentXPUStream());
+    }
+  } else {
+    // D2H or D2D: dst_is_pageable was probed once from the first non-zero
+    // entry (all entries share the same direction and memory class).
+    // Pageable D2H is unsafe with async DMA; fall back to sync copy.
+    for (int64_t i = 0; i < n; i++) {
+      size_t sz = static_cast<size_t>(sizes[i]);
+      if (sz == 0) continue;
+
+      const void* src = reinterpret_cast<const void*>(src_ptrs[i]);
+      void* dst = reinterpret_cast<void*>(dst_ptrs[i]);
+
+      if (dst_is_pageable) {
+        queue.memcpy(dst, src, sz).wait();
+      } else {
+        queue.memcpy(dst, src, sz);
+      }
+    }
+  }
+}
+
 } // namespace xpu
 } // namespace vllm
diff --git a/csrc/utils/mem_cpy.h b/csrc/utils/mem_cpy.h
index ab274dc12..465ef4c1d 100644
--- a/csrc/utils/mem_cpy.h
+++ b/csrc/utils/mem_cpy.h
@@ -1,5 +1,6 @@
 #pragma once
 #include
+#include <cstdint>
 
 namespace vllm {
 namespace xpu {
@@ -32,5 +33,27 @@ void xpuAsyncMemcpy(
     const void* hctx,
     bool is_pinned);
 
+/**
+ * @brief Batch async memcpy: copies N independent (src, dst, size) triples
+ * in a single call, amortising per-copy overhead.
+ *
+ * The copy direction is auto-detected from the first non-zero entry's USM
+ * pointer types. All entries must share the same direction.
+ *
+ * For H2D: snapshots all source blocks through a single contiguous pinned
+ * staging buffer so the caller may safely mutate host memory immediately.
+ * For D2H / D2D: direct async DMA without staging.
+ *
+ * @param src_ptrs Array of N raw source addresses
+ * @param dst_ptrs Array of N raw destination addresses
+ * @param sizes    Array of N byte counts
+ * @param n        Number of entries
+ */
+void xpuAsyncMemcpyBatch(
+    const uint64_t* src_ptrs,
+    const uint64_t* dst_ptrs,
+    const uint64_t* sizes,
+    int64_t n);
+
 } // namespace xpu
 } // namespace vllm
diff --git a/tests/register_ops.py b/tests/register_ops.py
index c5ee99d21..9feec089c 100644
--- a/tests/register_ops.py
+++ b/tests/register_ops.py
@@ -487,6 +487,17 @@ def swap_blocks(
                                        block_mapping)
 
 
+def swap_blocks_batch(
+    src_ptrs: torch.Tensor,
+    dst_ptrs: torch.Tensor,
+    sizes: torch.Tensor,
+) -> None:
+    """Batch version of swap_blocks: copies N independent (src, dst, size)
+    triples in a single call. The target XPU device is auto-inferred from the
+    device-side pointers in src_ptrs/dst_ptrs.
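+
+    Illustrative usage (a minimal sketch; ``src`` and ``dst`` are assumed to
+    be contiguous tensors of equal byte size, e.g. a pinned-CPU block and an
+    XPU block):
+
+        sizes = torch.tensor([src.numel() * src.element_size()],
+                             dtype=torch.uint64)
+        src_ptrs = torch.tensor([src.data_ptr()], dtype=torch.uint64)
+        dst_ptrs = torch.tensor([dst.data_ptr()], dtype=torch.uint64)
+        swap_blocks_batch(src_ptrs, dst_ptrs, sizes)
+        torch.xpu.synchronize()  # copies are submitted asynchronously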
+    """
+    torch.ops._C_cache_ops.swap_blocks_batch(src_ptrs, dst_ptrs, sizes)
+
+
 def topk_sigmoid(topk_weights: torch.Tensor, topk_ids: torch.Tensor,
                  token_expert_indices: torch.Tensor,
                  gating_output: torch.Tensor, renormalize: bool,
diff --git a/tests/test_cache.py b/tests/test_cache.py
index dfcef8cc8..1f727c7d3 100644
--- a/tests/test_cache.py
+++ b/tests/test_cache.py
@@ -2,6 +2,7 @@
 
 import random
 
+import numpy as np
 import pytest
 import torch
 
@@ -101,6 +102,16 @@
         "device": ["xpu:0"],
         "kv_cache_dtype": KV_CACHE_DTYPE,
     },
+    "test_swap_blocks_batch": {
+        "direction": [("cpu", "xpu")],
+        "device": ["xpu:0"],
+    },
+    "test_swap_blocks_batch_empty": {
+        "device": ["xpu:0"],
+    },
+    "test_swap_blocks_batch_h2d_mutation_race": {
+        "device": ["xpu:0"],
+    },
 }
 
@@ -948,3 +959,152 @@ def test_swap_blocks_mla(
             msg=f"Block {src} from src should have been swapped to block "
             f"{dst} in dst_cache.",
         )
+
+
+# ---------------------------------------------------------------------------
+# swap_blocks_batch tests
+# ---------------------------------------------------------------------------
+
+
+def _build_batch_args(
+    src_cache: torch.Tensor,
+    dst_cache: torch.Tensor,
+    block_mapping: list[tuple[int, int]],
+    block_size_in_bytes: int,
+) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    """Build (src_ptrs, dst_ptrs, sizes) tensors for swap_blocks_batch."""
+    n = len(block_mapping)
+    src_arr = np.empty(n, dtype=np.uint64)
+    dst_arr = np.empty(n, dtype=np.uint64)
+    sz_arr = np.full(n, block_size_in_bytes, dtype=np.uint64)
+
+    src_base = src_cache.data_ptr()
+    dst_base = dst_cache.data_ptr()
+    stride = src_cache.stride(0) * src_cache.element_size()
+
+    for i, (sb, db) in enumerate(block_mapping):
+        src_arr[i] = src_base + sb * stride
+        dst_arr[i] = dst_base + db * stride
+
+    return (torch.from_numpy(src_arr), torch.from_numpy(dst_arr),
+            torch.from_numpy(sz_arr))
+
+
+@pytest.mark.parametrize("direction", COPYING_DIRECTION)
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_swap_blocks_batch(
+    direction: tuple[str, str],
+    device: str,
+) -> None:
+    """Test swap_blocks_batch for H2D, D2H and D2D directions."""
+    num_mappings = 64
+    num_heads = 8
+    head_size = 64
+    block_size = 8
+    num_blocks = 256
+    dtype = torch.bfloat16
+    seed = 0
+
+    seed_everything(seed)
+
+    src_device = device if direction[0] == "xpu" else "cpu"
+    dst_device = device if direction[1] == "xpu" else "cpu"
+    if "xpu" in direction:
+        torch.xpu.set_device(device)
+
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    if src_device == dst_device:
+        remaining = list(set(range(num_blocks)) - set(src_blocks))
+        dst_blocks = random.sample(remaining, num_mappings)
+    else:
+        dst_blocks = random.sample(range(num_blocks), num_mappings)
+    block_mapping = list(zip(src_blocks, dst_blocks))
+
+    src_key, src_val = create_kv_caches_with_random(num_blocks, block_size, 1,
+                                                    num_heads, head_size,
+                                                    "auto", dtype, seed,
+                                                    src_device)
+    dst_key, dst_val = create_kv_caches_with_random(num_blocks, block_size, 1,
+                                                    num_heads, head_size,
+                                                    "auto", dtype, seed,
+                                                    dst_device)
+
+    src_key_clone = src_key[0].clone()
+    src_val_clone = src_val[0].clone()
+
+    block_size_in_bytes = src_key[0].element_size() * src_key[0].stride(0)
+
+    # Build batch args and call
+    for src_cache, dst_cache in [(src_key[0], dst_key[0]),
+                                 (src_val[0], dst_val[0])]:
+        sp, dp, sz = _build_batch_args(src_cache, dst_cache, block_mapping,
+                                       block_size_in_bytes)
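+        # swap_blocks_batch may return before the copies have completed;
+        # results are only checked after the synchronize below.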
+        ops.swap_blocks_batch(sp, dp, sz)
+
+    torch.xpu.synchronize()
+
+    for sb, db in block_mapping:
+        torch.testing.assert_close(src_key_clone[sb].cpu(),
+                                   dst_key[0][db].cpu())
+        torch.testing.assert_close(src_val_clone[sb].cpu(),
+                                   dst_val[0][db].cpu())
+
+
+@pytest.mark.parametrize("device", DEVICES)
+@torch.inference_mode()
+def test_swap_blocks_batch_h2d_mutation_race(device: str) -> None:
+    """Verify staging buffer protects against caller mutation for H2D batch."""
+    num_mappings = 256
+    num_heads = 8
+    head_size = 128
+    block_size = 32
+    num_blocks = 512
+    dtype = torch.bfloat16
+    seed = 0
+
+    seed_everything(seed)
+
+    src_blocks = random.sample(range(num_blocks), num_mappings)
+    dst_blocks = random.sample(range(num_blocks), num_mappings)
+    block_mapping = list(zip(src_blocks, dst_blocks))
+
+    # Source: pinned CPU memory
+    src_key, src_val = create_kv_caches_with_pinned(num_blocks, block_size, 1,
+                                                    num_heads, head_size,
+                                                    "auto", dtype, seed, "cpu")
+    assert src_key[0].is_pinned()
+
+    # Destination: XPU
+    dst_key, dst_val = create_kv_caches_with_random(num_blocks, block_size, 1,
+                                                    num_heads, head_size,
+                                                    "auto", dtype, seed)
+
+    src_key_clone = src_key[0].clone()
+    src_val_clone = src_val[0].clone()
+
+    block_size_in_bytes = src_key[0].element_size() * src_key[0].stride(0)
+
+    for src_cache, dst_cache in [(src_key[0], dst_key[0]),
+                                 (src_val[0], dst_val[0])]:
+        sp, dp, sz = _build_batch_args(src_cache, dst_cache, block_mapping,
+                                       block_size_in_bytes)
+        ops.swap_blocks_batch(sp, dp, sz)
+
+    # Immediately mutate source — should not affect destination.
+    src_key[0].fill_(0)
+    src_val[0].fill_(0)
+
+    torch.xpu.synchronize()
+
+    for sb, db in block_mapping:
+        torch.testing.assert_close(
+            src_key_clone[sb].cpu(),
+            dst_key[0][db].cpu(),
+            msg=f"Key block {sb}→{db} corrupted by post-call mutation",
+        )
+        torch.testing.assert_close(
+            src_val_clone[sb].cpu(),
+            dst_val[0][db].cpu(),
+            msg=f"Value block {sb}→{db} corrupted by post-call mutation",
+        )