diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index df3f8a55d5320..2debb6e63fef9 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -4127,6 +4127,14 @@ def set_gguf_parameters(self): class MambaModel(TextModel): model_arch = gguf.MODEL_ARCH.MAMBA + def __init__(self, dir_model: Path, *args, **kwargs): + # Avoid using AutoConfig for hparams + hparams = kwargs.pop("hparams", None) + if hparams is None: + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + super().__init__(dir_model, *args, hparams=hparams, **kwargs) + def set_vocab(self): vocab_size = self.hparams["vocab_size"] # Round vocab size to next multiple of 8 @@ -4201,6 +4209,100 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter return [(new_name, data_torch)] +@ModelBase.register("Mamba2ForCausalLM") +class Mamba2Model(TextModel): + model_arch = gguf.MODEL_ARCH.MAMBA2 + + def __init__(self, dir_model: Path, *args, **kwargs): + # Avoid using AutoConfig for hparams + # It wrongly assumes all Mamba2 models are Mamba-Codestral-7B-v0.1 + hparams = kwargs.pop("hparams", None) + if hparams is None: + with open(dir_model / "config.json", "r", encoding="utf-8") as f: + hparams = json.load(f) + super().__init__(dir_model, *args, hparams=hparams, **kwargs) + + def set_vocab(self): + vocab_size = self.hparams["vocab_size"] + # Round vocab size to next multiple of 16 + pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16) + # pad using ceiling division + # ref: https://stackoverflow.com/a/17511341/22827863 + vocab_size = -(vocab_size // -pad_vocab) * pad_vocab + self.hparams["vocab_size"] = vocab_size + + if (self.dir_model / "tokenizer.model").is_file(): + self._set_vocab_sentencepiece() + elif (self.dir_model / "tokenizer.model.v3").is_file(): + # mamba-codestral + raise NotImplementedError(f"Please rename {self.dir_model / 'tokenizer.model.v3'} to {self.dir_model / 'tokenizer.model'}") + elif (self.dir_model / "tokenizer.json").is_file(): + self._set_vocab_gpt2() + else: + # Use the GPT-NeoX tokenizer when no tokenizer files are present + self._set_vocab_builtin("gpt-neox", vocab_size) + + def set_gguf_parameters(self): + d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) + d_conv = self.find_hparam(["conv_kernel", "d_conv"], optional=True) or 4 + d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + d_state = self.find_hparam(["state_size", "d_state"], optional=True) or 128 + head_dim = self.find_hparam(["head_dim"], optional=True) or 64 + n_group = self.find_hparam(["n_groups"], optional=True) or 1 + + rms_norm_eps = self.find_hparam(["layer_norm_epsilon", "rms_norm_eps"], optional=True) or 1e-5 + + # Fail early for models which don't have a block expansion factor of 2 + # TODO: does this really matter? + assert d_inner == 2 * d_model + assert d_inner % head_dim == 0 + + self.gguf_writer.add_context_length(2**20) # arbitrary value; for those who use the default + self.gguf_writer.add_embedding_length(d_model) + self.gguf_writer.add_feed_forward_length(0) # unused, but seemingly required when loading + self.gguf_writer.add_head_count(0) # unused, but seemingly required when loading + self.gguf_writer.add_block_count(self.block_count) + self.gguf_writer.add_ssm_conv_kernel(d_conv) + self.gguf_writer.add_ssm_inner_size(d_inner) + self.gguf_writer.add_ssm_state_size(d_state) + self.gguf_writer.add_ssm_time_step_rank(d_inner // head_dim) + self.gguf_writer.add_ssm_group_count(n_group) + self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) + self.gguf_writer.add_file_type(self.ftype) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + + if name.startswith("model.backbone") or name.startswith("model.lm_head"): + # map Mamba-Codestral-7B-v0.1 tensor names to the names used by Mamba-2 + name = name.removeprefix("model.") + + if name.endswith(".dt_bias"): + name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias" + + new_name = self.map_tensor_name(name) + + if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_CONV1D, bid): + data_torch = data_torch.squeeze() + elif any(self.match_model_tensor_name(new_name, t, bid, suffix="") for t in [ + gguf.MODEL_TENSOR.SSM_A, + gguf.MODEL_TENSOR.SSM_D, + ]): + # unsqueeze A to use similar shape semantics as Mamba-1 + # (D is also unsqueezed, but for more straightforward broadcast internally) + data_torch = data_torch.reshape((*data_torch.shape, 1)) + elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.SSM_NORM, bid): + d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) + d_inner = self.find_hparam(["intermediate_size", "d_inner"], optional=True) or 2 * d_model + n_group = self.hparams.get("n_groups", 1) + data_torch = data_torch.reshape((n_group, d_inner // n_group)) + + if name.endswith(".A_log"): + logger.debug("A_log --> A ==> " + new_name) + data_torch = -torch.exp(data_torch) + + yield (new_name, data_torch) + + @ModelBase.register("CohereForCausalLM") class CommandR2Model(TextModel): model_arch = gguf.MODEL_ARCH.COMMAND_R @@ -5883,12 +5985,20 @@ def get_model_architecture(dir_model: Path, model_type: ModelType, hparams: Any hparams = ModelBase.load_hparams(dir_model) if hparams is None else hparams text_config = hparams.get("text_config", {}) vision_config = hparams.get("vision_config", {}) - arch = hparams["architectures"][0] + arch = None + if (arches := hparams.get("architectures")) is not None and len(arches) > 0: + arch = arches[0] + elif "ssm_cfg" in hparams: + # For non-hf Mamba and Mamba2 models + arch = hparams["ssm_cfg"].get("layer", "Mamba") + "ForCausalLM" + # if "architectures" is found in the sub-config, use that instead if model_type == ModelType.TEXT and text_config.get("architectures") is not None: arch = text_config["architectures"][0] elif model_type == ModelType.VISION and vision_config.get("architectures") is not None: arch = vision_config["architectures"][0] + if arch is None: + raise ValueError("Failed to detect model architecture") return arch diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 1b8603e78e553..c75f4331857c8 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -1854,7 +1854,8 @@ extern "C" { struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C); + struct ggml_tensor * C, + struct ggml_tensor * ids); // partition into non-overlapping windows with padding if needed // example: diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 7413192b746b6..8518c3d7911de 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -7506,74 +7506,151 @@ void ggml_compute_forward_ssm_conv( static void ggml_compute_forward_ssm_scan_f32( const ggml_compute_params * params, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; // s - const ggml_tensor * src1 = dst->src[1]; // x - const ggml_tensor * src2 = dst->src[2]; // dt - const ggml_tensor * src3 = dst->src[3]; // A - const ggml_tensor * src4 = dst->src[4]; // B - const ggml_tensor * src5 = dst->src[5]; // C + const ggml_tensor * src0 = dst->src[0]; // s {d_state, dim, n_head, n_seqs+} + const ggml_tensor * src1 = dst->src[1]; // x {dim, n_head, n_seq_tokens, n_seqs} + const ggml_tensor * src2 = dst->src[2]; // dt {n_head, n_seq_tokens, n_seqs} + const ggml_tensor * src3 = dst->src[3]; // A {d_state, n_head} or {1, n_head} + const ggml_tensor * src4 = dst->src[4]; // B {d_state, n_group, n_seq_tokens, n_seqs} + const ggml_tensor * src5 = dst->src[5]; // C {d_state, n_group, n_seq_tokens, n_seqs} + const ggml_tensor * src6 = dst->src[6]; // ids {n_seqs} const int ith = params->ith; const int nth = params->nth; - const int64_t nc = src0->ne[0]; // d_state - const int64_t nr = src0->ne[1]; // d_inner - const int64_t n_t = src1->ne[1]; // number of tokens per sequence - const int64_t n_s = src0->ne[2]; // number of sequences in the batch + const int64_t nc = src0->ne[0]; // d_state + const int64_t nr = src0->ne[1]; // dim + const int64_t nh = src1->ne[1]; // n_head + const int64_t ng = src4->ne[1]; + const int64_t nt = src1->ne[2]; // number of tokens per sequence + const int64_t ns = src1->ne[3]; // number of sequences in the batch - GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst)); + // can't use ggml_nbytes because src1 is not necessarily contiguous + const int64_t s_off = ggml_nelements(src1) * ggml_element_size(src1); + + GGML_ASSERT(ggml_nelements(src1) + nc*nr*nh*ns == ggml_nelements(dst)); GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src1->nb[0] == sizeof(float)); GGML_ASSERT(src2->nb[0] == sizeof(float)); GGML_ASSERT(src3->nb[0] == sizeof(float)); GGML_ASSERT(src4->nb[0] == sizeof(float)); GGML_ASSERT(src5->nb[0] == sizeof(float)); - // required for the dot product between s and C - GGML_ASSERT(src0->nb[1] == src0->ne[0]*sizeof(float)); - // required for per-sequence offsets for states - GGML_ASSERT(src0->nb[2] == src0->ne[0]*src0->ne[1]*sizeof(float)); - // required to get correct offset for state destination (i.e. src1->nb[3]) - GGML_ASSERT(src1->nb[3] == src1->ne[0]*src1->ne[1]*src1->ne[2]*sizeof(float)); + GGML_ASSERT(src6->nb[0] == sizeof(int32_t)); + // allows optimizing the modulo since n_group should be a power of 2 + GGML_ASSERT((ng & -ng) == ng); + + // heads per thread + const int dh = (nh + nth - 1)/nth; + + // head range for this thread + const int ih0 = dh*ith; + const int ih1 = MIN(ih0 + dh, nh); + + const int32_t * ids = (const int32_t *) src6->data; + + for (int i3 = 0; i3 < ns; ++i3) { + const float * s0 = (const float *) ((const char *) src0->data + ids[i3]*(src0->nb[3])); // {d_state, dim, nh, ns} + float * s = ( float *) (( char *) dst->data + i3*(src0->nb[3]) + s_off); // {d_state, dim, nh, ns} + + for (int i2 = 0; i2 < nt; ++i2) { + const float * x = (const float *) ((const char *) src1->data + i2*(src1->nb[2]) + i3*(src1->nb[3])); // {dim, nh, nt, ns} + const float * dt = (const float *) ((const char *) src2->data + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {nh, nt, ns} + const float * A = (const float *) ((const char *) src3->data); // {d_state, nh} or {1, nh} + const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[2]) + i3*(src4->nb[3])); // {d_state, ng, nt, ns} + const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[2]) + i3*(src5->nb[3])); // {d_state, ng, nt, ns} + float * y = ( float *) (( char *) dst->data + i2*(nh*nr*sizeof(float)) + i3*(nt*nh*nr*sizeof(float))); // {dim, nh, nt, ns} + + if (src3->ne[0] == 1) { + // Mamba-2 has a scalar decay factor per head; dA can be outside the state-wise loop + + // n_head + for (int h = ih0; h < ih1; ++h) { + // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 + const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const float dA = expf(dt_soft_plus * A[h]); + + // dim + for (int i1 = 0; i1 < nr; ++i1) { + const int ii = i1 + h*nr; + const float x_dt = x[ii] * dt_soft_plus; + float sumf = 0.0f; +#if defined(GGML_SIMD) + const int np = (nc & ~(GGML_F32_STEP - 1)); - // rows per thread - const int dr = (nr + nth - 1)/nth; + GGML_F32_VEC sum[GGML_F32_ARR] = { GGML_F32_VEC_ZERO }; - // row range for this thread - const int ir0 = dr*ith; - const int ir1 = MIN(ir0 + dr, nr); - const int ir = ir1 - ir0; + GGML_F32_VEC adA = GGML_F32_VEC_SET1(dA); + GGML_F32_VEC axdt = GGML_F32_VEC_SET1(x_dt); - for (int i3 = 0; i3 < n_s; ++i3) { - for (int i2 = 0; i2 < n_t; ++i2) { - const float * s0 = (const float *) ((const char *) src0->data + ir0*(src0->nb[1]) + i3*(src0->nb[2])); // {d_state, d_inner, n_s} - const float * x = (const float *) ((const char *) src1->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - const float * dt = (const float *) ((const char *) src2->data + ir0*(src2->nb[0]) + i2*(src2->nb[1]) + i3*(src2->nb[2])); // {d_inner, n_t, n_s} - const float * A = (const float *) ((const char *) src3->data + ir0*(src3->nb[1])); // {d_state, d_inner} - const float * B = (const float *) ((const char *) src4->data + i2*(src4->nb[1]) + i3*(src4->nb[2])); // {d_state, n_t, n_s} - const float * C = (const float *) ((const char *) src5->data + i2*(src5->nb[1]) + i3*(src5->nb[2])); // {d_state, n_t, n_s} - float * y = ( float *) (( char *) dst->data + ir0*(src1->nb[0]) + i2*(src1->nb[1]) + i3*(src1->nb[2])); // {d_inner, n_t, n_s} - float * s = ( float *) (( char *) dst->data + ir0*(src0->nb[1]) + i3*(src0->nb[2]) + src1->nb[3]); // {d_state, d_inner, n_s} - - // use the output as the source for the next token-wise iterations - if (i2 > 0) { s0 = s; } + GGML_F32_VEC ax[GGML_F32_ARR]; + GGML_F32_VEC ay[GGML_F32_ARR]; + GGML_F32_VEC az[GGML_F32_ARR]; - // d_inner - for (int i1 = 0; i1 < ir; ++i1) { - // ref: https://github.com/state-spaces/mamba/blob/34076d664838588a3c97727b263478ab9f621a07/mamba_ssm/ops/triton/selective_state_update.py#L78 - float dt_soft_plus = dt[i1] <= 20.0f ? log1pf(expf(dt[i1])) : dt[i1]; - float x_dt = x[i1] * dt_soft_plus; - float sumf = 0.0f; - // d_state - for (int i0 = 0; i0 < nc; ++i0) { - int i = i0 + i1*nc; - // state = prev_state * dA + dB * x - float state = (s0[i] * expf(dt_soft_plus * A[i])) + (B[i0] * x_dt); - // y = rowwise_dotprod(state, C) - sumf += state * C[i0]; - s[i] = state; + for (int i = 0; i < np; i += GGML_F32_STEP) { + for (int j = 0; j < GGML_F32_ARR; j++) { + ax[j] = GGML_F32_VEC_LOAD(s0 + i + j*GGML_F32_EPR + ii*nc); + ay[j] = GGML_F32_VEC_LOAD(B + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + az[j] = GGML_F32_VEC_LOAD(C + i + j*GGML_F32_EPR + (h & (ng - 1))*nc); + + ax[j] = GGML_F32_VEC_MUL(ax[j], adA); + ay[j] = GGML_F32_VEC_MUL(ay[j], axdt); + + ax[j] = GGML_F32_VEC_ADD(ax[j], ay[j]); + + sum[j] = GGML_F32_VEC_FMA(sum[j], ax[j], az[j]); + + GGML_F32_VEC_STORE(s + i + j*GGML_F32_EPR + ii*nc, ax[j]); + } + } + + // reduce sum0..sum3 to sum0 + GGML_F32_VEC_REDUCE(sumf, sum); +#else + const int np = 0; +#endif + // d_state + for (int i0 = np; i0 < nc; ++i0) { + const int i = i0 + ii*nc; + const int ig = i0 + (h & (ng - 1))*nc; + // state = prev_state * dA + dB * x + const float state = (s0[i] * dA) + (B[ig] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[ig]; + s[i] = state; + } + y[ii] = sumf; + } + } + } else { + // Mamba-1 has an element-wise decay factor for the states + + // n_head + for (int h = ih0; h < ih1; ++h) { + // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 + const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + + // dim + for (int i1 = 0; i1 < nr; ++i1) { + const int ii = i1 + h*nr; + const float x_dt = x[ii] * dt_soft_plus; + float sumf = 0.0f; + // NOTE: can't really use GGML_SIMD here because d_state is usually 16 + // and also because expf is used within the loop. + // d_state + for (int i0 = 0; i0 < nc; ++i0) { + const int i = i0 + ii*nc; + const int ig = i0 + (h & (ng - 1))*nc; + // state = prev_state * dA + dB * x + const float state = (s0[i] * expf(dt_soft_plus * A[i0 + h*nc])) + (B[ig] * x_dt); + // y = rowwise_dotprod(state, C) + sumf += state * C[ig]; + s[i] = state; + } + y[ii] = sumf; + } } - y[i1] = sumf; } + // use the output as the source when it's not the first token-wise iteration + s0 = s; } } } diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 8721b272de6a9..75362686860de 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -463,26 +463,25 @@ typedef struct { typedef struct { int64_t d_state; int64_t d_inner; + int64_t n_head; + int64_t n_group; int64_t n_seq_tokens; int64_t n_seqs; - uint64_t nb00; uint64_t nb01; uint64_t nb02; - uint64_t nb10; + uint64_t nb03; uint64_t nb11; uint64_t nb12; uint64_t nb13; - uint64_t nb20; uint64_t nb21; uint64_t nb22; - uint64_t nb30; uint64_t nb31; - uint64_t nb40; uint64_t nb41; uint64_t nb42; - uint64_t nb50; + uint64_t nb43; uint64_t nb51; uint64_t nb52; + uint64_t nb53; } ggml_metal_kargs_ssm_scan; typedef struct { diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index d92392edb7eb1..7d6377c790903 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -189,6 +189,7 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte GGML_METAL_KERNEL_TYPE_NORM, GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, + GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, @@ -1125,6 +1126,7 @@ @implementation GGMLMetalClass GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_CONV_F32, ssm_conv_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32, ssm_scan_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP, ssm_scan_f32_group, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV6_F32, rwkv_wkv6_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RWKV_WKV7_F32, rwkv_wkv7_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, has_simdgroup_reduction); @@ -2586,71 +2588,91 @@ static bool ggml_metal_encode_node( struct ggml_tensor * src3 = node->src[3]; struct ggml_tensor * src4 = node->src[4]; struct ggml_tensor * src5 = node->src[5]; + struct ggml_tensor * src6 = node->src[6]; GGML_ASSERT(src3); GGML_ASSERT(src4); GGML_ASSERT(src5); + GGML_ASSERT(src6); size_t offs_src3 = 0; size_t offs_src4 = 0; size_t offs_src5 = 0; + size_t offs_src6 = 0; id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; id id_src4 = src4 ? ggml_metal_get_buffer(src4, &offs_src4) : nil; id id_src5 = src5 ? ggml_metal_get_buffer(src5, &offs_src5) : nil; + id id_src6 = src6 ? ggml_metal_get_buffer(src6, &offs_src6) : nil; - const int64_t ne30 = src3->ne[0]; GGML_UNUSED(ne30); + const int64_t ne30 = src3->ne[0]; const int64_t ne31 = src3->ne[1]; GGML_UNUSED(ne31); - const uint64_t nb30 = src3->nb[0]; + const uint64_t nb30 = src3->nb[0]; GGML_UNUSED(nb30); const uint64_t nb31 = src3->nb[1]; const int64_t ne40 = src4->ne[0]; GGML_UNUSED(ne40); - const int64_t ne41 = src4->ne[1]; GGML_UNUSED(ne41); + const int64_t ne41 = src4->ne[1]; const int64_t ne42 = src4->ne[2]; GGML_UNUSED(ne42); + const int64_t ne43 = src4->ne[3]; GGML_UNUSED(ne43); - const uint64_t nb40 = src4->nb[0]; + const uint64_t nb40 = src4->nb[0]; GGML_UNUSED(nb40); const uint64_t nb41 = src4->nb[1]; const uint64_t nb42 = src4->nb[2]; + const uint64_t nb43 = src4->nb[3]; const int64_t ne50 = src5->ne[0]; GGML_UNUSED(ne50); const int64_t ne51 = src5->ne[1]; GGML_UNUSED(ne51); const int64_t ne52 = src5->ne[2]; GGML_UNUSED(ne52); + const int64_t ne53 = src5->ne[3]; GGML_UNUSED(ne53); - const uint64_t nb50 = src5->nb[0]; + const uint64_t nb50 = src5->nb[0]; GGML_UNUSED(nb50); const uint64_t nb51 = src5->nb[1]; const uint64_t nb52 = src5->nb[2]; + const uint64_t nb53 = src5->nb[3]; + + const int64_t ne60 = src6->ne[0]; GGML_UNUSED(ne60); + + const uint64_t nb60 = src6->nb[0]; GGML_UNUSED(nb60); const int64_t d_state = ne00; const int64_t d_inner = ne01; - const int64_t n_seq_tokens = ne11; - const int64_t n_seqs = ne02; + const int64_t n_head = ne02; + const int64_t n_group = ne41; + const int64_t n_seq_tokens = ne12; + const int64_t n_seqs = ne13; + + id pipeline = nil; - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + if (ne30 == 1) { + // Mamba-2 + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32_GROUP].pipeline; + } else { + pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + } ggml_metal_kargs_ssm_scan args = { - /*.d_state =*/ d_state, - /*.d_inner =*/ d_inner, + /*.d_state =*/ d_state, + /*.d_inner =*/ d_inner, + /*.n_head =*/ n_head, + /*.n_group =*/ n_group, /*.n_seq_tokens =*/ n_seq_tokens, - /*.n_seqs =*/ n_seqs, - /*.nb00 =*/ nb00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb12 =*/ nb12, - /*.nb13 =*/ nb13, - /*.nb20 =*/ nb20, - /*.nb21 =*/ nb21, - /*.nb22 =*/ nb22, - /*.nb30 =*/ nb30, - /*.nb31 =*/ nb31, - /*.nb40 =*/ nb40, - /*.nb41 =*/ nb41, - /*.nb42 =*/ nb42, - /*.nb50 =*/ nb50, - /*.nb51 =*/ nb51, - /*.nb52 =*/ nb52, + /*.n_seqs =*/ n_seqs, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.nb21 =*/ nb21, + /*.nb22 =*/ nb22, + /*.nb31 =*/ nb31, + /*.nb41 =*/ nb41, + /*.nb42 =*/ nb42, + /*.nb43 =*/ nb43, + /*.nb51 =*/ nb51, + /*.nb52 =*/ nb52, + /*.nb53 =*/ nb53, }; [encoder setComputePipelineState:pipeline]; @@ -2660,10 +2682,17 @@ static bool ggml_metal_encode_node( [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; [encoder setBuffer:id_src4 offset:offs_src4 atIndex:4]; [encoder setBuffer:id_src5 offset:offs_src5 atIndex:5]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:6]; - [encoder setBytes:&args length:sizeof(args) atIndex:7]; + [encoder setBuffer:id_src6 offset:offs_src6 atIndex:6]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:7]; + [encoder setBytes:&args length:sizeof(args) atIndex:8]; - [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + if (ne30 == 1) { + // Mamba-2 + [encoder dispatchThreadgroups:MTLSizeMake(d_inner, n_head, n_seqs) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } else { + GGML_ASSERT(d_inner == 1); + [encoder dispatchThreadgroups:MTLSizeMake(n_head, n_seqs, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } } break; case GGML_OP_RWKV_WKV6: { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 9f4147e93974d..4e50efdee41ca 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -1256,7 +1256,7 @@ kernel void kernel_ssm_conv_f32( x[0] = sumf; } -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32 +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part kernel void kernel_ssm_scan_f32( device const void * src0, device const void * src1, @@ -1264,46 +1264,119 @@ kernel void kernel_ssm_scan_f32( device const void * src3, device const void * src4, device const void * src5, + device const void * src6, device float * dst, constant ggml_metal_kargs_ssm_scan & args, uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]], uint3 ntg[[threads_per_threadgroup]]) { - const int64_t ir = tgpig.x; - const int64_t i3 = tgpig.y; + const int64_t i1 = 0; + const int64_t ir = tgpig.x; // current head + const int64_t i3 = tgpig.y; // current seq + + const uint64_t nb00 = sizeof(float); + const uint64_t nb10 = sizeof(float); + const uint64_t nb20 = sizeof(float); const int64_t nc = args.d_state; - // const int64_t nr = args.d_inner; + const int64_t nr = args.d_inner; + const int64_t nh = args.n_head; + const int64_t ng = args.n_group; const int64_t n_t = args.n_seq_tokens; - // const int64_t n_s = args.n_seqs; + + const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); + + device const int32_t * ids = (device const int32_t *) src6; + + device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb01 + i3*args.nb02); - device const float * x = (device const float *) ((device const char *) src1 + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); - device const float * dt = (device const float *) ((device const char *) src2 + ir*args.nb20 + i2*args.nb21 + i3*args.nb22); - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); - device const float * B = (device const float *) ((device const char *) src4 + i2*args.nb41 + i3*args.nb42); - device const float * C = (device const float *) ((device const char *) src5 + i2*args.nb51 + i3*args.nb52); - device float * y = (device float *) ((device char *) dst + ir*args.nb10 + i2*args.nb11 + i3*args.nb12); // TODO: do not use src1 strides - device float * s = (device float *) ((device char *) dst + ir*args.nb01 + i3*args.nb02 + args.nb13); - - if (i2 > 0) { - s0 = s; - } - - // i1 == 0 - float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - float x_dt = x[0] * dt_soft_plus; + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {d_state, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} + + const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; + const float x_dt = x[0] * dt_soft_plus; float sumf = 0.0f; for (int64_t i0 = 0; i0 < nc; ++i0) { - int64_t i = i0; - float state = (s0[i] * exp(dt_soft_plus * A[i])) + (B[i0] * x_dt); + const int64_t i = i0 + i1*nc; + const float state = (s0[i] * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); sumf += state * C[i0]; s[i] = state; } y[0] = sumf; + + // recurse + s0 = s; + } +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part +// TODO: optimize (e.g. by parallelizing over d_state) +kernel void kernel_ssm_scan_f32_group( + device const void * src0, + device const void * src1, + device const void * src2, + device const void * src3, + device const void * src4, + device const void * src5, + device const void * src6, + device float * dst, + constant ggml_metal_kargs_ssm_scan & args, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i1 = tgpig.x; + const int64_t ir = tgpig.y; // current head + const int64_t i3 = tgpig.z; // current seq + + const uint64_t nb00 = sizeof(float); + const uint64_t nb10 = sizeof(float); + const uint64_t nb20 = sizeof(float); + + const int64_t nc = args.d_state; + const int64_t nr = args.d_inner; + const int64_t nh = args.n_head; + const int64_t ng = args.n_group; + const int64_t n_t = args.n_seq_tokens; + + const int64_t s_off = nr * nh * n_t * args.n_seqs * sizeof(float); + + device const int32_t * ids = (device const int32_t *) src6; + + device const float * s0 = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); + device float * s = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); + + for (int64_t i2 = 0; i2 < n_t; ++i2) { + device const float * x = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i2*args.nb12 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *) ((device const char *) src2 + ir*nb20 + i2*args.nb21 + i3*args.nb22); // {nh, nt, ns} + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} + device const float * B = (device const float *) ((device const char *) src4 + (ir & (ng - 1))*args.nb41 + i2*args.nb42 + i3*args.nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *) ((device const char *) src5 + (ir & (ng - 1))*args.nb51 + i2*args.nb52 + i3*args.nb53); // {d_state, ng, nt, ns} + device float * y = (device float *) ((device char *) dst + (i1 + ir*(nr) + i2*(nh*nr) + i3*(n_t*nh*nr))*nb00); // {dim, nh, nt, ns} + + const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; + const float x_dt = x[0] * dt_soft_plus; + const float dA = exp(dt_soft_plus * A[0]); + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc; ++i0) { + const int64_t i = i0 + i1*nc; + const float state = (s0[i] * dA) + (B[i0] * x_dt); + sumf += state * C[i0]; + s[i] = state; + } + + y[0] = sumf; + + // recurse + s0 = s; } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 7654ae1779b1d..05832b038bdcb 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -4498,7 +4498,6 @@ struct ggml_tensor * ggml_ssm_conv( const int64_t n_s = sx->ne[2]; // TODO: maybe support other strides than 1? - // FIXME: this is always true? GGML_ASSERT(sx->ne[0] == d_conv - 1 + n_t); GGML_ASSERT(sx->ne[1] == d_inner); GGML_ASSERT(n_t >= 0); @@ -4521,36 +4520,49 @@ struct ggml_tensor * ggml_ssm_scan( struct ggml_tensor * dt, struct ggml_tensor * A, struct ggml_tensor * B, - struct ggml_tensor * C) { + struct ggml_tensor * C, + struct ggml_tensor * ids) { GGML_ASSERT(ggml_is_contiguous(s)); - GGML_ASSERT(ggml_is_contiguous(x)); GGML_ASSERT(ggml_is_contiguous(dt)); GGML_ASSERT(ggml_is_contiguous(A)); - GGML_ASSERT(ggml_is_matrix(A)); - GGML_ASSERT(ggml_is_3d(B)); - GGML_ASSERT(ggml_is_3d(s)); + GGML_ASSERT(x->nb[0] == ggml_type_size(x->type)); GGML_ASSERT(B->nb[0] == ggml_type_size(B->type)); GGML_ASSERT(C->nb[0] == ggml_type_size(C->type)); - GGML_ASSERT(ggml_are_same_shape(x, dt)); + GGML_ASSERT(x->nb[1] == x->ne[0]*x->nb[0]); + GGML_ASSERT(B->nb[1] == B->ne[0]*B->nb[0]); + GGML_ASSERT(C->nb[1] == C->ne[0]*C->nb[0]); GGML_ASSERT(ggml_are_same_shape(B, C)); + GGML_ASSERT(ids->type == GGML_TYPE_I32); { const int64_t d_state = s->ne[0]; - const int64_t d_inner = s->ne[1]; - const int64_t n_seq_tokens = x->ne[1]; - const int64_t n_seqs = x->ne[2]; - - GGML_ASSERT(s->ne[2] == n_seqs); - GGML_ASSERT(x->ne[0] == d_inner); - GGML_ASSERT(A->ne[0] == d_state); - GGML_ASSERT(A->ne[1] == d_inner); + const int64_t head_dim = x->ne[0]; + const int64_t n_head = x->ne[1]; + const int64_t n_seq_tokens = x->ne[2]; + const int64_t n_seqs = x->ne[3]; + + GGML_ASSERT(dt->ne[0] == n_head); + GGML_ASSERT(dt->ne[1] == n_seq_tokens); + GGML_ASSERT(dt->ne[2] == n_seqs); + GGML_ASSERT(ggml_is_3d(dt)); + GGML_ASSERT(s->ne[1] == head_dim); + GGML_ASSERT(s->ne[2] == n_head); GGML_ASSERT(B->ne[0] == d_state); - GGML_ASSERT(B->ne[1] == n_seq_tokens); - GGML_ASSERT(B->ne[2] == n_seqs); + GGML_ASSERT(B->ne[2] == n_seq_tokens); + GGML_ASSERT(B->ne[3] == n_seqs); + GGML_ASSERT(ids->ne[0] == n_seqs); + GGML_ASSERT(ggml_is_vector(ids)); + GGML_ASSERT(A->ne[1] == n_head); + GGML_ASSERT(ggml_is_matrix(A)); + + if (A->ne[0] != 1) { + // Mamba-1 has more granular decay factors + GGML_ASSERT(A->ne[0] == d_state); + } } // concatenated y + ssm_states - struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + ggml_nelements(s)); + struct ggml_tensor * result = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ggml_nelements(x) + s->ne[0]*s->ne[1]*s->ne[2]*ids->ne[0]); result->op = GGML_OP_SSM_SCAN; result->src[0] = s; @@ -4559,6 +4571,7 @@ struct ggml_tensor * ggml_ssm_scan( result->src[3] = A; result->src[4] = B; result->src[5] = C; + result->src[6] = ids; return result; } diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index a2540bd93fd91..b1e8a000d88e6 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -164,6 +164,7 @@ class SSM: INNER_SIZE = "{arch}.ssm.inner_size" STATE_SIZE = "{arch}.ssm.state_size" TIME_STEP_RANK = "{arch}.ssm.time_step_rank" + GROUP_COUNT = "{arch}.ssm.group_count" DT_B_C_RMS = "{arch}.ssm.dt_b_c_rms" class WKV: @@ -297,6 +298,7 @@ class MODEL_ARCH(IntEnum): RWKV7 = auto() ARWKV7 = auto() MAMBA = auto() + MAMBA2 = auto() XVERSE = auto() COMMAND_R = auto() COHERE2 = auto() @@ -380,6 +382,7 @@ class MODEL_TENSOR(IntEnum): SSM_DT = auto() SSM_A = auto() SSM_D = auto() + SSM_NORM = auto() SSM_OUT = auto() TIME_MIX_W0 = auto() TIME_MIX_W1 = auto() @@ -554,6 +557,7 @@ class MODEL_TENSOR(IntEnum): MODEL_ARCH.RWKV7: "rwkv7", MODEL_ARCH.ARWKV7: "arwkv7", MODEL_ARCH.MAMBA: "mamba", + MODEL_ARCH.MAMBA2: "mamba2", MODEL_ARCH.XVERSE: "xverse", MODEL_ARCH.COMMAND_R: "command-r", MODEL_ARCH.COHERE2: "cohere2", @@ -637,6 +641,7 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_DT: "blk.{bid}.ssm_dt", MODEL_TENSOR.SSM_A: "blk.{bid}.ssm_a", MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", + MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0", MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1", @@ -1527,6 +1532,19 @@ class MODEL_TENSOR(IntEnum): MODEL_TENSOR.SSM_D, MODEL_TENSOR.SSM_OUT, ], + MODEL_ARCH.MAMBA2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.SSM_IN, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_D, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_OUT, + ], MODEL_ARCH.XVERSE: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -2245,6 +2263,7 @@ class VisionProjectorType: KEY_SSM_INNER_SIZE = Keys.SSM.INNER_SIZE KEY_SSM_STATE_SIZE = Keys.SSM.STATE_SIZE KEY_SSM_TIME_STEP_RANK = Keys.SSM.TIME_STEP_RANK +KEY_SSM_GROUP_COUNT = Keys.SSM.GROUP_COUNT KEY_SSM_DT_B_C_RMS = Keys.SSM.DT_B_C_RMS # tokenization diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index a30c49e32b351..a6977dacbddf4 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -842,6 +842,9 @@ def add_ssm_state_size(self, value: int) -> None: def add_ssm_time_step_rank(self, value: int) -> None: self.add_uint32(Keys.SSM.TIME_STEP_RANK.format(arch=self.arch), value) + def add_ssm_group_count(self, value: int) -> None: + self.add_uint32(Keys.SSM.GROUP_COUNT.format(arch=self.arch), value) + def add_ssm_dt_b_c_rms(self, value: bool) -> None: self.add_bool(Keys.SSM.DT_B_C_RMS.format(arch=self.arch), value) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 2f6326104ffa7..df50b5229f2d6 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -457,7 +457,7 @@ class TensorNameMap: "encoder.layers.{bid}.norm2", # nomic-bert "transformer.decoder_layer.{bid}.rms_norm_3", # Grok "encoder.layer.{bid}.mlp.layernorm", # jina-bert-v2 - "encoder.layer.{bid}.layer_norm_2" # jina-v2-code + "encoder.layer.{bid}.layer_norm_2", # jina-v2-code ), MODEL_TENSOR.SSM_IN: ( @@ -490,6 +490,10 @@ class TensorNameMap: "backbone.layers.{bid}.mixer.D", ), + MODEL_TENSOR.SSM_NORM: ( + "backbone.layers.{bid}.mixer.norm", # mamba2 + ), + MODEL_TENSOR.SSM_OUT: ( "model.layers.{bid}.out_proj", "backbone.layers.{bid}.mixer.out_proj", diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index f2bc8ca768502..49f829827fd72 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -43,6 +43,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_GEMMA3, "gemma3" }, { LLM_ARCH_STARCODER2, "starcoder2" }, { LLM_ARCH_MAMBA, "mamba" }, + { LLM_ARCH_MAMBA2, "mamba2" }, { LLM_ARCH_XVERSE, "xverse" }, { LLM_ARCH_COMMAND_R, "command-r" }, { LLM_ARCH_COHERE2, "cohere2" }, @@ -164,6 +165,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_SSM_INNER_SIZE, "%s.ssm.inner_size" }, { LLM_KV_SSM_STATE_SIZE, "%s.ssm.state_size" }, { LLM_KV_SSM_TIME_STEP_RANK, "%s.ssm.time_step_rank" }, + { LLM_KV_SSM_GROUP_COUNT, "%s.ssm.group_count" }, { LLM_KV_SSM_DT_B_C_RMS, "%s.ssm.dt_b_c_rms" }, { LLM_KV_WKV_HEAD_SIZE, "%s.wkv.head_size" }, @@ -926,6 +928,22 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, }, }, + { + LLM_ARCH_MAMBA2, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, + { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, + { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, + { LLM_TENSOR_SSM_A, "blk.%d.ssm_a" }, + { LLM_TENSOR_SSM_D, "blk.%d.ssm_d" }, + { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, + { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, + }, + }, { LLM_ARCH_XVERSE, { @@ -1635,6 +1653,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_CONV1D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_CONV}}, {LLM_TENSOR_SSM_A, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_SSM_SCAN}}, {LLM_TENSOR_SSM_D, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_SSM_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LERP_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_TIME_MIX_LN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_CHANNEL_MIX_LERP_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 41a023da3da6e..7bfbb44ce17ba 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -47,6 +47,7 @@ enum llm_arch { LLM_ARCH_GEMMA3, LLM_ARCH_STARCODER2, LLM_ARCH_MAMBA, + LLM_ARCH_MAMBA2, LLM_ARCH_XVERSE, LLM_ARCH_COMMAND_R, LLM_ARCH_COHERE2, @@ -168,6 +169,7 @@ enum llm_kv { LLM_KV_SSM_CONV_KERNEL, LLM_KV_SSM_STATE_SIZE, LLM_KV_SSM_TIME_STEP_RANK, + LLM_KV_SSM_GROUP_COUNT, LLM_KV_SSM_DT_B_C_RMS, LLM_KV_WKV_HEAD_SIZE, @@ -269,6 +271,7 @@ enum llm_tensor { LLM_TENSOR_SSM_DT, LLM_TENSOR_SSM_A, LLM_TENSOR_SSM_D, + LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_TIME_MIX_W0, LLM_TENSOR_TIME_MIX_W1, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a88b9a5ff90da..2b2235a661c58 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1066,12 +1066,12 @@ int llama_context::decode(llama_batch & inp_batch) { } } + // set to total number of outputs in the batch, for use in llama_get_logits_ith + n_outputs = n_outputs_all; + // make the outputs have the same order they had in the user-provided batch // note: this is mostly relevant for recurrent models atm if (!sorted_output) { - const uint32_t n_vocab = model.vocab.n_tokens(); - const uint32_t n_embd = model.hparams.n_embd; - GGML_ASSERT((size_t) n_outputs == out_ids.size()); // TODO: is there something more efficient which also minimizes swaps? @@ -1086,12 +1086,12 @@ int llama_context::decode(llama_batch & inp_batch) { if (j_min == i) { continue; } std::swap(out_ids[i], out_ids[j_min]); if (logits_size > 0) { - for (uint32_t k = 0; k < n_vocab; k++) { + for (int32_t k = 0; k < n_vocab; k++) { std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]); } } if (embd_size > 0) { - for (uint32_t k = 0; k < n_embd; k++) { + for (int64_t k = 0; k < n_embd; k++) { std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]); } } @@ -1101,11 +1101,10 @@ int llama_context::decode(llama_batch & inp_batch) { output_ids[out_ids[i]] = i; } } + // sorted, so no need for the indices anymore + out_ids.clear(); } - // set to total number of outputs in the batch, for use in llama_get_logits_ith - n_outputs = n_outputs_all; - // wait for the computation to finish (automatically done when obtaining the model output) //synchronize(); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0da4e7d2b0547..f2513bd551f18 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -284,23 +284,23 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n for (uint32_t i = 0; i < n_kv; ++i) { - data[i] = kv_self->s_copy(i); - } - } -} + const uint32_t cell_id = i + kv_self->head; -void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) { - GGML_UNUSED(ubatch); + const auto & kv_cell = kv_self->cells[cell_id]; - const int64_t n_kv = kv_self->n; + int32_t src = kv_cell.src0; - if (s_mask) { - GGML_ASSERT(ggml_backend_buffer_is_host(s_mask->buffer)); - float * data = (float *) s_mask->data; + // prevent out-of-bound sources + if (src < 0) { + GGML_ASSERT(kv_self->rs_z >= 0); // Need a valid zero-ed cell as a source + src = kv_self->rs_z; + } + if ((uint32_t) src >= kv_self->size) { + // ignore out-of-bound sources + src = cell_id; + } - // clear unused states - for (int i = 0; i < n_kv; ++i) { - data[i] = kv_self->s_mask(i); + data[i] = src; } } } @@ -1093,23 +1093,6 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const { return cur; } -ggml_tensor * llm_graph_context::build_inp_s_mask() const { - const llama_kv_cache_recurrent * kv_self = static_cast(memory); - - auto inp = std::make_unique(kv_self); - - const auto n_kv = kv_self->n; - - auto & cur = inp->s_mask; - - cur = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1, n_kv); - ggml_set_input(cur); - - res->add_input(std::move(inp)); - - return cur; -} - ggml_tensor * llm_graph_context::build_inp_cross_embd() const { auto inp = std::make_unique(cross); @@ -1550,43 +1533,48 @@ ggml_tensor * llm_graph_context::build_attn( return cur; } -ggml_tensor * llm_graph_context::build_copy_mask_state( +ggml_tensor * llm_graph_context::build_recurrent_state( ggml_cgraph * gf, ggml_tensor * s, ggml_tensor * state_copy, - ggml_tensor * state_mask, int32_t n_state, - int32_t n_seqs) const { + int32_t n_seqs, + bool avoid_copies) const { const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto n_kv = kv_self->n; const auto kv_head = kv_self->head; + const auto rs_zero = kv_self->rs_z; ggml_tensor * states = ggml_reshape_2d(ctx0, s, n_state, kv_self->size); - // copy states - // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv - // this shrinks the tensors's ne[1] to n_kv - states = ggml_get_rows(ctx0, states, state_copy); - - // clear states of sequences which are starting at the beginning of this batch - // FIXME: zero-out NANs? - states = ggml_mul(ctx0, states, state_mask); + // Clear a single state which will then be copied to the other cleared states. + // Note that this is a no-op when the view is zero-sized. + ggml_tensor * state_zero = ggml_view_1d(ctx0, states, n_state*(rs_zero >= 0), rs_zero*states->nb[1]*(rs_zero >= 0)); + ggml_build_forward_expand(gf, ggml_scale_inplace(ctx0, state_zero, 0)); // copy states which won't be changed further (between n_seqs and n_kv) + ggml_tensor * states_extra = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_kv - n_seqs, n_seqs*state_copy->nb[0])); ggml_build_forward_expand(gf, ggml_cpy(ctx0, - ggml_view_1d(ctx0, states, n_state*(n_kv - n_seqs), (n_seqs )*n_state*ggml_element_size(states)), - ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s)))); + states_extra, + ggml_view_1d(ctx0, s, n_state*(n_kv - n_seqs), (kv_head + n_seqs)*n_state*ggml_element_size(s)))); + + if (!avoid_copies) { + // copy states + // NOTE: assuming the copy destinations are ALL contained between kv_head and kv_head + n_kv + // this shrinks the tensors's ne[1] to n_kv + states = ggml_get_rows(ctx0, states, ggml_view_1d(ctx0, state_copy, n_seqs, 0)); + // the part of the states that will be used and modified + states = ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0); + } - // the part of the states that will be used and modified - return ggml_view_2d(ctx0, states, n_state, n_seqs, states->nb[1], 0); + return states; } ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_cgraph * gf, ggml_tensor * state_copy, - ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { const llama_kv_cache_recurrent * kv_self = static_cast(memory); @@ -1597,8 +1585,8 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_tensor * token_shift_all = kv_self->k_l[il]; - ggml_tensor * token_shift = build_copy_mask_state( - gf, token_shift_all, state_copy, state_mask, + ggml_tensor * token_shift = build_recurrent_state( + gf, token_shift_all, state_copy, hparams.n_embd_k_s(), n_seqs); token_shift = ggml_reshape_3d(ctx0, token_shift, hparams.n_embd, token_shift_count, n_seqs); diff --git a/src/llama-graph.h b/src/llama-graph.h index 5b404366dc145..d126ea109029a 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -197,18 +197,6 @@ class llm_graph_input_s_copy : public llm_graph_input_i { const llama_kv_cache_recurrent * kv_self; }; -class llm_graph_input_s_mask : public llm_graph_input_i { -public: - llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {} - virtual ~llm_graph_input_s_mask() = default; - - void set_input(const llama_ubatch * ubatch) override; - - ggml_tensor * s_mask; // F32 [1, n_kv] - - const llama_kv_cache_recurrent * kv_self; -}; - class llm_graph_input_cross_embd : public llm_graph_input_i { public: llm_graph_input_cross_embd( @@ -491,7 +479,6 @@ struct llm_graph_context { ggml_tensor * build_inp_mean() const; ggml_tensor * build_inp_cls() const; ggml_tensor * build_inp_s_copy() const; - ggml_tensor * build_inp_s_mask() const; ggml_tensor * build_inp_cross_embd() const; ggml_tensor * build_inp_pos_bucket_enc() const; @@ -562,18 +549,17 @@ struct llm_graph_context { // recurrent // - ggml_tensor * build_copy_mask_state( + ggml_tensor * build_recurrent_state( ggml_cgraph * gf, ggml_tensor * s, ggml_tensor * state_copy, - ggml_tensor * state_mask, int32_t n_state, - int32_t n_seqs) const; + int32_t n_seqs, + bool avoid_copies = false) const; ggml_tensor * build_rwkv_token_shift_load( ggml_cgraph * gf, ggml_tensor * state_copy, - ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const; diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index 90dfe7a7fcc00..1d77154e55e7f 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -57,7 +57,8 @@ uint32_t llama_hparams::n_embd_k_s() const { // TODO: maybe support other convolution strides than 1 // NOTE: since the first column of the conv_state is shifted out each time, it's not actually needed - return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * ssm_d_inner; + // Corresponds to Mamba's conv_states size + return (ssm_d_conv > 0 ? ssm_d_conv - 1 : 0) * (ssm_d_inner + 2*ssm_n_group*ssm_d_state); } uint32_t llama_hparams::n_embd_v_s() const { diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 7ee6a5b75ad1e..1ed8553ed49eb 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -101,6 +101,7 @@ struct llama_hparams { uint32_t ssm_d_inner = 0; uint32_t ssm_d_state = 0; uint32_t ssm_dt_rank = 0; + uint32_t ssm_n_group = 0; bool ssm_dt_b_c_rms = false; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 3dcad65bb6a85..99dd20b68fd73 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -469,8 +469,6 @@ bool llama_kv_cache_unified::find_slot( head = 0; } - // otherwise, one cell per token. - if (n_tokens > size) { LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size); return false; @@ -1894,6 +1892,18 @@ bool llama_kv_cache_recurrent::find_slot( } } + // Find first to-be-cleared cell + rs_z = -1; + for (int i = min; i <= max; ++i) { + if (rs_z < 0 && cells[i].src == -1) { + rs_z = i; + } + // Stage the source ids for all used cells to allow correct seq_* behavior + // and still make these values available when setting the inputs + cells[i].src0 = cells[i].src; + cells[i].src = i; + } + // allow getting the range of used cells, from head to head + n head = min; n = max - min + 1; @@ -1928,47 +1938,8 @@ llama_pos llama_kv_cache_recurrent::get_pos_max() const { } bool llama_kv_cache_recurrent::get_can_shift() const { - return false; -} - -int32_t llama_kv_cache_recurrent::s_copy(int i) const { - const uint32_t cell_id = i + head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - kv_cell & cell = const_cast(cells[cell_id]); - - // prevent out-of-bound sources - if (cell.src < 0 || (uint32_t) cell.src >= size) { - cell.src = cell_id; - } - - int32_t res = cell.src; - - // TODO: do not mutate the KV cache - // ensure copy only happens once - if (cell.src != (int32_t) cell_id) { - cell.src = cell_id; - } - - return res; -} - -float llama_kv_cache_recurrent::s_mask(int i) const { - const uint32_t cell_id = i + head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - kv_cell & cell = const_cast(cells[cell_id]); - - float res = (float) (cell.src >= 0); - - // only clear once - if (cell.src < 0) { - cell.src = cell_id; - } - - return res; + // shifting is trivial, the recurrent states don't care about the absolute position + return true; } uint32_t llama_kv_cache_recurrent::cell_max() const { diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index bf3b4b6a4430f..7a23c0ba5160d 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -266,7 +266,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache { public: struct kv_cell { llama_pos pos = -1; - int32_t src = -1; // used to copy states + int32_t src = -1; // used to know where states should be copied from + int32_t src0 = -1; // like src, but used when setting the inputs (allowing to copy once) int32_t tail = -1; std::set seq_id; @@ -334,10 +335,6 @@ class llama_kv_cache_recurrent : public llama_kv_cache { bool get_can_shift() const override; - // TODO: temporary methods - they are not really const as they do const_cast<>, fix this - int32_t s_copy(int i) const; - float s_mask(int i) const; - // state write/load void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; @@ -353,6 +350,9 @@ class llama_kv_cache_recurrent : public llama_kv_cache { // computed before each graph build uint32_t n = 0; + // first zero-ed state + int32_t rs_z = -1; + std::vector cells; std::vector k_l; // per layer @@ -374,8 +374,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache { std::vector ranges; } pending; - ggml_type type_k = GGML_TYPE_F16; - ggml_type type_v = GGML_TYPE_F16; + ggml_type type_k = GGML_TYPE_F32; + ggml_type type_v = GGML_TYPE_F32; std::vector ctxs; std::vector bufs; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index e163de76a7596..1efa94cea709b 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -195,23 +195,27 @@ static bool weight_buft_supported(const llama_hparams & hparams, ggml_tensor * w } break; case GGML_OP_SSM_CONV: { - // FIXME - ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, 12345, w->ne[1], 6789); + const int64_t n_seq_tokens = 512; + const int64_t n_seqs = 3; + ggml_tensor * conv_x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, w->ne[0] - 1 + n_seq_tokens, w->ne[1], n_seqs); op_tensor = ggml_ssm_conv(ctx, conv_x, w); } break; case GGML_OP_SSM_SCAN: { - // FIXME - const int64_t d_state = w->ne[0]; - const int64_t d_inner = w->ne[1]; + // w is ssm_a, which is used to distinguish Mamba-1 and Mamba-2 + const int64_t d_state = w->ne[0] == 1 ? hparams.ssm_d_state : w->ne[0]; + const int64_t n_head = w->ne[1]; + const int64_t head_dim = hparams.ssm_d_inner / n_head; + const int64_t n_group = hparams.ssm_n_group; const int64_t n_seq_tokens = 512; - const int64_t n_seqs = 1; - ggml_tensor * s = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, d_inner, n_seqs); - ggml_tensor * x = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); - ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_inner, n_seq_tokens, n_seqs); - ggml_tensor * B = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); - ggml_tensor * C = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, d_state, n_seq_tokens, n_seqs); - op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C); + const int64_t n_seqs = 3; + ggml_tensor * s = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, head_dim, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, head_dim, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, n_head, n_seq_tokens, n_seqs); + ggml_tensor * B = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + op_tensor = ggml_ssm_scan(ctx, s, x, dt, w, B, C, ids); } break; case GGML_OP_RWKV_WKV6: { @@ -1007,6 +1011,38 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MAMBA2: + { + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + switch (hparams.n_layer) { + case 24: + switch (hparams.n_embd) { + case 768: type = LLM_TYPE_SMALL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 48: + switch (hparams.n_embd) { + case 1024: type = LLM_TYPE_MEDIUM; break; + case 1536: type = LLM_TYPE_LARGE; break; + case 2048: type = LLM_TYPE_XL; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + case 64: + switch (hparams.n_embd) { + case 2560: type = LLM_TYPE_3B; break; + case 4096: type = LLM_TYPE_7B; break; + default: type = LLM_TYPE_UNKNOWN; + } break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_XVERSE: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -2917,6 +2953,54 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {d_state, d_inner}, 0); layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {d_inner}, 0); + // out_proj + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); + } + } break; + case LLM_ARCH_MAMBA2: + { + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; + const int64_t n_group = hparams.ssm_n_group; + const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_head; + + // only an expansion factor of 2 is supported for now + GGML_ASSERT(2 * n_embd == d_inner); + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + { + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed, duplicated to allow offloading + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, llama_model_loader::TENSOR_DUPLICATED); + } + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + // norm + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0); + + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0); + layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, 0); + + layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_head}, 0); + + // no "weight" suffix for these + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_head}, 0); + layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_head}, 0); + + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0); + // out_proj layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0); } @@ -4303,10 +4387,14 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); + } + + if (arch == LLM_ARCH_MAMBA || arch == LLM_ARCH_MAMBA2) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner); LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state); LLAMA_LOG_INFO("%s: ssm_dt_rank = %u\n", __func__, hparams.ssm_dt_rank); + LLAMA_LOG_INFO("%s: ssm_n_group = %u\n", __func__, hparams.ssm_n_group); LLAMA_LOG_INFO("%s: ssm_dt_b_c_rms = %d\n", __func__, hparams.ssm_dt_b_c_rms); } @@ -8670,7 +8758,6 @@ struct llm_build_mamba : public llm_graph_context { inpL = build_inp_embd(model.tok_embd); ggml_tensor * state_copy = build_inp_s_copy(); - ggml_tensor * state_mask = build_inp_s_mask(); for (int il = 0; il < n_layer; ++il) { // norm @@ -8679,8 +8766,11 @@ struct llm_build_mamba : public llm_graph_context { LLM_NORM_RMS, il); cb(cur, "attn_norm", il); - //cur = build_mamba_layer(gf, cur, state_copy, state_mask, il); - cur = build_mamba_layer(gf, cur, state_copy, state_mask, ubatch, il); + if (model.arch == LLM_ARCH_MAMBA2) { + cur = build_mamba2_layer(gf, cur, state_copy, ubatch, il); + } else { + cur = build_mamba_layer(gf, cur, state_copy, ubatch, il); + } if (il == n_layer - 1) { // skip computing output for unused tokens @@ -8716,12 +8806,10 @@ struct llm_build_mamba : public llm_graph_context { ggml_build_forward_expand(gf, cur); } - // TODO: split ggml_tensor * build_mamba_layer( ggml_cgraph * gf, ggml_tensor * cur, ggml_tensor * state_copy, - ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { const llama_kv_cache_recurrent * kv_self = static_cast(memory); @@ -8732,6 +8820,8 @@ struct llm_build_mamba : public llm_graph_context { const int64_t d_inner = hparams.ssm_d_inner; const int64_t d_state = hparams.ssm_d_state; const int64_t dt_rank = hparams.ssm_dt_rank; + const int64_t n_head = d_inner; + const int64_t head_dim = 1; const int64_t n_seqs = ubatch.n_seqs; // Some variants of Mamba arch (e.g. FalconMamba do apply layer norm on B and Dt layers) const bool ssm_dt_b_c_rms = hparams.ssm_dt_b_c_rms; @@ -8748,14 +8838,10 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * ssm_states_all = kv_self->v_l[il]; // (ab)using the KV cache to store the states - ggml_tensor * conv = build_copy_mask_state( - gf, conv_states_all, state_copy, state_mask, - hparams.n_embd_k_s(), n_seqs); + ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs); conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner, n_seqs); - ggml_tensor * ssm = build_copy_mask_state( - gf, ssm_states_all, state_copy, state_mask, - hparams.n_embd_v_s(), n_seqs); - ssm = ggml_reshape_3d(ctx0, ssm, d_state, d_inner, n_seqs); + ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs, true); + ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_self->size); // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); @@ -8804,8 +8890,8 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * x_db = build_lora_mm(model.layers[il].ssm_x, x); // split ggml_tensor * dt = ggml_view_3d(ctx0, x_db, dt_rank, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], 0); - ggml_tensor * B = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); - ggml_tensor * C = ggml_view_3d(ctx0, x_db, d_state, n_seq_tokens, n_seqs, x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); + ggml_tensor * B = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*dt_rank); + ggml_tensor * C = ggml_view_4d(ctx0, x_db, d_state, /* n_group */ 1, n_seq_tokens, n_seqs, d_state*x_db->nb[0], x_db->nb[1], x_db->nb[2], ggml_element_size(x_db)*(dt_rank+d_state)); // Some Mamba variants (e.g. FalconMamba) apply RMS norm in B, C & Dt layers if (ssm_dt_b_c_rms) { @@ -8818,32 +8904,156 @@ struct llm_build_mamba : public llm_graph_context { dt = build_lora_mm(model.layers[il].ssm_dt, dt); dt = ggml_add(ctx0, dt, model.layers[il].ssm_dt_b); + cur = x; + x = ggml_reshape_4d(ctx0, x, head_dim, n_head, n_seq_tokens, n_seqs); + + ggml_tensor * ssm_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0); // Custom operator to optimize the parallel associative scan // as described in the Annex D of the Mamba paper. // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} - ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C); + ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids); // store last states ggml_build_forward_expand(gf, ggml_cpy(ctx0, - ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]), + ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, x->nb[3]*x->ne[3]), ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); - ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[1], x->nb[2], 0); + ggml_tensor * y = ggml_view_3d(ctx0, y_ssm, d_inner, n_seq_tokens, n_seqs, x->nb[2], x->nb[3], 0); + + // TODO: skip computing output earlier for unused tokens + + y = ggml_add(ctx0, y, ggml_mul(ctx0, cur, model.layers[il].ssm_d)); + y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); + + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} + cur = build_lora_mm(model.layers[il].ssm_out, y); + } + + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); + // cb(cur, "mamba_out", il); + + return cur; + } + + ggml_tensor * build_mamba2_layer( + ggml_cgraph * gf, + ggml_tensor * cur, + ggml_tensor * state_copy, + const llama_ubatch & ubatch, + int il) const { + const llama_kv_cache_recurrent * kv_self = static_cast(memory); + + const auto kv_head = kv_self->head; + + const int64_t d_conv = hparams.ssm_d_conv; + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t d_state = hparams.ssm_d_state; + const int64_t n_head = hparams.ssm_dt_rank; + const int64_t head_dim = d_inner / n_head; + const int64_t n_group = hparams.ssm_n_group; + const int64_t n_seqs = ubatch.n_seqs; + + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + ggml_tensor * conv_states_all = kv_self->k_l[il]; + ggml_tensor * ssm_states_all = kv_self->v_l[il]; + + // (ab)using the KV cache to store the states + ggml_tensor * conv = build_recurrent_state(gf, conv_states_all, state_copy, hparams.n_embd_k_s(), n_seqs); + conv = ggml_reshape_3d(ctx0, conv, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs); + ggml_tensor * ssm = build_recurrent_state(gf, ssm_states_all, state_copy, hparams.n_embd_v_s(), n_seqs, true); + ssm = ggml_reshape_4d(ctx0, ssm, d_state, head_dim, n_head, kv_self->size); + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + // d_in_proj = 2 * self.d_inner + 2 * self.ngroups * self.d_state + self.nheads + + // {n_embd, d_in_proj} @ {n_embd, n_seq_tokens, n_seqs} => {d_in_proj, n_seq_tokens, n_seqs} + ggml_tensor * zxBCdt = build_lora_mm(model.layers[il].ssm_in, cur); + + // split the above in three + ggml_tensor * z = ggml_view_4d(ctx0, zxBCdt, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*zxBCdt->nb[0], zxBCdt->nb[1], zxBCdt->nb[2], 0); + ggml_tensor * xBC = ggml_view_3d(ctx0, zxBCdt, d_inner + 2*n_group*d_state, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], d_inner*ggml_element_size(zxBCdt)); + ggml_tensor * dt = ggml_view_3d(ctx0, zxBCdt, n_head, n_seq_tokens, n_seqs, zxBCdt->nb[1], zxBCdt->nb[2], (2*d_inner + 2*n_group*d_state)*ggml_element_size(zxBCdt)); + + // conv + { + // => {d_conv - 1 + n_seq_tokens, d_inner + 2*n_group*d_state, n_seqs} + ggml_tensor * conv_x = ggml_concat(ctx0, conv, ggml_transpose(ctx0, xBC), 0); + + // copy last (d_conv - 1) columns back into the state cache + ggml_tensor * last_conv = ggml_view_3d(ctx0, conv_x, d_conv - 1, d_inner + 2*n_group*d_state, n_seqs, conv_x->nb[1], conv_x->nb[2], n_seq_tokens*(conv_x->nb[0])); + + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, last_conv, + ggml_view_1d(ctx0, conv_states_all, + (d_conv - 1)*(d_inner + 2*n_group*d_state)*(n_seqs), + kv_head*(d_conv - 1)*(d_inner + 2*n_group*d_state)*ggml_element_size(conv_states_all)))); + + // 1D convolution + // The equivalent is to make a self-overlapping view of conv_x + // over d_conv columns at each stride in the 3rd dimension, + // then element-wise multiply that with the conv1d weight, + // then sum the elements of each row, + // (the last two steps are a dot product over rows (also doable with mul_mat)) + // then permute away the ne[0] dimension, + // and then you're left with the resulting x tensor. + // For simultaneous sequences, all sequences need to have the same length. + xBC = ggml_ssm_conv(ctx0, conv_x, model.layers[il].ssm_conv1d); + + // bias + xBC = ggml_add(ctx0, xBC, model.layers[il].ssm_conv1d_b); + + xBC = ggml_silu(ctx0, xBC); + } + + // ssm + { + // These correspond to V K Q in SSM/attention duality + ggml_tensor * x = ggml_view_4d(ctx0, xBC, head_dim, n_head, n_seq_tokens, n_seqs, head_dim*xBC->nb[0], xBC->nb[1], xBC->nb[2], 0); + ggml_tensor * B = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], d_inner*ggml_element_size(xBC)); + ggml_tensor * C = ggml_view_4d(ctx0, xBC, d_state, n_group, n_seq_tokens, n_seqs, d_state*xBC->nb[0], xBC->nb[1], xBC->nb[2], (d_inner + n_group*d_state)*ggml_element_size(xBC)); + + // {n_head, n_seq_tokens, n_seqs} + dt = ggml_add(ctx0, ggml_cont(ctx0, dt), model.layers[il].ssm_dt_b); + + ggml_tensor * ssm_ids = ggml_view_1d(ctx0, state_copy, n_seqs, 0); + // TODO: use semistructured matrices to implement state-space duality + // => {d_inner, n_seq_tokens, n_seqs} and {d_state, d_inner, n_seqs} + ggml_tensor * y_ssm = ggml_ssm_scan(ctx0, ssm, x, dt, model.layers[il].ssm_a, B, C, ssm_ids); + + // store last states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, + ggml_view_1d(ctx0, y_ssm, d_state*d_inner*n_seqs, ggml_nelements(x)*x->nb[0]), + ggml_view_1d(ctx0, ssm_states_all, d_state*d_inner*n_seqs, kv_head*d_state*d_inner*ggml_element_size(ssm_states_all)))); + + ggml_tensor * y = ggml_view_4d(ctx0, y_ssm, head_dim, n_head, n_seq_tokens, n_seqs, x->nb[1], n_head*x->nb[1], n_seq_tokens*n_head*x->nb[1], 0); // TODO: skip computing output earlier for unused tokens - // {d_inner, n_seq_tokens, n_seqs} * {d_inner} => {d_inner, n_seq_tokens, n_seqs} y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); y = ggml_mul(ctx0, y, ggml_silu(ctx0, ggml_cont(ctx0, z))); + // grouped RMS norm + y = ggml_reshape_4d(ctx0, y, d_inner / n_group, n_group, n_seq_tokens, n_seqs); + y = build_norm(y, model.layers[il].ssm_norm, NULL, LLM_NORM_RMS, il); + y = ggml_reshape_3d(ctx0, y, d_inner, n_seq_tokens, n_seqs); + // {d_inner, n_embd} @ {d_inner, n_seq_tokens, n_seqs} => {n_embd, n_seq_tokens, n_seqs} cur = build_lora_mm(model.layers[il].ssm_out, y); } // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} cur = ggml_reshape_2d(ctx0, cur, cur->ne[0], n_seq_tokens * n_seqs); - //cb(cur, "mamba_out", il); + // cb(cur, "mamba_out", il); return cur; } @@ -11469,7 +11679,6 @@ struct llm_build_rwkv6_base : public llm_graph_context { ggml_tensor * cur, ggml_tensor * x_prev, ggml_tensor * state_copy, - ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { const llama_kv_cache_recurrent * kv_self = static_cast(memory); @@ -11593,8 +11802,8 @@ struct llm_build_rwkv6_base : public llm_graph_context { k = ggml_sub(ctx0, k, ggml_mul(ctx0, k, w)); } - ggml_tensor * wkv_state = build_copy_mask_state( - gf, kv_self->v_l[il], state_copy, state_mask, + ggml_tensor * wkv_state = build_recurrent_state( + gf, kv_self->v_l[il], state_copy, hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output; @@ -11650,7 +11859,6 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); ggml_tensor * state_copy = build_inp_s_copy(); - ggml_tensor * state_mask = build_inp_s_mask(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -11661,7 +11869,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, state_mask, ubatch, il + gf, state_copy, ubatch, il ); ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); @@ -11677,7 +11885,7 @@ struct llm_build_rwkv6 : public llm_build_rwkv6_base { 1 ); - cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il); + cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -11748,7 +11956,6 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { inpL = build_inp_embd(model.tok_embd); ggml_tensor * state_copy = build_inp_s_copy(); - ggml_tensor * state_mask = build_inp_s_mask(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -11759,7 +11966,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, state_mask, ubatch, il + gf, state_copy, ubatch, il ); ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); @@ -11772,7 +11979,7 @@ struct llm_build_rwkv6qwen2 : public llm_build_rwkv6_base { 1 ); - cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, state_mask, ubatch, il); + cur = build_rwkv6_time_mix(gf, att_norm, x_prev, state_copy, ubatch, il); token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); @@ -11864,7 +12071,6 @@ struct llm_build_rwkv7_base : public llm_graph_context { ggml_tensor * cur, ggml_tensor * x_prev, ggml_tensor * state_copy, - ggml_tensor * state_mask, ggml_tensor *& first_layer_value, const llama_ubatch & ubatch, int il) const { @@ -11947,8 +12153,8 @@ struct llm_build_rwkv7_base : public llm_graph_context { v = ggml_reshape_3d(ctx0, v, head_size, head_count, n_tokens); a = ggml_reshape_3d(ctx0, a, head_size, head_count, n_tokens); - ggml_tensor * wkv_state = build_copy_mask_state( - gf, kv_self->v_l[il], state_copy, state_mask, + ggml_tensor * wkv_state = build_recurrent_state( + gf, kv_self->v_l[il], state_copy, hparams.n_embd_v_s(), n_seqs); ggml_tensor * wkv_output = ggml_rwkv_wkv7(ctx0, r, w, k, v, ggml_neg(ctx0, kk), ggml_mul(ctx0, kk, a), wkv_state); @@ -12006,7 +12212,6 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { inpL = build_norm(inpL, model.tok_norm, model.tok_norm_b, LLM_NORM, -1); ggml_tensor * state_copy = build_inp_s_copy(); - ggml_tensor * state_mask = build_inp_s_mask(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -12017,7 +12222,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, state_mask, ubatch, il + gf, state_copy, ubatch, il ); ggml_tensor * att_shift = ggml_view_3d(ctx0, token_shift, n_embd, 1, n_seqs, token_shift->nb[1], token_shift->nb[2], 0); @@ -12033,7 +12238,7 @@ struct llm_build_rwkv7 : public llm_build_rwkv7_base { 1 ); - cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il); + cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il); ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); cb(ffn_inp, "ffn_inp", il); @@ -12100,7 +12305,6 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { inpL = build_inp_embd(model.tok_embd); ggml_tensor * state_copy = build_inp_s_copy(); - ggml_tensor * state_mask = build_inp_s_mask(); const auto n_embd = hparams.n_embd; const auto n_seq_tokens = ubatch.n_seq_tokens; @@ -12111,7 +12315,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_seq_tokens, n_seqs); ggml_tensor * token_shift = build_rwkv_token_shift_load( - gf, state_copy, state_mask, ubatch, il + gf, state_copy, ubatch, il ); ggml_tensor * att_norm = build_norm(inpL, layer->attn_norm, layer->attn_norm_b, LLM_NORM_RMS, il); @@ -12124,7 +12328,7 @@ struct llm_build_arwkv7 : public llm_build_rwkv7_base { 1 ); - cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, state_mask, v_first, ubatch, il); + cur = build_rwkv7_time_mix(gf, att_norm, x_prev, state_copy, v_first, ubatch, il); token_shift = ggml_view_3d(ctx0, att_norm, n_embd, 1, n_seqs, att_norm->nb[1], att_norm->nb[2], (n_seq_tokens-1)*n_embd*ggml_element_size(att_norm)); ggml_build_forward_expand(gf, build_rwkv_token_shift_store(token_shift, ubatch, il)); @@ -12833,6 +13037,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, switch (arch) { case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6QWEN2: case LLM_ARCH_RWKV7: @@ -12999,6 +13204,7 @@ llm_graph_result_ptr llama_model::build_graph( llm = std::make_unique(*this, params, gf); } break; case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: { llm = std::make_unique(*this, params, gf); } break; @@ -13225,6 +13431,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_REFACT: case LLM_ARCH_BLOOM: case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_T5: case LLM_ARCH_T5ENCODER: @@ -13391,12 +13598,15 @@ llama_token llama_model_decoder_start_token(const llama_model * model) { bool llama_model_is_recurrent(const llama_model * model) { switch (model->arch) { - case LLM_ARCH_MAMBA: return true; - case LLM_ARCH_RWKV6: return true; - case LLM_ARCH_RWKV6QWEN2: return true; - case LLM_ARCH_RWKV7: return true; - case LLM_ARCH_ARWKV7: return true; - default: return false; + case LLM_ARCH_MAMBA: + case LLM_ARCH_MAMBA2: + case LLM_ARCH_RWKV6: + case LLM_ARCH_RWKV6QWEN2: + case LLM_ARCH_RWKV7: + case LLM_ARCH_ARWKV7: + return true; + default: + return false; } } diff --git a/src/llama-model.h b/src/llama-model.h index 4c7e7a335b4e2..c24c133e7abcc 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -165,6 +165,7 @@ struct llama_layer { struct ggml_tensor * ffn_sub_norm = nullptr; struct ggml_tensor * attn_norm_cross = nullptr; struct ggml_tensor * attn_norm_enc = nullptr; + struct ggml_tensor * ssm_norm = nullptr; // attention struct ggml_tensor * wq = nullptr; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 767498ef39dd5..25915da78aeda 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1854,28 +1854,58 @@ struct test_ssm_scan : public test_case { const ggml_type type; const int64_t d_state; - const int64_t d_inner; + const int64_t head_dim; + const int64_t n_head; + const int64_t n_group; const int64_t n_seq_tokens; const int64_t n_seqs; std::string vars() override { - return VARS_TO_STR5(type, d_state, d_inner, n_seq_tokens, n_seqs); + return VARS_TO_STR7(type, d_state, head_dim, n_head, n_group, n_seq_tokens, n_seqs); } test_ssm_scan(ggml_type type = GGML_TYPE_F32, - int64_t d_state = 32, int64_t d_inner = 32, int64_t n_seq_tokens = 32, int64_t n_seqs = 32) - : type(type), d_state(d_state), d_inner(d_inner), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} + int64_t d_state = 32, + int64_t head_dim = 1, // non-zero for Mamba-2 + int64_t n_head = 32, + int64_t n_group = 1, + int64_t n_seq_tokens = 32, + int64_t n_seqs = 32) + : type(type), d_state(d_state), head_dim(head_dim), n_head(n_head), n_group(n_group), n_seq_tokens(n_seq_tokens), n_seqs(n_seqs) {} ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * s = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, n_seqs, 1 }.data()); - ggml_tensor * x = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * dt = ggml_new_tensor(ctx, type, 4, std::vector{ d_inner, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * A = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, d_inner, 1 , 1 }.data()); - ggml_tensor * B = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * C = ggml_new_tensor(ctx, type, 4, std::vector{ d_state, n_seq_tokens, n_seqs, 1 }.data()); - ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C); + ggml_tensor * s = ggml_new_tensor_4d(ctx, type, d_state, head_dim, n_head, n_seqs); + ggml_tensor * x = ggml_new_tensor_4d(ctx, type, head_dim, n_head, n_seq_tokens, n_seqs); + ggml_tensor * dt = ggml_new_tensor_3d(ctx, type, n_head, n_seq_tokens, n_seqs); + ggml_tensor * A = ggml_new_tensor_2d(ctx, type, (head_dim > 1) ? 1 : d_state, n_head); + ggml_tensor * B = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * C = ggml_new_tensor_4d(ctx, type, d_state, n_group, n_seq_tokens, n_seqs); + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_seqs); + ggml_tensor * out = ggml_ssm_scan(ctx, s, x, dt, A, B, C, ids); return out; } + + // similar to test_mul_mat_id + void initialize_tensors(ggml_context * ctx) override { + std::random_device rd; + std::default_random_engine rng(rd()); + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I32) { + if (ggml_is_view_op(t->op)) { continue; } + // ids + for (int64_t r = 0; r < ggml_nrows(t); r++) { + std::vector data(t->ne[0]); + for (int i = 0; i < t->ne[0]; i++) { + data[i] = i; + } + std::shuffle(data.begin(), data.end(), rng); + ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t)); + } + } else { + init_tensor_uniform(t); + } + } + } }; // GGML_OP_RWKV_WKV6 @@ -4128,7 +4158,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {8, 1536, 1, 1}, {4, 1536, 1, 1})); test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 1536, 4, 1}, {4, 1536, 1, 1})); - test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1024, 32, 4)); + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 16, 1, 1024, 1, 32, 4)); // Mamba-1 + test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 32, 32, 2, 32, 4)); // Mamba-2 test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 1, 1)); test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 1));