Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 42 additions & 30 deletions csrc/cache_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -104,37 +104,49 @@ void swap_blocks_batch(const torch::Tensor& src_ptrs,
// NOTE(review): the pasted source contained BOTH the pre-PR and post-PR
// hunks of this diff concatenated, which would perform the batched copy and
// then unconditionally re-run the per-item fallback loop (a double copy) on
// CUDA >= 12.8 builds. Only the final merged implementation is kept below.
//
// These asserts guarantee that reinterpreting the int64_t-typed tensor data
// (dst_data / src_data / size_data, defined earlier in this function) as
// CUdeviceptr* / size_t* arrays is size-safe.
static_assert(sizeof(CUdeviceptr) == sizeof(int64_t));
static_assert(sizeof(size_t) == sizeof(int64_t));
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && CUDA_VERSION >= 12080
  // Resolve cuMemcpyBatchAsync at runtime via cuGetProcAddress so that
  // binaries compiled with CUDA 12.8+ still work on older drivers, and
  // we avoid the CUDA 13.0 header remapping (#define to _v2 signature).
  // The function pointer is cached after the first call (static lambda
  // initializer runs exactly once, thread-safe per C++11 magic statics).
  using BatchFn =
      CUresult (*)(CUdeviceptr*, CUdeviceptr*, size_t*, size_t,
                   CUmemcpyAttributes*, size_t*, size_t, size_t*, CUstream);
  static BatchFn batch_fn = []() -> BatchFn {
    CUdriverProcAddressQueryResult sym_status;
    void* fn_ptr = nullptr;
    // 12080 requests the CUDA 12.8 ABI of the symbol; on drivers older than
    // 12.8 the lookup fails and we fall back to per-item copies below.
    CUresult res = cuGetProcAddress("cuMemcpyBatchAsync", &fn_ptr, 12080,
                                    CU_GET_PROC_ADDRESS_DEFAULT, &sym_status);
    if (res != CUDA_SUCCESS || fn_ptr == nullptr) {
      return nullptr;
    }
    return reinterpret_cast<BatchFn>(fn_ptr);
  }();

  if (batch_fn != nullptr) {
    CUmemcpyAttributes attr = {};
    // Stream-ordered source access: copies are ordered with respect to prior
    // work on `stream`, matching cudaMemcpyAsync semantics.
    attr.srcAccessOrder = CU_MEMCPY_SRC_ACCESS_ORDER_STREAM;
    size_t attrs_idx = 0;  // single attribute struct applies to all n copies
    size_t fail_idx = 0;   // index of the first failing copy, if any
    CUresult result = batch_fn(reinterpret_cast<CUdeviceptr*>(dst_data),
                               reinterpret_cast<CUdeviceptr*>(src_data),
                               reinterpret_cast<size_t*>(size_data),
                               static_cast<size_t>(n), &attr, &attrs_idx, 1,
                               &fail_idx, static_cast<CUstream>(stream));
    TORCH_CHECK(result == CUDA_SUCCESS, "cuMemcpyBatchAsync failed at index ",
                fail_idx, " with error ", result);
  } else
#endif
  {
    // Fallback for CUDA < 12.8, older drivers, and ROCm:
    // individual async copies.
    // cudaMemcpyDefault lets the driver infer direction from pointer types.
    for (int64_t i = 0; i < n; i++) {
      cudaMemcpyAsync(reinterpret_cast<void*>(dst_data[i]),
                      reinterpret_cast<void*>(src_data[i]),
                      static_cast<size_t>(size_data[i]), cudaMemcpyDefault,
                      stream);
    }
  }
}

namespace vllm {
Expand Down
Loading