Skip to content

cuda : improve cuda pool efficiency using virtual memory #4606

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 22 commits into from
Dec 24, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,10 @@ endif # LLAMA_BLIS

ifdef LLAMA_CUBLAS
MK_CPPFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include -I/usr/local/cuda/targets/aarch64-linux/include
MK_LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib
MK_LDFLAGS += -lcuda -L/usr/lib/wsl/lib -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib -L/usr/local/cuda/targets/aarch64-linux/lib
OBJS += ggml-cuda.o
MK_NVCCFLAGS = -use_fast_math

ifndef JETSON_EOL_MODULE_DETECT
MK_NVCCFLAGS += --forward-unknown-to-host-compiler
endif # JETSON_EOL_MODULE_DETECT
Expand Down
173 changes: 148 additions & 25 deletions ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@
#define __trap abort
#else
#include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
// CUDA 10.2 does not have these macro definitions.
Expand Down Expand Up @@ -213,6 +214,24 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
} \
} while (0)

// driver API
#define CU_CHECK(err) \
do { \
CUresult err_ = (err); \
if (err_ != CUDA_SUCCESS) { \
int id; \
cuDeviceGet(&id, 0); \
const char * err_str; \
cuGetErrorString(err_, &err_str); \
fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
err_str); \
fprintf(stderr, "%s\n", #err); \
fprintf(stderr, "current device: %d\n", id); \
GGML_ASSERT(!"CUDA error"); \
} \
} while (0)


Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In my test implementation for peer access I was able to simply use CUDA_CHECK for the CUresult. I don't know whether that works universally though.

Copy link
Member Author

@slaren slaren Dec 22, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It didn't work for me, it is a different type and requires a different function to get the error description, cuGetErrorString.

#if CUDART_VERSION >= 12000
#define CUBLAS_CHECK(err) \
do { \
Expand Down Expand Up @@ -6543,21 +6562,26 @@ struct scoped_spin_lock {
scoped_spin_lock& operator=(const scoped_spin_lock&) = delete;
};

static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;

#if 0
#define DEBUG_CUDA_MALLOC
struct cuda_buffer {
void * ptr = nullptr;
size_t size = 0;
};

static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;

static size_t g_cuda_pool_size[GGML_CUDA_MAX_DEVICES] = {0};

static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));
#ifdef DEBUG_CUDA_MALLOC
int nnz = 0;
size_t max_size = 0, tot_size = 0;
size_t max_size = 0;
#endif
size_t best_diff = 1ull << 36;
int ibest = -1;
Expand All @@ -6566,7 +6590,6 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
if (b.ptr != nullptr) {
#ifdef DEBUG_CUDA_MALLOC
++nnz;
tot_size += b.size;
if (b.size > max_size) max_size = b.size;
#endif
if (b.size >= size) {
Expand All @@ -6593,15 +6616,16 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
b.size = 0;
return ptr;
}
#ifdef DEBUG_CUDA_MALLOC
fprintf(stderr, "%s: %d buffers, max_size = %u MB, tot_size = %u MB, requested %u MB\n", __func__, nnz,
(uint32_t)(max_size/1024/1024), (uint32_t)(tot_size/1024/1024), (uint32_t)(size/1024/1024));
#endif
void * ptr;
size_t look_ahead_size = (size_t) (1.05 * size);
look_ahead_size = 256 * ((look_ahead_size + 255)/256);
CUDA_CHECK(cudaMalloc((void **) &ptr, look_ahead_size));
*actual_size = look_ahead_size;
g_cuda_pool_size[id] += look_ahead_size;
#ifdef DEBUG_CUDA_MALLOC
fprintf(stderr, "%s[%d]: %d buffers, max_size = %u MB, pool_size = %u MB, requested %u MB\n", __func__, id, nnz,
(uint32_t)(max_size/1024/1024), (uint32_t)(g_cuda_pool_size[id]/1024/1024), (uint32_t)(size/1024/1024));
#endif
return ptr;
}

Expand All @@ -6620,8 +6644,107 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
}
fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n");
CUDA_CHECK(cudaFree(ptr));
g_cuda_pool_size[id] -= size;
}
#else

static std::vector<CUmemGenericAllocationHandle> g_cuda_pool_handles[GGML_CUDA_MAX_DEVICES];
static CUdeviceptr g_cuda_pool_addr[GGML_CUDA_MAX_DEVICES] = {0};
static size_t g_cuda_pool_size[GGML_CUDA_MAX_DEVICES] = {0};
static size_t g_cuda_pool_used[GGML_CUDA_MAX_DEVICES] = {0};

static const size_t CUDA_POOL_MAX_SIZE = 1ull << 36; // 64 GB

//#define DEBUG_CUDA_MALLOC

#define ggml_cuda_pool_malloc(size, actual_size) ggml_cuda_pool_malloc_(size, actual_size, #size " " #actual_size)
static void * ggml_cuda_pool_malloc_(size_t size, size_t * actual_size, const char * call) {
scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));

size_t avail = g_cuda_pool_size[id] - g_cuda_pool_used[id];

