From 071dcd351b7e886a194cdd3903ffbddaf341c383 Mon Sep 17 00:00:00 2001
From: JohannesGaessler <johannesg@5d6.de>
Date: Tue, 23 May 2023 09:17:31 +0200
Subject: [PATCH 1/4] CUDA op template

---
 ggml-cuda.cu | 511 ++++++++++++++++++++++++++-------------------------
 ggml-cuda.h  |   1 -
 2 files changed, 262 insertions(+), 250 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index 98170a3ae17de..e6e83b05001ff 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -32,9 +32,23 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
         }                                                                               \
     } while (0)
 
+// Q = quantized, F = float, order is src0, src1, dst
+enum ggml_cuda_op_type {
+    GGML_CUDA_OP_TYPE_QQQ = 0,
+    GGML_CUDA_OP_TYPE_QQF = 1,
+    GGML_CUDA_OP_TYPE_QFQ = 2,
+    GGML_CUDA_OP_TYPE_QFF = 3,
+    GGML_CUDA_OP_TYPE_FQQ = 4,
+    GGML_CUDA_OP_TYPE_FQF = 5,
+    GGML_CUDA_OP_TYPE_FFQ = 6,
+    GGML_CUDA_OP_TYPE_FFF = 7,
+};
+
 typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
 typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
-typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream);
+typedef void (*ggml_cuda_op_t)(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int i1, cudaStream_t & cudaStream_main);
 
 // QK = number of values after dequantization
 // QR = QK / number of values before dequantization
@@ -360,25 +374,6 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
     }
 }
 
-static dequantize_mul_mat_vec_cuda_t ggml_get_dequantize_mul_mat_vec_cuda(ggml_type type) {
-    switch (type) {
-        case GGML_TYPE_Q4_0:
-            return dequantize_mul_mat_vec_q4_0_cuda;
-        case GGML_TYPE_Q4_1:
-            return dequantize_mul_mat_vec_q4_1_cuda;
-        case GGML_TYPE_Q5_0:
-            return dequantize_mul_mat_vec_q5_0_cuda;
-        case GGML_TYPE_Q5_1:
-            return dequantize_mul_mat_vec_q5_1_cuda;
-        case GGML_TYPE_Q8_0:
-            return dequantize_mul_mat_vec_q8_0_cuda;
-        case GGML_TYPE_F16:
-            return convert_mul_mat_vec_f16_cuda;
-        default:
-            return nullptr;
-    }
-}
-
 // buffer pool for cuda
 #define MAX_CUDA_BUFFERS 256
 
@@ -441,20 +436,24 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
 #define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication.
 #define GGML_CUDA_MAX_EVENTS 64
 static cublasHandle_t g_cublasH = nullptr;
-static cudaStream_t g_cudaStreams[GGML_CUDA_MAX_STREAMS] = { nullptr };
-static cudaStream_t g_cudaStreams2[GGML_CUDA_MAX_STREAMS] = { nullptr };
-static cudaEvent_t g_cudaEvents[GGML_CUDA_MAX_EVENTS] = { nullptr };
+static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_STREAMS] = { nullptr };
+static cudaEvent_t g_cudaEvents_main[GGML_CUDA_MAX_EVENTS] = { nullptr };
+static cudaStream_t g_cudaStreams_memcpy_src1[GGML_CUDA_MAX_STREAMS] = { nullptr };
+static cudaStream_t g_cudaStreams_memcpy_dst[GGML_CUDA_MAX_STREAMS] = { nullptr };
+static cudaEvent_t g_cudaEvents_memcpy_src1[GGML_CUDA_MAX_EVENTS] = { nullptr };
 
 void ggml_init_cublas() {
     if (g_cublasH == nullptr) {
         // create streams
         for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
-            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[i], cudaStreamNonBlocking));
-            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams2[i], cudaStreamNonBlocking));
+            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[i], cudaStreamNonBlocking));
+            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_memcpy_src1[i], cudaStreamNonBlocking));
+            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_memcpy_dst[i], cudaStreamNonBlocking));
         }
         // create events
         for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
-            CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents[i], cudaEventDisableTiming));
+            CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents_main[i], cudaEventDisableTiming));
+            CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents_memcpy_src1[i], cudaEventDisableTiming));
         }
 
         // create cublas handle
@@ -514,125 +513,6 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor
     }
 }
 
-static void ggml_cuda_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src1->backend == GGML_BACKEND_CUDA);
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[2];
-    const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-    const int64_t ne12 = src1->ne[2];
-    const int64_t ne13 = src1->ne[3];
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
-    size_t x_size, d_size;
-
-    float * d_X = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &x_size); // src0
-    float * d_Y = (float *) src1->data; // src1 is already on device, broadcasted.
-    float * d_D = (float *) ggml_cuda_pool_malloc(ne0 * sizeof(float), &d_size); // dst
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            const int i0 = i03*ne02 + i02;
-            float * c_X2 = d_X + i0*ne01*ne00;
-            float * c_D2 = d_D + i0*ne01*ne00;
-
-            cudaStream_t cudaStream = g_cudaStreams[i0 % GGML_CUDA_MAX_STREAMS];
-            cudaStream_t cudaStream2 = g_cudaStreams2[i0 % GGML_CUDA_MAX_STREAMS];
-            cudaEvent_t  cudaEvent = g_cudaEvents[i0 % GGML_CUDA_MAX_EVENTS];
-
-            // copy src0 to device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X2, src0, i03, i02, cudaStream2));
-            CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
-
-            // wait for data
-            CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
-
-            for (int64_t i01 = 0; i01 < ne01; i01++) {
-                const int64_t i13 = i03%ne13;
-                const int64_t i12 = i02%ne12;
-                const int64_t i11 = i01%ne11;
-                const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
-
-                float * c_X1 = c_X2 + i01*ne00;
-                float * c_Y = d_Y + i1*ne10;
-                float * c_D1 = c_D2 + i01*ne00;
-
-                // compute
-                mul_f32_cuda(c_X1, c_Y, c_D1, ne00, ne10, cudaStream);
-                CUDA_CHECK(cudaGetLastError());
-            }
-
-            // copy dst to host
-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            CUDA_CHECK(cudaMemcpyAsync(d, c_D2, sizeof(float)*ne00*ne01, cudaMemcpyDeviceToHost, cudaStream));
-        }
-    }
-    CUDA_CHECK(cudaDeviceSynchronize());
-    ggml_cuda_pool_free(d_X, x_size);
-    ggml_cuda_pool_free(d_D, d_size);
-}
-
-static void ggml_cuda_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
-
-    const float alpha = 1.0f;
-    const float beta = 0.0f;
-    const int x_ne = ne01 * ne00;
-    const int y_ne = ne11 * ne10;
-    const int d_ne = ne11 * ne01;
-    const int n_mm = ne03 * ne02;
-
-    size_t x_size, y_size, d_size;
-    float * d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
-    float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
-    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            int i = i03*ne02 + i02;
-            cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
-
-            float * c_X = d_X + i * x_ne;
-            float * c_Y = d_Y + i * y_ne;
-            float * c_D = d_D + i * d_ne;
-
-            // copy data to device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
-
-            // compute
-            CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
-            CUBLAS_CHECK(
-                cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                        ne01, ne11, ne10,
-                        &alpha, c_X, ne00,
-                                c_Y, ne10,
-                        &beta,  c_D, ne01));
-
-            // copy dst to host
-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
-        }
-    }
-
-    CUDA_CHECK(cudaDeviceSynchronize());
-    ggml_cuda_pool_free(d_X, x_size);
-    ggml_cuda_pool_free(d_Y, y_size);
-    ggml_cuda_pool_free(d_D, d_size);
-}
-
 static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
@@ -668,7 +548,7 @@ static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor *
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             int i = i03*ne02 + i02;
-            cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
+            cudaStream_t cudaStream = g_cudaStreams_main[i % GGML_CUDA_MAX_STREAMS];
 
             half  * c_X = d_X + i * x_ne;
             half  * c_Y = d_Y + i * y_ne;
@@ -726,7 +606,110 @@ static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor *
     ggml_cuda_pool_free(d_D, d_size);
 }
 
-static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+inline void ggml_cuda_op_mul(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int i1, cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(src1_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+
+    for (int64_t i01 = 0; i01 < ne01; i01++) {
+        const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0
+
+        float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
+        float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
+        float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
+
+        // compute
+        mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
+        CUDA_CHECK(cudaGetLastError());
+    }
+
+    (void) dst;
+    (void) src0_ddq_i;
+}
+
+inline void ggml_cuda_op_dequantize_mul_mat_vec(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int i1, cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddq_i != nullptr);
+    GGML_ASSERT(src1_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+
+    switch (src0->type) {
+        case GGML_TYPE_Q4_0:
+            dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, ne01, cudaStream_main);
+            break;
+        case GGML_TYPE_Q4_1:
+            dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, ne01, cudaStream_main);
+            break;
+        case GGML_TYPE_Q5_0:
+            dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, ne01, cudaStream_main);
+            break;
+        case GGML_TYPE_Q5_1:
+            dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, ne01, cudaStream_main);
+            break;
+        case GGML_TYPE_Q8_0:
+            dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, ne01, cudaStream_main);
+            break;
+        case GGML_TYPE_F16:
+            convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, ne01, cudaStream_main);
+            break;
+        default:
+            GGML_ASSERT(false);
+            break;
+    }
+    CUDA_CHECK(cudaGetLastError());
+
+    (void) src1;
+    (void) dst;
+    (void) src0_ddf_i;
+    (void) i1;
+}
+
+inline void ggml_cuda_op_mul_mat_cublas(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int i1, cudaStream_t & cudaStream_main){
+
+    GGML_ASSERT(src0_ddf_i != nullptr);
+    GGML_ASSERT(src1_ddf_i != nullptr);
+    GGML_ASSERT(dst_ddf_i != nullptr);
+
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+
+    CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream_main));
+    CUBLAS_CHECK(
+        cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
+                ne01, ne11, ne10,
+                &alpha, src0_ddf_i, ne00,
+                        src1_ddf_i, ne10,
+                &beta,  dst_ddf_i,  ne01));
+
+    (void) dst;
+    (void) src0_ddq_i;
+    (void) i1;
+}
+
+template<enum ggml_cuda_op_type op_type, ggml_cuda_op_t op>
+static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
@@ -734,107 +717,154 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t ne0 = dst->ne[0];
+    const int64_t ne1 = dst->ne[1];
 
     const int nb2  = dst->nb[2];
     const int nb3  = dst->nb[3];
-    const ggml_type type = src0->type;
-    const bool mul_mat_vec = ne11 == 1;
 
-    const float alpha = 1.0f;
-    const float beta = 0.0f;
-    const int x_ne = ne01 * ne00;
-    const int y_ne = ne11 * ne10;
-    const int d_ne = ne11 * ne01;
-    const int n_mm = ne03 * ne02;
-    const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
+    const int64_t src0_stride = ne00 * ne01;
+    const int64_t src1_stride = ne10 * ne11;
+    const int64_t dst_stride = ne0 * ne1;
+    const int64_t num_iters = ne02 * ne03;
+
+    const size_t src0_ts = ggml_type_size(src0->type);
+
+    const bool src0_on_device = src0->backend == GGML_BACKEND_CUDA;
+    const bool src0_is_f32 = src0->type == GGML_TYPE_F32;
+    const bool src0_needs_f32 = op_type & 0x4; // 3rd least significant bit = src0 needs f32
 
-    size_t x_size, y_size, d_size, q_size;
-    float * d_X = nullptr;
-    if (!mul_mat_vec) {
-        d_X = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size);
+    const bool src1_on_device = src1->backend == GGML_BACKEND_CUDA;
+
+    const bool dst_on_device = dst->backend == GGML_BACKEND_CUDA;
+
+    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
+
+    // dd = data device
+    char  * src0_ddq = nullptr; // quantized
+    float * src0_ddf = nullptr; // float
+    float * src1_ddf = nullptr;
+    float * dst_ddf = nullptr;
+
+    bool src0_ddq_malloced = false;
+    bool src0_ddf_malloced = false;
+    bool src1_ddf_malloced = false;
+    bool dst_ddf_malloced = false;
+
+    // asq = actual size quantized, asf = actual size float
+    size_t src0_asq, src0_asf, src1_asf, dst_asf;
+
+    if (src0_on_device) {
+        if (src0_is_f32) {
+            src0_ddf = (float *) src0->data;
+        } else {
+            src0_ddq = (char *) src0->data;
+        }
+    } else {
+        if (src0_is_f32) {
+            src0_ddf = (float *) ggml_cuda_pool_malloc(num_iters * src0_stride * sizeof(float), &src0_asf);
+            src0_ddf_malloced = true;
+        } else {
+            src0_ddq = (char *) ggml_cuda_pool_malloc(num_iters * src0_stride * src0_ts, &src0_asq);
+            src0_ddq_malloced = true;
+        }
     }
-    float * d_Y = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size);
-    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
-    char  * d_Q = (char  *) ggml_cuda_pool_malloc(n_mm * q_sz, &q_size);
 
-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(type);
-    dequantize_mul_mat_vec_cuda_t dmmv = ggml_get_dequantize_mul_mat_vec_cuda(type);
-    GGML_ASSERT(to_fp32_cuda != nullptr);
+    if (src0_needs_f32 && !src0_is_f32) {
+        src0_ddf = (float *) ggml_cuda_pool_malloc(num_iters * src0_stride * sizeof(float), &src0_asf);
+        src0_ddf_malloced = true;
+    }
+
+    if (src1_on_device) {
+        src1_ddf = (float *) src1->data;
+    } else {
+        src1_ddf = (float *) ggml_cuda_pool_malloc(num_iters * src1_stride * sizeof(float), &src1_asf);
+        src1_ddf_malloced = true;
+    }
+    if (dst_on_device) {
+        dst_ddf = (float *) dst->data;
+    } else {
+        dst_ddf = (float *) ggml_cuda_pool_malloc(num_iters * dst_stride * sizeof(float), &dst_asf);
+        dst_ddf_malloced = true;
+    }
 
     for (int64_t i03 = 0; i03 < ne03; i03++) {
+        const int64_t i13 = i03 % ne13;
         for (int64_t i02 = 0; i02 < ne02; i02++) {
-            int i = i03*ne02 + i02;
-            cudaStream_t cudaStream = g_cudaStreams[i % GGML_CUDA_MAX_STREAMS];
-            cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_CUDA_MAX_STREAMS];
-            cudaEvent_t  cudaEvent = g_cudaEvents[i % GGML_CUDA_MAX_EVENTS];
+            const int64_t i12 = i02 % ne12;
 
-            float * c_Y = d_Y + i * y_ne;
-            float * c_D = d_D + i * d_ne;
-            char  * c_Q = d_Q + i * q_sz;
+            const int64_t i0 = i03*ne02 + i02;
+            const int64_t i1 = i13*ne12 + i12;
 
-            // copy src0 to device if necessary
-            if (src0->backend == GGML_BACKEND_CPU) {
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2));
-            } else if (src0->backend == GGML_BACKEND_CUDA) {
-                c_Q = ((char *) src0->data) + i * q_sz;
-            } else {
-                GGML_ASSERT(false);
-            }
-            if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
-                CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
+            cudaStream_t cudaStream_main = g_cudaStreams_main[i0 % GGML_CUDA_MAX_STREAMS];
+            cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[i0 % GGML_CUDA_MAX_STREAMS];
+            cudaStream_t cudaStream_memcpy_dst = g_cudaStreams_memcpy_dst[i0 % GGML_CUDA_MAX_STREAMS];
+            cudaEvent_t  cudaEvent_main = g_cudaEvents_main[i0 % GGML_CUDA_MAX_EVENTS];
+            cudaEvent_t  cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[i0 % GGML_CUDA_MAX_EVENTS];
 
-                // copy src1 to device
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
+            char  * src0_ddq_i = src0_ddq + i0*src0_stride;
+            float * src0_ddf_i = src0_ddf + i0*src0_stride;
+            float * src1_ddf_i = src1_ddf + i1*src1_stride;
+            float * dst_ddf_i = dst_ddf + i0*dst_stride;
 
-                // wait for data
-                CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
+            // copy src0, src1 to device if necessary
+            if (!src1_on_device) { // src1 first to avoid blocking device queues
+                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src1_ddf, src1, i03, i02, cudaStream_memcpy_src1));
+            }
+            CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));
+            if (!src0_on_device) {
+                if (src0_is_f32) {
+                    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src0_ddf, src0, i03, i02, cudaStream_main));
+                } else {
+                    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src0_ddq, src0, i03, i02, cudaStream_main));
+                }
+            }
 
-                // compute
-                dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
+            if (src0_needs_f32 && !src0_is_f32) {
+                to_fp32_cuda(src0_ddq_i, src0_ddf_i, src0_stride, cudaStream_main);
                 CUDA_CHECK(cudaGetLastError());
+            }
 
-            } else { // general dequantization kernel + cuBLAS matrix matrix multiplication
-                float * c_X = d_X + i * x_ne;
+            // wait with main stream until src1 memcpy is done
+            CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, cudaEvent_memcpy_src1, 0));
 
-                // convert src0 to fp32 on device
-                to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2);
-                CUDA_CHECK(cudaGetLastError());
-                CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2));
-
-                // copy src1 to device
-                CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream));
-
-                // wait for conversion
-                CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
-
-                // compute
-                CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
-                CUBLAS_CHECK(
-                    cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                            ne01, ne11, ne10,
-                            &alpha, c_X, ne00,
-                                    c_Y, ne10,
-                            &beta,  c_D, ne01));
-            }
+            // do the computation
+            op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i1, cudaStream_main);
 
-            // copy dst to host
-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+            CUDA_CHECK(cudaEventRecord(cudaEvent_main, cudaStream_main));
+
+            // copy dst to host if necessary
+            if (!dst_on_device) {
+                // wait with memcpy until main stream is done
+                CUDA_CHECK(cudaStreamWaitEvent(cudaStream_memcpy_dst, cudaEvent_main, 0));
+
+                float * dhf_dst_i = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+                CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), cudaMemcpyDeviceToHost, cudaStream_memcpy_dst));
+            }
         }
     }
 
     CUDA_CHECK(cudaDeviceSynchronize());
-    if (!mul_mat_vec) {
-        ggml_cuda_pool_free(d_X, x_size);
+    if (src0_ddf_malloced) {
+        ggml_cuda_pool_free(src0_ddf, src0_asf);
+    }
+    if (src0_ddq_malloced) {
+        ggml_cuda_pool_free(src0_ddq, src0_asq);
+    }
+    if (src1_ddf_malloced) {
+        ggml_cuda_pool_free(src1_ddf, src1_asf);
+    }
+    if (dst_ddf_malloced) {
+        ggml_cuda_pool_free(dst_ddf, dst_asf);
     }
-    ggml_cuda_pool_free(d_Y, y_size);
-    ggml_cuda_pool_free(d_D, d_size);
-    ggml_cuda_pool_free(d_Q, q_size);
 }
 
 void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_mul_f32(src0, src1, dst);
+    ggml_cuda_op<GGML_CUDA_OP_TYPE_FFF, ggml_cuda_op_mul>(src0, src1, dst);
 }
 
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
@@ -873,18 +903,27 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
     GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
 
     if (src0->type == GGML_TYPE_F32) {
-        ggml_cuda_mul_mat_f32(src0, src1, dst);
+        ggml_cuda_op<GGML_CUDA_OP_TYPE_FFF, ggml_cuda_op_mul_mat_cublas>(src0, src1, dst);
     }
     else if (src0->type == GGML_TYPE_F16) {
         if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
+            // ggml_cuda_op<GGML_CUDA_OP_TYPE_QQF, ggml_cuda_op_mul_mat_cublas>(src0, src1, dst);
             ggml_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize);
         }
         else {
-            ggml_cuda_mul_mat_q_f32(src0, src1, dst);
+            if (src1->ne[1] == 1) {
+                ggml_cuda_op<GGML_CUDA_OP_TYPE_QFF, ggml_cuda_op_dequantize_mul_mat_vec>(src0, src1, dst);
+            } else {
+                ggml_cuda_op<GGML_CUDA_OP_TYPE_FFF, ggml_cuda_op_mul_mat_cublas>(src0, src1, dst);
+            }
         }
     }
     else if (ggml_is_quantized(src0->type)) {
-        ggml_cuda_mul_mat_q_f32(src0, src1, dst);
+        if (src1->ne[1] == 1) {
+            ggml_cuda_op<GGML_CUDA_OP_TYPE_QFF, ggml_cuda_op_dequantize_mul_mat_vec>(src0, src1, dst);
+        } else {
+            ggml_cuda_op<GGML_CUDA_OP_TYPE_FFF, ggml_cuda_op_mul_mat_cublas>(src0, src1, dst);
+        }
     }
     else {
         GGML_ASSERT(false);
@@ -900,32 +939,6 @@ size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct
     }
 }
 
-void ggml_cuda_transform_tensor(ggml_tensor * tensor) {
-    const int64_t ne0 = tensor->ne[0];
-    const int64_t ne1 = tensor->ne[1];
-    const int64_t ne2 = tensor->ne[2];
-    const int64_t ne3 = tensor->ne[3];
-
-    const ggml_type type = tensor->type;
-    const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
-
-    size_t q_size;
-    char * dst = (char *) ggml_cuda_pool_malloc(q_sz, &q_size);
-
-    cudaStream_t cudaStream2 = g_cudaStreams2[0];
-
-    // copy tensor to device
-    for (int64_t i3 = 0; i3 < ne3; i3++) {
-        for (int64_t i2 = 0; i2 < ne2; i2++) {
-            int i = i3*ne2 + i2;
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(dst + i*ne0*ne1, tensor, i3, i2, cudaStream2));
-        }
-    }
-
-    tensor->data = dst;
-    tensor->backend = GGML_BACKEND_CUDA;
-}
-
 void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
     FILE * fp = fopen(fname, "rb");
 
diff --git a/ggml-cuda.h b/ggml-cuda.h
index 6a04dde6c37a9..2ef636593400b 100644
--- a/ggml-cuda.h
+++ b/ggml-cuda.h
@@ -15,7 +15,6 @@ void   ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
 void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);
 
-void ggml_cuda_transform_tensor(struct ggml_tensor * tensor);
 void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset);
 
 #ifdef  __cplusplus

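Note on PATCH 1/4: the new ggml_cuda_op template selects its behavior from a GGML_CUDA_OP_TYPE_* value whose three bits record, in src0/src1/dst order, whether each operand is quantized (Q) or float (F); so far the template only inspects bit 2 via `op_type & 0x4` to decide whether src0 must be dequantized to f32. The standalone sketch below is illustrative only and not part of the patch; the free helper functions (src0_needs_f32 and friends) are made-up names for clarity.

// Standalone sketch (not part of the patch) of how GGML_CUDA_OP_TYPE_* values
// encode, per operand, whether the op consumes quantized (Q) or float (F) data:
// bit 2 -> src0, bit 1 -> src1, bit 0 -> dst, matching the enum order QQQ = 0 ... FFF = 7.
#include <cstdio>

enum ggml_cuda_op_type {
    GGML_CUDA_OP_TYPE_QQQ = 0, GGML_CUDA_OP_TYPE_QQF = 1,
    GGML_CUDA_OP_TYPE_QFQ = 2, GGML_CUDA_OP_TYPE_QFF = 3,
    GGML_CUDA_OP_TYPE_FQQ = 4, GGML_CUDA_OP_TYPE_FQF = 5,
    GGML_CUDA_OP_TYPE_FFQ = 6, GGML_CUDA_OP_TYPE_FFF = 7,
};

// hypothetical helpers, named for illustration only
static bool src0_needs_f32(ggml_cuda_op_type t) { return t & 0x4; }
static bool src1_needs_f32(ggml_cuda_op_type t) { return t & 0x2; }
static bool dst_is_f32    (ggml_cuda_op_type t) { return t & 0x1; }

int main() {
    // QFF: quantized src0, float src1, float dst -> the dequantize_mul_mat_vec path
    const ggml_cuda_op_type t = GGML_CUDA_OP_TYPE_QFF;
    printf("src0 f32: %d, src1 f32: %d, dst f32: %d\n",
           src0_needs_f32(t), src1_needs_f32(t), dst_is_f32(t));
    return 0;
}

For the QFF case this prints "src0 f32: 0, src1 f32: 1, dst f32: 1", matching the dequantize_mul_mat_vec path that consumes quantized src0 directly.
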
From 971920e93580b7ae3e979f6d577596301c629349 Mon Sep 17 00:00:00 2001
From: JohannesGaessler <johannesg@5d6.de>
Date: Wed, 24 May 2023 12:55:50 +0200
Subject: [PATCH 2/4] ggml_cuda_compute_forward

---
 ggml-cuda.cu | 35 +++++++++++++++++++++++++++++
 ggml-cuda.h  |  1 +
 ggml.c       | 63 +++++++++-------------------------------------------
 ggml.h       | 18 +++++++++++++++
 4 files changed, 65 insertions(+), 52 deletions(-)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index e6e83b05001ff..cec2da5e7df06 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -862,6 +862,10 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     }
 }
 
+bool ggml_cuda_can_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    return src1->backend == GGML_BACKEND_CUDA;
+}
+
 void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
     ggml_cuda_op<GGML_CUDA_OP_TYPE_FFF, ggml_cuda_op_mul>(src0, src1, dst);
@@ -968,3 +972,34 @@ void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const
     free(buf_host);
     fclose(fp);
 }
+
+bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
+    switch (tensor->op) {
+        case GGML_OP_MUL:
+            if (!ggml_cuda_can_mul(tensor->src0, tensor->src1, tensor)) {
+                return false;
+            }
+            if (params->ith != 0) {
+                return true;
+            }
+            if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+                return true;
+            }
+            ggml_cuda_mul(tensor->src0, tensor->src1, tensor);
+            return true;
+        case GGML_OP_MUL_MAT:
+            if (!ggml_cuda_can_mul_mat(tensor->src0, tensor->src1, tensor)) {
+                return false;
+            }
+            if (params->ith != 0) {
+                return true;
+            }
+            if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+                return true;
+            }
+            ggml_cuda_mul_mat(tensor->src0, tensor->src1, tensor, params->wdata, params->wsize);
+            return true;
+        default:
+            return false;
+    }
+}
diff --git a/ggml-cuda.h b/ggml-cuda.h
index 2ef636593400b..f71701ce88371 100644
--- a/ggml-cuda.h
+++ b/ggml-cuda.h
@@ -16,6 +16,7 @@ void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);
 
 void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset);
+bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
 
 #ifdef  __cplusplus
 }
diff --git a/ggml.c b/ggml.c
index 00bbee503f52a..4a05df1e3e7bf 100644
--- a/ggml.c
+++ b/ggml.c
@@ -3647,26 +3647,6 @@ struct ggml_context_container {
     struct ggml_context context;
 };
 
-//
-// compute types
-//
-
-enum ggml_task_type {
-    GGML_TASK_INIT = 0,
-    GGML_TASK_COMPUTE,
-    GGML_TASK_FINALIZE,
-};
-
-struct ggml_compute_params {
-    enum ggml_task_type type;
-
-    int ith, nth;
-
-    // work buffer for all threads
-    size_t wsize;
-    void * wdata;
-};
-
 //
 // ggml state
 //
@@ -8166,14 +8146,7 @@ static void ggml_compute_forward_mul_f32(
     const int ith = params->ith;
     const int nth = params->nth;
 
-#ifdef GGML_USE_CUBLAS
-    if (src1->backend == GGML_BACKEND_CUDA) {
-        if (ith == 0) {
-            ggml_cuda_mul(src0, src1, dst);
-        }
-        return;
-    }
-#elif defined(GGML_USE_CLBLAST)
+#ifdef GGML_USE_CLBLAST
     if (src1->backend == GGML_BACKEND_CL) {
         if (ith == 0) {
             ggml_cl_mul(src0, src1, dst);
@@ -9614,14 +9587,7 @@ static void ggml_compute_forward_mul_mat_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_CUBLAS)
-    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
-        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
-        }
-        return;
-    }
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -9786,14 +9752,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_CUBLAS)
-    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
-        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
-        }
-        return;
-    }
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -9998,14 +9957,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_CUBLAS)
-    if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
-        if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
-            ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
-        }
-        return;
-    }
-#elif defined(GGML_USE_CLBLAST)
+#if defined(GGML_USE_CLBLAST)
     if (ggml_cl_can_mul_mat(src0, src1, dst)) {
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
             ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
@@ -12931,6 +12883,13 @@ static void ggml_compute_forward_map_binary(
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
     GGML_ASSERT(params);
 
+#ifdef GGML_USE_CUBLAS
+    bool used_cuda = ggml_cuda_compute_forward(params, tensor);
+    if (used_cuda) {
+        return;
+    }
+#endif // GGML_USE_CUBLAS
+
     switch (tensor->op) {
         case GGML_OP_DUP:
             {
diff --git a/ggml.h b/ggml.h
index 2ea87ce9a9749..5dbe0f2ffb78f 100644
--- a/ggml.h
+++ b/ggml.h
@@ -413,6 +413,24 @@ extern "C" {
         bool   no_alloc;   // don't allocate memory for the tensor data
     };
 
+
+    // compute types
+    enum ggml_task_type {
+        GGML_TASK_INIT = 0,
+        GGML_TASK_COMPUTE,
+        GGML_TASK_FINALIZE,
+    };
+
+    struct ggml_compute_params {
+        enum ggml_task_type type;
+
+        int ith, nth;
+
+        // work buffer for all threads
+        size_t wsize;
+        void * wdata;
+    };
+
     // misc
 
     GGML_API void    ggml_time_init(void); // call this once at the beginning of the program

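Note on PATCH 2/4: ggml_cuda_compute_forward centralizes the GPU dispatch that was previously scattered through ggml.c. For each supported op it returns false when the op cannot run on CUDA (so the CPU path takes over), and otherwise returns true for every thread and task phase while doing the actual work only on thread 0 during the COMPUTE phase. The sketch below reproduces just that guard pattern with trimmed stand-in types; it is not the real ggml API.

// Minimal sketch (assumptions: simplified stand-ins for ggml's real structs) of
// the per-op guard pattern used by ggml_cuda_compute_forward.
#include <cstdio>

enum task_type { TASK_INIT, TASK_COMPUTE, TASK_FINALIZE };   // mirrors ggml_task_type
struct compute_params { task_type type; int ith; };          // trimmed ggml_compute_params

static bool can_run_on_cuda() { return true; }               // placeholder for ggml_cuda_can_mul(_mat)
static void run_cuda_op()     { printf("CUDA op executed\n"); }

// returns true if the op was handled (possibly as a no-op), false -> CPU fallback
static bool cuda_compute_forward(const compute_params & params) {
    if (!can_run_on_cuda()) {
        return false;                       // let the CPU implementation handle it
    }
    if (params.ith != 0) {
        return true;                        // other threads: nothing to do, but op is claimed
    }
    if (params.type == TASK_INIT || params.type == TASK_FINALIZE) {
        return true;                        // GPU path needs no init/finalize work
    }
    run_cuda_op();                          // thread 0, COMPUTE phase: do the work
    return true;
}

int main() {
    compute_params p = { TASK_COMPUTE, 0 };
    printf("handled: %d\n", cuda_compute_forward(p));
    return 0;
}
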
From 4f9640b8fe3d8c0b1c2bf89cd32f2a8a7a8184b4 Mon Sep 17 00:00:00 2001
From: JohannesGaessler <johannesg@5d6.de>
Date: Wed, 24 May 2023 14:29:21 +0200
Subject: [PATCH 3/4] Tensor parallelism

---
 examples/common.cpp        |  28 ++
 examples/common.h          |  15 +-
 examples/server/server.cpp |  33 ++
 ggml-cuda.cu               | 756 ++++++++++++++++++++-----------------
 ggml-cuda.h                |  14 +-
 ggml-opencl.cpp            |  20 +-
 ggml.c                     |  14 +-
 ggml.h                     |  16 +-
 llama.cpp                  |  85 +++--
 llama.h                    |  14 +-
 10 files changed, 591 insertions(+), 404 deletions(-)

diff --git a/examples/common.cpp b/examples/common.cpp
index b5810f28f4901..4bdb2be9562d9 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -9,6 +9,7 @@
 #include <algorithm>
 #include <sstream>
 #include <unordered_set>
+#include <regex>
 
 #if defined(__APPLE__) && defined(__MACH__)
 #include <sys/types.h>
@@ -295,6 +296,30 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
             fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
             fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
 #endif
+        } else if (arg == "--tensor-split" || arg == "-ts") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+#ifdef GGML_USE_CUBLAS
+            std::string arg_next = argv[i];
+
+            // split string by , and /
+            const std::regex regex{R"([,/]+)"};
+            std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
+            std::vector<std::string> split_arg{it, {}};
+            GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
+
+            for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i) {
+                if (i < split_arg.size()) {
+                    params.tensor_split[i] = std::stof(split_arg[i]);
+                } else {
+                    params.tensor_split[i] = 0.0f;
+                }
+            }
+#else
+            fprintf(stderr, "WARNING: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
+#endif // GGML_USE_CUBLAS
         } else if (arg == "--no-mmap") {
             params.use_mmap = false;
         } else if (arg == "--mtest") {
@@ -438,6 +463,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
 #ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
     fprintf(stderr, "  -ngl N, --n-gpu-layers N\n");
     fprintf(stderr, "                        number of layers to store in VRAM\n");
+    fprintf(stderr, "  -ts SPLIT, --tensor-split SPLIT\n");
+    fprintf(stderr, "                        how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
 #endif
     fprintf(stderr, "  --mtest               compute maximum memory usage\n");
     fprintf(stderr, "  --export              export the computation graph to 'llama.ggml'\n");
@@ -484,6 +511,7 @@ struct llama_context * llama_init_from_gpt_params(const gpt_params & params) {
 
     lparams.n_ctx        = params.n_ctx;
     lparams.n_gpu_layers = params.n_gpu_layers;
+    memcpy(lparams.tensor_split, params.tensor_split, LLAMA_MAX_DEVICES*sizeof(float));
     lparams.seed         = params.seed;
     lparams.f16_kv       = params.memory_f16;
     lparams.use_mmap     = params.use_mmap;
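Note on the examples/common.cpp hunk above: --tensor-split accepts proportions separated by ',' or '/', which are tokenized with std::regex and padded with zeros up to LLAMA_MAX_DEVICES. A minimal standalone sketch of that parsing follows; MAX_DEVICES is a stand-in constant, not the real LLAMA_MAX_DEVICES.

// Standalone sketch (not part of the patch) of the "--tensor-split 3,1" parsing:
// split on ',' or '/' and fill the remaining device slots with 0.0f.
#include <cstdio>
#include <regex>
#include <string>
#include <vector>

int main() {
    const size_t MAX_DEVICES = 4;         // stand-in for LLAMA_MAX_DEVICES
    const std::string arg_next = "3,1";   // value passed to --tensor-split / -ts

    // split string by , and /
    const std::regex regex{R"([,/]+)"};
    std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
    std::vector<std::string> split_arg{it, {}};

    float tensor_split[MAX_DEVICES];
    for (size_t i = 0; i < MAX_DEVICES; ++i) {
        tensor_split[i] = i < split_arg.size() ? std::stof(split_arg[i]) : 0.0f;
    }
    for (size_t i = 0; i < MAX_DEVICES; ++i) {
        printf("device %zu: %.1f\n", i, tensor_split[i]);
    }
    return 0;
}
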
diff --git a/examples/common.h b/examples/common.h
index 66bdeb5e9287d..0518584c1734c 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -21,13 +21,14 @@
 int32_t get_num_physical_cores();
 
 struct gpt_params {
-    int32_t seed          = -1;  // RNG seed
-    int32_t n_threads     = get_num_physical_cores();
-    int32_t n_predict     = -1;  // new tokens to predict
-    int32_t n_ctx         = 512; // context size
-    int32_t n_batch       = 512; // batch size for prompt processing (must be >=32 to use BLAS)
-    int32_t n_keep        = 0;   // number of tokens to keep from initial prompt
-    int32_t n_gpu_layers  = 0;   // number of layers to store in VRAM
+    int32_t seed                           = -1;  // RNG seed
+    int32_t n_threads                      = get_num_physical_cores();
+    int32_t n_predict                      = -1;  // new tokens to predict
+    int32_t n_ctx                          = 512; // context size
+    int32_t n_batch                        = 512; // batch size for prompt processing (must be >=32 to use BLAS)
+    int32_t n_keep                         = 0;   // number of tokens to keep from initial prompt
+    int32_t n_gpu_layers                   = 0;   // number of layers to store in VRAM
+    float   tensor_split[LLAMA_MAX_DEVICES] = {0}; // how split tensors should be distributed across GPUs
 
     // sampling parameters
     std::unordered_map<llama_token, float> logit_bias; // logit bias for specific tokens
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 9aa7db255aab3..0ea7d859967aa 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -401,6 +401,8 @@ void server_print_usage(int /*argc*/, char **argv, const gpt_params &params)
 #ifdef LLAMA_SUPPORTS_GPU_OFFLOAD
   fprintf(stderr, "  -ngl N, --n-gpu-layers N\n");
   fprintf(stderr, "                        number of layers to store in VRAM\n");
+  fprintf(stderr, "  -ts SPLIT, --tensor-split SPLIT\n");
+  fprintf(stderr, "                        how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
 #endif
   fprintf(stderr, "  -m FNAME, --model FNAME\n");
   fprintf(stderr, "                        model path (default: %s)\n", params.model.c_str());
@@ -503,6 +505,37 @@ bool server_params_parse(int argc, char **argv, server_params &sparams, gpt_para
       fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
       fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
 #endif
+    }
+    else if (arg == "--tensor-split" || arg == "-ts")
+    {
+      if (++i >= argc)
+      {
+        invalid_param = true;
+        break;
+      }
+#ifdef GGML_USE_CUBLAS
+      std::string arg_next = argv[i];
+
+      // split string by , and /
+      const std::regex regex{R"([,/]+)"};
+      std::sregex_token_iterator it{arg_next.begin(), arg_next.end(), regex, -1};
+      std::vector<std::string> split_arg{it, {}};
+      GGML_ASSERT(split_arg.size() <= LLAMA_MAX_DEVICES);
+
+      for (size_t i = 0; i < LLAMA_MAX_DEVICES; ++i)
+      {
+        if (i < split_arg.size())
+        {
+          params.tensor_split[i] = std::stof(split_arg[i]);
+        }
+        else
+        {
+          params.tensor_split[i] = 0.0f;
+        }
+      }
+#else
+      fprintf(stderr, "WARNING: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
+#endif // GGML_USE_CUBLAS
     }
     else
     {
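Note ahead of the ggml-cuda.cu changes below: the per-device proportions are converted into cumulative row-start fractions, either from total VRAM in ggml_init_cublas or from the user-supplied split in ggml_cuda_set_tensor_split (prefix sum, then divide by the sum). A small illustrative sketch of that normalization, not part of the patch:

// Illustrative sketch of the normalization done in ggml_cuda_set_tensor_split:
// a user split like {3, 1} becomes cumulative start fractions {0.0, 0.75}.
#include <cstdio>

int main() {
    const int device_count = 2;
    float tensor_split[2] = {3.0f, 1.0f}; // user proportions, e.g. --tensor-split 3,1
    float boundaries[2];

    float split_sum = 0.0f;
    for (int i = 0; i < device_count; ++i) {
        boundaries[i] = split_sum;   // prefix sum before normalization
        split_sum += tensor_split[i];
    }
    for (int i = 0; i < device_count; ++i) {
        boundaries[i] /= split_sum;  // normalize to fractions in [0, 1)
    }

    for (int i = 0; i < device_count; ++i) {
        printf("device %d starts at fraction %.2f\n", i, boundaries[i]);
    }
    return 0;
}

So --tensor-split 3,1 assigns the first 75% of a tensor's rows to device 0 and the remainder to device 1.
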
diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index cec2da5e7df06..bccf74664ee1f 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -23,32 +23,33 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
         }                                                                               \
     } while (0)
 
+#if CUDART_VERSION >= 12
 #define CUBLAS_CHECK(err)                                                               \
     do {                                                                                \
         cublasStatus_t err_ = (err);                                                    \
         if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
-            fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);    \
+            fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n",                         \
+                    err_, __FILE__, __LINE__, cublasGetStatusString(err_));             \
             exit(1);                                                                    \
         }                                                                               \
     } while (0)
-
-// Q = quantized, F = float, order is src0, src1, dst
-enum ggml_cuda_op_type {
-    GGML_CUDA_OP_TYPE_QQQ = 0,
-    GGML_CUDA_OP_TYPE_QQF = 1,
-    GGML_CUDA_OP_TYPE_QFQ = 2,
-    GGML_CUDA_OP_TYPE_QFF = 3,
-    GGML_CUDA_OP_TYPE_FQQ = 4,
-    GGML_CUDA_OP_TYPE_FQF = 5,
-    GGML_CUDA_OP_TYPE_FFQ = 6,
-    GGML_CUDA_OP_TYPE_FFF = 7,
-};
+#else
+#define CUBLAS_CHECK(err)                                                               \
+    do {                                                                                \
+        cublasStatus_t err_ = (err);                                                    \
+        if (err_ != CUBLAS_STATUS_SUCCESS) {                                            \
+            fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__);  \
+            exit(1);                                                                    \
+        }                                                                               \
+    } while (0)
+#endif // CUDART_VERSION >= 12
 
 typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
 typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
+typedef void (*ggml_cuda_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
 typedef void (*ggml_cuda_op_t)(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int i1, cudaStream_t & cudaStream_main);
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i0_low, int64_t i0_high, int i1, cudaStream_t & cudaStream_main);
 
 // QK = number of values after dequantization
 // QR = QK / number of values before dequantization
@@ -201,8 +202,8 @@ static __device__ void dequantize_q8_0(const void * vx, const int ib, const int
 static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
     const half * x = (const half *) vx;
 
-    v0 = __half2float(x[ib + 0]);
-    v1 = __half2float(x[ib + 1]);
+    v0 = __half2float(x[ib + iqs + 0]);
+    v1 = __half2float(x[ib + iqs + 1]);
 }
 
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
@@ -344,7 +345,7 @@ static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, f
 
 static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
     const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
-    dequantize_block<32, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
+    dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
 }
 
 static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
@@ -396,14 +397,16 @@ struct cuda_buffer {
     size_t size = 0;
 };
 
-static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS];
+static cuda_buffer g_cuda_buffer_pool[GGML_CUDA_MAX_DEVICES][MAX_CUDA_BUFFERS];
 static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT;
 
 static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
 
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
-        cuda_buffer& b = g_cuda_buffer_pool[i];
+        cuda_buffer& b = g_cuda_buffer_pool[id][i];
         if (b.size >= size && b.ptr != nullptr) {
             void * ptr = b.ptr;
             *actual_size = b.size;
@@ -420,9 +423,11 @@ static void * ggml_cuda_pool_malloc(size_t size, size_t * actual_size) {
 
 static void ggml_cuda_pool_free(void * ptr, size_t size) {
     scoped_spin_lock lock(g_cuda_pool_lock);
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
 
     for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) {
-        cuda_buffer& b = g_cuda_buffer_pool[i];
+        cuda_buffer& b = g_cuda_buffer_pool[id][i];
         if (b.ptr == nullptr) {
             b.ptr = ptr;
             b.size = size;
@@ -435,33 +440,79 @@ static void ggml_cuda_pool_free(void * ptr, size_t size) {
 
 #define GGML_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication.
 #define GGML_CUDA_MAX_EVENTS 64
-static cublasHandle_t g_cublasH = nullptr;
-static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_STREAMS] = { nullptr };
-static cudaEvent_t g_cudaEvents_main[GGML_CUDA_MAX_EVENTS] = { nullptr };
-static cudaStream_t g_cudaStreams_memcpy_src1[GGML_CUDA_MAX_STREAMS] = { nullptr };
-static cudaStream_t g_cudaStreams_memcpy_dst[GGML_CUDA_MAX_STREAMS] = { nullptr };
-static cudaEvent_t g_cudaEvents_memcpy_src1[GGML_CUDA_MAX_EVENTS] = { nullptr };
+
+static int g_device_count = -1;
+static float g_tensor_split[GGML_CUDA_MAX_DEVICES] = {0};
+
+static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
+
+static cudaStream_t g_cudaStreams_main[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
+
+static cudaStream_t g_cudaStreams_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { nullptr };
+static cudaEvent_t g_cudaEvents_memcpy_src1[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_EVENTS] = { nullptr };
 
 void ggml_init_cublas() {
-    if (g_cublasH == nullptr) {
-        // create streams
-        for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
-            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[i], cudaStreamNonBlocking));
-            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_memcpy_src1[i], cudaStreamNonBlocking));
-            CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_memcpy_dst[i], cudaStreamNonBlocking));
+    static bool initialized = false;
+
+    if (!initialized) {
+        CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
+        GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
+        int64_t total_vram = 0;
+        fprintf(stderr, "%s: found %d CUDA devices:\n", __func__, g_device_count);
+        for (int i = 0; i < g_device_count; ++i) {
+            cudaDeviceProp prop;
+            CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
+            fprintf(stderr, "  %d. %s\n", i+1, prop.name);
+            g_tensor_split[i] = total_vram;
+            total_vram += prop.totalGlobalMem;
         }
-        // create events
-        for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
-            CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents_main[i], cudaEventDisableTiming));
-            CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents_memcpy_src1[i], cudaEventDisableTiming));
+        for (int i = 0; i < g_device_count; ++i) {
+            g_tensor_split[i] /= total_vram;
         }
 
-        // create cublas handle
-        CUBLAS_CHECK(cublasCreate(&g_cublasH));
-        CUBLAS_CHECK(cublasSetMathMode(g_cublasH, CUBLAS_TF32_TENSOR_OP_MATH));
+        for (int id = 0; id < g_device_count; ++id) {
+            CUDA_CHECK(cudaSetDevice(id));
+
+            // create streams
+            for (int i = 0; i < GGML_CUDA_MAX_STREAMS; ++i) {
+                CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_main[id][i], cudaStreamNonBlocking));
+                CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams_memcpy_src1[id][i], cudaStreamNonBlocking));
+            }
+            // create events
+            for (int i = 0; i < GGML_CUDA_MAX_EVENTS; ++i) {
+                CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents_memcpy_src1[id][i], cudaEventDisableTiming));
+            }
+
+            // create cublas handle
+            CUBLAS_CHECK(cublasCreate(&g_cublas_handles[id]));
+            CUBLAS_CHECK(cublasSetMathMode(g_cublas_handles[id], CUBLAS_TF32_TENSOR_OP_MATH));
+        }
 
         // configure logging to stdout
         // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr));
+
+        initialized = true;
+    }
+}
+
+void ggml_cuda_set_tensor_split(float * tensor_split) {
+    bool all_zero = true;
+    for (int i = 0; i < g_device_count; ++i) {
+        if (tensor_split[i] != 0.0f) {
+            all_zero = false;
+            break;
+        }
+    }
+    if (all_zero) {
+        return;
+    }
+    float split_sum = 0.0f;
+    for (int i = 0; i < g_device_count; ++i) {
+        g_tensor_split[i] = split_sum;
+        split_sum += tensor_split[i];
+    }
+    for (int i = 0; i < g_device_count; ++i) {
+        g_tensor_split[i] /= split_sum;
     }
 }
 
@@ -485,26 +536,29 @@ void ggml_cuda_host_free(void * ptr) {
     CUDA_CHECK(cudaFreeHost(ptr));
 }
 
-static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) {
-    const uint64_t ne0 = src->ne[0];
-    const uint64_t ne1 = src->ne[1];
-    const uint64_t nb0 = src->nb[0];
-    const uint64_t nb1 = src->nb[1];
-    const uint64_t nb2 = src->nb[2];
-    const uint64_t nb3 = src->nb[3];
+static cudaError_t ggml_cuda_h2d_tensor_2d(
+    void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) {
+
+    char * dst_char = (char *) dst;
+    const int64_t ne0 = src->ne[0];
+    const int64_t nb0 = src->nb[0];
+    const int64_t nb1 = src->nb[1];
+    const int64_t nb2 = src->nb[2];
+    const int64_t nb3 = src->nb[3];
     const enum ggml_type type = src->type;
-    const size_t ts = ggml_type_size(type);
-    const size_t bs = ggml_blck_size(type);
+    const int64_t ts = ggml_type_size(type);
+    const int64_t bs = ggml_blck_size(type);
+    int64_t i1_diff = i1_high - i1_low;
 
-    const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
+    const void * x = (const void *) ((const char *) src->data + i1_low*nb1 + i2*nb2 + i3*nb3);
     if (nb0 == ts && nb1 == ts*ne0/bs) {
-        return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream);
+        return cudaMemcpyAsync(dst_char, x, i1_diff*nb1, cudaMemcpyHostToDevice, stream);
     } else if (nb0 == ts) {
-        return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream);
+        return cudaMemcpy2DAsync(dst_char, ts*ne0/bs, x, nb1, ts*ne0/bs, i1_diff, cudaMemcpyHostToDevice, stream);
     } else {
-        for (uint64_t i1 = 0; i1 < ne1; i1++) {
+        for (int64_t i1 = 0; i1 < i1_diff; i1++) {
             const void * rx = (const void *) ((const char *) x + i1*nb1);
-            void * rd = (void *) ((char *) dst + i1*ts*ne0/bs);
+            void * rd = (void *) (dst_char + i1*ts*ne0/bs);
             // pretend the row is a matrix with cols=1
             cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream);
             if (r != cudaSuccess) return r;
@@ -513,114 +567,21 @@ static cudaError_t ggml_cuda_h2d_tensor_2d(void * dst, const struct ggml_tensor
     }
 }
 
-static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t ne03 = src0->ne[3];
-
-    const int64_t ne10 = src1->ne[0];
-    const int64_t ne11 = src1->ne[1];
-
-    const int nb10 = src1->nb[0];
-    const int nb11 = src1->nb[1];
-    const int nb12 = src1->nb[2];
-    const int nb13 = src1->nb[3];
-
-    const int nb2  = dst->nb[2];
-    const int nb3  = dst->nb[3];
-
-    const float alpha = 1.0f;
-    const float beta = 0.0f;
-    const int x_ne = ne01 * ne00;
-    const int y_ne = ne11 * ne10;
-    const int d_ne = ne11 * ne01;
-    const int n_mm = ne03 * ne02;
-
-    size_t x_size, y_size, d_size;
-    half  * d_X =  (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * x_ne, &x_size);
-    half  * d_Y =  (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * y_ne, &y_size);
-    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
-
-    bool src1_cont_rows = nb10 == sizeof(float);
-    bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            int i = i03*ne02 + i02;
-            cudaStream_t cudaStream = g_cudaStreams_main[i % GGML_CUDA_MAX_STREAMS];
-
-            half  * c_X = d_X + i * x_ne;
-            half  * c_Y = d_Y + i * y_ne;
-            float * c_D = d_D + i * d_ne;
-
-            // copy src0 to device
-            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream));
-
-            // convert src1 to fp16
-            // TODO: use multiple threads
-            ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
-            char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
-            if (src1_cont_rows) {
-                if (src1_cont_cols) {
-                    ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
-                }
-                else {
-                    for (int64_t i01 = 0; i01 < ne11; i01++) {
-                        ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
-                    }
-                }
-            }
-            else {
-                for (int64_t i01 = 0; i01 < ne11; i01++) {
-                    for (int64_t i00 = 0; i00 < ne10; i00++) {
-                        // very slow due to no inlining
-                        tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
-                    }
-                }
-            }
-
-            // copy src1 to device
-            CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream));
-
-            // compute
-            CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream));
-            CUBLAS_CHECK(
-                cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                        ne01, ne11, ne10,
-                        &alpha, c_X, CUDA_R_16F, ne00,
-                                c_Y, CUDA_R_16F, ne10,
-                        &beta,  c_D, CUDA_R_32F, ne01,
-                        CUBLAS_COMPUTE_32F_FAST_16F,
-                        CUBLAS_GEMM_DEFAULT));
-
-            // copy dst to host
-            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-            CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
-        }
-    }
-
-    CUDA_CHECK(cudaDeviceSynchronize());
-    ggml_cuda_pool_free(d_X, x_size);
-    ggml_cuda_pool_free(d_Y, y_size);
-    ggml_cuda_pool_free(d_D, d_size);
-}
-
 inline void ggml_cuda_op_mul(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int i1, cudaStream_t & cudaStream_main){
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i0_low, int64_t i0_high, int i1,
+    cudaStream_t & cudaStream_main){
 
     GGML_ASSERT(src0_ddf_i != nullptr);
     GGML_ASSERT(src1_ddf_i != nullptr);
     GGML_ASSERT(dst_ddf_i != nullptr);
 
     const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
 
-    for (int64_t i01 = 0; i01 < ne01; i01++) {
+    for (int64_t i01 = i0_low; i01 < i0_high; i01++) {
         const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0
 
         float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
@@ -638,33 +599,34 @@ inline void ggml_cuda_op_mul(
 
 inline void ggml_cuda_op_dequantize_mul_mat_vec(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int i1, cudaStream_t & cudaStream_main){
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i0_low, int64_t i0_high, int i1,
+    cudaStream_t & cudaStream_main){
 
     GGML_ASSERT(src0_ddq_i != nullptr);
     GGML_ASSERT(src1_ddf_i != nullptr);
     GGML_ASSERT(dst_ddf_i != nullptr);
 
     const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
+    const int64_t nrows = i0_high - i0_low;
 
     switch (src0->type) {
         case GGML_TYPE_Q4_0:
-            dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, ne01, cudaStream_main);
+            dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
             break;
         case GGML_TYPE_Q4_1:
-            dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, ne01, cudaStream_main);
+            dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
             break;
         case GGML_TYPE_Q5_0:
-            dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, ne01, cudaStream_main);
+            dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
             break;
         case GGML_TYPE_Q5_1:
-            dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, ne01, cudaStream_main);
+            dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
             break;
         case GGML_TYPE_Q8_0:
-            dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, ne01, cudaStream_main);
+            dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
             break;
         case GGML_TYPE_F16:
-            convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, ne01, cudaStream_main);
+            convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
             break;
         default:
             GGML_ASSERT(false);
@@ -680,7 +642,8 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
 
 inline void ggml_cuda_op_mul_mat_cublas(
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
-    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int i1, cudaStream_t & cudaStream_main){
+    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i0_low, int64_t i0_high, int i1,
+    cudaStream_t & cudaStream_main){
 
     GGML_ASSERT(src0_ddf_i != nullptr);
     GGML_ASSERT(src1_ddf_i != nullptr);
@@ -690,30 +653,35 @@ inline void ggml_cuda_op_mul_mat_cublas(
     const float beta = 0.0f;
 
     const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
 
-    CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream_main));
+    const int64_t i0_diff = i0_high - i0_low;
+
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+
+    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], cudaStream_main));
     CUBLAS_CHECK(
-        cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N,
-                ne01, ne11, ne10,
+        cublasSgemm(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                i0_diff, ne11, ne10,
                 &alpha, src0_ddf_i, ne00,
                         src1_ddf_i, ne10,
-                &beta,  dst_ddf_i,  ne01));
+                &beta,  dst_ddf_i,  i0_diff));
 
     (void) dst;
     (void) src0_ddq_i;
     (void) i1;
 }
 
-template<enum ggml_cuda_op_type op_type, ggml_cuda_op_t op>
-static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+                         ggml_cuda_op_t op, bool src0_needs_f32) {
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
+    const int64_t nrows0 = ggml_nrows(src0);
 
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
@@ -726,152 +694,230 @@ static void ggml_cuda_op(const ggml_tensor * src0, const ggml_tensor * src1, ggm
     const int nb2  = dst->nb[2];
     const int nb3  = dst->nb[3];
 
+    // strides for iteration over dims 3 and 2
     const int64_t src0_stride = ne00 * ne01;
     const int64_t src1_stride = ne10 * ne11;
     const int64_t dst_stride = ne0 * ne1;
     const int64_t num_iters = ne02 * ne03;
 
     const size_t src0_ts = ggml_type_size(src0->type);
+    const size_t src0_bs = ggml_blck_size(src0->type);
+
+    struct ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    struct ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
 
-    const bool src0_on_device = src0->backend == GGML_BACKEND_CUDA;
+    // indices of the devices on which the input data is stored
+    int src0_id = src0_extra == nullptr ? -1 : src0_extra->i_device;
+    int src1_id = src1_extra == nullptr ? -1 : src1_extra->i_device;
+
+    const bool src0_on_device = src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT;
     const bool src0_is_f32 = src0->type == GGML_TYPE_F32;
-    const bool src0_needs_f32 = op_type & 0x4; // 3rd least significant bit = src0 needs f32
 
-    const bool src1_on_device = src1->backend == GGML_BACKEND_CUDA;
+    const bool src1_on_device = src1->backend == GGML_BACKEND_GPU || src1->backend == GGML_BACKEND_GPU_SPLIT;
 
-    const bool dst_on_device = dst->backend == GGML_BACKEND_CUDA;
+    const bool dst_on_device = dst->backend == GGML_BACKEND_GPU || dst->backend == GGML_BACKEND_GPU_SPLIT;
 
     const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(src0->type);
 
     // dd = data device
-    char  * src0_ddq = nullptr; // quantized
-    float * src0_ddf = nullptr; // float
-    float * src1_ddf = nullptr;
-    float * dst_ddf = nullptr;
-
-    bool src0_ddq_malloced = false;
-    bool src0_ddf_malloced = false;
-    bool src1_ddf_malloced = false;
-    bool dst_ddf_malloced = false;
+    char  * src0_ddq[GGML_CUDA_MAX_DEVICES] = {nullptr}; // quantized
+    float * src0_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr}; // float
+    float * src1_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
+    float * dst_ddf[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
     // asq = actual size quantized, asf = actual size float
-    size_t src0_asq, src0_asf, src1_asf, dst_asf;
+    size_t src0_asq[GGML_CUDA_MAX_DEVICES] = {0};
+    size_t src0_asf[GGML_CUDA_MAX_DEVICES] = {0};
+    size_t src1_asf[GGML_CUDA_MAX_DEVICES] = {0};
+    size_t dst_asf[GGML_CUDA_MAX_DEVICES] = {0};
+
+    for (int id = 0; id < g_device_count; ++id) {
+        // if data is on one device (!= -1) but not this one, continue
+        if (src0_id != -1 && src0_id != id) {
+            continue;
+        }
+        if (src1_id != -1 && src1_id != id) {
+            continue;
+        }
 
-    if (src0_on_device) {
-        if (src0_is_f32) {
-            src0_ddf = (float *) src0->data;
+        bool split = src0_id == -1 && src1_id == -1;
+        int64_t row_low, row_high;
+        if (split) {
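+            // round both boundaries down to a multiple of GGML_CUDA_DMMV_Y so that each device's
+            // row count stays compatible with the dequantize_mul_mat_vec row blocking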
+            row_low = id == 0 ? 0 : nrows0*g_tensor_split[id];
+            row_low -= row_low % GGML_CUDA_DMMV_Y;
+            row_high = id == g_device_count - 1 ? nrows0 : nrows0*g_tensor_split[id + 1];
+            row_high -= row_high % GGML_CUDA_DMMV_Y;
         } else {
-            src0_ddq = (char *) src0->data;
+            row_low = 0;
+            row_high = ne01;
         }
-    } else {
-        if (src0_is_f32) {
-            src0_ddf = (float *) ggml_cuda_pool_malloc(num_iters * src0_stride * sizeof(float), &src0_asf);
-            src0_ddf_malloced = true;
-        } else {
-            src0_ddq = (char *) ggml_cuda_pool_malloc(num_iters * src0_stride * src0_ts, &src0_asq);
-            src0_ddq_malloced = true;
+        if (row_low == row_high) {
+            continue;
         }
-    }
-
-    if (src0_needs_f32 && !src0_is_f32) {
-        src0_ddf = (float *) ggml_cuda_pool_malloc(num_iters * src0_stride * sizeof(float), &src0_asf);
-        src0_ddf_malloced = true;
-    }
-
-    if (src1_on_device) {
-        src1_ddf = (float *) src1->data;
-    } else {
-        src1_ddf = (float *) ggml_cuda_pool_malloc(num_iters * src1_stride * sizeof(float), &src1_asf);
-        src1_ddf_malloced = true;
-    }
-    if (dst_on_device) {
-        dst_ddf = (float *) dst->data;
-    } else {
-        dst_ddf = (float *) ggml_cuda_pool_malloc(num_iters * dst_stride * sizeof(float), &dst_asf);
-        dst_ddf_malloced = true;
-    }
-
-    for (int64_t i03 = 0; i03 < ne03; i03++) {
-        const int64_t i13 = i03 % ne13;
-        for (int64_t i02 = 0; i02 < ne02; i02++) {
-            const int64_t i12 = i02 % ne12;
 
-            const int64_t i0 = i03*ne02 + i02;
-            const int64_t i1 = i13*ne12 + i12;
+        int64_t row_diff = row_high - row_low;
 
-            cudaStream_t cudaStream_main = g_cudaStreams_main[i0 % GGML_CUDA_MAX_STREAMS];
-            cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[i0 % GGML_CUDA_MAX_STREAMS];
-            cudaStream_t cudaStream_memcpy_dst = g_cudaStreams_memcpy_dst[i0 % GGML_CUDA_MAX_STREAMS];
-            cudaEvent_t  cudaEvent_main = g_cudaEvents_main[i0 % GGML_CUDA_MAX_EVENTS];
-            cudaEvent_t  cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[i0 % GGML_CUDA_MAX_EVENTS];
+        cudaSetDevice(id);
 
-            char  * src0_ddq_i = src0_ddq + i0*src0_stride;
-            float * src0_ddf_i = src0_ddf + i0*src0_stride;
-            float * src1_ddf_i = src1_ddf + i1*src1_stride;
-            float * dst_ddf_i = dst_ddf + i0*dst_stride;
-
-            // copy src0, src1 to device if necessary
-            if (!src1_on_device) { // src1 first to avoid blocking device queues
-                    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src1_ddf, src1, i03, i02, cudaStream_memcpy_src1));
+        if (src0_on_device) {
+            if (src0_is_f32) {
+                src0_ddf[id] = (float *) src0_extra->data_device[id];
+            } else {
+                src0_ddq[id] = (char *) src0_extra->data_device[id];
             }
-            CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));
-            if (!src0_on_device) {
-                if (src0_is_f32) {
-                    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src0_ddf, src0, i03, i02, cudaStream_main));
-                } else {
-                    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src0_ddq, src0, i03, i02, cudaStream_main));
-                }
+        } else {
+            if (src0_is_f32) {
+                src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
+            } else {
+                src0_ddq[id] = (char *) ggml_cuda_pool_malloc(row_diff*ne00 * src0_ts/src0_bs, &src0_asq[id]);
             }
+        }
 
-            if (src0_needs_f32 && !src0_is_f32) {
-                to_fp32_cuda(src0_ddq_i, src0_ddf_i, src0_stride, cudaStream_main);
-                CUDA_CHECK(cudaGetLastError());
-            }
+        if (src0_needs_f32 && !src0_is_f32) {
+            src0_ddf[id] = (float *) ggml_cuda_pool_malloc(row_diff*ne00 * sizeof(float), &src0_asf[id]);
+        }
 
-            // wait with main stream until src1 memcpy is done
-            CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, cudaEvent_memcpy_src1, 0));
+        if (src1_on_device) {
+            src1_ddf[id] = (float *) src1_extra->data_device[id];
+        } else {
+            src1_ddf[id] = (float *) ggml_cuda_pool_malloc(num_iters*src1_stride * sizeof(float), &src1_asf[id]);
+        }
+        if (dst_on_device) {
+            dst_ddf[id] = (float *) dst_extra->data_device[id];
+        } else {
+            size_t size_dst_ddf = split ? row_diff*ne1 * sizeof(float) : num_iters*dst_stride * sizeof(float);
+            dst_ddf[id] = (float *) ggml_cuda_pool_malloc(size_dst_ddf, &dst_asf[id]);
+        }
 
-            // do the computation
-            op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i1, cudaStream_main);
+        for (int64_t i03 = 0; i03 < ne03; i03++) {
+            const int64_t i13 = i03 % ne13;
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                const int64_t i12 = i02 % ne12;
 
-            CUDA_CHECK(cudaEventRecord(cudaEvent_main, cudaStream_main));
+                const int64_t i0 = i03*ne02 + i02;
+                const int64_t i0_offset_low = row_low/ne01;
+                const int64_t i0_offset_high = row_high/ne01;
 
-            // copy dst to host if necessary
-            if (!dst_on_device) {
-                // wait with memcpy until main stream is done
-                CUDA_CHECK(cudaStreamWaitEvent(cudaStream_memcpy_dst, cudaEvent_main, 0));
+                int64_t i01_low = 0;
+                int64_t i01_high = ne01;
+                if (split) {
+                    if (i0 < i0_offset_low || i0 > i0_offset_high) {
+                        continue;
+                    }
+                    if (i0 == i0_offset_low) {
+                        i01_low = row_low % ne01;
+                    }
+                    if (i0 == i0_offset_high) {
+                        i01_high = row_high % ne01;
+                    }
+                }
+                const int64_t i01_diff = i01_high - i01_low;
+                if (i01_diff == 0) {
+                    continue;
+                }
+                const int64_t i1 = i13*ne12 + i12;
+
+                cudaStream_t cudaStream_main = g_cudaStreams_main[id][i0 % GGML_CUDA_MAX_STREAMS];
+                cudaStream_t cudaStream_memcpy_src1 = g_cudaStreams_memcpy_src1[id][i0 % GGML_CUDA_MAX_STREAMS];
+                cudaEvent_t  cudaEvent_memcpy_src1 = g_cudaEvents_memcpy_src1[id][i0 % GGML_CUDA_MAX_EVENTS];
+
+                // for split tensors the data begins at i0 == i0_offset_low
+                char  * src0_ddq_i = src0_ddq[id] + (i0 - i0_offset_low)*src0_stride*src0_ts/src0_bs;
+                float * src0_ddf_i = src0_ddf[id] + (i0 - i0_offset_low)*src0_stride;
+                float * src1_ddf_i = src1_ddf[id] + i1*src1_stride;
+                float * dst_ddf_i = dst_ddf[id] + (i0 - i0_offset_low)*dst_stride;
+
+                // for split tensors the data pointers need to be rounded down
+                // to the bin edge for i03, i02 bins beyond the first
+                if (i0 - i0_offset_low > 0) {
+                    src0_ddq_i -= (row_low % ne01)*ne00 * src0_ts/src0_bs;
+                    src0_ddf_i -= (row_low % ne01)*ne00;
+                    dst_ddf_i  -= (row_low % ne0)*ne1;
+                }
+
+                // copy src0, src1 to device if necessary
+                if (!src1_on_device) {
+                    CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src1_ddf_i, src1, i03, i02, 0, ne11, cudaStream_memcpy_src1));
+                }
+                CUDA_CHECK(cudaEventRecord(cudaEvent_memcpy_src1, cudaStream_memcpy_src1));
+                if (!src0_on_device) {
+                    if (src0_is_f32) {
+                        CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src0_ddf_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                    } else {
+                        CUDA_CHECK(ggml_cuda_h2d_tensor_2d(src0_ddq_i, src0, i03, i02, i01_low, i01_high, cudaStream_main));
+                    }
+                }
 
-                float * dhf_dst_i = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
-                CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), cudaMemcpyDeviceToHost, cudaStream_memcpy_dst));
+                // convert src0 to f32 if the ggml_cuda_op needs it
+                if (src0_needs_f32 && !src0_is_f32) {
+                    to_fp32_cuda(src0_ddq_i, src0_ddf_i, i01_diff*ne00, cudaStream_main);
+                    CUDA_CHECK(cudaGetLastError());
+                }
+
+                // wait with main stream until src1 memcpy is done
+                CUDA_CHECK(cudaStreamWaitEvent(cudaStream_main, cudaEvent_memcpy_src1, 0));
+
+                // do the computation
+                op(src0, src1, dst, src0_ddq_i, src0_ddf_i, src1_ddf_i, dst_ddf_i, i01_low, i01_high, i1, cudaStream_main);
+
+                // copy dst to host if necessary
+                if (!dst_on_device) {
+                    if (split) {
+                        // src0 (the weight matrix) is stored transposed for better memory layout,
+                        // but dst is NOT transposed.
+                        // The outputs of the cuBLAS matrix-matrix multiplications can therefore NOT simply be concatenated for >1 GPU;
+                        // instead they need to be copied to the correct slice in ne0 (the dst row index).
+                        // If dst is a vector (ne0 == 1) this copy is not strictly needed, but it still produces correct results.
+                        for (int64_t j = 0; j < ne1; ++j) {
+                            float * dhf_dst_i = (float *) ((char *) dst->data + (j*ne0 + i01_low)*sizeof(float) + i02*nb2 + i03*nb3);
+                            CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i + j*i01_diff, i01_diff*sizeof(float),
+                                                    cudaMemcpyDeviceToHost, cudaStream_main));
+                        }
+                    } else {
+                        float * dhf_dst_i = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+                        CUDA_CHECK(cudaMemcpyAsync(dhf_dst_i, dst_ddf_i, dst_stride*sizeof(float), cudaMemcpyDeviceToHost, cudaStream_main));
+                    }
+                }
             }
         }
     }
 
-    CUDA_CHECK(cudaDeviceSynchronize());
-    if (src0_ddf_malloced) {
-        ggml_cuda_pool_free(src0_ddf, src0_asf);
-    }
-    if (src0_ddq_malloced) {
-        ggml_cuda_pool_free(src0_ddq, src0_asq);
-    }
-    if (src1_ddf_malloced) {
-        ggml_cuda_pool_free(src1_ddf, src1_asf);
-    }
-    if (dst_ddf_malloced) {
-        ggml_cuda_pool_free(dst_ddf, dst_asf);
+    // wait until each device is finished, then free its buffers
+    for (int id = 0; id < g_device_count; ++id) {
+        CUDA_CHECK(cudaSetDevice(id));
+        CUDA_CHECK(cudaDeviceSynchronize());
+        if (src0_asq[id] > 0) {
+            ggml_cuda_pool_free(src0_ddq[id], src0_asq[id]);
+        }
+        if (src0_asf[id] > 0) {
+            ggml_cuda_pool_free(src0_ddf[id], src0_asf[id]);
+        }
+        if (src1_asf[id] > 0) {
+            ggml_cuda_pool_free(src1_ddf[id], src1_asf[id]);
+        }
+        if (dst_asf[id] > 0) {
+            ggml_cuda_pool_free(dst_ddf[id], dst_asf[id]);
+        }
     }
 }
 
 bool ggml_cuda_can_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
-    return src1->backend == GGML_BACKEND_CUDA;
+    GGML_ASSERT(src1->backend != GGML_BACKEND_GPU_SPLIT);
+    (void) src0;
+    (void) dst;
+    return src1->backend == GGML_BACKEND_GPU;
 }
 
-void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
-    ggml_cuda_op<GGML_CUDA_OP_TYPE_FFF, ggml_cuda_op_mul>(src0, src1, dst);
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true);
 }
 
 bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    GGML_ASSERT(src0->backend != GGML_BACKEND_GPU);
     const int64_t ne10 = src1->ne[0];
 
     const int64_t ne0 = dst->ne[0];
@@ -881,125 +927,145 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
         src1->type == GGML_TYPE_F32 &&
         dst->type == GGML_TYPE_F32 &&
-        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CUDA)) {
+        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_GPU_SPLIT)) {
         return true;
     }
 
     return false;
 }
 
-bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
-    size_t src0_sz = ggml_nbytes(src0);
-    size_t src1_sz = ggml_nbytes(src1);
-
-    // mul_mat_q: src0 is converted to fp32 on device
-    size_t mul_mat_q_transfer = src0_sz + src1_sz;
-
-    // mul_mat_f16: src1 is converted to fp16 on cpu
-    size_t mul_mat_f16_transfer = src0_sz + sizeof(half) * ggml_nelements(src1);
-
-    // choose the smaller one to transfer to the device
-    // TODO: this is not always the best choice due to the overhead of converting to fp16
-    return mul_mat_f16_transfer < mul_mat_q_transfer;
-}
-
-void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
+void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
 
     if (src0->type == GGML_TYPE_F32) {
-        ggml_cuda_op<GGML_CUDA_OP_TYPE_FFF, ggml_cuda_op_mul_mat_cublas>(src0, src1, dst);
-    }
-    else if (src0->type == GGML_TYPE_F16) {
-        if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
-            // ggml_cuda_op<GGML_CUDA_OP_TYPE_QQF, ggml_cuda_op_mul_mat_cublas>(src0, src1, dst);
-            ggml_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize);
-        }
-        else {
-            if (src1->ne[1] == 1) {
-                ggml_cuda_op<GGML_CUDA_OP_TYPE_QFF, ggml_cuda_op_dequantize_mul_mat_vec>(src0, src1, dst);
-            } else {
-                ggml_cuda_op<GGML_CUDA_OP_TYPE_FFF, ggml_cuda_op_mul_mat_cublas>(src0, src1, dst);
-            }
-        }
-    }
-    else if (ggml_is_quantized(src0->type)) {
+        ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
+    } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
         if (src1->ne[1] == 1) {
-            ggml_cuda_op<GGML_CUDA_OP_TYPE_QFF, ggml_cuda_op_dequantize_mul_mat_vec>(src0, src1, dst);
+            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, false);
         } else {
-            ggml_cuda_op<GGML_CUDA_OP_TYPE_FFF, ggml_cuda_op_mul_mat_cublas>(src0, src1, dst);
+            ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, true);
         }
-    }
-    else {
+    } else {
         GGML_ASSERT(false);
     }
 }
 
-size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
-    if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
-        return ggml_nelements(src1) * sizeof(ggml_fp16_t);
-    }
-    else {
-        return 0;
-    }
-}
-
-void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
+void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset, int n_layer) {
     FILE * fp = fopen(fname, "rb");
+    int nrows = ggml_nrows(tensor);
+    const size_t nb1 = tensor->nb[1];
+    ggml_backend backend = tensor->backend;
+    struct ggml_tensor_extra_gpu * extra = (struct ggml_tensor_extra_gpu *) tensor->extra;
+
+    for (int id = 0; id < g_device_count; ++id) {
+        extra->data_device[id] = nullptr;
+
+        int layer_low = id == 0 ? 0 : n_layer*g_tensor_split[id];
+        int layer_high = id == g_device_count - 1 ? n_layer : n_layer*g_tensor_split[id + 1];
+        if (backend == GGML_BACKEND_GPU && (extra->layer < layer_low || extra->layer >= layer_high)) {
+            continue;
+        }
 
-    const size_t size = ggml_nbytes(tensor);
+        cudaSetDevice(id);
 
-    void * buf;
-    CUDA_CHECK(cudaMalloc(&buf, size));
-    void * buf_host = malloc(size);
+        int row_low, row_high;
+        if (backend == GGML_BACKEND_GPU) {
+            extra->i_device = id;
+
+            row_low = 0;
+            row_high = nrows;
+        } else if (backend == GGML_BACKEND_GPU_SPLIT) {
+            extra->i_device = -1;
+
+            row_low = id == 0 ? 0 : nrows*g_tensor_split[id];
+            row_low -= row_low % GGML_CUDA_DMMV_Y;
+            row_high = id == g_device_count - 1 ? nrows : nrows*g_tensor_split[id + 1];
+            row_high -= row_high % GGML_CUDA_DMMV_Y;
+        } else {
+            GGML_ASSERT(false);
+        }
+        if (row_low == row_high) {
+            continue;
+        }
+
+        int64_t nrows_split = row_high - row_low;
+
+        const size_t offset_split = offset + row_low*nb1;
+        const size_t size = ggml_nbytes_split(tensor, nrows_split);
+
+        void * buf;
+        CUDA_CHECK(cudaMalloc(&buf, size));
+        void * buf_host = malloc(size);
 
 #ifdef _WIN32
-    int ret = _fseeki64(fp, (__int64) offset, SEEK_SET);
+        int ret = _fseeki64(fp, (__int64) offset_split, SEEK_SET);
 #else
-    int ret = fseek(fp, (long) offset, SEEK_SET);
+        int ret = fseek(fp, (long) offset_split, SEEK_SET);
 #endif
-    GGML_ASSERT(ret == 0); // same
+        GGML_ASSERT(ret == 0); // same
 
-    size_t ret2 = fread(buf_host, size, 1, fp);
-    if (ret2 != 1) {
-        fprintf(stderr, "unexpectedly reached end of file");
-        exit(1);
-    }
+        size_t ret2 = fread(buf_host, size, 1, fp);
+        if (ret2 != 1) {
+            fprintf(stderr, "unexpectedly reached end of file");
+            exit(1);
+        }
+
+        cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
+        cudaDeviceSynchronize();
 
-    cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
-    cudaDeviceSynchronize();
+        free(buf_host);
+        extra->data_device[id] = buf;
+    }
 
-    tensor->data = buf;
-    free(buf_host);
+    tensor->extra = extra;
     fclose(fp);
 }
 
+void ggml_cuda_free_data(struct ggml_tensor * tensor) {
+    if (tensor->backend != GGML_BACKEND_GPU && tensor->backend != GGML_BACKEND_GPU_SPLIT) {
+        return;
+    }
+
+    ggml_tensor_extra_gpu * extra = (ggml_tensor_extra_gpu *) tensor->extra;
+
+    for (int id = 0; id < g_device_count; ++id) {
+        if (extra->data_device[id] == nullptr) {
+            continue;
+        }
+
+        CUDA_CHECK(cudaSetDevice(id));
+        CUDA_CHECK(cudaFree(extra->data_device[id]));
+    }
+
+    delete extra;
+}
+
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor){
+    ggml_cuda_func_t func;
+
     switch (tensor->op) {
         case GGML_OP_MUL:
             if (!ggml_cuda_can_mul(tensor->src0, tensor->src1, tensor)) {
                 return false;
             }
-            if (params->ith != 0) {
-                return true;
-            }
-            if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
-                return true;
-            }
-            ggml_cuda_mul(tensor->src0, tensor->src1, tensor);
-            return true;
+            func = ggml_cuda_mul;
+            break;
         case GGML_OP_MUL_MAT:
             if (!ggml_cuda_can_mul_mat(tensor->src0, tensor->src1, tensor)) {
                 return false;
             }
-            if (params->ith != 0) {
-                return true;
-            }
-            if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
-                return true;
-            }
-            ggml_cuda_mul_mat(tensor->src0, tensor->src1, tensor, params->wdata, params->wsize);
-            return true;
+            func = ggml_cuda_mul_mat;
+            break;
         default:
             return false;
     }
+
+    if (params->ith != 0) {
+        return true;
+    }
+    if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+        return true;
+    }
+    func(tensor->src0, tensor->src1, tensor);
+    return true;
 }
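
Note on the split logic above: ggml_cuda_op and ggml_cuda_load_data derive the per-device row
range from g_tensor_split in the same way. A minimal standalone sketch of that computation,
assuming g_tensor_split[id] holds the cumulative start fraction assigned to device id (the helper
name example_row_range is illustrative and not part of the patch):

    // illustrative sketch only
    static void example_row_range(int64_t nrows, int id, int device_count,
                                  const float * tensor_split, int64_t * row_low, int64_t * row_high) {
        *row_low   = id == 0 ? 0 : (int64_t) (nrows*tensor_split[id]);
        *row_low  -= *row_low  % GGML_CUDA_DMMV_Y; // align to the dequantize_mul_mat_vec block height
        *row_high  = id == device_count - 1 ? nrows : (int64_t) (nrows*tensor_split[id + 1]);
        *row_high -= *row_high % GGML_CUDA_DMMV_Y;
    }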
diff --git a/ggml-cuda.h b/ggml-cuda.h
index f71701ce88371..585e3feaa3078 100644
--- a/ggml-cuda.h
+++ b/ggml-cuda.h
@@ -1,10 +1,21 @@
+#pragma once
+
 #include "ggml.h"
 
 #ifdef  __cplusplus
 extern "C" {
 #endif
 
+#define GGML_CUDA_MAX_DEVICES       16
+
+struct ggml_tensor_extra_gpu {
+    int layer; // which layer the tensor is on
+    int i_device; // which device the data is on
+    void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
+};
+
 void   ggml_init_cublas(void);
+void ggml_cuda_set_tensor_split(float * tensor_split);
 
 void   ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
 bool   ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
@@ -15,7 +26,8 @@ void   ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
 void * ggml_cuda_host_malloc(size_t size);
 void   ggml_cuda_host_free(void * ptr);
 
-void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset);
+void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensors, size_t offset, int n_layer);
+void ggml_cuda_free_data(struct ggml_tensor * tensor);
 bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
 
 #ifdef  __cplusplus
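
The declarations above define the intended call order for an offloaded tensor. A rough usage
sketch (the file name, offset and layer count are placeholders; the tensor is assumed to already
carry a ggml_tensor_extra_gpu with its layer field set, as done in llama.cpp further below; the
tensor_split values are assumed to be the cumulative start fractions used in ggml-cuda.cu):

    // rough sketch only, not part of the patch
    float tensor_split[GGML_CUDA_MAX_DEVICES] = {0.0f, 0.5f}; // two GPUs, 50/50 row split
    ggml_init_cublas();
    ggml_cuda_set_tensor_split(tensor_split);
    ggml_cuda_load_data("model.bin", tensor, /*offset =*/ 0, /*n_layer =*/ 32); // upload rows to the assigned device(s)
    // ... graph evaluation then calls ggml_cuda_compute_forward() per node ...
    ggml_cuda_free_data(tensor); // release the per-device copies stored in tensor->extra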
diff --git a/ggml-opencl.cpp b/ggml-opencl.cpp
index 52ba3aaac3f0a..8bc8cbe8cf687 100644
--- a/ggml-opencl.cpp
+++ b/ggml-opencl.cpp
@@ -676,7 +676,7 @@ static cl_int ggml_cl_h2d_tensor_2d(cl_command_queue queue, cl_mem dst, size_t o
 }
 
 static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    GGML_ASSERT(src1->backend == GGML_BACKEND_CL);
+    GGML_ASSERT(src1->backend == GGML_BACKEND_GPU);
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
@@ -789,7 +789,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
     size_t y_size;
     size_t d_size;
     cl_mem d_X;
-    if (src0->backend == GGML_BACKEND_CL) {
+    if (src0->backend == GGML_BACKEND_GPU) { // NOLINT
         d_X = (cl_mem) src0->data;
     } else {
         d_X = ggml_cl_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &x_size, CL_MEM_READ_ONLY);
@@ -800,7 +800,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             // copy data to device
-            if (src0->backend != GGML_BACKEND_CL) {
+            if (src0->backend != GGML_BACKEND_GPU) {
                 CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
             }
             CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL));
@@ -829,7 +829,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
         }
     }
 
-    if (src0->backend != GGML_BACKEND_CL) {
+    if (src0->backend != GGML_BACKEND_GPU) {
         ggml_cl_pool_free(d_X, x_size);
     }
     ggml_cl_pool_free(d_Y, y_size);
@@ -865,7 +865,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
     size_t y_size;
     size_t d_size;
     cl_mem d_X;
-    if (src0->backend == GGML_BACKEND_CL) {
+    if (src0->backend == GGML_BACKEND_GPU) { // NOLINT
         d_X = (cl_mem) src0->data;
     } else {
         d_X = ggml_cl_pool_malloc(sizeof(ggml_fp16_t) * x_ne, &x_size, CL_MEM_READ_ONLY);
@@ -879,7 +879,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
             // copy src0 to device
-            if (src0->backend != GGML_BACKEND_CL) {
+            if (src0->backend != GGML_BACKEND_GPU) {
                 CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
             }
 
@@ -936,7 +936,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
         }
     }
 
-    if (src0->backend != GGML_BACKEND_CL) {
+    if (src0->backend != GGML_BACKEND_GPU) {
         ggml_cl_pool_free(d_X, x_size);
     }
     ggml_cl_pool_free(d_Y, y_size);
@@ -992,7 +992,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
             if (src0->backend == GGML_BACKEND_CPU) {
                 events.emplace_back();
                 CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Q, 0, src0, i03, i02, events.data() + ev_idx++));
-            } else if (src0->backend == GGML_BACKEND_CL) {
+            } else if (src0->backend == GGML_BACKEND_GPU) {
                 d_Q = (cl_mem) src0->data;
             } else {
                 GGML_ASSERT(false);
@@ -1077,7 +1077,7 @@ bool ggml_cl_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tens
     if ((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
         src1->type == GGML_TYPE_F32 &&
         dst->type == GGML_TYPE_F32 &&
-        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_CL)) {
+        ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_BACKEND_GPU)) {
         return true;
     }
 
@@ -1156,7 +1156,7 @@ void ggml_cl_transform_tensor(ggml_tensor * tensor) {
     CL_CHECK(clFinish(queue));
 
     tensor->data = dst;
-    tensor->backend = GGML_BACKEND_CL;
+    tensor->backend = GGML_BACKEND_GPU;
 }
 
 void ggml_cl_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset) {
diff --git a/ggml.c b/ggml.c
index 4a05df1e3e7bf..9623fa42439b5 100644
--- a/ggml.c
+++ b/ggml.c
@@ -3722,6 +3722,12 @@ size_t ggml_nbytes(const struct ggml_tensor * tensor) {
     return MAX(tensor->ne[3]*tensor->nb[3], (ggml_nelements(tensor)*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type]);
 }
 
+size_t ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split) {
+    static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
+
+    return (nrows_split*tensor->ne[0]*GGML_TYPE_SIZE[tensor->type])/GGML_BLCK_SIZE[tensor->type];
+}
+
 int ggml_blck_size(enum ggml_type type) {
     return GGML_BLCK_SIZE[type];
 }
@@ -4144,6 +4150,7 @@ struct ggml_tensor * ggml_new_tensor_impl(
         /*.perf_time_us =*/ 0,
         /*.data         =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data,
         /*.name         =*/ { 0 },
+        /*.extra        =*/ NULL,
         /*.pad          =*/ { 0 },
     };
 
@@ -8147,7 +8154,7 @@ static void ggml_compute_forward_mul_f32(
     const int nth = params->nth;
 
 #ifdef GGML_USE_CLBLAST
-    if (src1->backend == GGML_BACKEND_CL) {
+    if (src1->backend == GGML_BACKEND_GPU) {
         if (ith == 0) {
             ggml_cl_mul(src0, src1, dst);
         }
@@ -12884,8 +12891,8 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
     GGML_ASSERT(params);
 
 #ifdef GGML_USE_CUBLAS
-    bool used_cuda = ggml_cuda_compute_forward(params, tensor);
-    if (used_cuda) {
+    bool skip_cpu = ggml_cuda_compute_forward(params, tensor);
+    if (skip_cpu) {
         return;
     }
 #endif // GGML_USE_CUBLAS
@@ -14196,7 +14203,6 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
                             node->n_tasks = 1; // TODO: this actually is doing nothing
                                                 //       the threads are still spinning
-                            cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
                         }
                         else
 #elif defined(GGML_USE_CLBLAST)
diff --git a/ggml.h b/ggml.h
index 5dbe0f2ffb78f..f0e8f4f0a8325 100644
--- a/ggml.h
+++ b/ggml.h
@@ -249,8 +249,8 @@ extern "C" {
 
     enum ggml_backend {
         GGML_BACKEND_CPU = 0,
-        GGML_BACKEND_CUDA = 1,
-        GGML_BACKEND_CL = 2,
+        GGML_BACKEND_GPU = 10,
+        GGML_BACKEND_GPU_SPLIT = 20,
     };
 
     // model file types
@@ -375,7 +375,9 @@ extern "C" {
 
         char name[GGML_MAX_NAME];
 
-        char padding[16];
+        void * extra; // extra things e.g. for ggml-cuda.cu
+
+        char padding[4];
     };
 
     static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
@@ -424,6 +426,7 @@ extern "C" {
     struct ggml_compute_params {
         enum ggml_task_type type;
 
+        // ith = thread index, nth = number of threads
         int ith, nth;
 
         // work buffer for all threads
@@ -442,9 +445,10 @@ extern "C" {
     GGML_API void    ggml_print_object (const struct ggml_object * obj);
     GGML_API void    ggml_print_objects(const struct ggml_context * ctx);
 
-    GGML_API int64_t ggml_nelements(const struct ggml_tensor * tensor);
-    GGML_API int64_t ggml_nrows    (const struct ggml_tensor * tensor);
-    GGML_API size_t  ggml_nbytes   (const struct ggml_tensor * tensor);
+    GGML_API int64_t ggml_nelements   (const struct ggml_tensor * tensor);
+    GGML_API int64_t ggml_nrows       (const struct ggml_tensor * tensor);
+    GGML_API size_t  ggml_nbytes      (const struct ggml_tensor * tensor);
+    GGML_API size_t  ggml_nbytes_split(const struct ggml_tensor * tensor, int nrows_split);
 
     GGML_API int     ggml_blck_size (enum ggml_type type);
     GGML_API size_t  ggml_type_size (enum ggml_type type); // size in bytes for all elements in a block
diff --git a/llama.cpp b/llama.cpp
index bc58ad960c139..3f72624274a51 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -199,6 +199,12 @@ struct llama_model {
         if (ctx) {
             ggml_free(ctx);
         }
+
+#ifdef GGML_USE_CUBLAS
+        for (size_t i = 0; i < tensors_by_name.size(); ++i) {
+            ggml_cuda_free_data(tensors_by_name[i].second);
+        }
+#endif // GGML_USE_CUBLAS
     }
 };
 
@@ -665,7 +671,7 @@ struct llama_model_loader {
         }
     }
 
-    struct ggml_tensor * get_tensor(const std::string & name, const std::vector<uint32_t> & ne, ggml_backend backend) {
+    struct ggml_tensor * get_tensor(const std::string & name, const std::vector<uint32_t> & ne, int layer, ggml_backend backend) {
         auto it = tensors_map.name_to_idx.find(name);
         if (it == tensors_map.name_to_idx.end()) {
             throw format("llama.cpp: tensor '%s' is missing from model", name.c_str());
@@ -676,10 +682,10 @@ struct llama_model_loader {
                          name.c_str(), llama_format_tensor_shape(ne).c_str(), llama_format_tensor_shape(lt.ne).c_str());
         }
 
-        return get_tensor_for(lt, backend);
+        return get_tensor_for(lt, layer, backend);
     }
 
-    struct ggml_tensor * get_tensor_for(llama_load_tensor & lt, ggml_backend backend) {
+    struct ggml_tensor * get_tensor_for(llama_load_tensor & lt, int layer, ggml_backend backend) {
         struct ggml_tensor * tensor;
         if (lt.ne.size() == 2) {
             tensor = ggml_new_tensor_2d(ggml_ctx, lt.type, lt.ne.at(0), lt.ne.at(1));
@@ -689,6 +695,17 @@ struct llama_model_loader {
         }
         ggml_set_name(tensor, lt.name.c_str());
         LLAMA_ASSERT(lt.ggml_tensor == NULL); // if this fails, we called get_tensor twice on the same tensor
+
+#ifdef GGML_USE_CUBLAS
+        if (backend == GGML_BACKEND_GPU || backend == GGML_BACKEND_GPU_SPLIT) {
+            struct ggml_tensor_extra_gpu * extra = new struct ggml_tensor_extra_gpu;
+            extra->layer = layer;
+            tensor->extra = extra;
+        }
+#else
+        (void) layer;
+#endif // GGML_USE_CUBLAS
+
         tensor->backend = backend;
         lt.ggml_tensor = tensor;
         num_ggml_tensors_created++;
@@ -842,6 +859,7 @@ struct llama_context_params llama_context_default_params() {
     struct llama_context_params result = {
         /*.n_ctx                       =*/ 512,
         /*.gpu_layers                  =*/ 0,
+        /*.tensor_split                =*/ {0},
         /*.seed                        =*/ -1,
         /*.f16_kv                      =*/ true,
         /*.logits_all                  =*/ false,
@@ -926,6 +944,7 @@ static void llama_model_load_internal(
         llama_context & lctx,
         int n_ctx,
         int n_gpu_layers,
+        float * tensor_split,
         ggml_type memory_type,
         bool use_mmap,
         bool use_mlock,
@@ -1019,13 +1038,16 @@ static void llama_model_load_internal(
     }
 
 #if defined(GGML_USE_CUBLAS)
-#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CUDA
     fprintf(stderr, "%s: using CUDA for GPU acceleration\n", __func__);
+#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU
+#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU_SPLIT
 #elif defined(GGML_USE_CLBLAST)
-#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CL
     fprintf(stderr, "%s: using OpenCL for GPU acceleration\n", __func__);
+#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU
+#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_GPU
 #else
 #define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_CPU
+#define LLAMA_BACKEND_OFFLOAD_SPLIT GGML_BACKEND_CPU
 #endif
 
     // prepare memory for the weights
@@ -1037,45 +1059,46 @@ static void llama_model_load_internal(
 
         ml->ggml_ctx = ctx;
 
-        model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, GGML_BACKEND_CPU);
-        model.norm           = ml->get_tensor("norm.weight",           {n_embd},          GGML_BACKEND_CPU);
+        model.tok_embeddings = ml->get_tensor("tok_embeddings.weight", {n_embd, n_vocab}, -1, GGML_BACKEND_CPU);
+        model.norm           = ml->get_tensor("norm.weight",           {n_embd},          -1, GGML_BACKEND_CPU);
 
         // "output" tensor
         {
             ggml_backend backend_output;
             if (n_gpu_layers > int(n_layer)) { // NOLINT
-                backend_output = LLAMA_BACKEND_OFFLOAD;
+                backend_output = LLAMA_BACKEND_OFFLOAD_SPLIT;
             } else {
                 backend_output = GGML_BACKEND_CPU;
             }
 
-            model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, backend_output);
+            model.output = ml->get_tensor("output.weight", {n_embd, n_vocab}, -1, backend_output);
         }
 
         const int i_gpu_start = n_layer - n_gpu_layers;
 
         model.layers.resize(n_layer);
         for (uint32_t i = 0; i < n_layer; ++i) {
-            const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD;
+            const ggml_backend backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD; // NOLINT
+            const ggml_backend backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : LLAMA_BACKEND_OFFLOAD_SPLIT; // NOLINT
 
             auto & layer = model.layers[i];
 
             std::string layers_i = "layers." + std::to_string(i);
 
-            layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, backend);
+            layer.attention_norm = ml->get_tensor(layers_i + ".attention_norm.weight", {n_embd}, i, backend);
 
-            layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, backend);
-            layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}, backend);
-            layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}, backend);
-            layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, backend);
+            layer.wq = ml->get_tensor(layers_i + ".attention.wq.weight", {n_embd, n_embd}, i, backend_split);
+            layer.wk = ml->get_tensor(layers_i + ".attention.wk.weight", {n_embd, n_embd}, i, backend_split);
+            layer.wv = ml->get_tensor(layers_i + ".attention.wv.weight", {n_embd, n_embd}, i, backend_split);
+            layer.wo = ml->get_tensor(layers_i + ".attention.wo.weight", {n_embd, n_embd}, i, backend_split);
 
-            layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, backend);
+            layer.ffn_norm = ml->get_tensor(layers_i + ".ffn_norm.weight", {n_embd}, i, backend);
 
-            layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd,   n_ff},   backend);
-            layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", {  n_ff,   n_embd}, backend);
-            layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd,   n_ff},   backend);
+            layer.w1 = ml->get_tensor(layers_i + ".feed_forward.w1.weight", {n_embd,   n_ff},   i, backend_split);
+            layer.w2 = ml->get_tensor(layers_i + ".feed_forward.w2.weight", {  n_ff,   n_embd}, i, backend_split);
+            layer.w3 = ml->get_tensor(layers_i + ".feed_forward.w3.weight", {n_embd,   n_ff},   i, backend_split);
 
-            if (backend == LLAMA_BACKEND_OFFLOAD) {
+            if (backend == GGML_BACKEND_GPU) {
                 vram_total +=
                     ggml_nbytes(layer.attention_norm) + ggml_nbytes(layer.wq) + ggml_nbytes(layer.wk)             +
                     ggml_nbytes(layer.wv)             + ggml_nbytes(layer.wo) + ggml_nbytes(layer.attention_norm) +
@@ -1127,6 +1150,8 @@ static void llama_model_load_internal(
 
 #if defined(GGML_USE_CUBLAS)
     {
+        ggml_cuda_set_tensor_split(tensor_split);
+
         size_t done_size = 0;
         size_t data_size = 0;
         for (llama_load_tensor & lt : ml->tensors_map.tensors) {
@@ -1136,13 +1161,14 @@ static void llama_model_load_internal(
             }
         }
         for (llama_load_tensor & lt : ml->tensors_map.tensors) {
-            if (lt.ggml_tensor->backend != GGML_BACKEND_CUDA) {
+            ggml_backend backend = lt.ggml_tensor->backend;
+            if (backend != GGML_BACKEND_GPU && backend != GGML_BACKEND_GPU_SPLIT) {
                 continue;
             }
             if (progress_callback) {
                 progress_callback((float) done_size / data_size, progress_callback_user_data);
             }
-            ggml_cuda_load_data(fname.c_str(), lt.ggml_tensor, lt.shards.at(0).file_off);
+            ggml_cuda_load_data(fname.c_str(), lt.ggml_tensor, lt.shards.at(0).file_off, hparams.n_layer);
             done_size += lt.size;
         }
     }
@@ -1157,7 +1183,7 @@ static void llama_model_load_internal(
             }
         }
         for (llama_load_tensor & lt : ml->tensors_map.tensors) {
-            if (lt.ggml_tensor->backend != GGML_BACKEND_CL) {
+            if (lt.ggml_tensor->backend != GGML_BACKEND_GPU) {
                 continue;
             }
             if (progress_callback) {
@@ -1167,6 +1193,8 @@ static void llama_model_load_internal(
             done_size += lt.size;
         }
     }
+#else
+    (void) tensor_split;
 #endif
 
     if (progress_callback) {
@@ -1185,6 +1213,7 @@ static bool llama_model_load(
         llama_context & lctx,
         int n_ctx,
         int n_gpu_layers,
+        float * tensor_split,
         ggml_type memory_type,
         bool use_mmap,
         bool use_mlock,
@@ -1192,8 +1221,8 @@ static bool llama_model_load(
         llama_progress_callback progress_callback,
         void *progress_callback_user_data) {
     try {
-        llama_model_load_internal(fname, lctx, n_ctx, n_gpu_layers, memory_type, use_mmap, use_mlock,
-                                  vocab_only, progress_callback, progress_callback_user_data);
+        llama_model_load_internal(fname, lctx, n_ctx, n_gpu_layers, tensor_split, memory_type, use_mmap,
+                                  use_mlock, vocab_only, progress_callback, progress_callback_user_data);
         return true;
     } catch (const std::string & err) {
         fprintf(stderr, "error loading model: %s\n", err.c_str());
@@ -2293,8 +2322,8 @@ struct llama_context * llama_init_from_file(
 
     ggml_type memory_type = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
-    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_gpu_layers, memory_type,
-                params.use_mmap, params.use_mlock, params.vocab_only,
+    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_gpu_layers, params.tensor_split,
+                memory_type, params.use_mmap, params.use_mlock, params.vocab_only,
                 params.progress_callback, params.progress_callback_user_data)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         llama_free(ctx);
@@ -2547,7 +2576,7 @@ int llama_apply_lora_from_file_internal(struct llama_context * ctx, const char *
                 }
                 size_t idx = model_loader->tensors_map.name_to_idx[base_name];
                 llama_load_tensor & lt = model_loader->tensors_map.tensors[idx];
-                base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, GGML_BACKEND_CPU);
+                base_t = model_loader->get_tensor(base_name, { (uint32_t)dest_t->ne[0], (uint32_t)dest_t->ne[1] }, -1, GGML_BACKEND_CPU);
                 lt.data = (uint8_t *) lt.ggml_tensor->data;
                 model_loader->load_data_for(lt);
                 lt.ggml_tensor->data = lt.data;
diff --git a/llama.h b/llama.h
index 87fa9736784c8..015c6a9894ef5 100644
--- a/llama.h
+++ b/llama.h
@@ -1,6 +1,13 @@
 #ifndef LLAMA_H
 #define LLAMA_H
 
+#include "ggml.h"
+#ifdef GGML_USE_CUBLAS
+#include "ggml-cuda.h"
+#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
+#else
+#define LLAMA_MAX_DEVICES 1
+#endif // GGML_USE_CUBLAS
 #include <stddef.h>
 #include <stdint.h>
 #include <stdbool.h>
@@ -65,9 +72,10 @@ extern "C" {
     typedef void (*llama_progress_callback)(float progress, void *ctx);
 
     struct llama_context_params {
-        int n_ctx;        // text context
-        int n_gpu_layers; // number of layers to store in VRAM
-        int seed;         // RNG seed, -1 for random
+        int n_ctx;                            // text context
+        int n_gpu_layers;                     // number of layers to store in VRAM
+        float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
+        int seed;                             // RNG seed, -1 for random
 
         bool f16_kv;     // use fp16 for KV cache
         bool logits_all; // the llama_eval() call computes all logits, not just the last one
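
On the llama.cpp side the new tensor_split array is passed through llama_model_load_internal
unchanged into ggml_cuda_set_tensor_split. A rough example of filling it for two GPUs (the model
path and layer count are placeholders, and the values are assumed to be interpreted as the
cumulative start fractions described above):

    // rough sketch only
    struct llama_context_params params = llama_context_default_params();
    params.n_gpu_layers    = 40;   // number of layers to offload (placeholder)
    params.tensor_split[0] = 0.0f; // device 0 starts at row fraction 0.0
    params.tensor_split[1] = 0.6f; // device 1 takes the remaining 40 % of the rows
    struct llama_context * ctx = llama_init_from_file("model.bin", params);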

From 11af67866e927ddab9dd86735a7e140fcc93a253 Mon Sep 17 00:00:00 2001
From: JohannesGaessler <johannesg@5d6.de>
Date: Mon, 5 Jun 2023 14:32:20 +0200
Subject: [PATCH 4/4] Fixed single GPU performance regression

---
 ggml-cuda.cu | 133 +++++++++++++++++++++++++++++++++++++++++++++++++++
 ggml.c       |   1 +
 2 files changed, 134 insertions(+)

diff --git a/ggml-cuda.cu b/ggml-cuda.cu
index bccf74664ee1f..38a285e5044a1 100644
--- a/ggml-cuda.cu
+++ b/ggml-cuda.cu
@@ -934,6 +934,30 @@ bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_te
     return false;
 }
 
+bool ggml_cuda_mul_mat_use_f16(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * /* dst */) {
+    size_t src0_sz = ggml_nbytes(src0);
+    size_t src1_sz = ggml_nbytes(src1);
+
+    // mul_mat_q: src0 is converted to fp32 on device
+    size_t mul_mat_q_transfer = src0_sz + src1_sz;
+
+    // mul_mat_f16: src1 is converted to fp16 on cpu
+    size_t mul_mat_f16_transfer = src0_sz + sizeof(half) * ggml_nelements(src1);
+
+    // choose the smaller one to transfer to the device
+    // TODO: this is not always the best choice due to the overhead of converting to fp16
+    return mul_mat_f16_transfer < mul_mat_q_transfer;
+}
+
+size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
+    if (ggml_cuda_mul_mat_use_f16(src0, src1, dst)) {
+        return ggml_nelements(src1) * sizeof(ggml_fp16_t);
+    }
+    else {
+        return 0;
+    }
+}
+
 void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     GGML_ASSERT(ggml_cuda_can_mul_mat(src0, src1, dst));
 
@@ -950,6 +974,99 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
     }
 }
 
+static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
+    const int64_t ne00 = src0->ne[0];
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+
+    const int nb10 = src1->nb[0];
+    const int nb11 = src1->nb[1];
+    const int nb12 = src1->nb[2];
+    const int nb13 = src1->nb[3];
+
+    const int nb2  = dst->nb[2];
+    const int nb3  = dst->nb[3];
+
+    const float alpha = 1.0f;
+    const float beta = 0.0f;
+    const int x_ne = ne01 * ne00;
+    const int y_ne = ne11 * ne10;
+    const int d_ne = ne11 * ne01;
+    const int n_mm = ne03 * ne02;
+
+    size_t x_size, y_size, d_size;
+    half  * d_X =  (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * x_ne, &x_size);
+    half  * d_Y =  (half *) ggml_cuda_pool_malloc(n_mm * sizeof(half) * y_ne, &y_size);
+    float * d_D = (float *) ggml_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size);
+
+    bool src1_cont_rows = nb10 == sizeof(float);
+    bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
+
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        for (int64_t i02 = 0; i02 < ne02; i02++) {
+            int i = i03*ne02 + i02;
+            cudaStream_t cudaStream = g_cudaStreams_main[0][i % GGML_CUDA_MAX_STREAMS];
+
+            half  * c_X = d_X + i * x_ne;
+            half  * c_Y = d_Y + i * y_ne;
+            float * c_D = d_D + i * d_ne;
+
+            // copy src0 to device
+            CUDA_CHECK(ggml_cuda_h2d_tensor_2d(c_X, src0, i03, i02, 0, ne01, cudaStream));
+
+            // convert src1 to fp16
+            // TODO: use multiple threads
+            ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02);
+            char * src1i = (char *) src1->data + i03*nb13 + i02*nb12;
+            if (src1_cont_rows) {
+                if (src1_cont_cols) {
+                    ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
+                }
+                else {
+                    for (int64_t i01 = 0; i01 < ne11; i01++) {
+                        ggml_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10);
+                    }
+                }
+            }
+            else {
+                for (int64_t i01 = 0; i01 < ne11; i01++) {
+                    for (int64_t i00 = 0; i00 < ne10; i00++) {
+                        // very slow due to no inlining
+                        tmp[i01*ne10 + i00] = ggml_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10));
+                    }
+                }
+            }
+
+            // copy src1 to device
+            CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream));
+
+            // compute
+            CUBLAS_CHECK(cublasSetStream(g_cublas_handles[0], cudaStream));
+            CUBLAS_CHECK(
+                cublasGemmEx(g_cublas_handles[0], CUBLAS_OP_T, CUBLAS_OP_N,
+                        ne01, ne11, ne10,
+                        &alpha, c_X, CUDA_R_16F, ne00,
+                                c_Y, CUDA_R_16F, ne10,
+                        &beta,  c_D, CUDA_R_32F, ne01,
+                        CUBLAS_COMPUTE_32F_FAST_16F,
+                        CUBLAS_GEMM_DEFAULT));
+
+            // copy dst to host
+            float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+            CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream));
+        }
+    }
+
+    CUDA_CHECK(cudaDeviceSynchronize());
+    ggml_cuda_pool_free(d_X, x_size);
+    ggml_cuda_pool_free(d_Y, y_size);
+    ggml_cuda_pool_free(d_D, d_size);
+}
+
 void ggml_cuda_load_data(const char * fname, struct ggml_tensor * tensor, const size_t offset, int n_layer) {
     FILE * fp = fopen(fname, "rb");
     int nrows = ggml_nrows(tensor);
@@ -1054,6 +1171,22 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             if (!ggml_cuda_can_mul_mat(tensor->src0, tensor->src1, tensor)) {
                 return false;
             }
+
+            // For prompt processing the multi-GPU code is currently slower than the single-GPU code that existed before.
+            // To avoid a performance regression the old code is kept for now:
+            if (g_device_count == 1 && tensor->src0->type == GGML_TYPE_F16 &&
+                ggml_cuda_mul_mat_use_f16(tensor->src0, tensor->src1, tensor)) {
+
+                if (params->ith != 0) {
+                    return true;
+                }
+                if (params->type == GGML_TASK_COMPUTE) {
+                    ggml_cuda_mul_mat_f16(tensor->src0, tensor->src1, tensor, params->wdata, params->wsize);
+                    return true;
+                }
+
+                return false;
+            }
             func = ggml_cuda_mul_mat;
             break;
         default:
diff --git a/ggml.c b/ggml.c
index 9623fa42439b5..cd3784e94d7f6 100644
--- a/ggml.c
+++ b/ggml.c
@@ -14203,6 +14203,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) {
                             node->n_tasks = 1; // TODO: this actually is doing nothing
                                                 //       the threads are still spinning
+                            cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node);
                         }
                         else
 #elif defined(GGML_USE_CLBLAST)
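
With the last hunk the graph planner again reserves a work buffer for the single-GPU f16 path.
As an illustration (the shape is a placeholder, assuming the f16 path is selected for the node):

    // src1: 4096 x 512 f32 activations, e.g. a 512-token prompt at n_embd = 4096
    size_t wsize = (size_t) 4096 * 512 * sizeof(ggml_fp16_t); // = 4194304 bytes (4 MiB) of wdata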