@@ -1284,20 +1284,20 @@ kernel void kernel_ssm_scan_f32(
12841284 const int64_t ng = args.n_group ;
12851285 const int64_t n_t = args.n_seq_tokens ;
12861286
1287- const int64_t s_off = nr * nh * nt * args.n_seqs * sizeof (float );
1287+ const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof (float );
12881288
12891289 device const int32_t * ids = (device const int32_t *) src6;
12901290
12911291 device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03 );
12921292 device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
12931293
12941294 for (int64_t i2 = 0 ; i2 < n_t ; ++i2) {
1295- device const float * x = (device const float *) ((device const char *) src1 + i1*args. nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13 ); // {dim, nh, nt, ns}
1296- device const float * dt = (device const float *) ((device const char *) src2 + ir*args. nb20 + i2*args.nb21 + i3*args.nb22 ); // {nh, nt, ns}
1295+ device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13 ); // {dim, nh, nt, ns}
1296+ device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22 ); // {nh, nt, ns}
12971297 device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31 ); // {d_state, nh}
12981298 device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1 ))*args.nb41 + i2*args.nb42 + i3*args.nb43 ); // {d_state, ng, nt, ns}
12991299 device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1 ))*args.nb51 + i2*args.nb52 + i3*args.nb53 ); // {d_state, ng, nt, ns}
1300- device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t *nh*nr))*args. nb00 ); // {dim, nh, nt, ns}
1300+ device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t *nh*nr))*nb00); // {dim, nh, nt, ns}
13011301
13021302 const float dt_soft_plus = dt[0 ] <= 20 .0f ? log (1 .0f + exp (dt[0 ])) : dt[0 ];
13031303 const float x_dt = x[0 ] * dt_soft_plus;
@@ -1354,12 +1354,12 @@ kernel void kernel_ssm_scan_f32_group(
13541354 device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off);
13551355
13561356 for (int64_t i2 = 0 ; i2 < n_t ; ++i2) {
1357- device const float * x = (device const float *) ((device const char *) src1 + i1*args. nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13 ); // {dim, nh, nt, ns}
1358- device const float * dt = (device const float *) ((device const char *) src2 + ir*args. nb20 + i2*args.nb21 + i3*args.nb22 ); // {nh, nt, ns}
1357+ device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13 ); // {dim, nh, nt, ns}
1358+ device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22 ); // {nh, nt, ns}
13591359 device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31 ); // {1, nh}
13601360 device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1 ))*args.nb41 + i2*args.nb42 + i3*args.nb43 ); // {d_state, ng, nt, ns}
13611361 device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1 ))*args.nb51 + i2*args.nb52 + i3*args.nb53 ); // {d_state, ng, nt, ns}
1362- device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t *nh*nr))*args. nb00 ); // {dim, nh, nt, ns}
1362+ device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t *nh*nr))*nb00); // {dim, nh, nt, ns}
13631363
13641364 const float dt_soft_plus = dt[0 ] <= 20 .0f ? log (1 .0f + exp (dt[0 ])) : dt[0 ];
13651365 const float x_dt = x[0 ] * dt_soft_plus;
0 commit comments