@@ -111,12 +111,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t *
     dst_row[i0] = (dst_t) result;
 }
 
-template <float (*bin_op)(const float, const float),
-          int n_fuse,
-          typename src0_t,
-          typename src1_t,
-          typename dst_t,
-          size_t... I>
+template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, size_t... I>
 static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
                                   const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
                                   cudaStream_t stream, std::index_sequence<I...>) {
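The `n_fuse` template parameter is dropped here because the fuse count is already encoded in the `std::index_sequence<I...>` argument: inside `launch_bin_bcast_pack` it can be recovered as `sizeof...(I)`, so carrying it separately was redundant. A minimal sketch of the idiom, with hypothetical names (`launch_pack` is a stand-in, not the actual ggml launcher):

    #include <cstdio>
    #include <utility>

    // Stand-in for launch_bin_bcast_pack: the pack I... both indexes the fused
    // operands and carries the fuse count, so no separate n_fuse parameter is needed.
    template <typename T, size_t... I>
    static void launch_pack(const T * const * srcs, std::index_sequence<I...>) {
        constexpr size_t n_fuse = sizeof...(I);  // count recovered from the pack
        ((void) printf("operand %zu -> %f\n", I, (double) srcs[I][0]), ...);  // unrolled once per operand
        (void) printf("n_fuse = %zu\n", n_fuse);
    }

    int main() {
        const float a[1] = {1.0f};
        const float b[1] = {2.0f};
        const float * srcs[2] = {a, b};
        launch_pack<float>(srcs, std::make_index_sequence<2>{});  // fuse count deduced as 2
    }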
@@ -301,7 +296,7 @@ template <float (*bin_op)(const float, const float), int n_fuse = 1> struct bin_
     void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
                     const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
                     cudaStream_t stream) {
-        launch_bin_bcast_pack<bin_op, n_fuse, src0_t, src1_t, dst_t>(
+        launch_bin_bcast_pack<bin_op, src0_t, src1_t, dst_t>(
             src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence<n_fuse>{});
     }
 };
@@ -367,19 +362,19 @@ template <int n_fuse> static void ggml_cuda_op_fused_add_impl(ggml_backend_cuda_
     const ggml_tensor * src1 = dst->src[1];
 
     if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-        launch_bin_bcast_pack<op_add, n_fuse, float, float, float>(src0, src1, dst,
+        launch_bin_bcast_pack<op_add, float, float, float>(src0, src1, dst,
             (const float *) src0->data, (const float *) src1->data, (float *) dst->data,
             stream, std::make_index_sequence<n_fuse>{});
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-        launch_bin_bcast_pack<op_add, n_fuse, half, half, half>(src0, src1, dst,
+        launch_bin_bcast_pack<op_add, half, half, half>(src0, src1, dst,
             (const half *) src0->data, (const half *) src1->data, (half *) dst->data,
             stream, std::make_index_sequence<n_fuse>{});
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
-        launch_bin_bcast_pack<op_add, n_fuse, half, float, half>(src0, src1, dst,
+        launch_bin_bcast_pack<op_add, half, float, half>(src0, src1, dst,
             (const half *) src0->data, (const float *) src1->data, (half *) dst->data,
             stream, std::make_index_sequence<n_fuse>{});
     } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
-        launch_bin_bcast_pack<op_add, n_fuse, half, float, float>(src0, src1, dst,
+        launch_bin_bcast_pack<op_add, half, float, float>(src0, src1, dst,
             (const half *) src0->data, (const float *) src1->data, (float *) dst->data,
             stream, std::make_index_sequence<n_fuse>{});
     } else {
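The callers still supply the fuse count as a compile-time constant: `ggml_cuda_op_fused_add_impl<n_fuse>` builds `std::make_index_sequence<n_fuse>{}` and the launcher expands it over the fused operands. A hedged sketch of how a runtime operand count could be mapped onto that compile-time parameter; the dispatcher name, empty argument lists, and fallback are illustrative assumptions, not the actual ggml entry points:

    #include <utility>

    template <int n_fuse>
    static void fused_add_impl(/* ctx, dst, stream */) {
        // The real implementation would call
        // launch_bin_bcast_pack<op_add, ...>(..., std::make_index_sequence<n_fuse>{});
    }

    // Illustrative dispatcher: turn a runtime count of fused adds into the
    // compile-time n_fuse the templated implementation expects.
    static void fused_add_dispatch(int n_fuse) {
        switch (n_fuse) {
            case 1: fused_add_impl<1>(); break;
            case 2: fused_add_impl<2>(); break;
            case 3: fused_add_impl<3>(); break;
            default: break;  // fall back to the unfused path
        }
    }

    int main() {
        fused_add_dispatch(2);  // selects fused_add_impl<2> at runtime
    }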