
Commit 61f2b2a

Address review comments
1 parent 5adf50e commit 61f2b2a

File tree: 4 files changed (+63, -56 lines)

ggml/src/ggml-cuda/binbcast.cu

Lines changed: 25 additions & 45 deletions
@@ -2,35 +2,14 @@
 #include <cstdint>
 #include <utility>
 
-static __device__ __forceinline__ float op_repeat(const float a, const float b) {
-    return b;
-    GGML_UNUSED(a);
-}
-
-static __device__ __forceinline__ float op_add(const float a, const float b) {
-    return a + b;
-}
-
-static __device__ __forceinline__ float op_sub(const float a, const float b) {
-    return a - b;
-}
-
-static __device__ __forceinline__ float op_mul(const float a, const float b) {
-    return a * b;
-}
-
-static __device__ __forceinline__ float op_div(const float a, const float b) {
-    return a / b;
-}
-
-template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... S1Ptrs>
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
 static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
         const int ne0, const int ne1, const int ne2, const int ne3,
         const int ne10, const int ne11, const int ne12, const int ne13,
         /*int s0, */ const int s1, const int s2, const int s3,
         /*int s00,*/ const int s01, const int s02, const int s03,
         /*int s10,*/ const int s11, const int s12, const int s13,
-        S1Ptrs... src1s) {
+        src1_ptrs... src1s) {
     const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
     const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
     const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
@@ -67,14 +46,14 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
     }
 }
 
-template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... S1Ptrs>
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
 static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
         const int ne0, const int ne1, const int ne2,const int ne3,
         const int ne10, const int ne11, const int ne12, const int ne13,
         /*int s0, */ const int s1, const int s2, const int s3,
         /*int s00,*/ const int s01, const int s02, const int s03,
         /*int s10,*/ const int s11, const int s12, const int s13,
-        S1Ptrs... src1s) {
+        src1_ptrs ... src1s) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
     const int i3 = i/(ne2*ne1*ne0);
@@ -101,12 +80,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t *
 
     float result = src0_row ? (float) src0_row[i0] : 0.0f;
 
-    auto add_one = [&](const src1_t * p) {
-        const src1_t * row = p + i_src1;
-        result = bin_op(result, (float) row[i10]);
-        return 0;
-    };
-    (void) std::initializer_list<int>{ (add_one(src1s), 0)... };
+    result = (... , bin_op(result, (float)src1s[i_src1 + i10]));
 
     dst_row[i0] = (dst_t) result;
 }
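
Note: the hunk above replaces a lambda-plus-std::initializer_list pack expansion with a C++17 fold over the comma operator. Below is a minimal host-side C++17 sketch of the two expansion styles, with illustrative names (op_add_like, fold_with_*) rather than the kernel's code; in this sketch each fold term assigns into the accumulator so that every pack element contributes.

#include <cstdio>
#include <initializer_list>

// Illustrative stand-in for the kernel's bin_op template argument.
static float op_add_like(float a, float b) { return a + b; }

// Pre-C++17 style: expand the pack through a discarded initializer_list.
template <typename... Srcs>
static float fold_with_initializer_list(float acc, Srcs... srcs) {
    (void) std::initializer_list<int>{ (acc = op_add_like(acc, srcs), 0)... };
    return acc;
}

// C++17 style: a left fold over the comma operator; each term updates acc.
template <typename... Srcs>
static float fold_with_fold_expr(float acc, Srcs... srcs) {
    (..., (acc = op_add_like(acc, srcs)));
    return acc;
}

int main() {
    // Both apply op_add_like once per pack element, left to right: ((1+2)+3)+4 = 10.
    std::printf("%g\n", fold_with_initializer_list(1.0f, 2.0f, 3.0f, 4.0f));
    std::printf("%g\n", fold_with_fold_expr(1.0f, 2.0f, 3.0f, 4.0f));
    return 0;
}
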
@@ -291,7 +265,8 @@ static __global__ void k_repeat_back(
     dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
 }
 
-template <float (*bin_op)(const float, const float), int n_fuse = 1> struct bin_bcast_cuda {
+template <float (*bin_op)(const float, const float), int n_fuse = 1>
+struct bin_bcast_cuda {
     template<typename src0_t, typename src1_t, typename dst_t>
     void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
             const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
@@ -355,26 +330,27 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
 }
 
-template <int n_fuse> static void ggml_cuda_op_fused_add_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+template <float (*op)(const float, const float), int n_fuse>
+static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     cudaStream_t stream = ctx.stream();
 
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
 
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        launch_bin_bcast_pack<op_add, float, float, float>(src0, src1, dst,
+        launch_bin_bcast_pack<op, float, float, float>(src0, src1, dst,
             (const float *) src0->data, (const float *) src1->data, (float *) dst->data,
             stream, std::make_index_sequence<n_fuse>{});
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        launch_bin_bcast_pack<op_add, half, half, half>(src0, src1, dst,
+        launch_bin_bcast_pack<op, half, half, half>(src0, src1, dst,
            (const half *) src0->data, (const half *) src1->data, (half *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
-        launch_bin_bcast_pack<op_add, half, float, half>(src0, src1, dst,
+        launch_bin_bcast_pack<op, half, float, half>(src0, src1, dst,
             (const half *) src0->data, (const float *) src1->data, (half *) dst->data,
             stream, std::make_index_sequence<n_fuse>{});
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-        launch_bin_bcast_pack<op_add, half, float, float>(src0, src1, dst,
+        launch_bin_bcast_pack<op, half, float, float>(src0, src1, dst,
             (const half *) src0->data, (const float *) src1->data, (float *) dst->data,
             stream, std::make_index_sequence<n_fuse>{});
     } else {
@@ -385,30 +361,32 @@ template <int n_fuse> static void ggml_cuda_op_fused_add_impl(ggml_backend_cuda_
     }
 }
 
-void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
+
+template<float (*op)(const float, const float)>
+void ggml_cuda_op_fused_binbcast(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
     GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);
 
     switch (n_fuse) {
         case 2:
-            ggml_cuda_op_fused_add_impl<2>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 2>(ctx, dst);
             break;
         case 3:
-            ggml_cuda_op_fused_add_impl<3>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 3>(ctx, dst);
             break;
         case 4:
-            ggml_cuda_op_fused_add_impl<4>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 4>(ctx, dst);
            break;
        case 5:
-            ggml_cuda_op_fused_add_impl<5>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 5>(ctx, dst);
            break;
        case 6:
-            ggml_cuda_op_fused_add_impl<6>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 6>(ctx, dst);
            break;
        case 7:
-            ggml_cuda_op_fused_add_impl<7>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 7>(ctx, dst);
            break;
        case 8:
-            ggml_cuda_op_fused_add_impl<8>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 8>(ctx, dst);
            break;
        default:
            GGML_ASSERT(false && "Unsupported n_fuse value");
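
Note: the switch above is the usual way to map a runtime n_fuse onto a compile-time template parameter; each case instantiates one specialization, so inside the impl the fuse count can drive constructs such as std::make_index_sequence. A minimal C++ sketch of the same shape, with hypothetical names (do_work, do_work_impl), not the ggml functions:

#include <cstdio>
#include <utility>

// n is a compile-time constant here, so it can feed std::make_index_sequence,
// array sizes, unrolling, and similar.
template <int n>
static void do_work_impl() {
    std::printf("instantiated with n = %d (index_sequence size %zu)\n",
                n, std::make_index_sequence<n>{}.size());
}

// Runtime entry point: each case pins one specialization into the binary.
static void do_work(int n) {
    switch (n) {
        case 2: do_work_impl<2>(); break;
        case 3: do_work_impl<3>(); break;
        case 4: do_work_impl<4>(); break;
        default: std::printf("unsupported n = %d\n", n); break;
    }
}

int main() {
    do_work(3);
    return 0;
}
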
@@ -445,3 +423,5 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
         } break;
     }
 }
+
+template void ggml_cuda_op_fused_binbcast<op_add>(ggml_backend_cuda_context &, ggml_tensor *, int);
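
Note: the new last line is an explicit instantiation. The template's definition lives in this .cu file while binbcast.cuh only declares it, so the specializations other translation units call must be emitted here or the program fails to link. A minimal plain-C++ sketch of the pattern, with hypothetical names (apply, my_add), not the ggml ones:

// --- lib.h: what other translation units see ---
float my_add(float a, float b);                     // op usable as a template argument

template <float (*op)(float, float)>
void apply(float * dst, const float * src, int n);  // declaration only

// --- lib.cpp: holds the definitions ---
float my_add(float a, float b) { return a + b; }

template <float (*op)(float, float)>
void apply(float * dst, const float * src, int n) {
    for (int i = 0; i < n; ++i) {
        dst[i] = op(dst[i], src[i]);                // op is fixed at compile time
    }
}

// Explicit instantiation: emits object code for apply<my_add> in this
// translation unit, so a caller that only included lib.h still links.
template void apply<my_add>(float * dst, const float * src, int n);
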

ggml/src/ggml-cuda/binbcast.cuh

Lines changed: 25 additions & 2 deletions
@@ -1,11 +1,34 @@
 #include "common.cuh"
 
+
+__device__ __forceinline__ float op_repeat(const float a, const float b) {
+    return b;
+    GGML_UNUSED(a);
+}
+
+__device__ __forceinline__ float op_add(const float a, const float b) {
+    return a + b;
+}
+
+__device__ __forceinline__ float op_sub(const float a, const float b) {
+    return a - b;
+}
+
+__device__ __forceinline__ float op_mul(const float a, const float b) {
+    return a * b;
+}
+
+__device__ __forceinline__ float op_div(const float a, const float b) {
+    return a / b;
+}
+
 void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
-void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);
-
 void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
+template<float (*op)(const float, const float)>
+void ggml_cuda_op_fused_binbcast(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse);
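
Note: defining the op_* helpers in the header lets every including .cu file pass them to templates as non-type arguments, so the operation is baked into each instantiation and can be inlined into the kernel rather than reached through a runtime function pointer. A small C++ sketch of the distinction, with illustrative names only:

#include <cstdio>

float op_add_like(float a, float b) { return a + b; }   // stand-in for op_add

// The op is a template argument: each instantiation hard-codes the call,
// which is what allows it to be inlined into device code.
template <float (*op)(float, float)>
float combine_static(float a, float b) {
    return op(a, b);
}

// The op is a runtime value: the call goes through a function pointer and
// generally cannot be inlined.
float combine_dynamic(float (*op)(float, float), float a, float b) {
    return op(a, b);
}

int main() {
    std::printf("%g %g\n",
                combine_static<op_add_like>(1.0f, 2.0f),
                combine_dynamic(op_add_like, 1.0f, 2.0f));  // 3 3
    return 0;
}
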

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 5 additions & 3 deletions
@@ -2817,7 +2817,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
         return false;
     }
 
-    if (ops.size() >= 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
+    if ((ops.size() == 2 || ops.size() == 3) && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) {
         const ggml_tensor *rms_norm = cgraph->nodes[node_idx];
         const ggml_tensor *mul = cgraph->nodes[node_idx+1];
         const ggml_tensor *add = nullptr;
@@ -2905,7 +2905,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
 
             if (node->op == GGML_OP_ADD) {
                 int n_fuse = 0;
-                ggml_op ops[8] = {GGML_OP_ADD};
+                ggml_op ops[8];
+                std::fill(ops, ops + 8, GGML_OP_ADD);
 
                 for (; n_fuse <= 6; ++n_fuse){
                     if (!ggml_can_fuse(cgraph, i + n_fuse, ops + n_fuse, 2)) {
@@ -2926,8 +2927,9 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
                         node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
                     }
                     cgraph->nodes[i + n_fuse - 1]->data = node->data;
-                    ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
+                    ggml_cuda_op_fused_binbcast<op_add>(*cuda_ctx, node, n_fuse);
                     i += n_fuse - 1;
+
                     continue;
                 }
             }
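
Note: the ops[8] change is about initialization semantics. ggml_op ops[8] = {GGML_OP_ADD}; puts GGML_OP_ADD only in element 0 and value-initializes the other seven to 0, while the fusion loop indexes ops + n_fuse and needs every slot to hold GGML_OP_ADD; std::fill makes that explicit. A minimal C++ sketch with an illustrative enum standing in for ggml_op:

#include <algorithm>
#include <cstdio>

enum my_op { MY_OP_NONE = 0, MY_OP_ADD = 1 };   // illustrative stand-in for ggml_op

int main() {
    // Aggregate init: only element 0 is MY_OP_ADD, the rest are zero (MY_OP_NONE).
    my_op partial[8] = { MY_OP_ADD };

    // std::fill: every element is MY_OP_ADD.
    my_op filled[8];
    std::fill(filled, filled + 8, MY_OP_ADD);

    std::printf("partial[3] = %d, filled[3] = %d\n", partial[3], filled[3]);  // 0, 1
    return 0;
}
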

ggml/src/ggml-cuda/norm.cu

Lines changed: 8 additions & 6 deletions
@@ -178,14 +178,16 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
     const float scale = rsqrtf(mean + eps);
 
     for (int col = tid; col < ncols; col += block_size) {
-        if constexpr (do_multiply) {
+        if constexpr (do_multiply && do_add) {
+            const int mul_col = col % mul_ncols;
+            const int add_col = col % add_ncols;
+            dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
+        } else if constexpr (do_multiply) {
             const int mul_col = col % mul_ncols;
             dst[col] = scale * x[col] * mul[mul_col];
-
-            if constexpr (do_add) {
-                const int add_col = col % add_ncols;
-                dst[col] += add[add_col];
-            }
+        } else if constexpr (do_add) {
+            const int add_col = col % add_ncols;
+            dst[col] += add[add_col];
         } else {
             dst[col] = scale * x[col];
         }
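
Note: rms_norm_f32 is templated on do_multiply and do_add, and if constexpr discards the untaken branches at compile time, so each specialization compiles exactly one of the four bodies and the fused multiply+add path costs nothing when disabled. A minimal host-side C++17 sketch of the same branch structure; the names and branch bodies are simplified stand-ins, not the kernel's:

#include <cstdio>

template <bool do_multiply, bool do_add>
float normalize_one(float scale, float x, float m, float a) {
    if constexpr (do_multiply && do_add) {
        return scale * x * m + a;      // fused multiply + add
    } else if constexpr (do_multiply) {
        return scale * x * m;
    } else if constexpr (do_add) {
        return scale * x + a;
    } else {
        return scale * x;
    }
}

int main() {
    std::printf("%g\n", normalize_one<true, true>(0.5f, 4.0f, 2.0f, 1.0f));    // 5
    std::printf("%g\n", normalize_one<false, false>(0.5f, 4.0f, 2.0f, 1.0f));  // 2
    return 0;
}
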