if (size > avail) {
size_t reserve_size = size - avail;

// allocate more physical memory
CUmemAllocationProp prop = {};
prop.type = CU_MEM_ALLOCATION_TYPE_PINNED;
prop.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
prop.location.id = id;

// get the minimum allocation granularity for this device
size_t granularity = 0;
CU_CHECK(cuMemGetAllocationGranularity(&granularity, &prop, CU_MEM_ALLOC_GRANULARITY_MINIMUM));

// round up to the nearest granularity
reserve_size = granularity * ((reserve_size + granularity - 1) / granularity);

GGML_ASSERT(g_cuda_pool_size[id] + reserve_size <= CUDA_POOL_MAX_SIZE);

CUmemGenericAllocationHandle handle;
CU_CHECK(cuMemCreate(&handle, reserve_size, &prop, 0));

// reserve virtual address space (if not already reserved)
if (g_cuda_pool_addr[id] == 0) {
CU_CHECK(cuMemAddressReserve(&g_cuda_pool_addr[id], CUDA_POOL_MAX_SIZE, 0, 0, 0));
}

// map at the end of the pool
CU_CHECK(cuMemMap(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, 0, handle, 0));

// set access
CUmemAccessDesc access = {};
access.location.type = CU_MEM_LOCATION_TYPE_DEVICE;
access.location.id = id;
access.flags = CU_MEM_ACCESS_FLAGS_PROT_READWRITE;
CU_CHECK(cuMemSetAccess(g_cuda_pool_addr[id] + g_cuda_pool_size[id], reserve_size, &access, 1));

// add to the pool
g_cuda_pool_handles[id].push_back(handle);
g_cuda_pool_size[id] += reserve_size;

printf("cuda pool[%d]: size increased to %llu MB (reserved %llu MB) [%s]\n",
id, (unsigned long long) (g_cuda_pool_size[id]/1024/1024),
(unsigned long long) (reserve_size/1024/1024), call);
}

GGML_ASSERT(g_cuda_pool_addr[id] != 0);

void * ptr = (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]);
*actual_size = size;
g_cuda_pool_used[id] += size;

#ifdef DEBUG_CUDA_MALLOC
printf("cuda pool[%d]: allocated %llu bytes at %llx [%s]\n", id, (unsigned long long) size, ptr, call);
#endif

return ptr;

GGML_UNUSED(call);
}

#define ggml_cuda_pool_free(ptr, size) ggml_cuda_pool_free_(ptr, size, #ptr " " #size)
static void ggml_cuda_pool_free_(void * ptr, size_t size, const char * call) {
scoped_spin_lock lock(g_cuda_pool_lock);
int id;
CUDA_CHECK(cudaGetDevice(&id));

#ifdef DEBUG_CUDA_MALLOC
printf("cuda pool[%d]: free %llu bytes at %llx [%s]\n", id, (unsigned long long) size, ptr, call);
#endif

g_cuda_pool_used[id] -= size;

// all deallocations must be in reverse order of the allocations
GGML_ASSERT(ptr == (void *) (g_cuda_pool_addr[id] + g_cuda_pool_used[id]));

GGML_UNUSED(call);
}

#endif

static bool g_cublas_loaded = false;

bool ggml_cublas_loaded(void) {
Expand Down Expand Up @@ -7437,13 +7560,13 @@ inline void ggml_cuda_op_mul_mat_cublas(

ggml_cuda_pool_free(dst_f16, dst_as);

if (src0_as != 0) {
ggml_cuda_pool_free(src0_as_f16, src0_as);
}

if (src1_as != 0) {
ggml_cuda_pool_free(src1_as_f16, src1_as);
}

if (src0_as != 0) {
ggml_cuda_pool_free(src0_as_f16, src0_as);
}
}
else {
float * src0_ddq_as_f32 = nullptr;
Expand Down Expand Up @@ -7800,14 +7923,14 @@ static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * s
CUDA_CHECK(cudaMemcpyAsync(dst->data, dst_ddf, ggml_nbytes(dst), cudaMemcpyDeviceToHost, main_stream));
}

if (src0_asf > 0) {
ggml_cuda_pool_free(src0_ddf, src0_asf);
if (dst_asf > 0) {
ggml_cuda_pool_free(dst_ddf, dst_asf);
}
if (src1_asf > 0) {
ggml_cuda_pool_free(src1_ddf, src1_asf);
}
if (dst_asf > 0) {
ggml_cuda_pool_free(dst_ddf, dst_asf);
if (src0_asf > 0) {
ggml_cuda_pool_free(src0_ddf, src0_asf);
}

if (dst->backend == GGML_BACKEND_CPU) {
Expand Down Expand Up @@ -8119,17 +8242,17 @@ static void ggml_cuda_op_mul_mat(
CUDA_CHECK(ggml_cuda_set_device(id));

// free buffers again when done
if (src0_as[id] > 0) {
ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
}
if (src1_asf[id] > 0) {
ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
if (dst_as[id] > 0) {
ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
}
if (src1_asq[id] > 0) {
ggml_cuda_pool_free(src1_ddq[id], src1_asq[id]);
}
if (dst_as[id] > 0) {
ggml_cuda_pool_free(dst_dd[id], dst_as[id]);
if (src1_asf[id] > 0) {
ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
}
if (src0_as[id] > 0) {
ggml_cuda_pool_free(src0_dd[id], src0_as[id]);
}
}

Expand Down Expand Up @@ -8497,12 +8620,12 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
cu_compute_type,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));

if (ptrs_src_s != 0) {
ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
}
if (ptrs_dst_s != 0) {
ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
}
if (ptrs_src_s != 0) {
ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
}
}
#endif

Expand Down