Skip to content

Commit 791998b

Browse files
committed
metal : single-user mamba2 inference works
1 parent 6def5cd commit 791998b

File tree

2 files changed

+8
-8
lines changed

2 files changed

+8
-8
lines changed

ggml/src/ggml-metal/ggml-metal.metal

+7-7
Original file line numberDiff line numberDiff line change
@@ -1284,20 +1284,20 @@ kernel void kernel_ssm_scan_f32(
12841284
const int64_t ng = args.n_group;
12851285
const int64_t n_t = args.n_seq_tokens;
12861286

1287-
const int64_t s_off = nr * nh * nt * args.n_seqs * sizeof(float);
1287+
const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float);
12881288

12891289
device const int32_t * ids = (device const int32_t *) src6;
12901290

12911291
device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03);
12921292
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
12931293

12941294
for (int64_t i2 = 0; i2 < n_t; ++i2) {
1295-
device const float * x = (device const float *) ((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
1296-
device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
1295+
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
1296+
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
12971297
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh}
12981298
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
12991299
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
1300-
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*args.nb00); // {dim, nh, nt, ns}
1300+
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
13011301

13021302
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
13031303
const float x_dt = x[0] * dt_soft_plus;
@@ -1354,12 +1354,12 @@ kernel void kernel_ssm_scan_f32_group(
13541354
device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
13551355

13561356
for (int64_t i2 = 0; i2 < n_t; ++i2) {
1357-
device const float * x = (device const float *) ((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
1358-
device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
1357+
device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns}
1358+
device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns}
13591359
device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh}
13601360
device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns}
13611361
device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns}
1362-
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*args.nb00); // {dim, nh, nt, ns}
1362+
device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns}
13631363

13641364
const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0];
13651365
const float x_dt = x[0] * dt_soft_plus;

src/llama-model.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -9009,7 +9009,7 @@ struct llm_build_mamba : public llm_graph_context {
90099009
ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC));
90109010

90119011
// {n_head, n_seq_tokens, n_seqs}
9012-
dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b);
9012+
dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b);
90139013

90149014
ggml_tensor * ssm_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0);
90159015
// TODO: use semistructured matrices to implement state-space duality

0 commit comments

Comments
 (0)