Skip to content

Commit e9b0d19

Browse files
committed
llama : fix Mamba-2 conv state saving
* ggml : make the ggml_mul fast broadcast path more consistently formatted
1 parent 3326dc4 commit e9b0d19

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

ggml/src/ggml.c

+2-2
@@ -10226,11 +10226,11 @@ static void ggml_compute_forward_mul_f32(
             if (scale == 0.0f) {
                 // NOTE: this also sets NANs to zero, which is not compliant with IEEE754,
                 //       but it is useful when resetting the state of recurrent models.
-                memset((char *)dst->data + ir*nb1, 0, nb1);
+                memset((char *) dst->data + ir*nb1, 0, ne0 * sizeof(float));
             } else {
                 if (dst->data != src0->data) {
                     // src0 is same shape as dst => same indices
-                    memcpy((char *)dst->data + ir*nb1, (char *)src0->data + ir*nb01, ne0 * sizeof(float));
+                    memcpy((char *) dst->data + ir*nb1, (char *) src0->data + ir*nb01, ne0 * sizeof(float));
                 }
                 if (scale != 1.0f) {
                     ggml_vec_scale_f32(ne0, (float *) ((char *) dst->data + ir*nb1), scale);

src/llama.cpp

+1-1
@@ -9335,7 +9335,7 @@ static struct ggml_tensor * llm_build_mamba2(
             ggml_cpy(ctx, last_conv,
                 ggml_view_1d(ctx, conv_states_all,
                     (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs),
-                    kv_head*(d_conv - 1)*(d_inner)*ggml_element_size(conv_states_all))));
+                    kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all))));

         // 1D convolution
         // The equivalent is to make a self-overlapping view of conv_x

0 commit comments

Comments
 (0)