diff --git a/csrc/cache_kernels.cu b/csrc/cache_kernels.cu index 7e456d32598b..895490f45a79 100644 --- a/csrc/cache_kernels.cu +++ b/csrc/cache_kernels.cu @@ -97,13 +97,13 @@ void swap_blocks_batch(const torch::Tensor& src_ptrs, const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); - // Use cuMemcpyBatchAsync (CUDA 12.8+) to submit all copies in a single - // driver call, amortizing per-copy submission overhead. - // int64_t and CUdeviceptr/size_t are both 8 bytes on 64-bit platforms, - // so we reinterpret_cast the tensor data directly to avoid copies. - static_assert(sizeof(CUdeviceptr) == sizeof(int64_t)); + // Use cuMemcpyBatchAsync / hipMemcpyBatchAsync to submit all copies in a + // single driver call, amortizing per-copy submission overhead. int64_t + // and CUdeviceptr/void*/size_t are all 8 bytes on 64-bit platforms, so we + // reinterpret_cast the tensor data directly to avoid copies. static_assert(sizeof(size_t) == sizeof(int64_t)); #if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080 + static_assert(sizeof(CUdeviceptr) == sizeof(int64_t)); // Resolve cuMemcpyBatchAsync at runtime via cuGetProcAddress so that // binaries compiled with CUDA 12.8+ still work on older drivers, and // we avoid the CUDA 13.0 header remapping (#define to _v2 signature). @@ -134,12 +134,30 @@ void swap_blocks_batch(const torch::Tensor& src_ptrs, &fail_idx, static_cast<CUstream>(stream)); TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ", fail_idx, " with error ", result); - } else + return; + } +#elif defined(USE_ROCM) && defined(HIP_VERSION) && HIP_VERSION >= 70100000 + // ROCm 7.1+ exposes hipMemcpyBatchAsync. The 7.2.1 implementation early- + // returns hipErrorNotSupported whenever numAttrs > 0 (see ROCm/clr @ + // rocm-7.2.1 hipamd/src/hip_memory.cpp:2819-2822), so call with + // numAttrs=0. 
+ { + hipMemcpyAttributes attr = {}; + size_t attrs_idx = 0; + size_t fail_idx = 0; + hipError_t result = hipMemcpyBatchAsync( + reinterpret_cast<void**>(dst_data), reinterpret_cast<void**>(src_data), + reinterpret_cast<size_t*>(size_data), static_cast<size_t>(n), &attr, + &attrs_idx, 0, &fail_idx, static_cast<hipStream_t>(stream)); + TORCH_CHECK(result == hipSuccess, "hipMemcpyBatchAsync failed at index ", + fail_idx, " with error ", result); + return; + } #endif { - // Fallback for CUDA < 12.8, older drivers, and ROCm: - // individual async copies. - // cudaMemcpyDefault lets the driver infer direction from pointer types. + // Fallback for CUDA < 12.8, older CUDA drivers, and ROCm < 7.1: + // individual async copies. cudaMemcpyDefault lets the driver infer + // direction from pointer types. for (int64_t i = 0; i < n; i++) { cudaMemcpyAsync(reinterpret_cast<void*>(dst_data[i]), reinterpret_cast<void*>(src_data[i]), diff --git a/tests/v1/simple_kv_offload/test_integration.py b/tests/v1/simple_kv_offload/test_integration.py index 29399516be18..02f6360e08e8 100644 --- a/tests/v1/simple_kv_offload/test_integration.py +++ b/tests/v1/simple_kv_offload/test_integration.py @@ -10,8 +10,8 @@ from vllm.config import KVTransferConfig from vllm.platforms import current_platform -if not current_platform.is_cuda(): - pytest.skip("Requires CUDA", allow_module_level=True) +if not current_platform.is_cuda_alike(): + pytest.skip("Requires CUDA or ROCm", allow_module_level=True) # Small models for default CI / local runs (accuracy only). 
SMALL_MODELS = [ diff --git a/tests/v1/simple_kv_offload/test_scheduler.py b/tests/v1/simple_kv_offload/test_scheduler.py index 132f52fe3b36..4d685103df60 100644 --- a/tests/v1/simple_kv_offload/test_scheduler.py +++ b/tests/v1/simple_kv_offload/test_scheduler.py @@ -186,7 +186,13 @@ def make_request( if request_id is None: request_id = f"req-{_req_counter}" - num_tokens = num_blocks * BLOCK_SIZE + # Add one extra token beyond the last full block so that + # ``max_cache_hit_length = num_tokens - 1`` (see + # KVCacheManager.get_computed_blocks) does not truncate the final + # full block: ``find_longest_cache_hit`` uses + # ``max_length // block_size`` and would otherwise drop one block + # when the prompt is an exact multiple of block_size. + num_tokens = num_blocks * BLOCK_SIZE + 1 start = _req_counter * 10000 prompt_token_ids = list(range(start, start + num_tokens)) sampling_params = SamplingParams(max_tokens=1) diff --git a/vllm/v1/simple_kv_offload/cuda_mem_ops.py b/vllm/v1/simple_kv_offload/cuda_mem_ops.py index 03338421c457..b4c68aff3ca9 100644 --- a/vllm/v1/simple_kv_offload/cuda_mem_ops.py +++ b/vllm/v1/simple_kv_offload/cuda_mem_ops.py @@ -1,6 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -"""Low-level CUDA memory helpers: pinning and batch DMA transfers.""" +"""Low-level CUDA/HIP memory helpers: pinning and batch DMA transfers.""" import ctypes from typing import Any, NamedTuple @@ -9,6 +9,7 @@ import torch from vllm.logger import init_logger +from vllm.platforms import current_platform logger = init_logger(__name__) @@ -39,7 +40,7 @@ class _CUmemcpyAttributes(ctypes.Structure): _BATCH_MEMCPY_FUNC_TYPE = ctypes.CFUNCTYPE( - ctypes.c_uint, # CUresult + ctypes.c_uint, # CUresult / hipError_t ctypes.c_void_p, ctypes.c_void_p, ctypes.c_void_p, @@ -56,7 +57,42 @@ class _CUmemcpyAttributes(ctypes.Structure): def _resolve_batch_memcpy(): - """Resolve cuMemcpyBatchAsync via cuGetProcAddress 
(one-time).""" + """Resolve the platform batch-memcpy entry point (one-time). + + * CUDA: ``cuMemcpyBatchAsync`` via ``cuGetProcAddress`` (uses + srcAccessOrder=ANY via one attributes entry). + * ROCm: ``hipMemcpyBatchAsync`` from libamdhip64 (ROCm 7.1+). ROCm + 7.2.1 or 7.2.2 rejects any call with ``numAttrs > 0`` + (see ROCm/clr @ rocm-7.2.1 hipamd/src/hip_memory.cpp:2819-2822), so + we call with ``numAttrs=0``. + + Raises ``RuntimeError`` if the symbol is unavailable (older CUDA + driver, ROCm < 7.1, unusual install). The connector requires the + batch API. + """ + if current_platform.is_rocm(): + try: + lib = ctypes.CDLL("libamdhip64.so", mode=ctypes.RTLD_GLOBAL) + fn = lib.hipMemcpyBatchAsync + except (OSError, AttributeError) as e: + raise RuntimeError( + "hipMemcpyBatchAsync is unavailable in this ROCm install; " + "SimpleCPUOffloadConnector requires ROCm 7.1+." + ) from e + fn.restype = ctypes.c_uint + fn.argtypes = [ + ctypes.c_void_p, # dsts + ctypes.c_void_p, # srcs + ctypes.c_void_p, # sizes + ctypes.c_size_t, # count + ctypes.c_void_p, # attrs + ctypes.c_void_p, # attrIdxs + ctypes.c_size_t, # numAttrs + ctypes.c_void_p, # failIdx + ctypes.c_void_p, # stream + ] + return fn + from cuda.bindings import driver as drv err, ptr, _ = drv.cuGetProcAddress(b"cuMemcpyBatchAsync", 12080, 0) @@ -70,6 +106,8 @@ class BatchMemcpyParams(NamedTuple): dst_bases: np.ndarray # [num_layers] uint64 bpb: np.ndarray # [num_layers] uint64 — bytes per block num_layers: int + # CUDA only: one attributes entry with srcAccessOrder=ANY. Unused on + # ROCm (7.2.1 or 7.2.2) because the current runtime rejects numAttrs > 0. attrs: _CUmemcpyAttributes attrs_idx: ctypes.c_size_t # NOTE: cuMemcpyBatchAsync_v2() removed fail_idx field, but we use @@ -99,8 +137,10 @@ def build_params( dst_bases.append(d.data_ptr()) bpb.append(s_bpb) - # Refer to https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g6f1ff58e3065df3eb4b573dba77ad31f for details. 
# noqa: E501 - attrs = _CUmemcpyAttributes(srcAccessOrder=3) # ANY + # ``srcAccessOrder=3`` == CU_MEMCPY_SRC_ACCESS_ORDER_ANY / + # hipMemcpySrcAccessOrderAny. See + # https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g6f1ff58e3065df3eb4b573dba77ad31f # noqa: E501 + attrs = _CUmemcpyAttributes(srcAccessOrder=3) return BatchMemcpyParams( src_bases=np.array(src_bases, dtype=np.uint64), @@ -119,7 +159,7 @@ def copy_blocks( dst_block_ids: list[int], params: BatchMemcpyParams, ) -> None: - """Copy blocks via cuMemcpyBatchAsync.""" + """Copy blocks via cuMemcpyBatchAsync / hipMemcpyBatchAsync.""" n = len(src_block_ids) if n == 0: return @@ -134,8 +174,13 @@ def copy_blocks( params.dst_bases[:, None] + dst_ids[None, :] * params.bpb[:, None] ).ravel() sz_all = np.repeat(params.bpb, n) - total = n * params.num_layers + + # ROCm 7.2.1/7.2.2 rejects any call with numAttrs>0 (hipMemcpyBatchAsync + # hipamd/src/hip_memory.cpp:2819-2822); CUDA uses one attrs entry so + # srcAccessOrder is honored. attrs / attrsIdxs are ignored when + # numAttrs==0, so we pass the same values from both paths. + num_attrs = 0 if current_platform.is_rocm() else 1 err = _batch_memcpy_fn( dst_all.ctypes.data, src_all.ctypes.data, @@ -143,11 +188,11 @@ def copy_blocks( total, ctypes.addressof(params.attrs), ctypes.byref(params.attrs_idx), - 1, + num_attrs, ctypes.byref(params.fail_idx), params.stream_handle, ) if err != 0: raise RuntimeError( - f"cuMemcpyBatchAsync failed: err={err} failIdx={params.fail_idx.value}" + f"batch memcpy failed: err={err} failIdx={params.fail_idx.value}" )