Skip to content

Commit 3b9ea65

Browse files
committed
cuda : use CUBLAS_COMPUTE_32F to speed-up and avoid dst cpy
1 parent c8d6a1f commit 3b9ea65

File tree: 1 file changed (+21 lines, −36 lines)

ggml-cuda.cu

Lines changed: 21 additions & 36 deletions
(diff columns: original line number · new line number · change)
@@ -6385,27 +6385,19 @@ inline void ggml_cuda_op_mul_mat_cublas(
         }

         const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddq_i : src1_as_f16;

-        size_t dst_as = 0;
-        half * dst_f16 = (half *) ggml_cuda_pool_malloc(row_diff*src1_ncols * sizeof(half), &dst_as);
-
-        const half alpha_f16 = 1.0f;
-        const half beta_f16 = 0.0f;
+        const float alpha = 1.0f;
+        const float beta = 0.0f;

         CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], stream));
         CUBLAS_CHECK(
             cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 row_diff, src1_ncols, ne10,
-                &alpha_f16, src0_ptr, CUDA_R_16F, ne00,
-                            src1_ptr, CUDA_R_16F, ne10,
-                &beta_f16,  dst_f16,  CUDA_R_16F, ldc,
-                CUBLAS_COMPUTE_16F,
+                &alpha, src0_ptr, CUDA_R_16F, ne00,
+                        src1_ptr, CUDA_R_16F, ne10,
+                &beta,  dst_dd_i, CUDA_R_32F, ldc,
+                CUBLAS_COMPUTE_32F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));

-        const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-        to_fp32_cuda(dst_f16, dst_dd_i, row_diff*src1_ncols, stream);
-
-        ggml_cuda_pool_free(dst_f16, dst_as);
-
         if (src0_as != 0) {
             ggml_cuda_pool_free(src0_as_f16, src0_as);
         }
@@ -7189,18 +7181,15 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
     to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);

-    size_t dst_as = 0;
-    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
-
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);

     // broadcast factors
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;

-    const half alpha_f16 = 1.0f;
-    const half beta_f16 = 0.0f;
+    const float alpha = 1.0f;
+    const float beta = 0.0f;

 #if 0
     // use cublasGemmEx
@@ -7213,10 +7202,10 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
                 CUBLAS_CHECK(
                         cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                             ne01, ne11, ne10,
-                            &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F, nb01/sizeof(half),
-                                        (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
-                            &beta_f16,  (      char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
-                            CUBLAS_COMPUTE_16F,
+                            &alpha, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F, nb01/sizeof(half),
+                                    (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
+                            &beta,  (      char *)     dst_ddf + i12* dst->nb[2]   + i13* dst->nb[3]  , CUDA_R_32F, ne01,
+                            CUBLAS_COMPUTE_32F,
                             CUBLAS_GEMM_DEFAULT_TENSOR_OP));
             }
         }
@@ -7228,11 +7217,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         CUBLAS_CHECK(
         cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
-                            (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
-                &beta_f16,  (      char *)     dst_f16, CUDA_R_16F, ne01,                dst->nb[2]/sizeof(float), // strideC
+                &alpha, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
+                        (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
+                &beta,  (      char *)     dst_ddf, CUDA_R_32F, ne01,                dst->nb[2]/sizeof(float), // strideC
                 ne12*ne13,
-                CUBLAS_COMPUTE_16F,
+                CUBLAS_COMPUTE_32F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));
     } else {
         // use cublasGemmBatchedEx
@@ -7249,7 +7238,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const

                 ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3];
                 ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
-                ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
+                ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_ddf + i12* dst->nb[2]   + i13* dst->nb[3]  ;
             }
         }

@@ -7269,11 +7258,11 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
         CUBLAS_CHECK(
         cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
                 ne01, ne11, ne10,
-                &alpha_f16, (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
-                            (const void **) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
-                &beta_f16,  (      void **) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
+                &alpha, (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+                        (const void **) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
+                &beta,  (      void **) (ptrs_as + 2*ne23), CUDA_R_32F, ne01,
                 ne23,
-                CUBLAS_COMPUTE_16F,
+                CUBLAS_COMPUTE_32F,
                 CUBLAS_GEMM_DEFAULT_TENSOR_OP));

         // free device memory for pointers
@@ -7282,11 +7271,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
     }
 #endif

-    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
-    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
     ggml_cuda_pool_free(src1_as_f16, src1_as);
-    ggml_cuda_pool_free(dst_f16, dst_as);
 }

 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {

0 commit comments

Comments
 (0)