Skip to content

Commit 5adf50e

Browse files
committed
Remove n_fuse from template params
1 parent 4d10578 commit 5adf50e

File tree

1 file changed

+6
-11
lines changed

1 file changed

+6
-11
lines changed

ggml/src/ggml-cuda/binbcast.cu

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -111,12 +111,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t *
111111
dst_row[i0] = (dst_t) result;
112112
}
113113

114-
template <float (*bin_op)(const float, const float),
115-
int n_fuse,
116-
typename src0_t,
117-
typename src1_t,
118-
typename dst_t,
119-
size_t... I>
114+
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, size_t... I>
120115
static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
121116
const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
122117
cudaStream_t stream, std::index_sequence<I...>) {
@@ -301,7 +296,7 @@ template <float (*bin_op)(const float, const float), int n_fuse = 1> struct bin_
301296
void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
302297
const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
303298
cudaStream_t stream) {
304-
launch_bin_bcast_pack<bin_op, n_fuse, src0_t, src1_t, dst_t>(
299+
launch_bin_bcast_pack<bin_op, src0_t, src1_t, dst_t>(
305300
src0, src1, dst, src0_dd, src1_dd, dst_dd, stream, std::make_index_sequence<n_fuse>{});
306301
}
307302
};
@@ -367,19 +362,19 @@ template <int n_fuse> static void ggml_cuda_op_fused_add_impl(ggml_backend_cuda_
367362
const ggml_tensor * src1 = dst->src[1];
368363

369364
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
370-
launch_bin_bcast_pack<op_add, n_fuse, float, float, float>(src0, src1, dst,
365+
launch_bin_bcast_pack<op_add, float, float, float>(src0, src1, dst,
371366
(const float *) src0->data, (const float *) src1->data, (float *) dst->data,
372367
stream, std::make_index_sequence<n_fuse>{});
373368
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
374-
launch_bin_bcast_pack<op_add, n_fuse, half, half, half>(src0, src1, dst,
369+
launch_bin_bcast_pack<op_add, half, half, half>(src0, src1, dst,
375370
(const half *) src0->data, (const half *) src1->data, (half *) dst->data,
376371
stream, std::make_index_sequence<n_fuse>{});
377372
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
378-
launch_bin_bcast_pack<op_add, n_fuse, half, float, half>(src0, src1, dst,
373+
launch_bin_bcast_pack<op_add, half, float, half>(src0, src1, dst,
379374
(const half *) src0->data, (const float *) src1->data, (half *) dst->data,
380375
stream, std::make_index_sequence<n_fuse>{});
381376
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
382-
launch_bin_bcast_pack<op_add, n_fuse, half, float, float>(src0, src1, dst,
377+
launch_bin_bcast_pack<op_add, half, float, float>(src0, src1, dst,
383378
(const half *) src0->data, (const float *) src1->data, (float *) dst->data,
384379
stream, std::make_index_sequence<n_fuse>{});
385380
} else {

0 commit comments

Comments
 (0)