#include <cstdint>
#include <utility>

-static __device__ __forceinline__ float op_repeat(const float a, const float b) {
-    return b;
-    GGML_UNUSED(a);
-}
-
-static __device__ __forceinline__ float op_add(const float a, const float b) {
-    return a + b;
-}
-
-static __device__ __forceinline__ float op_sub(const float a, const float b) {
-    return a - b;
-}
-
-static __device__ __forceinline__ float op_mul(const float a, const float b) {
-    return a * b;
-}
-
-static __device__ __forceinline__ float op_div(const float a, const float b) {
-    return a / b;
-}
-
-template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... S1Ptrs>
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
        const int ne0, const int ne1, const int ne2, const int ne3,
        const int ne10, const int ne11, const int ne12, const int ne13,
        /*int s0, */ const int s1, const int s2, const int s3,
        /*int s00,*/ const int s01, const int s02, const int s03,
        /*int s10,*/ const int s11, const int s12, const int s13,
-        S1Ptrs... src1s) {
+        src1_ptrs... src1s) {
    const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
    const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
    const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
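
Note on this hunk: the kernel now takes a trailing parameter pack (`src1_ptrs... src1s`), one pointer per fused src1 tensor, so a single launch can apply `bin_op` against several sources. Below is a minimal, self-contained sketch (plain C++, hypothetical names `fake_kernel` / `launch_pack`) of the dispatch pattern that feeds such a pack from an array of pointers via `std::index_sequence`; the file's real helper is `launch_bin_bcast_pack`, whose body is not part of this hunk.

// Illustrative sketch only (hypothetical names, plain C++): expanding an array
// of fused src1 pointers into a parameter pack, in the spirit of how
// launch_bin_bcast_pack feeds k_bin_bcast.
#include <cstddef>
#include <cstdio>
#include <utility>

// Stand-in for the kernel: one parameter per fused src1 tensor.
template <typename... src1_ptrs>
static void fake_kernel(const float * src0, src1_ptrs... src1s) {
    std::printf("src0=%p with %zu fused src1 pointer(s)\n", (const void *) src0, sizeof...(src1s));
}

// src1[I]... expands to src1[0], src1[1], ..., src1[N-1].
template <std::size_t... I>
static void launch_pack(const float * src0, const float * const * src1, std::index_sequence<I...>) {
    fake_kernel(src0, src1[I]...);
}

int main() {
    float a = 0.0f, b = 1.0f, c = 2.0f, d = 3.0f;
    const float * fused[3] = { &b, &c, &d };
    launch_pack(&a, fused, std::make_index_sequence<3>{});   // pack of 3 src1 pointers
}
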
@@ -67,14 +46,14 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
    }
}

-template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... S1Ptrs>
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
        const int ne0, const int ne1, const int ne2, const int ne3,
        const int ne10, const int ne11, const int ne12, const int ne13,
        /*int s0, */ const int s1, const int s2, const int s3,
        /*int s00,*/ const int s01, const int s02, const int s03,
        /*int s10,*/ const int s11, const int s12, const int s13,
-        S1Ptrs... src1s) {
+        src1_ptrs... src1s) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;

    const int i3 = i/(ne2*ne1*ne0);
@@ -101,12 +80,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t *

    float result = src0_row ? (float) src0_row[i0] : 0.0f;

-    auto add_one = [&](const src1_t * p) {
-        const src1_t * row = p + i_src1;
-        result = bin_op(result, (float) row[i10]);
-        return 0;
-    };
-    (void) std::initializer_list<int>{ (add_one(src1s), 0)... };
+    result = (..., (result = bin_op(result, (float) src1s[i_src1 + i10])));

    dst_row[i0] = (dst_t) result;
}
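
Note on the fold above: `(..., (result = bin_op(result, ...)))` is a C++17 left fold over the built-in comma operator, which is evaluated strictly left to right, so each fused src1 term updates `result` in order; the inner assignment is what carries the running value from one term to the next. A minimal standalone illustration of the semantics (plain C++, illustrative names):

// Minimal sketch of the comma-operator fold (plain C++).
// For three pack elements the fold expands, left to right, to
//   ((r = f(r, x0)), (r = f(r, x1))), (r = f(r, x2))
#include <cstdio>

static float op_add(const float a, const float b) { return a + b; }

template <float (*bin_op)(const float, const float), typename... Ts>
static float fold_apply(float r, Ts... xs) {
    (..., (r = bin_op(r, (float) xs)));   // left fold over the comma operator
    return r;
}

int main() {
    std::printf("%f\n", fold_apply<op_add>(1.0f, 2.0f, 3.0f, 4.0f));   // 10.000000
}
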
@@ -291,7 +265,8 @@ static __global__ void k_repeat_back(
    dst[tid3*ne2*ne1*ne0 + tid2*ne1*ne0 + tid1*ne0 + tid0] = sum;
}

-template <float (*bin_op)(const float, const float), int n_fuse = 1> struct bin_bcast_cuda {
+template <float (*bin_op)(const float, const float), int n_fuse = 1>
+struct bin_bcast_cuda {
    template <typename src0_t, typename src1_t, typename dst_t>
    void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
            const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
@@ -355,26 +330,27 @@ void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

-template <int n_fuse> static void ggml_cuda_op_fused_add_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+template <float (*op)(const float, const float), int n_fuse>
+static void ggml_cuda_op_fused_binbcast_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    cudaStream_t stream = ctx.stream();

    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];

    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        launch_bin_bcast_pack<op_add, float, float, float>(src0, src1, dst,
+        launch_bin_bcast_pack<op, float, float, float>(src0, src1, dst,
            (const float *) src0->data, (const float *) src1->data, (float *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        launch_bin_bcast_pack<op_add, half, half, half>(src0, src1, dst,
+        launch_bin_bcast_pack<op, half, half, half>(src0, src1, dst,
            (const half *) src0->data, (const half *) src1->data, (half *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
-        launch_bin_bcast_pack<op_add, half, float, half>(src0, src1, dst,
+        launch_bin_bcast_pack<op, half, float, half>(src0, src1, dst,
            (const half *) src0->data, (const float *) src1->data, (half *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-        launch_bin_bcast_pack<op_add, half, float, float>(src0, src1, dst,
+        launch_bin_bcast_pack<op, half, float, float>(src0, src1, dst,
            (const half *) src0->data, (const float *) src1->data, (float *) dst->data,
            stream, std::make_index_sequence<n_fuse>{});
    } else {
@@ -385,30 +361,32 @@ template <int n_fuse> static void ggml_cuda_op_fused_add_impl(ggml_backend_cuda_
    }
}

-void ggml_cuda_op_fused_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
+
+template <float (*op)(const float, const float)>
+void ggml_cuda_op_fused_binbcast(ggml_backend_cuda_context & ctx, ggml_tensor * dst, int n_fuse) {
    GGML_ASSERT(2 <= n_fuse && n_fuse <= 8);

    switch (n_fuse) {
        case 2:
-            ggml_cuda_op_fused_add_impl<2>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 2>(ctx, dst);
            break;
        case 3:
-            ggml_cuda_op_fused_add_impl<3>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 3>(ctx, dst);
            break;
        case 4:
-            ggml_cuda_op_fused_add_impl<4>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 4>(ctx, dst);
            break;
        case 5:
-            ggml_cuda_op_fused_add_impl<5>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 5>(ctx, dst);
            break;
        case 6:
-            ggml_cuda_op_fused_add_impl<6>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 6>(ctx, dst);
            break;
        case 7:
-            ggml_cuda_op_fused_add_impl<7>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 7>(ctx, dst);
            break;
        case 8:
-            ggml_cuda_op_fused_add_impl<8>(ctx, dst);
+            ggml_cuda_op_fused_binbcast_impl<op, 8>(ctx, dst);
            break;
        default:
            GGML_ASSERT(false && "Unsupported n_fuse value");
@@ -445,3 +423,5 @@ void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst
        } break;
    }
}
+
+template void ggml_cuda_op_fused_binbcast<op_add>(ggml_backend_cuda_context &, ggml_tensor *, int);
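
The explicit instantiation above is what emits the `op_add` specialization in this translation unit, so callers that only see a declaration of `ggml_cuda_op_fused_binbcast` in a header can still link against the definition kept in this .cu file. A generic, standalone sketch of that mechanism (all names below are made up for illustration; the actual header is not part of this diff):

// Standalone illustration of explicit instantiation (plain C++, hypothetical names).
#include <cstdio>

// What a header would expose: a declaration only, no template body.
template <int N> int scaled(int x);

// What the implementation file provides: the definition plus an explicit
// instantiation, analogous to the ggml_cuda_op_fused_binbcast<op_add> line above.
template <int N> int scaled(int x) { return N * x; }
template int scaled<3>(int);

int main() {
    std::printf("%d\n", scaled<3>(7));   // 21
}
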