Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -97,13 +97,13 @@ void swap_blocks_batch(const torch::Tensor& src_ptrs,

const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

// Use cuMemcpyBatchAsync (CUDA 12.8+) to submit all copies in a single
// driver call, amortizing per-copy submission overhead.
// int64_t and CUdeviceptr/size_t are both 8 bytes on 64-bit platforms,
// so we reinterpret_cast the tensor data directly to avoid copies.
static_assert(sizeof(CUdeviceptr) == sizeof(int64_t));
// Use cuMemcpyBatchAsync / hipMemcpyBatchAsync to submit all copies in a
// single driver call, amortizing per-copy submission overhead. int64_t
// and CUdeviceptr/void*/size_t are all 8 bytes on 64-bit platforms, so we
// reinterpret_cast the tensor data directly to avoid copies.
static_assert(sizeof(size_t) == sizeof(int64_t));
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
static_assert(sizeof(CUdeviceptr) == sizeof(int64_t));
// Resolve cuMemcpyBatchAsync at runtime via cuGetProcAddress so that
// binaries compiled with CUDA 12.8+ still work on older drivers, and
// we avoid the CUDA 13.0 header remapping (#define to _v2 signature).
Expand Down Expand Up @@ -134,12 +134,30 @@ void swap_blocks_batch(const torch::Tensor& src_ptrs,
&fail_idx, static_cast<CUstream>(stream));
TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ",
fail_idx, " with error ", result);
} else
return;
}
#elif defined(USE_ROCM) && defined(HIP_VERSION) && HIP_VERSION >= 70100000
// ROCm 7.1+ exposes hipMemcpyBatchAsync. The 7.2.1 implementation early-
// returns hipErrorNotSupported whenever numAttrs > 0 (see ROCm/clr @
// rocm-7.2.1 hipamd/src/hip_memory.cpp:2819-2822), so call with
// numAttrs=0.
{
hipMemcpyAttributes attr = {};
size_t attrs_idx = 0;
size_t fail_idx = 0;
hipError_t result = hipMemcpyBatchAsync(
reinterpret_cast<void**>(dst_data), reinterpret_cast<void**>(src_data),
reinterpret_cast<size_t*>(size_data), static_cast<size_t>(n), &attr,
&attrs_idx, 0, &fail_idx, static_cast<hipStream_t>(stream));
TORCH_CHECK(result == hipSuccess, "hipMemcpyBatchAsync failed at index ",
fail_idx, " with error ", result);
return;
}
#endif
{
// Fallback for CUDA < 12.8, older drivers, and ROCm:
// individual async copies.
// cudaMemcpyDefault lets the driver infer direction from pointer types.
// Fallback for CUDA < 12.8, older CUDA drivers, and ROCm < 7.1:
// individual async copies. cudaMemcpyDefault lets the driver infer
// direction from pointer types.
for (int64_t i = 0; i < n; i++) {
cudaMemcpyAsync(reinterpret_cast<void*>(dst_data[i]),
reinterpret_cast<void*>(src_data[i]),
Expand Down
4 changes: 2 additions & 2 deletions tests/v1/simple_kv_offload/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from vllm.config import KVTransferConfig
from vllm.platforms import current_platform

if not current_platform.is_cuda():
pytest.skip("Requires CUDA", allow_module_level=True)
if not current_platform.is_cuda_alike():
pytest.skip("Requires CUDA or ROCm", allow_module_level=True)

# Small models for default CI / local runs (accuracy only).
SMALL_MODELS = [
Expand Down
8 changes: 7 additions & 1 deletion tests/v1/simple_kv_offload/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,13 @@ def make_request(
if request_id is None:
request_id = f"req-{_req_counter}"

num_tokens = num_blocks * BLOCK_SIZE
# Add one extra token beyond the last full block so that
# ``max_cache_hit_length = num_tokens - 1`` (see
# KVCacheManager.get_computed_blocks) does not truncate the final
# full block: ``find_longest_cache_hit`` uses
# ``max_length // block_size`` and would otherwise drop one block
# when the prompt is an exact multiple of block_size.
num_tokens = num_blocks * BLOCK_SIZE + 1
start = _req_counter * 10000
prompt_token_ids = list(range(start, start + num_tokens))
sampling_params = SamplingParams(max_tokens=1)
Expand Down
69 changes: 58 additions & 11 deletions vllm/v1/simple_kv_offload/cuda_mem_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Low-level CUDA memory helpers: pinning and batch DMA transfers."""
"""Low-level CUDA/HIP memory helpers: pinning and batch DMA transfers."""

import ctypes
from typing import Any, NamedTuple
Expand All @@ -9,12 +9,13 @@
import torch

from vllm.logger import init_logger
from vllm.platforms import current_platform

logger = init_logger(__name__)


def pin_tensor(tensor: torch.Tensor) -> None:
"""Pin a CPU tensor via cudaHostRegister.
"""Pin a CPU tensor via cudaHostRegister / hipHostRegister.

This bypasses PyTorch's CUDACachingHostAllocator which rounds
every ``pin_memory=True`` allocation up to the next power of 2
Expand All @@ -25,6 +26,8 @@ def pin_tensor(tensor: torch.Tensor) -> None:
raise RuntimeError(f"cudaHostRegister failed: {err}")


# NOTE: ``CUmemcpyAttributes`` and ``hipMemcpyAttributes`` share the same
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't need to spell this out: it is known that CUmemcpyAttributes and hipMemcpyAttributes are compatible as long as we don't take a custom code path. There's no need to be this verbose.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed

# layout, so a single ctypes struct definition works for both.
# ctypes mirror of the driver's ``CUmemLocation`` sub-struct (used as the
# ``srcLocHint``/``dstLocHint`` fields of the memcpy-attributes struct).
# Field order and ctypes widths are ABI-significant — do not reorder.
# NOTE(review): presumably ``type`` is the location-type enum and ``id`` a
# device ordinal / NUMA node id — TODO confirm against the CUDA driver header.
class _CUmemLocation(ctypes.Structure):
    _fields_ = [("type", ctypes.c_uint), ("id", ctypes.c_int)]

Expand All @@ -39,7 +42,7 @@ class _CUmemcpyAttributes(ctypes.Structure):


_BATCH_MEMCPY_FUNC_TYPE = ctypes.CFUNCTYPE(
ctypes.c_uint, # CUresult
ctypes.c_uint, # CUresult / hipError_t
ctypes.c_void_p,
ctypes.c_void_p,
ctypes.c_void_p,
Expand All @@ -56,7 +59,42 @@ class _CUmemcpyAttributes(ctypes.Structure):


def _resolve_batch_memcpy():
"""Resolve cuMemcpyBatchAsync via cuGetProcAddress (one-time)."""
"""Resolve the platform batch-memcpy entry point (one-time).

* CUDA: ``cuMemcpyBatchAsync`` via ``cuGetProcAddress`` (uses
srcAccessOrder=STREAM via one attributes entry).
* ROCm: ``hipMemcpyBatchAsync`` from libamdhip64 (ROCm 7.1+). ROCm
7.2.1 rejects any call with ``numAttrs > 0``
(see ROCm/clr @ rocm-7.2.1 hipamd/src/hip_memory.cpp:2819-2822), so
we call with ``numAttrs=0``.

Raises ``RuntimeError`` if the symbol is unavailable (older CUDA
driver, ROCm < 7.1, unusual install). The connector requires the
batch API.
"""
if current_platform.is_rocm():
try:
lib = ctypes.CDLL("libamdhip64.so", mode=ctypes.RTLD_GLOBAL)
fn = lib.hipMemcpyBatchAsync
except (OSError, AttributeError) as e:
raise RuntimeError(
"hipMemcpyBatchAsync is unavailable in this ROCm install; "
"SimpleCPUOffloadConnector requires ROCm 7.1+."
) from e
fn.restype = ctypes.c_uint
fn.argtypes = [
ctypes.c_void_p, # dsts
ctypes.c_void_p, # srcs
ctypes.c_void_p, # sizes
ctypes.c_size_t, # count
ctypes.c_void_p, # attrs
ctypes.c_void_p, # attrIdxs
ctypes.c_size_t, # numAttrs
ctypes.c_void_p, # failIdx
ctypes.c_void_p, # stream
]
return fn
Comment thread
hongxiayang marked this conversation as resolved.

from cuda.bindings import driver as drv

err, ptr, _ = drv.cuGetProcAddress(b"cuMemcpyBatchAsync", 12080, 0)
Expand All @@ -70,12 +108,14 @@ class BatchMemcpyParams(NamedTuple):
dst_bases: np.ndarray # [num_layers] uint64
bpb: np.ndarray # [num_layers] uint64 — bytes per block
num_layers: int
# CUDA only: one attributes entry with srcAccessOrder=ANY. Unused on
# ROCm because the current runtime rejects numAttrs > 0.
attrs: _CUmemcpyAttributes
attrs_idx: ctypes.c_size_t
# NOTE: cuMemcpyBatchAsync_v2() removed fail_idx field, but we use
# cuMemcpyBatchAsync() with fail_idx for backward compatibility
fail_idx: ctypes.c_size_t
stream_handle: int # raw cudaStream_t / CUstream
stream_handle: int # raw cudaStream_t / CUstream / hipStream_t


def build_params(
Expand All @@ -99,8 +139,10 @@ def build_params(
dst_bases.append(d.data_ptr())
bpb.append(s_bpb)

# Refer to https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g6f1ff58e3065df3eb4b573dba77ad31f for details. # noqa: E501
attrs = _CUmemcpyAttributes(srcAccessOrder=3) # ANY
# ``srcAccessOrder=3`` == CU_MEMCPY_SRC_ACCESS_ORDER_ANY /
# hipMemcpySrcAccessOrderAny. See
# https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g6f1ff58e3065df3eb4b573dba77ad31f # noqa: E501
attrs = _CUmemcpyAttributes(srcAccessOrder=3)

return BatchMemcpyParams(
src_bases=np.array(src_bases, dtype=np.uint64),
Expand All @@ -119,7 +161,7 @@ def copy_blocks(
dst_block_ids: list[int],
params: BatchMemcpyParams,
) -> None:
"""Copy blocks via cuMemcpyBatchAsync."""
"""Copy blocks via cuMemcpyBatchAsync / hipMemcpyBatchAsync."""
n = len(src_block_ids)
if n == 0:
return
Expand All @@ -134,20 +176,25 @@ def copy_blocks(
params.dst_bases[:, None] + dst_ids[None, :] * params.bpb[:, None]
).ravel()
sz_all = np.repeat(params.bpb, n)

total = n * params.num_layers

# ROCm 7.2.1 rejects any call with numAttrs>0 (hipMemcpyBatchAsync
# hipamd/src/hip_memory.cpp:2819-2822); CUDA uses one attrs entry so
# srcAccessOrder is honored. attrs / attrsIdxs are ignored when
# numAttrs==0, so we pass the same values from both paths.
num_attrs = 0 if current_platform.is_rocm() else 1
err = _batch_memcpy_fn(
dst_all.ctypes.data,
src_all.ctypes.data,
sz_all.ctypes.data,
total,
ctypes.addressof(params.attrs),
ctypes.byref(params.attrs_idx),
1,
num_attrs,
ctypes.byref(params.fail_idx),
params.stream_handle,
)
if err != 0:
raise RuntimeError(
f"cuMemcpyBatchAsync failed: err={err} failIdx={params.fail_idx.value}"
f"batch memcpy failed: err={err} failIdx={params.fail_idx.value}"
)
Loading