From 22bda4864485060a5270f62411a80f5b7d2410dd Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Mon, 7 Apr 2025 16:50:58 +0300 Subject: [PATCH 01/29] kv-cache : serparate recurrent vs non-recurrent impl (wip) ggml-ci --- src/llama-context.cpp | 133 +++-- src/llama-context.h | 2 +- src/llama-graph.cpp | 16 +- src/llama-graph.h | 9 +- src/llama-kv-cache.cpp | 1265 +++++++++++++++++++++++++++++++--------- src/llama-kv-cache.h | 188 +++++- src/llama-model.cpp | 8 +- 7 files changed, 1240 insertions(+), 381 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 5a2eef9b784a1..b511185d77da3 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -179,24 +179,37 @@ llama_context::llama_context( // init the memory module // TODO: for now, always create a unified KV cache if (!hparams.vocab_only) { - kv_self.reset(static_cast(model.create_memory())); + uint32_t kv_size = 0; + ggml_type type_k = params.type_k; + ggml_type type_v = params.type_v; - LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); + if (!llama_model_is_recurrent(&model)) { + //kv_self.reset(static_cast(model.create_memory())); + auto * kv = static_cast(model.create_memory()); - cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams)); + LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); - LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv->get_padding(cparams)); - uint32_t kv_size = cparams.n_ctx; - ggml_type type_k = params.type_k; - ggml_type type_v = params.type_v; + LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + + kv_size = cparams.n_ctx; + type_k = params.type_k; + type_v = params.type_v; + + kv_self.reset(kv); + } else { + auto * kv = static_cast(model.create_memory()); + + LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); - if (llama_model_is_recurrent(&model)) { // Mamba needs at least as many KV cells as there are sequences kept at any time kv_size = std::max((uint32_t) 1, params.n_seq_max); // it's probably best to keep as much precision as possible for the states type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states + + kv_self.reset(kv); } GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0); @@ -305,7 +318,7 @@ llama_context::llama_context( int n_nodes_tg = -1; // simulate full KV cache - kv_self->n = kv_self->size; + kv_self->set_full(); cross.v_embd.clear(); @@ -546,7 +559,9 @@ llm_graph_result_ptr llama_context::build_kv_self_shift( //GGML_ASSERT(kv_self->size == n_ctx); - auto inp = std::make_unique(kv_self.get()); + const auto * kv = static_cast(kv_self.get()); + + auto inp = std::make_unique(kv); inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx); ggml_set_input(inp->k_shift); @@ -562,13 +577,13 @@ llm_graph_result_ptr llama_context::build_kv_self_shift( const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; const float freq_scale_l = is_swa ? 
hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; - ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il); + ggml_tensor * rope_factors = kv->cbs.get_rope_factors(n_ctx_per_seq(), il); ggml_tensor * k = - ggml_view_3d(ctx0, kv_self->k_l[il], - n_embd_head_k, n_head_kv, kv_self->size, - ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k), - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), + ggml_view_3d(ctx0, kv->k_l[il], + n_embd_head_k, n_head_kv, kv->size, + ggml_row_size(kv->k_l[il]->type, n_embd_head_k), + ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa), 0); ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); @@ -586,9 +601,11 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag( ggml_cgraph * gf) const { auto res = std::make_unique(); + auto * kv = static_cast(kv_self.get()); + const auto & hparams = model.hparams; - const auto & ids = kv_self->defrag_info.ids; + const auto & ids = kv->defrag_info.ids; #if 0 // CPU defrag @@ -678,40 +695,40 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag( const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); - ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il], + ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv->k_l[il], n_embd_k_gqa, nm, - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i)); + ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*i)); - ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il], + ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv->k_l[il], n_embd_k_gqa, nm, - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id)); + ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*id)); ggml_tensor * view_v_src; ggml_tensor * view_v_dst; if (cparams.flash_attn) { // NOTE: the V cache is not transposed when using flash attention - view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il], + view_v_src = ggml_view_2d(ctx0, kv->v_l[il], n_embd_v_gqa, nm, - ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i)); + ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*i)); - view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il], + view_v_dst = ggml_view_2d(ctx0, kv->v_l[il], n_embd_v_gqa, nm, - ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id)); + ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa), + ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*id)); } else { - view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il], + view_v_src = ggml_view_2d(ctx0, kv->v_l[il], nm, n_embd_v_gqa, - ggml_row_size(kv_self->v_l[il]->type, kv_self->size), - ggml_row_size(kv_self->v_l[il]->type, i)); + ggml_row_size(kv->v_l[il]->type, kv->size), + ggml_row_size(kv->v_l[il]->type, i)); - view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il], + view_v_dst = ggml_view_2d(ctx0, kv->v_l[il], nm, n_embd_v_gqa, - ggml_row_size(kv_self->v_l[il]->type, kv_self->size), - ggml_row_size(kv_self->v_l[il]->type, id)); + ggml_row_size(kv->v_l[il]->type, kv->size), + ggml_row_size(kv->v_l[il]->type, id)); } ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst)); @@ -728,13 +745,11 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag( } void 
llama_context::kv_self_update() { - auto & kv = kv_self; - bool need_reserve = false; - if (kv->has_shift) { - if (!kv->get_can_shift()) { - GGML_ABORT("The current context does not support K-shift"); + if (kv_self->get_has_shift()) { + if (!kv_self->get_can_shift()) { + GGML_ABORT("The current KV cache / model configuration does not support K-shift"); } LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__); @@ -757,6 +772,8 @@ void llama_context::kv_self_update() { } { + auto * kv = static_cast(kv_self.get()); + kv->has_shift = false; for (uint32_t i = 0; i < kv->size; ++i) { @@ -766,9 +783,11 @@ void llama_context::kv_self_update() { } // defragment the KV cache if needed - if (kv->do_defrag) { + if (kv_self->get_do_defrag()) { LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); + auto * kv = static_cast(kv_self.get()); + if (kv->defrag_prepare(graph_max_nodes())) { ggml_backend_sched_reset(sched.get()); @@ -797,7 +816,7 @@ void llama_context::kv_self_update() { uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); // simulate full KV cache - kv_self->n = kv_self->size; + kv_self->set_full(); llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; @@ -1017,8 +1036,8 @@ int llama_context::encode(llama_batch & inp_batch) { } // temporary allocate memory for the input batch if needed - // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); + // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1); const llama_batch & batch = batch_allocr.batch; const int32_t n_tokens = batch.n_tokens; @@ -1182,8 +1201,8 @@ int llama_context::decode(llama_batch & inp_batch) { } // temporary allocate memory for the input batch if needed - // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); + // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? 
-1 : kv_self->get_pos_max() + 1); const llama_batch & batch = batch_allocr.batch; @@ -1238,8 +1257,10 @@ int llama_context::decode(llama_batch & inp_batch) { const bool logits_all = n_outputs_all == n_tokens_all; + const bool is_recurrent = llama_model_is_recurrent(&model); + sbatch.from_batch(batch, n_embd, - /* simple_split */ !kv_self->recurrent, + /* simple_split */ !is_recurrent, /* logits_all */ logits_all); // reserve output buffer @@ -1258,7 +1279,7 @@ int llama_context::decode(llama_batch & inp_batch) { const auto & n_ubatch = cparams.n_ubatch; - if (kv_self->recurrent) { + if (is_recurrent) { if (embd_pooled) { // Pooled embeddings cannot be split across ubatches (yet) ubatch = sbatch.split_seq(cparams.n_ubatch); @@ -1296,17 +1317,19 @@ int llama_context::decode(llama_batch & inp_batch) { return 1; } - if (!kv_self->recurrent) { + if (!is_recurrent) { + auto * kv = static_cast(kv_self.get()); + // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - const uint32_t pad = kv_self->get_padding(cparams); - kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad))); + const uint32_t pad = kv->get_padding(cparams); + kv->n = std::min(kv->size, std::max(pad, GGML_PAD(kv->cell_max(), pad))); + + //printf("kv.n = %5d, kv.used = %5d, kv.head = %5d\n", kv->n, kv->used, kv->head); } } - //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head); - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); @@ -1446,10 +1469,12 @@ int llama_context::decode(llama_batch & inp_batch) { //synchronize(); // decide if we need to defrag the kv cache - if (cparams.causal_attn && cparams.defrag_thold > 0.0f) { + if (!llama_model_is_recurrent(&model) && cparams.causal_attn && cparams.defrag_thold > 0.0f) { + auto * kv = static_cast(kv_self.get()); + // - do not defrag small contexts (i.e. < 2048 tokens) // - count the padding towards the number of used tokens - const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f; + const float fragmentation = kv->n >= 2048 ? std::max(0.0f, 1.0f - float(kv->used + kv->get_padding(cparams))/float(kv->n)) : 0.0f; // queue defragmentation for next llama_kv_cache_update if (fragmentation > cparams.defrag_thold) { diff --git a/src/llama-context.h b/src/llama-context.h index 5457f077c15bf..c39c193da6f20 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -200,7 +200,7 @@ struct llama_context { llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably - std::unique_ptr kv_self; + std::unique_ptr kv_self; // TODO: remove bool logits_all = false; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index fabb9ca237653..cd3896be05c97 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -288,7 +288,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { ////////////////////////////////////////////// // TODO: this should not mutate the KV cache ! 
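For context on the two graph inputs being retargeted here, and as a summary of what this hunk already does rather than an addition to it: s_copy records, for each cell in the active window, the index of the cell its recurrent state should be gathered from, while s_mask flags whether a valid state exists yet. A simplified, self-contained sketch of that idea (not a line-for-line copy of set_input; the cell type below is a stand-in for llama_kv_cell):

    #include <cstdint>
    #include <vector>

    struct cell { int32_t src = -1; };   // stand-in for llama_kv_cell

    void fill_state_inputs(std::vector<cell> & cells, uint32_t n_kv,
                           int32_t * s_copy /* I32 [n_kv] */,
                           float   * s_mask /* F32 [n_kv] */) {
        for (uint32_t i = 0; i < n_kv; ++i) {
            cell & c = cells[i];

            s_mask[i] = (float) (c.src >= 0);   // 1.0f when a state already exists

            // prevent out-of-bound sources, as in the hunk above
            if (c.src < 0 || (uint32_t) c.src >= cells.size()) {
                c.src = (int32_t) i;            // fall back to the cell itself
            }

            s_copy[i] = c.src;                  // gather index for the state copy
        }
    }
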
- llama_kv_cell & kv_cell = const_cast(kv_self)->cells[i]; + llama_kv_cell & kv_cell = const_cast(kv_self)->cells[i]; // prevent out-of-bound sources if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) { @@ -321,7 +321,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) { ////////////////////////////////////////////// // TODO: this should not mutate the KV cache ! - llama_kv_cell & kv_cell = const_cast(kv_self)->cells[i]; + llama_kv_cell & kv_cell = const_cast(kv_self)->cells[i]; data[i] = (float) (kv_cell.src >= 0); @@ -1105,7 +1105,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const { } ggml_tensor * llm_graph_context::build_inp_s_copy() const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); auto inp = std::make_unique(kv_self); @@ -1122,7 +1122,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const { } ggml_tensor * llm_graph_context::build_inp_s_mask() const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); auto inp = std::make_unique(kv_self); @@ -1436,8 +1436,6 @@ ggml_tensor * llm_graph_context::build_attn( // store to KV cache { - GGML_ASSERT(!kv_self->recurrent); - const auto kv_head = kv_self->head; GGML_ASSERT(kv_self->size == n_ctx); @@ -1587,7 +1585,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state( ggml_tensor * state_mask, int32_t n_state, int32_t n_seqs) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto n_kv = kv_self->n; const auto kv_head = kv_self->head; @@ -1619,7 +1617,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto token_shift_count = hparams.token_shift_count; @@ -1640,7 +1638,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( ggml_tensor * token_shift, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto token_shift_count = hparams.token_shift_count; const auto n_embd = hparams.n_embd; diff --git a/src/llama-graph.h b/src/llama-graph.h index d0c8d32192784..23397a76fd3bd 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -19,6 +19,7 @@ struct llama_cparams; class llama_memory_i; class llama_kv_cache_unified; +class llama_kv_cache_recurrent; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -186,26 +187,26 @@ class llm_graph_input_cls : public llm_graph_input_i { class llm_graph_input_s_copy : public llm_graph_input_i { public: - llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {} virtual ~llm_graph_input_s_copy() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * s_copy; // I32 [kv_size] - const llama_kv_cache_unified * kv_self; + const llama_kv_cache_recurrent * kv_self; }; class llm_graph_input_s_mask : public llm_graph_input_i { public: - llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + llm_graph_input_s_mask(const 
llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {} virtual ~llm_graph_input_s_mask() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * s_mask; // F32 [1, n_kv] - const llama_kv_cache_unified * kv_self; + const llama_kv_cache_recurrent * kv_self; }; class llm_graph_input_cross_embd : public llm_graph_input_i { diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 7c9d46d8119b3..283171b2ebd63 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -11,6 +11,10 @@ #include #include +// +// llama_kv_cache_unified +// + llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) { } @@ -25,9 +29,10 @@ bool llama_kv_cache_unified::init( has_shift = false; - recurrent = llama_model_is_recurrent(&model); - v_trans = !recurrent && !cparams.flash_attn; - can_shift = !recurrent; + GGML_ASSERT(!llama_model_is_recurrent(&model)); + + v_trans = !cparams.flash_attn; + can_shift = true; LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n", __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift); @@ -135,6 +140,14 @@ int32_t llama_kv_cache_unified::get_used_cells() const { return used; } +bool llama_kv_cache_unified::get_has_shift() const { + return has_shift; +} + +bool llama_kv_cache_unified::get_do_defrag() const { + return do_defrag; +} + size_t llama_kv_cache_unified::total_size() const { size_t size = 0; for (const auto & buf : bufs) { @@ -144,7 +157,7 @@ size_t llama_kv_cache_unified::total_size() const { return size; } -llama_pos llama_kv_cache_unified::pos_max() const { +llama_pos llama_kv_cache_unified::get_pos_max() const { llama_pos pos_max = -1; for (const auto & cell : cells) { pos_max = std::max(pos_max, cell.pos); @@ -179,35 +192,6 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1 = std::numeric_limits::max(); } - // models like Mamba or RWKV can't have a state partially erased - if (recurrent) { - if (seq_id >= (int64_t) size) { - // could be fatal - return false; - } - if (0 <= seq_id) { - int32_t & tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - const llama_kv_cell & cell = cells[tail_id]; - // partial intersection is invalid - if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { - return false; - } - // invalidate tails which will be cleared - if (p0 <= cell.pos && cell.pos < p1) { - tail_id = -1; - } - } - } else { - // seq_id is negative, then the range should include everything or nothing - if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { - return false; - } - } - - return true; - } - for (uint32_t i = 0; i < size; ++i) { if (cells[i].pos >= p0 && cells[i].pos < p1) { if (seq_id < 0) { @@ -254,34 +238,6 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id p1 = std::numeric_limits::max(); } - if (recurrent) { - if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { - llama_kv_cell & tail_src = cells[seq_id_src]; - llama_kv_cell & tail_dst = cells[seq_id_dst]; - if (tail_dst.tail >= 0) { - // clear destination seq_id if it wasn't empty - llama_kv_cell & cell_dst = cells[tail_dst.tail]; - - cell_dst.seq_id.erase(seq_id_dst); - tail_dst.tail = -1; - if (cell_dst.seq_id.empty()) { - cell_dst.pos = -1; - cell_dst.delta = -1; - cell_dst.src = -1; - used -= 1; - } - } - if (tail_src.tail >= 0) { - llama_kv_cell & cell_src = cells[tail_src.tail]; - - 
cell_src.seq_id.insert(seq_id_dst); - tail_dst.tail = tail_src.tail; - } - } - - return; - } - // otherwise, this is the KV of a Transformer-like model head = 0; @@ -296,9 +252,10 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { uint32_t new_head = size; for (uint32_t i = 0; i < size; ++i) { - if (recurrent && (llama_seq_id) i != seq_id) { - cells[i].tail = -1; - } + // TODO: remove tail + //if (recurrent && (llama_seq_id) i != seq_id) { + // cells[i].tail = -1; + //} if (!cells[i].has_seq_id(seq_id)) { if (cells[i].pos >= 0) { @@ -344,20 +301,6 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po return; } - if (recurrent) { - // for Mamba-like or RWKV models, only the pos needs to be shifted - if (0 <= seq_id && seq_id < (int64_t) size) { - const int32_t tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - llama_kv_cell & cell = cells[tail_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos += delta; - } - } - } - return; - } - for (uint32_t i = 0; i < size; ++i) { if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { has_shift = true; @@ -400,21 +343,6 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po return; } - if (recurrent) { - // for Mamba-like or RWKV models, only the pos needs to be changed - if (0 <= seq_id && seq_id < (int64_t) size) { - const int32_t tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - llama_kv_cell & cell = cells[tail_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos /= d; - } - } - } - - return; - } - for (uint32_t i = 0; i < size; ++i) { if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { has_shift = true; @@ -441,9 +369,7 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { } void llama_kv_cache_unified::defrag() { - if (!recurrent) { - do_defrag = true; - } + do_defrag = true; } void llama_kv_cache_unified::restore() { @@ -451,12 +377,6 @@ void llama_kv_cache_unified::restore() { return; } - // TODO: tmp - move to llama_kv_cache_recurrent - if (recurrent) { - seq_rm(-1, -1, -1); - return; - } - uint32_t new_head = size; for (auto & range : pending.ranges) { @@ -481,11 +401,6 @@ void llama_kv_cache_unified::restore() { } void llama_kv_cache_unified::commit() { - // TODO: tmp - move to llama_kv_cache_recurrent - if (recurrent) { - return; - } - if (pending.ranges.empty()) { LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n", __func__, "https://github.com/ggml-org/llama.cpp/pull/12695"); @@ -511,169 +426,6 @@ bool llama_kv_cache_unified::find_slot( head = 0; } - if (recurrent) { - // For recurrent state architectures (like Mamba or RWKV), - // each cache cell can store the state for a whole sequence. - // A slot should be always be contiguous. - - // can only process batches with an equal number of new tokens in each sequence - GGML_ASSERT(ubatch.equal_seqs); - - int32_t min = size - 1; - int32_t max = 0; - - // everything should fit if all seq_ids are smaller than the max - for (uint32_t s = 0; s < n_seqs; ++s) { - const uint32_t n_seq_id = ubatch.n_seq_id[s]; - for (uint32_t j = 0; j < n_seq_id; ++j) { - const llama_seq_id seq_id = ubatch.seq_id[s][j]; - - if (seq_id < 0 || (uint32_t) seq_id >= size) { - // too big seq_id - // TODO: would it be possible to resize the cache instead? 
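The restore()/commit() pair shown above hinges on find_slot() recording which cell range a batch claimed (that bookkeeping is not visible in this hunk). A minimal sketch of the mechanism, using simplified stand-in types rather than the real llama.cpp ones:

    #include <cstdint>
    #include <vector>

    struct slot_range { uint32_t c0 = 0, c1 = 0; };   // [c0, c1) cell indices

    struct pending_state { std::vector<slot_range> ranges; };

    // what find_slot() would do after claiming a slot of n_tokens cells at `head`
    void on_slot_found(pending_state & pending, uint32_t head, uint32_t n_tokens) {
        pending.ranges.push_back({head, head + n_tokens});
    }

    // commit(): the batch succeeded - keep the cells, forget the ranges
    void commit(pending_state & pending) {
        pending.ranges.clear();
    }

    // restore(): the batch failed - the real implementation walks each recorded
    // range, clears cells[i].pos / cells[i].seq_id and decrements `used`
    bool needs_rollback(const pending_state & pending) {
        return !pending.ranges.empty();
    }
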
- LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size); - return false; - } - if (j > 0) { - llama_kv_cell & seq = cells[seq_id]; - if (seq.tail >= 0) { - llama_kv_cell & cell = cells[seq.tail]; - // clear cells from seq_ids that become shared - // (should not normally happen, but let's handle it anyway) - cell.seq_id.erase(seq_id); - seq.tail = -1; - if (cell.seq_id.empty()) { - cell.pos = -1; - cell.src = -1; - used -= 1; - } - } - } - } - } - -#ifndef NDEBUG - { - std::vector tails_verif; - tails_verif.assign(size, -1); - for (uint32_t i = 0; i < size; ++i) { - llama_kv_cell & cell = cells[i]; - for (llama_seq_id seq_id : cell.seq_id) { - if (tails_verif[seq_id] != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); - } - tails_verif[seq_id] = i; - } - } - for (uint32_t i = 0; i < size; ++i) { - if (tails_verif[i] != cells[i].tail) { - LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]); - } - } - } -#endif - - // find next empty cell - uint32_t next_empty_cell = head; - - for (uint32_t i = 0; i < size; ++i) { - if (next_empty_cell >= size) { next_empty_cell -= size; } - llama_kv_cell & cell = cells[next_empty_cell]; - if (cell.is_empty()) { break; } - next_empty_cell += 1; - } - - // find usable cell range - for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - llama_kv_cell & seq_meta = cells[seq_id]; - bool has_cell = false; - if (seq_meta.tail >= 0) { - llama_kv_cell & cell = cells[seq_meta.tail]; - GGML_ASSERT(cell.has_seq_id(seq_id)); - // does this seq_id "own" the cell? - if (cell.seq_id.size() == 1) { has_cell = true; } - } - if (!has_cell) { - llama_kv_cell & empty_cell = cells[next_empty_cell]; - GGML_ASSERT(empty_cell.is_empty()); - // copy old tail into the empty cell - if (seq_meta.tail >= 0) { - llama_kv_cell & orig_cell = cells[seq_meta.tail]; - empty_cell.pos = orig_cell.pos; - empty_cell.src = orig_cell.src; - orig_cell.seq_id.erase(seq_id); - empty_cell.seq_id.insert(seq_id); // will be overwritten - } - seq_meta.tail = next_empty_cell; - // find next empty cell - if (s + 1 < n_seqs) { - next_empty_cell += 1; - for (uint32_t i = 0; i < size; ++i) { - if (next_empty_cell >= size) { next_empty_cell -= size; } - llama_kv_cell & cell = cells[next_empty_cell]; - if (cell.is_empty()) { break; } - next_empty_cell += 1; - } - } - } - if (min > seq_meta.tail) { min = seq_meta.tail; } - if (max < seq_meta.tail) { max = seq_meta.tail; } - } - - // gather and re-order - for (uint32_t s = 0; s < n_seqs; ++s) { - int32_t dst_id = s + min; - int32_t src_id = cells[ubatch.seq_id[s][0]].tail; - if (dst_id != src_id) { - llama_kv_cell & dst_cell = cells[dst_id]; - llama_kv_cell & src_cell = cells[src_id]; - - std::swap(dst_cell.pos, src_cell.pos); - std::swap(dst_cell.src, src_cell.src); - std::swap(dst_cell.seq_id, src_cell.seq_id); - - // swap tails (assuming they NEVER overlap) - for (const llama_seq_id seq_id : src_cell.seq_id) { - cells[seq_id].tail = src_id; - } - for (const llama_seq_id seq_id : dst_cell.seq_id) { - cells[seq_id].tail = dst_id; - } - } - } - - // update the pos of the used seqs - for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; - int32_t cell_id = s + min; - llama_kv_cell & cell = cells[cell_id]; - - if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) 
{ - // What should happen when the pos backtracks or skips a value? - // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", - __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); - } - cell.pos = last_pos; - cell.seq_id.clear(); - for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { - const llama_seq_id seq_id = ubatch.seq_id[s][j]; - cell.seq_id.insert(seq_id); - cells[seq_id].tail = cell_id; - } - } - - // allow getting the range of used cells, from head to head + n - head = min; - n = max - min + 1; - used = std::count_if(cells.begin(), cells.end(), - [](const llama_kv_cell& cell){ return !cell.is_empty(); }); - - // sanity check - return n >= n_seqs; - } - // otherwise, one cell per token. if (n_tokens > size) { @@ -745,6 +497,10 @@ uint32_t llama_kv_cache_unified::cell_max() const { return 0; } +void llama_kv_cache_unified::set_full() { + n = size; +} + size_t llama_kv_cache_unified::size_k_bytes() const { size_t size_k_bytes = 0; @@ -1133,15 +889,6 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell } cell.seq_id.insert(seq_id); - - if (recurrent) { - int32_t & tail = cells[seq_id].tail; - if (tail != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); - return false; - } - tail = i; - } } } @@ -1149,18 +896,964 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell used = cell_count; } - if (recurrent) { - for (uint32_t i = 0; i < cell_count; ++i) { - uint32_t cell_id = head + i; - // make sure the recurrent states will keep their restored state - cells[cell_id].src = cell_id; + return true; +} + +bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) { + uint32_t v_trans; + uint32_t n_layer; + io.read_to(&v_trans, sizeof(v_trans)); + io.read_to(&n_layer, sizeof(n_layer)); + + if (n_layer != hparams.n_layer) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); + return false; + } + if (cell_count > size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); + return false; + } + if (this->v_trans != (bool) v_trans) { + LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); + return false; + } + + // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Read type of key + int32_t k_type_i_ref; + io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + const int32_t k_type_i = (int32_t) k_l[il]->type; + if (k_type_i != k_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t k_size_row_ref; + io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + if (k_size_row != k_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + } + } + 
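Before the two value branches below, a note on the addressing they implement: when V is stored transposed, the per-layer tensor is laid out as [kv_size, n_embd_v_gqa], so restoring cells [head, head + cell_count) requires one strided write per embedding row instead of one contiguous block per cell. A self-contained sketch of just that offset arithmetic (the sizes are illustrative; the formula matches the code below):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t kv_size      = 8;   // total cells in the cache ("size")
        const uint32_t n_embd_v_gqa = 4;   // V elements per cell for this layer
        const size_t   v_size_el    = 2;   // bytes per element (e.g. f16)
        const uint32_t head         = 3;   // first restored cell
        const uint32_t cell_count   = 2;   // number of restored cells

        for (uint32_t j = 0; j < n_embd_v_gqa; ++j) {
            // row j holds element j of every cell; the restored cells form a
            // contiguous run of cell_count elements starting at column `head`
            const size_t dst_offset = (head + (size_t) j * kv_size) * v_size_el;
            printf("row %u: write %zu bytes at byte offset %zu\n",
                   j, (size_t) cell_count * v_size_el, dst_offset);
        }

        return 0;
    }
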
+ if (!this->v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read row size of value + uint64_t v_size_row_ref; + io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + } + } + } else { + // For each layer, read the values for each cell (transposed) + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read element size of value + uint32_t v_size_el_ref; + io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + const size_t v_size_el = ggml_type_size(v_l[il]->type); + if (v_size_el != v_size_el_ref) { + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); + return false; + } + + // Read GQA embedding size + uint32_t n_embd_v_gqa_ref; + io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); + return false; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (head + j * size) * v_size_el; + ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + } + } } } return true; } -bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell_count) { +// +// llama_kv_cache_recurrent +// + +llama_kv_cache_recurrent::llama_kv_cache_recurrent(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) { +} + +bool llama_kv_cache_recurrent::init( + const llama_model & model, + const llama_cparams & cparams, + ggml_type type_k, + ggml_type type_v, + uint32_t kv_size, + bool offload) { + GGML_UNUSED(cparams); + + const int32_t n_layer = hparams.n_layer; + + GGML_ASSERT(llama_model_is_recurrent(&model)); + + LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d\n", + __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer); + + head = 0; + size = kv_size; + used = 0; + + this->type_k = type_k; + this->type_v = type_v; + + cells.clear(); + cells.resize(kv_size); + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = 
[&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + + ctx_map[buft] = ctx; + ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + k_l.reserve(n_layer); + v_l.reserve(n_layer); + + for (int i = 0; i < n_layer; i++) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + + const char * dev_name = "CPU"; + + ggml_backend_buffer_type_t buft; + if (offload) { + auto * dev = model.dev_layer(i); + buft = ggml_backend_dev_buffer_type(dev); + + dev_name = ggml_backend_dev_name(dev); + } else { + buft = ggml_backend_cpu_buffer_type(); + } + + LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__, + i, n_embd_k_gqa, n_embd_v_gqa, dev_name); + + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { + LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__); + return false; + } + + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + k_l.push_back(k); + v_l.push_back(v); + } + + // allocate tensors and initialize the buffers to avoid NaNs in the padding + for (auto it : ctx_map) { + auto * buft = it.first; + auto * ctx = it.second; + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__); + return false; + } + ggml_backend_buffer_clear(buf, 0); + LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + bufs.emplace_back(buf); + } + + return true; +} + +int32_t llama_kv_cache_recurrent::get_n_tokens() const { + int32_t result = 0; + + for (uint32_t i = 0; i < size; i++) { + result += cells[i].seq_id.size(); + } + + return result; +} + +int32_t llama_kv_cache_recurrent::get_used_cells() const { + return used; +} + +bool llama_kv_cache_recurrent::get_has_shift() const { + return false; +} + +bool llama_kv_cache_recurrent::get_do_defrag() const { + return false; +} + +size_t llama_kv_cache_recurrent::total_size() const { + size_t size = 0; + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; +} + +llama_pos llama_kv_cache_recurrent::get_pos_max() const { + llama_pos pos_max = -1; + for (const auto & cell : cells) { + pos_max = std::max(pos_max, cell.pos); + } + + return pos_max; +} + +void llama_kv_cache_recurrent::clear() { + for (int32_t i = 0; i < (int32_t) size; ++i) { + cells[i].pos = -1; + cells[i].seq_id.clear(); + cells[i].src = -1; + cells[i].tail = -1; + } + head = 0; + used = 0; + + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } +} + +bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // models like Mamba or RWKV can't have a state partially erased + if (seq_id >= (int64_t) size) { + // could be fatal + return false; + } + if (0 <= 
seq_id) { + int32_t & tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + const llama_kv_cell & cell = cells[tail_id]; + // partial intersection is invalid + if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + return false; + } + // invalidate tails which will be cleared + if (p0 <= cell.pos && cell.pos < p1) { + tail_id = -1; + } + } + } else { + // seq_id is negative, then the range should include everything or nothing + if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + return false; + } + } + + return true; +} + +void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (seq_id_src == seq_id_dst) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { + llama_kv_cell & tail_src = cells[seq_id_src]; + llama_kv_cell & tail_dst = cells[seq_id_dst]; + if (tail_dst.tail >= 0) { + // clear destination seq_id if it wasn't empty + llama_kv_cell & cell_dst = cells[tail_dst.tail]; + + cell_dst.seq_id.erase(seq_id_dst); + tail_dst.tail = -1; + if (cell_dst.seq_id.empty()) { + cell_dst.pos = -1; + cell_dst.delta = -1; + cell_dst.src = -1; + used -= 1; + } + } + if (tail_src.tail >= 0) { + llama_kv_cell & cell_src = cells[tail_src.tail]; + + cell_src.seq_id.insert(seq_id_dst); + tail_dst.tail = tail_src.tail; + } + } +} + +void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) { + uint32_t new_head = size; + + for (uint32_t i = 0; i < size; ++i) { + if ((llama_seq_id) i != seq_id) { + cells[i].tail = -1; + } + + if (!cells[i].has_seq_id(seq_id)) { + if (cells[i].pos >= 0) { + used--; + } + + cells[i].pos = -1; + cells[i].src = -1; + cells[i].seq_id.clear(); + + if (new_head == size){ + new_head = i; + } + } else { + cells[i].seq_id.clear(); + cells[i].seq_id.insert(seq_id); + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != size && new_head < head) { + head = new_head; + } +} + +void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + if (delta == 0) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the + if (p0 == p1) { + return; + } + + // for Mamba-like or RWKV models, only the pos needs to be shifted + if (0 <= seq_id && seq_id < (int64_t) size) { + const int32_t tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + llama_kv_cell & cell = cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos += delta; + } + } + } +} + +void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (d == 1) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the cache. 
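(A brief aside on the tail bookkeeping that all of the recurrent sequence operations in this class rely on, written as an explanatory sketch rather than code from the patch: cells[seq_id].tail is the index of the single cell currently holding seq_id's state, so each operation touches at most one cell instead of scanning the whole cache.)

    #include <cstdint>
    #include <set>
    #include <vector>

    // illustrative, simplified stand-in for llama_kv_cell
    struct cell {
        int64_t           pos  = -1;
        int32_t           tail = -1;   // only meaningful when indexed by a seq_id
        std::set<int32_t> seq_id;
    };

    // "which cell holds this sequence's state?" - O(1) instead of a scan
    int32_t state_cell_of(const std::vector<cell> & cells, int32_t seq_id) {
        if (seq_id < 0 || (size_t) seq_id >= cells.size()) {
            return -1;                 // out of range: no state
        }
        return cells[seq_id].tail;     // -1 while the sequence has no state yet
    }
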
+ if (p0 == p1) { + return; + } + + // for Mamba-like or RWKV models, only the pos needs to be changed + if (0 <= seq_id && seq_id < (int64_t) size) { + const int32_t tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + llama_kv_cell & cell = cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos /= d; + } + } + } +} + +llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { + llama_pos result = 0; + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + result = std::max(result, cells[i].pos); + } + } + + return result; +} + +void llama_kv_cache_recurrent::defrag() { + LLAMA_LOG_ERROR("%s: not supported\n", __func__); +} + +void llama_kv_cache_recurrent::restore() { + if (pending.ranges.empty()) { + return; + } + + seq_rm(-1, -1, -1); +} + +void llama_kv_cache_recurrent::commit() { + pending.ranges.clear(); +} + +bool llama_kv_cache_recurrent::get_can_shift() const { + return false; +} + +bool llama_kv_cache_recurrent::find_slot( + const llama_ubatch & ubatch) { + const uint32_t n_tokens = ubatch.n_tokens; + const uint32_t n_seqs = ubatch.n_seqs; + + const uint32_t n_seq_tokens = ubatch.n_seq_tokens; + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head > used + 2*n_tokens) { + head = 0; + } + + // For recurrent state architectures (like Mamba or RWKV), + // each cache cell can store the state for a whole sequence. + // A slot should be always be contiguous. + + // can only process batches with an equal number of new tokens in each sequence + GGML_ASSERT(ubatch.equal_seqs); + + int32_t min = size - 1; + int32_t max = 0; + + // everything should fit if all seq_ids are smaller than the max + for (uint32_t s = 0; s < n_seqs; ++s) { + const uint32_t n_seq_id = ubatch.n_seq_id[s]; + for (uint32_t j = 0; j < n_seq_id; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + + if (seq_id < 0 || (uint32_t) seq_id >= size) { + // too big seq_id + // TODO: would it be possible to resize the cache instead? 
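Before the rest of find_slot() below, it may help to state the contract the recurrent version ends up satisfying (a summary of the code that follows, not an addition to it): every sequence in the ubatch is mapped to exactly one cell, those cells are packed into one contiguous run, and head/n are set to describe that run. As a standalone check, under those assumptions:

    #include <cstdint>

    struct slot_result {
        uint32_t head;   // first cell of the run (the smallest tail, `min`)
        uint32_t n;      // length of the run (`max - min + 1`)
    };

    bool slot_ok(const slot_result & r, uint32_t n_seqs, uint32_t cache_size) {
        return r.n >= n_seqs                   // every sequence got its own cell
            && r.head + r.n <= cache_size;     // and the run fits in the cache
    }
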
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size); + return false; + } + if (j > 0) { + llama_kv_cell & seq = cells[seq_id]; + if (seq.tail >= 0) { + llama_kv_cell & cell = cells[seq.tail]; + // clear cells from seq_ids that become shared + // (should not normally happen, but let's handle it anyway) + cell.seq_id.erase(seq_id); + seq.tail = -1; + if (cell.seq_id.empty()) { + cell.pos = -1; + cell.src = -1; + used -= 1; + } + } + } + } + } + +#ifndef NDEBUG + { + std::vector tails_verif; + tails_verif.assign(size, -1); + for (uint32_t i = 0; i < size; ++i) { + llama_kv_cell & cell = cells[i]; + for (llama_seq_id seq_id : cell.seq_id) { + if (tails_verif[seq_id] != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); + } + tails_verif[seq_id] = i; + } + } + for (uint32_t i = 0; i < size; ++i) { + if (tails_verif[i] != cells[i].tail) { + LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]); + } + } + } +#endif + + // find next empty cell + uint32_t next_empty_cell = head; + + for (uint32_t i = 0; i < size; ++i) { + if (next_empty_cell >= size) { next_empty_cell -= size; } + llama_kv_cell & cell = cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + + // find usable cell range + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + llama_kv_cell & seq_meta = cells[seq_id]; + bool has_cell = false; + if (seq_meta.tail >= 0) { + llama_kv_cell & cell = cells[seq_meta.tail]; + GGML_ASSERT(cell.has_seq_id(seq_id)); + // does this seq_id "own" the cell? + if (cell.seq_id.size() == 1) { has_cell = true; } + } + if (!has_cell) { + llama_kv_cell & empty_cell = cells[next_empty_cell]; + GGML_ASSERT(empty_cell.is_empty()); + // copy old tail into the empty cell + if (seq_meta.tail >= 0) { + llama_kv_cell & orig_cell = cells[seq_meta.tail]; + empty_cell.pos = orig_cell.pos; + empty_cell.src = orig_cell.src; + orig_cell.seq_id.erase(seq_id); + empty_cell.seq_id.insert(seq_id); // will be overwritten + } + seq_meta.tail = next_empty_cell; + // find next empty cell + if (s + 1 < n_seqs) { + next_empty_cell += 1; + for (uint32_t i = 0; i < size; ++i) { + if (next_empty_cell >= size) { next_empty_cell -= size; } + llama_kv_cell & cell = cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + } + } + if (min > seq_meta.tail) { min = seq_meta.tail; } + if (max < seq_meta.tail) { max = seq_meta.tail; } + } + + // gather and re-order + for (uint32_t s = 0; s < n_seqs; ++s) { + int32_t dst_id = s + min; + int32_t src_id = cells[ubatch.seq_id[s][0]].tail; + if (dst_id != src_id) { + llama_kv_cell & dst_cell = cells[dst_id]; + llama_kv_cell & src_cell = cells[src_id]; + + std::swap(dst_cell.pos, src_cell.pos); + std::swap(dst_cell.src, src_cell.src); + std::swap(dst_cell.seq_id, src_cell.seq_id); + + // swap tails (assuming they NEVER overlap) + for (const llama_seq_id seq_id : src_cell.seq_id) { + cells[seq_id].tail = src_id; + } + for (const llama_seq_id seq_id : dst_cell.seq_id) { + cells[seq_id].tail = dst_id; + } + } + } + + // update the pos of the used seqs + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + int32_t cell_id = s + min; + llama_kv_cell & cell = cells[cell_id]; + + if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) 
{ + // What should happen when the pos backtracks or skips a value? + // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", + __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); + } + cell.pos = last_pos; + cell.seq_id.clear(); + for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + cell.seq_id.insert(seq_id); + cells[seq_id].tail = cell_id; + } + } + + // allow getting the range of used cells, from head to head + n + head = min; + n = max - min + 1; + used = std::count_if(cells.begin(), cells.end(), + [](const llama_kv_cell& cell){ return !cell.is_empty(); }); + + // sanity check + return n >= n_seqs; +} + +uint32_t llama_kv_cache_recurrent::get_padding(const llama_cparams & cparams) const { + // the FA kernels require padding to avoid extra runtime boundary checks + return cparams.flash_attn ? 256u : 32u; +} + +uint32_t llama_kv_cache_recurrent::cell_max() const { + for (uint32_t i = size; i > 0; --i) { + const llama_kv_cell & cell = cells[i - 1]; + + if (cell.pos >= 0 && !cell.is_empty()) { + return i; + } + } + + return 0; +} + +void llama_kv_cache_recurrent::set_full() { + n = size; +} + +size_t llama_kv_cache_recurrent::size_k_bytes() const { + size_t size_k_bytes = 0; + + for (const auto & k : k_l) { + size_k_bytes += ggml_nbytes(k); + } + + return size_k_bytes; +} + +size_t llama_kv_cache_recurrent::size_v_bytes() const { + size_t size_v_bytes = 0; + + for (const auto & v : v_l) { + size_v_bytes += ggml_nbytes(v); + } + + return size_v_bytes; +} + +void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + std::vector> cell_ranges; // ranges, from inclusive, to exclusive + uint32_t cell_count = 0; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = size; + for (uint32_t i = 0; i < size; ++i) { + const auto & cell = cells[i]; + if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { + ++cell_count; + if (cell_range_begin == size) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != size) { + cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = size; + } + } + } + if (cell_range_begin != size) { + cell_ranges.emplace_back(cell_range_begin, size); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + + io.write(&cell_count, sizeof(cell_count)); + + state_write_meta(io, cell_ranges, seq_id); + state_write_data(io, cell_ranges); +} + +void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + uint32_t cell_count; + io.read_to(&cell_count, sizeof(cell_count)); + + bool res = true; + res = res && state_read_meta(io, cell_count, seq_id); + res = res && state_read_data(io, cell_count); + + if (!res) { + if (seq_id == -1) { + clear(); + } else { + seq_rm(seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); + } +} + +void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + const 
auto & cell = cells[i]; + const llama_pos pos = cell.pos; + const uint32_t n_seq_id = seq_id == -1 ? cell.seq_id.size() : 0; + + io.write(&pos, sizeof(pos)); + io.write(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id) { + for (auto seq_id : cell.seq_id) { + io.write(&seq_id, sizeof(seq_id)); + } + } + } + } +} + +void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { + const uint32_t v_trans = 0; + const uint32_t n_layer = hparams.n_layer; + + io.write(&v_trans, sizeof(v_trans)); + io.write(&n_layer, sizeof(n_layer)); + + std::vector tmp_buf; + + // Iterate and write all the keys first, each row is a cell + // Get whole range at a time + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Write key type + const int32_t k_type_i = (int32_t)k_l[il]->type; + io.write(&k_type_i, sizeof(k_type_i)); + + // Write row size of key + const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + io.write(&k_size_row, sizeof(k_size_row)); + + // Read each range of cells of k_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * k_size_row; + io.write_tensor(k_l[il], range.first * k_size_row, buf_size); + } + } + + if (!v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)v_l[il]->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write row size of value + const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + io.write(&v_size_row, sizeof(v_size_row)); + + // Read each range of cells of v_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * v_size_row; + io.write_tensor(v_l[il], range.first * v_size_row, buf_size); + } + } + } else { + // When v is transposed, we also need the element size and get the element ranges from each row + const uint32_t kv_size = size; + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)v_l[il]->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const uint32_t v_size_el = ggml_type_size(v_l[il]->type); + io.write(&v_size_el, sizeof(v_size_el)); + + // Write GQA embedding size + io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + const size_t buf_size = range_size * v_size_el; + io.write_tensor(v_l[il], src_offset, buf_size); + } + } + } + } +} + +bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { + if (dest_seq_id != -1) { + // single sequence + + seq_rm(dest_seq_id, -1, -1); + + llama_sbatch sbatch; + llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); + + batch.n_tokens = cell_count; + batch.n_seq_tokens = 
cell_count; + batch.n_seqs = 1; + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id != 0) { + LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + return false; + } + + batch.pos[i] = pos; + } + batch.n_seq_id[0] = 1; + batch.seq_id[0] = &dest_seq_id; + if (!find_slot(batch)) { + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + return false; + } + commit(); + + // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(head + cell_count <= size); + GGML_ASSERT(cells[head].pos == batch.pos[0]); + GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); + GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); + GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); + } else { + // whole KV cache restore + + if (cell_count > size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } + + clear(); + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_kv_cell & cell = cells[i]; + + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + cell.pos = pos; + + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + + // TODO: llama_kv_cache_recurrent should have a notion of max sequences + //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { + if (seq_id < 0) { + //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); + return false; + } + + cell.seq_id.insert(seq_id); + + int32_t & tail = cells[seq_id].tail; + if (tail != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); + return false; + } + tail = i; + } + } + + head = 0; + used = cell_count; + } + + for (uint32_t i = 0; i < cell_count; ++i) { + uint32_t cell_id = head + i; + // make sure the recurrent states will keep their restored state + cells[cell_id].src = cell_id; + } + + return true; +} + +bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) { uint32_t v_trans; uint32_t n_layer; io.read_to(&v_trans, sizeof(v_trans)); @@ -1174,7 +1867,7 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); return false; } - if (v_trans != (bool) v_trans) { + if (false != (bool) v_trans) { LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); return false; } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 56c74035ae1b9..a2d88c9cca2fe 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -15,17 +15,44 @@ struct llama_hparams; struct llama_ubatch; struct llama_kv_cache : public llama_memory_i { + virtual ~llama_kv_cache() = default; + using llama_memory_i::llama_memory_i; + // TODO: become constructor + virtual bool init( + const llama_model & model, // TODO: do not reference the model + const llama_cparams & cparams, + ggml_type type_k, + ggml_type type_v, + uint32_t kv_size, + bool offload) = 0; + virtual void restore() = 0; // call if batch processing 
fails - restores the cache state virtual void commit() = 0; // call after successful batch processing - clears any pending state virtual int32_t get_n_tokens() const = 0; virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache + virtual bool get_has_shift() const = 0; + virtual bool get_do_defrag() const = 0; + + virtual llama_pos get_pos_max() const = 0; + virtual bool get_can_shift() const = 0; bool get_can_edit() const override { return get_can_shift(); } + + virtual bool find_slot(const llama_ubatch & batch) = 0; + + // simulate full cache, used for allocating worst-case compute buffers + virtual void set_full() = 0; + + virtual size_t size_k_bytes() const = 0; + virtual size_t size_v_bytes() const = 0; + + virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0; + virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; }; struct llama_kv_cache_guard { @@ -74,12 +101,6 @@ class llama_kv_cache_unified : public llama_kv_cache { std::function get_rope_factors; }; - llama_kv_cache_unified( - const llama_hparams & hparams, - callbacks cbs); - - virtual ~llama_kv_cache_unified() = default; - // TODO: become constructor bool init( const llama_model & model, // TODO: do not reference the model @@ -87,21 +108,30 @@ class llama_kv_cache_unified : public llama_kv_cache { ggml_type type_k, ggml_type type_v, uint32_t kv_size, - bool offload); + bool offload) override; + + llama_kv_cache_unified( + const llama_hparams & hparams, + callbacks cbs); + + ~llama_kv_cache_unified() = default; int32_t get_n_tokens() const override; int32_t get_used_cells() const override; + bool get_has_shift() const override; + bool get_do_defrag() const override; + size_t total_size() const; // TODO: better data structures to reduce the cost of this operation - llama_pos pos_max() const; + llama_pos get_pos_max() const override; void clear() override; void defrag() override; - virtual void restore() override; - virtual void commit() override; + void restore() override; + void commit() override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; @@ -117,7 +147,7 @@ class llama_kv_cache_unified : public llama_kv_cache { // updates the cache head // Note: On success, it's important that cache.head points // to the first cell of the slot. 
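To make the widened interface above concrete, here is a hedged sketch, condensed from the llama-context.cpp changes earlier in this patch rather than copied verbatim, of how a caller can now stay on the base class and rely only on the virtuals:

    // assumes the declarations from this header
    #include "llama-kv-cache.h"

    // worst-case compute-buffer reservation: works for both cache types
    static void reserve_worst_case(llama_kv_cache & kv) {
        kv.set_full();   // replaces the old direct `kv_self->n = kv_self->size`
        // ... build and reserve the worst-case graphs here ...
    }

    // the periodic update only needs the concrete type once it knows work is due
    static void maybe_update(llama_kv_cache & kv) {
        if (kv.get_has_shift()) {
            if (!kv.get_can_shift()) {
                GGML_ABORT("the current KV cache does not support K-shift");
            }
            // ... apply the K-shift graph, then clear the per-cell deltas ...
        }

        if (kv.get_do_defrag()) {
            // ... prepare and apply the defrag graph (unified cache only) ...
        }
    }
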
- bool find_slot(const llama_ubatch & batch); + bool find_slot(const llama_ubatch & batch) override; // TODO: maybe not needed uint32_t get_padding(const llama_cparams & cparams) const; @@ -125,8 +155,10 @@ class llama_kv_cache_unified : public llama_kv_cache { // find how many cells are currently in use uint32_t cell_max() const; - size_t size_k_bytes() const; - size_t size_v_bytes() const; + void set_full() override; + + size_t size_k_bytes() const override; + size_t size_v_bytes() const override; // defrag @@ -151,8 +183,8 @@ class llama_kv_cache_unified : public llama_kv_cache { // state write/load - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1); + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; // members @@ -163,9 +195,6 @@ class llama_kv_cache_unified : public llama_kv_cache { bool has_shift = false; bool do_defrag = false; - // TODO: remove this and implement llama_kv_cache_recurrent instead - bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token - bool v_trans = true; // the value tensor is transposed bool can_shift = false; @@ -198,11 +227,124 @@ class llama_kv_cache_unified : public llama_kv_cache { bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; -// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified -//class llama_kv_cache_recurrent : public llama_kv_cache_unified { -//public: -// using llama_kv_cache_unified::llama_kv_cache_unified; -//}; +class llama_kv_cache_recurrent : public llama_kv_cache { +public: + // can be used to query data from the model if needed + struct callbacks { + std::function get_rope_factors; + }; + + llama_kv_cache_recurrent( + const llama_hparams & hparams, + callbacks cbs); + + ~llama_kv_cache_recurrent() = default; + + // TODO: become constructor + bool init( + const llama_model & model, // TODO: do not reference the model + const llama_cparams & cparams, + ggml_type type_k, + ggml_type type_v, + uint32_t kv_size, + bool offload) override; + + int32_t get_n_tokens() const override; + int32_t get_used_cells() const override; + + bool get_has_shift() const override; + bool get_do_defrag() const override; + + size_t total_size() const; + + // TODO: better data structures to reduce the cost of this operation + llama_pos get_pos_max() const override; + + void clear() override; + void defrag() override; + + void restore() override; + void commit() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; + void seq_div (llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + bool get_can_shift() const override; + + // find an empty slot of size "n_tokens" in the cache + // updates the cache head + // Note: On success, it's important that cache.head points + // to the first cell of the slot. 
+ bool find_slot(const llama_ubatch & batch) override; + + // TODO: maybe not needed + uint32_t get_padding(const llama_cparams & cparams) const; + + // find how many cells are currently in use + uint32_t cell_max() const; + + void set_full() override; + + size_t size_k_bytes() const override; + size_t size_v_bytes() const override; + + // commit/restore cache + + struct slot_range { + uint32_t c0 = 0; // note: these are cell indices, not sequence positions + uint32_t c1 = 0; + }; + + // pending cell updates that are not yet committed + struct { + std::vector ranges; + } pending; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + // members + + const llama_hparams & hparams; + + callbacks cbs; + + // Note: The value of head isn't only used to optimize searching + // for a free KV slot. llama_decode_impl also uses it, so it + // cannot be freely changed after a slot has been allocated. + uint32_t head = 0; + uint32_t size = 0; + uint32_t used = 0; // used cells (i.e. at least one seq_id) + + // computed before each graph build + uint32_t n = 0; + + std::vector cells; + + std::vector k_l; // per layer + std::vector v_l; + +private: + ggml_type type_k = GGML_TYPE_F16; + ggml_type type_v = GGML_TYPE_F16; + + std::vector ctxs; + std::vector bufs; + + void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; + + bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t cell_count); +}; + // // kv cache view diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 51092a128c5c6..a3ed8fa599cb7 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -8711,7 +8711,7 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto kv_head = kv_self->head; @@ -11459,7 +11459,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -11855,7 +11855,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { ggml_tensor *& first_layer_value, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -12825,7 +12825,7 @@ llama_memory_i * llama_model::create_memory() const { case LLM_ARCH_RWKV7: case LLM_ARCH_ARWKV7: { - res = new llama_kv_cache_unified(hparams, { + res = new llama_kv_cache_recurrent(hparams, { /*.get_rope_factors =*/ nullptr }); } break; From 81457990035d0780bc1d0a7e59f6098b72dc1182 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 15 Apr 2025 16:49:06 +0300 Subject: [PATCH 02/29] kv-cache : init -> contructor + add llama_memory_params ggml-ci --- src/llama-context.cpp | 36 +++++++++------ src/llama-kv-cache.cpp | 100 
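The pending slot_range list and the restore()/commit() pair above are what llama_kv_cache_guard builds on: find_slot() records the cells it touched, commit() accepts them, and restore() rolls them back if batch processing fails. A minimal standalone sketch of that contract, with illustrative names only (not the actual llama.cpp types):

#include <cstdint>
#include <vector>

// toy cache: find_slot() marks cells used and records the range as pending,
// commit() keeps them, restore() rolls back anything not yet committed
struct kv_cache_sketch {
    struct slot_range { uint32_t c0 = 0; uint32_t c1 = 0; }; // cell indices, not positions

    std::vector<bool>       cell_used = std::vector<bool>(64, false);
    std::vector<slot_range> pending;

    bool find_slot(uint32_t head, uint32_t n_tokens) {
        for (uint32_t i = head; i < head + n_tokens; ++i) {
            cell_used[i] = true;
        }
        pending.push_back({head, head + n_tokens});
        return true;
    }

    void commit() { pending.clear(); }

    void restore() {
        for (const auto & r : pending) {
            for (uint32_t i = r.c0; i < r.c1; ++i) {
                cell_used[i] = false;
            }
        }
        pending.clear();
    }
};

// RAII guard in the spirit of llama_kv_cache_guard: the destructor always
// calls restore(), which is a no-op once commit() has cleared the pending list
struct kv_guard_sketch {
    kv_cache_sketch & kv;

    explicit kv_guard_sketch(kv_cache_sketch & kv) : kv(kv) {}
    ~kv_guard_sketch() { kv.restore(); }

    void commit() { kv.commit(); }
};

int main() {
    kv_cache_sketch kv;
    {
        kv_guard_sketch guard(kv);
        kv.find_slot(/*head =*/ 0, /*n_tokens =*/ 8);
        // processing failed before guard.commit() -> destructor restores the cells
    }
    return kv.cell_used[0] ? 1 : 0; // expect 0: the slot was rolled back
}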
++++++++++------------------------------- src/llama-kv-cache.h | 62 ++++++++----------------- src/llama-memory.h | 14 ++++++ src/llama-model.cpp | 69 ++++++++++++++++++++-------- src/llama-model.h | 2 +- 6 files changed, 130 insertions(+), 153 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index b511185d77da3..c27cb559e10f4 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -184,12 +184,9 @@ llama_context::llama_context( ggml_type type_v = params.type_v; if (!llama_model_is_recurrent(&model)) { - //kv_self.reset(static_cast(model.create_memory())); - auto * kv = static_cast(model.create_memory()); - LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); - cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv->get_padding(cparams)); + cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_unified::get_padding(cparams)); LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); @@ -197,26 +194,37 @@ llama_context::llama_context( type_k = params.type_k; type_v = params.type_v; - kv_self.reset(kv); - } else { - auto * kv = static_cast(model.create_memory()); + llama_memory_params params_mem = { + /*.type_k =*/ type_k, + /*.type_v =*/ type_v, + /*.v_trans =*/ !cparams.flash_attn, + /*.offload_kqv =*/ cparams.offload_kqv, + /*.kv_size =*/ kv_size, + }; - LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); + auto * kv = static_cast(model.create_memory(params_mem)); + kv_self.reset(kv); + } else { // Mamba needs at least as many KV cells as there are sequences kept at any time kv_size = std::max((uint32_t) 1, params.n_seq_max); // it's probably best to keep as much precision as possible for the states type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states - kv_self.reset(kv); - } + llama_memory_params params_mem = { + /*.type_k =*/ type_k, + /*.type_v =*/ type_v, + /*.v_trans =*/ false, // unused + /*.offload_kqv =*/ params.offload_kqv, + /*.kv_size =*/ kv_size, + }; - GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0); - GGML_ASSERT(hparams.n_embd_head_v % ggml_blck_size(type_v) == 0); + auto * kv = static_cast(model.create_memory(params_mem)); - if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) { - throw std::runtime_error("failed to initialize self-attention cache"); + LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); + + kv_self.reset(kv); } { diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 283171b2ebd63..0f3c349941195 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -15,27 +15,22 @@ // llama_kv_cache_unified // -llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) { -} +llama_kv_cache_unified::llama_kv_cache_unified( + const llama_hparams & hparams, + callbacks cbs, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + uint32_t kv_size) : hparams(hparams), cbs(std::move(cbs)), v_trans(v_trans) { -bool llama_kv_cache_unified::init( - const llama_model & model, - const llama_cparams & cparams, - ggml_type type_k, - ggml_type type_v, - uint32_t kv_size, - bool offload) { const int32_t n_layer = hparams.n_layer; has_shift = false; - GGML_ASSERT(!llama_model_is_recurrent(&model)); - - v_trans = !cparams.flash_attn; can_shift = true; - LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n", - __func__, kv_size, 
offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift); + LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n", + __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift); head = 0; size = kv_size; @@ -79,25 +74,11 @@ bool llama_kv_cache_unified::init( const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); - const char * dev_name = "CPU"; - - ggml_backend_buffer_type_t buft; - if (offload) { - auto * dev = model.dev_layer(i); - buft = ggml_backend_dev_buffer_type(dev); - - dev_name = ggml_backend_dev_name(dev); - } else { - buft = ggml_backend_cpu_buffer_type(); - } - - LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__, - i, n_embd_k_gqa, n_embd_v_gqa, dev_name); + ggml_backend_buffer_type_t buft = cbs.get_buft(i); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { - LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__); - return false; + throw std::runtime_error("failed to create ggml context for kv cache"); } ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); @@ -115,15 +96,12 @@ bool llama_kv_cache_unified::init( ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); if (!buf) { - LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__); - return false; + throw std::runtime_error("failed to allocate buffer for kv cache"); } ggml_backend_buffer_clear(buf, 0); LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); bufs.emplace_back(buf); } - - return true; } int32_t llama_kv_cache_unified::get_n_tokens() const { @@ -480,7 +458,7 @@ bool llama_kv_cache_unified::find_slot( return true; } -uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const { +uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { // the FA kernels require padding to avoid extra runtime boundary checks return cparams.flash_attn ? 
256u : 32u; } @@ -1021,24 +999,16 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // llama_kv_cache_recurrent // -llama_kv_cache_recurrent::llama_kv_cache_recurrent(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) { -} - -bool llama_kv_cache_recurrent::init( - const llama_model & model, - const llama_cparams & cparams, - ggml_type type_k, - ggml_type type_v, - uint32_t kv_size, - bool offload) { - GGML_UNUSED(cparams); - +llama_kv_cache_recurrent::llama_kv_cache_recurrent( + const llama_hparams & hparams, + callbacks cbs, + ggml_type type_k, + ggml_type type_v, + uint32_t kv_size) : hparams(hparams), cbs(std::move(cbs)) { const int32_t n_layer = hparams.n_layer; - GGML_ASSERT(llama_model_is_recurrent(&model)); - - LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d\n", - __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer); + LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n", + __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer); head = 0; size = kv_size; @@ -1082,25 +1052,11 @@ bool llama_kv_cache_recurrent::init( const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); - const char * dev_name = "CPU"; - - ggml_backend_buffer_type_t buft; - if (offload) { - auto * dev = model.dev_layer(i); - buft = ggml_backend_dev_buffer_type(dev); - - dev_name = ggml_backend_dev_name(dev); - } else { - buft = ggml_backend_cpu_buffer_type(); - } - - LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__, - i, n_embd_k_gqa, n_embd_v_gqa, dev_name); + ggml_backend_buffer_type_t buft = cbs.get_buft(i); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { - LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__); - return false; + throw std::runtime_error("failed to create ggml context for kv cache"); } ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); @@ -1118,15 +1074,12 @@ bool llama_kv_cache_recurrent::init( ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); if (!buf) { - LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__); - return false; + throw std::runtime_error("failed to allocate buffer for kv cache"); } ggml_backend_buffer_clear(buf, 0); LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); bufs.emplace_back(buf); } - - return true; } int32_t llama_kv_cache_recurrent::get_n_tokens() const { @@ -1558,11 +1511,6 @@ bool llama_kv_cache_recurrent::find_slot( return n >= n_seqs; } -uint32_t llama_kv_cache_recurrent::get_padding(const llama_cparams & cparams) const { - // the FA kernels require padding to avoid extra runtime boundary checks - return cparams.flash_attn ? 
256u : 32u; -} - uint32_t llama_kv_cache_recurrent::cell_max() const { for (uint32_t i = size; i > 0; --i) { const llama_kv_cell & cell = cells[i - 1]; diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index a2d88c9cca2fe..e2e5d02d8edf9 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -15,19 +15,18 @@ struct llama_hparams; struct llama_ubatch; struct llama_kv_cache : public llama_memory_i { + // can be used to query data from the model if needed + struct callbacks { + std::function get_rope_factors; + + // get the buffer type of layer il, can be used to offload KV cache layers to a different device + std::function get_buft; + }; + virtual ~llama_kv_cache() = default; using llama_memory_i::llama_memory_i; - // TODO: become constructor - virtual bool init( - const llama_model & model, // TODO: do not reference the model - const llama_cparams & cparams, - ggml_type type_k, - ggml_type type_v, - uint32_t kv_size, - bool offload) = 0; - virtual void restore() = 0; // call if batch processing fails - restores the cache state virtual void commit() = 0; // call after successful batch processing - clears any pending state @@ -96,23 +95,13 @@ struct llama_kv_cell { // TODO: add notion of max sequences class llama_kv_cache_unified : public llama_kv_cache { public: - // can be used to query data from the model if needed - struct callbacks { - std::function get_rope_factors; - }; - - // TODO: become constructor - bool init( - const llama_model & model, // TODO: do not reference the model - const llama_cparams & cparams, - ggml_type type_k, - ggml_type type_v, - uint32_t kv_size, - bool offload) override; - llama_kv_cache_unified( const llama_hparams & hparams, - callbacks cbs); + callbacks cbs, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + uint32_t kv_size); ~llama_kv_cache_unified() = default; @@ -149,8 +138,7 @@ class llama_kv_cache_unified : public llama_kv_cache { // to the first cell of the slot. bool find_slot(const llama_ubatch & batch) override; - // TODO: maybe not needed - uint32_t get_padding(const llama_cparams & cparams) const; + static uint32_t get_padding(const llama_cparams & cparams); // find how many cells are currently in use uint32_t cell_max() const; @@ -229,26 +217,15 @@ class llama_kv_cache_unified : public llama_kv_cache { class llama_kv_cache_recurrent : public llama_kv_cache { public: - // can be used to query data from the model if needed - struct callbacks { - std::function get_rope_factors; - }; - llama_kv_cache_recurrent( const llama_hparams & hparams, - callbacks cbs); + callbacks cbs, + ggml_type type_k, + ggml_type type_v, + uint32_t kv_size); ~llama_kv_cache_recurrent() = default; - // TODO: become constructor - bool init( - const llama_model & model, // TODO: do not reference the model - const llama_cparams & cparams, - ggml_type type_k, - ggml_type type_v, - uint32_t kv_size, - bool offload) override; - int32_t get_n_tokens() const override; int32_t get_used_cells() const override; @@ -282,9 +259,6 @@ class llama_kv_cache_recurrent : public llama_kv_cache { // to the first cell of the slot. 
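The now-static get_padding() above picks 256 cells when flash attention is enabled (its kernels assume a padded KV range) and 32 otherwise, and n_ctx is then rounded up to that multiple with GGML_PAD. A small worked sketch of the rounding, assuming GGML_PAD behaves like the power-of-two round-up below:

#include <cassert>
#include <cstdint>

// round x up to a multiple of n; n is assumed to be a power of two,
// which holds for the 32/256 padding values used here
static uint32_t pad_to(uint32_t x, uint32_t n) {
    return (x + n - 1) & ~(n - 1);
}

int main() {
    assert(pad_to(4097, 256) == 4352); // flash attention: pad to 256
    assert(pad_to(4097,  32) == 4128); // otherwise: pad to 32
    assert(pad_to(4096, 256) == 4096); // already aligned, unchanged
    return 0;
}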
bool find_slot(const llama_ubatch & batch) override; - // TODO: maybe not needed - uint32_t get_padding(const llama_cparams & cparams) const; - // find how many cells are currently in use uint32_t cell_max() const; diff --git a/src/llama-memory.h b/src/llama-memory.h index dfa8c4e90fc2a..16407052c26b4 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -2,6 +2,20 @@ #include "llama.h" +struct llama_memory_params { + // kv cache + ggml_type type_k; + ggml_type type_v; + + bool v_trans; + bool offload_kqv; + + uint32_t kv_size; + + // other types of memory + // ... +}; + // general concept of LLM memory // the KV cache is a type of LLM memory, but there can be other types class llama_memory_i { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a3ed8fa599cb7..fe5d47c8816c5 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -12815,9 +12815,29 @@ struct llm_build_bailingmoe : public llm_graph_context { } }; -llama_memory_i * llama_model::create_memory() const { +llama_memory_i * llama_model::create_memory(const llama_memory_params & params) const { llama_memory_i * res; + const bool offload = params.offload_kqv; + + auto get_buft = [this, offload](int il) { + const char * dev_name = "CPU"; + + ggml_backend_buffer_type_t buft; + if (offload) { + auto * dev = dev_layer(il); + buft = ggml_backend_dev_buffer_type(dev); + + dev_name = ggml_backend_dev_name(dev); + } else { + buft = ggml_backend_cpu_buffer_type(); + } + + LLAMA_LOG_DEBUG("layer %3d: dev = %s\n", il, dev_name); + + return buft; + }; + switch (arch) { case LLM_ARCH_MAMBA: case LLM_ARCH_RWKV6: @@ -12825,26 +12845,39 @@ llama_memory_i * llama_model::create_memory() const { case LLM_ARCH_RWKV7: case LLM_ARCH_ARWKV7: { - res = new llama_kv_cache_recurrent(hparams, { - /*.get_rope_factors =*/ nullptr - }); + res = new llama_kv_cache_recurrent( + hparams, + { + /*.get_rope_factors =*/ nullptr, + /*.get_buft =*/ get_buft, + }, + params.type_k, + params.type_v, + params.kv_size); } break; default: { - res = new llama_kv_cache_unified(hparams, { - /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) { - // choose long/short freq factors based on the context size - if (layers[il].rope_freqs != nullptr) { - return layers[il].rope_freqs; - } - - if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) { - return layers[il].rope_long; - } - - return layers[il].rope_short; - } - }); + res = new llama_kv_cache_unified( + hparams, + { + /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) { + // choose long/short freq factors based on the context size + if (layers[il].rope_freqs != nullptr) { + return layers[il].rope_freqs; + } + + if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) { + return layers[il].rope_long; + } + + return layers[il].rope_short; + }, + /*.get_buft =*/ get_buft, + }, + params.type_k, + params.type_v, + params.v_trans, + params.kv_size); } } diff --git a/src/llama-model.h b/src/llama-model.h index 34aac337cff27..aa450361130ae 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -396,7 +396,7 @@ struct llama_model { const struct ggml_tensor * get_tensor(const char * name) const; // TODO: move this to new llm_arch_model_i interface - llama_memory_i * create_memory() const; // TODO: params + llama_memory_i * create_memory(const llama_memory_params & params) const; // TODO: move this to new llm_arch_model_i interface llm_graph_result_ptr build_graph( From 49aa8b83b8d98ea30559df86e78b6cf0e8044f8e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Tue, 15 Apr 2025 16:58:36 +0300 Subject: [PATCH 
03/29] kv-cache : fix callback reference ggml-ci --- src/llama-kv-cache.cpp | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 0f3c349941195..6f4836dbac219 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -22,11 +22,9 @@ llama_kv_cache_unified::llama_kv_cache_unified( ggml_type type_v, bool v_trans, uint32_t kv_size) : hparams(hparams), cbs(std::move(cbs)), v_trans(v_trans) { - const int32_t n_layer = hparams.n_layer; has_shift = false; - can_shift = true; LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n", @@ -74,7 +72,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); - ggml_backend_buffer_type_t buft = cbs.get_buft(i); + ggml_backend_buffer_type_t buft = this->cbs.get_buft(i); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { @@ -1052,7 +1050,7 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); - ggml_backend_buffer_type_t buft = cbs.get_buft(i); + ggml_backend_buffer_type_t buft = this->cbs.get_buft(i); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { From 838b3cca384425f5f07deef86e6c39e7f12b75c0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 17 Apr 2025 15:07:11 +0300 Subject: [PATCH 04/29] context : llama_kv_cache -> llama_memory_i ggml-ci --- src/llama-context.cpp | 88 ++++++++++++++++++------------------------ src/llama-context.h | 2 +- src/llama-kv-cache.cpp | 20 ++++++++++ src/llama-memory.h | 2 + 4 files changed, 61 insertions(+), 51 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index c27cb559e10f4..755ad2067aed7 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -177,65 +177,35 @@ llama_context::llama_context( } // init the memory module - // TODO: for now, always create a unified KV cache if (!hparams.vocab_only) { - uint32_t kv_size = 0; - ggml_type type_k = params.type_k; - ggml_type type_v = params.type_v; + LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); if (!llama_model_is_recurrent(&model)) { - LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); - cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_unified::get_padding(cparams)); LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); - kv_size = cparams.n_ctx; - type_k = params.type_k; - type_v = params.type_v; - llama_memory_params params_mem = { - /*.type_k =*/ type_k, - /*.type_v =*/ type_v, + /*.type_k =*/ params.type_k, + /*.type_v =*/ params.type_v, /*.v_trans =*/ !cparams.flash_attn, /*.offload_kqv =*/ cparams.offload_kqv, - /*.kv_size =*/ kv_size, + /*.kv_size =*/ cparams.n_ctx, }; - auto * kv = static_cast(model.create_memory(params_mem)); - - kv_self.reset(kv); + memory.reset(model.create_memory(params_mem)); } else { - // Mamba needs at least as many KV cells as there are sequences kept at any time - kv_size = std::max((uint32_t) 1, params.n_seq_max); - // it's probably best to keep as much precision as possible for the states - type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states - type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states - llama_memory_params params_mem = { - /*.type_k =*/ type_k, - /*.type_v =*/ 
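The this->cbs change in this patch is presumably needed because the constructor takes the callbacks by value and moves them into the member via cbs(std::move(cbs)); inside the constructor body the parameter then shadows the member and is already moved-from. A standalone repro of that pitfall, with illustrative names only:

#include <cassert>
#include <functional>
#include <utility>

struct callbacks_sketch {
    std::function<int(int)> get_buft;
};

struct cache_sketch {
    callbacks_sketch cbs;

    cache_sketch(callbacks_sketch cbs) : cbs(std::move(cbs)) {
        // here the parameter `cbs` shadows the member and has been moved-from:
        // its std::function is typically empty, and calling it would throw
        // std::bad_function_call; the member copy is the valid one
        assert(this->cbs.get_buft(0) == 42);
    }
};

int main() {
    cache_sketch c({ /*.get_buft =*/ [](int /*il*/) { return 42; } });
    (void) c;
    return 0;
}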
type_v, + /*.type_k =*/ GGML_TYPE_F32, // required by ggml_ssm_conv for Mamba's conv_states + /*.type_v =*/ GGML_TYPE_F32, // required by ggml_ssm_scan for Mamba's ssm_states /*.v_trans =*/ false, // unused - /*.offload_kqv =*/ params.offload_kqv, - /*.kv_size =*/ kv_size, + /*.offload_kqv =*/ cparams.offload_kqv, + /*.kv_size =*/ std::max((uint32_t) 1, params.n_seq_max), // Mamba needs at least as many KV cells as there are sequences kept at any time }; - auto * kv = static_cast(model.create_memory(params_mem)); - - LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); - - kv_self.reset(kv); + memory.reset(model.create_memory(params_mem)); } - { - const size_t memory_size_k = kv_self->size_k_bytes(); - const size_t memory_size_v = kv_self->size_v_bytes(); - - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), - ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), - ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); - } } // init backends @@ -326,6 +296,8 @@ llama_context::llama_context( int n_nodes_tg = -1; // simulate full KV cache + llama_kv_cache * kv_self = static_cast(memory.get()); + kv_self->set_full(); cross.v_embd.clear(); @@ -477,11 +449,13 @@ uint32_t llama_context::n_threads_batch() const { } llama_kv_cache * llama_context::get_kv_self() { - return kv_self.get(); + llama_kv_cache * kv_self = static_cast(memory.get()); + return kv_self; } const llama_kv_cache * llama_context::get_kv_self() const { - return kv_self.get(); + llama_kv_cache * kv_self = static_cast(memory.get()); + return kv_self; } ggml_tensor * llama_context::build_rope_shift( @@ -567,7 +541,7 @@ llm_graph_result_ptr llama_context::build_kv_self_shift( //GGML_ASSERT(kv_self->size == n_ctx); - const auto * kv = static_cast(kv_self.get()); + const auto * kv = static_cast(memory.get()); auto inp = std::make_unique(kv); @@ -609,7 +583,7 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag( ggml_cgraph * gf) const { auto res = std::make_unique(); - auto * kv = static_cast(kv_self.get()); + auto * kv = static_cast(memory.get()); const auto & hparams = model.hparams; @@ -755,6 +729,8 @@ llm_graph_result_ptr llama_context::build_kv_self_defrag( void llama_context::kv_self_update() { bool need_reserve = false; + llama_kv_cache * kv_self = static_cast(memory.get()); + if (kv_self->get_has_shift()) { if (!kv_self->get_can_shift()) { GGML_ABORT("The current KV cache / model configuration does not support K-shift"); @@ -780,7 +756,7 @@ void llama_context::kv_self_update() { } { - auto * kv = static_cast(kv_self.get()); + auto * kv = static_cast(kv_self); kv->has_shift = false; @@ -794,7 +770,7 @@ void llama_context::kv_self_update() { if (kv_self->get_do_defrag()) { LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); - auto * kv = static_cast(kv_self.get()); + auto * kv = static_cast(kv_self); if (kv->defrag_prepare(graph_max_nodes())) { ggml_backend_sched_reset(sched.get()); @@ -1043,6 +1019,8 @@ int llama_context::encode(llama_batch & inp_batch) { return -1; } + llama_kv_cache * kv_self = static_cast(memory.get()); + // temporary allocate memory for the input batch if needed // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? 
-1 : kv_self->get_pos_max() + 1); @@ -1208,6 +1186,8 @@ int llama_context::decode(llama_batch & inp_batch) { return -1; } + llama_kv_cache * kv_self = static_cast(memory.get()); + // temporary allocate memory for the input batch if needed // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1); @@ -1222,7 +1202,7 @@ int llama_context::decode(llama_batch & inp_batch) { const int64_t n_tokens_all = batch.n_tokens; const int64_t n_embd = hparams.n_embd; - llama_kv_cache_guard kv_guard(kv_self.get()); + llama_kv_cache_guard kv_guard(kv_self); GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT @@ -1326,7 +1306,7 @@ int llama_context::decode(llama_batch & inp_batch) { } if (!is_recurrent) { - auto * kv = static_cast(kv_self.get()); + auto * kv = static_cast(kv_self); // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears @@ -1478,7 +1458,7 @@ int llama_context::decode(llama_batch & inp_batch) { // decide if we need to defrag the kv cache if (!llama_model_is_recurrent(&model) && cparams.causal_attn && cparams.defrag_thold > 0.0f) { - auto * kv = static_cast(kv_self.get()); + auto * kv = static_cast(kv_self); // - do not defrag small contexts (i.e. < 2048 tokens) // - count the padding towards the number of used tokens @@ -1649,7 +1629,7 @@ llm_graph_result_ptr llama_context::graph_build( /*.backend_cpu =*/ backend_cpu, /*.cvec =*/ &cvec, /*.loras =*/ &loras, - /*.memory =*/ kv_self.get(), + /*.memory =*/ memory.get(), /*.cross =*/ &cross, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), @@ -2108,6 +2088,8 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { } LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); + llama_kv_cache * kv_self = static_cast(memory.get()); + kv_self->state_write(io); return io.n_bytes(); @@ -2192,6 +2174,8 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { } LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); + llama_kv_cache * kv_self = static_cast(memory.get()); + kv_self->state_read(io); return io.n_bytes(); @@ -2200,6 +2184,8 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) { GGML_UNUSED(seq_id); + llama_kv_cache * kv_self = static_cast(memory.get()); + kv_self->state_write(io, seq_id); return io.n_bytes(); @@ -2208,6 +2194,8 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) { GGML_UNUSED(seq_id); + llama_kv_cache * kv_self = static_cast(memory.get()); + kv_self->state_read(io, seq_id); return io.n_bytes(); diff --git a/src/llama-context.h b/src/llama-context.h index c39c193da6f20..879eef014f547 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -200,7 +200,7 @@ struct llama_context { llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably - std::unique_ptr kv_self; + std::unique_ptr memory; // TODO: remove bool logits_all = false; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 6f4836dbac219..82a5065a5e9b6 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -100,6 +100,16 @@ llama_kv_cache_unified::llama_kv_cache_unified( LLAMA_LOG_INFO("%s: %10s KV buffer size = 
%8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); bufs.emplace_back(buf); } + + { + const size_t memory_size_k = size_k_bytes(); + const size_t memory_size_v = size_v_bytes(); + + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } } int32_t llama_kv_cache_unified::get_n_tokens() const { @@ -1078,6 +1088,16 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); bufs.emplace_back(buf); } + + { + const size_t memory_size_k = size_k_bytes(); + const size_t memory_size_v = size_v_bytes(); + + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } } int32_t llama_kv_cache_recurrent::get_n_tokens() const { diff --git a/src/llama-memory.h b/src/llama-memory.h index 16407052c26b4..715c5a22260a0 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -20,6 +20,8 @@ struct llama_memory_params { // the KV cache is a type of LLM memory, but there can be other types class llama_memory_i { public: + virtual ~llama_memory_i() = default; + virtual void clear() = 0; virtual void defrag() = 0; From 8e4d3baacdb3b2a6dff7cdd355263904f2b08155 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 17 Apr 2025 15:31:06 +0300 Subject: [PATCH 05/29] context : move memory creation logic to model ggml-ci --- src/llama-context.cpp | 32 +++++--------------------------- src/llama-memory.h | 6 +++--- src/llama-model.cpp | 20 +++++++++++++------- src/llama-model.h | 3 ++- 4 files changed, 23 insertions(+), 38 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 755ad2067aed7..3e305fa7964e1 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -178,34 +178,12 @@ llama_context::llama_context( // init the memory module if (!hparams.vocab_only) { - LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); - - if (!llama_model_is_recurrent(&model)) { - cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_unified::get_padding(cparams)); - - LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); - - llama_memory_params params_mem = { - /*.type_k =*/ params.type_k, - /*.type_v =*/ params.type_v, - /*.v_trans =*/ !cparams.flash_attn, - /*.offload_kqv =*/ cparams.offload_kqv, - /*.kv_size =*/ cparams.n_ctx, - }; - - memory.reset(model.create_memory(params_mem)); - } else { - llama_memory_params params_mem = { - /*.type_k =*/ GGML_TYPE_F32, // required by ggml_ssm_conv for Mamba's conv_states - /*.type_v =*/ GGML_TYPE_F32, // required by ggml_ssm_scan for Mamba's ssm_states - /*.v_trans =*/ false, // unused - /*.offload_kqv =*/ cparams.offload_kqv, - /*.kv_size =*/ std::max((uint32_t) 1, params.n_seq_max), // Mamba needs at least as many KV cells as there are sequences kept at any time - }; - - memory.reset(model.create_memory(params_mem)); - } + llama_memory_params params_mem = { + /*.type_k =*/ params.type_k, + /*.type_v =*/ params.type_v, + }; + 
memory.reset(model.create_memory(cparams, params_mem)); } // init backends diff --git a/src/llama-memory.h b/src/llama-memory.h index 715c5a22260a0..bd1f6955cbd4d 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -7,10 +7,10 @@ struct llama_memory_params { ggml_type type_k; ggml_type type_v; - bool v_trans; - bool offload_kqv; + //bool v_trans; + //bool offload_kqv; - uint32_t kv_size; + //uint32_t kv_size; // other types of memory // ... diff --git a/src/llama-model.cpp b/src/llama-model.cpp index fe5d47c8816c5..87c1607680756 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -12815,10 +12815,10 @@ struct llm_build_bailingmoe : public llm_graph_context { } }; -llama_memory_i * llama_model::create_memory(const llama_memory_params & params) const { +llama_memory_i * llama_model::create_memory(llama_cparams & cparams, const llama_memory_params & params) const { llama_memory_i * res; - const bool offload = params.offload_kqv; + const bool offload = cparams.offload_kqv; auto get_buft = [this, offload](int il) { const char * dev_name = "CPU"; @@ -12838,6 +12838,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params) return buft; }; + LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); + switch (arch) { case LLM_ARCH_MAMBA: case LLM_ARCH_RWKV6: @@ -12851,12 +12853,16 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params) /*.get_rope_factors =*/ nullptr, /*.get_buft =*/ get_buft, }, - params.type_k, - params.type_v, - params.kv_size); + GGML_TYPE_F32, + GGML_TYPE_F32, + std::max((uint32_t) 1, cparams.n_seq_max)); } break; default: { + cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_unified::get_padding(cparams)); + + LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + res = new llama_kv_cache_unified( hparams, { @@ -12876,8 +12882,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params) }, params.type_k, params.type_v, - params.v_trans, - params.kv_size); + !cparams.flash_attn, + cparams.n_ctx); } } diff --git a/src/llama-model.h b/src/llama-model.h index aa450361130ae..c7dbf829765a1 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -395,8 +395,9 @@ struct llama_model { const struct ggml_tensor * get_tensor(const char * name) const; + // note: can mutate `cparams` // TODO: move this to new llm_arch_model_i interface - llama_memory_i * create_memory(const llama_memory_params & params) const; + llama_memory_i * create_memory(llama_cparams & cparams, const llama_memory_params & params) const; // TODO: move this to new llm_arch_model_i interface llm_graph_result_ptr build_graph( From 7fec0814e6dbb941661cb207d29e0786d87f65f0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 17 Apr 2025 18:28:00 +0300 Subject: [PATCH 06/29] llama : remove reference of memory during encode ggml-ci --- src/llama-context.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 3e305fa7964e1..7f021123f5727 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -179,8 +179,8 @@ llama_context::llama_context( // init the memory module if (!hparams.vocab_only) { llama_memory_params params_mem = { - /*.type_k =*/ params.type_k, - /*.type_v =*/ params.type_v, + /*.type_k =*/ params.type_k, + /*.type_v =*/ params.type_v, }; memory.reset(model.create_memory(cparams, params_mem)); @@ -997,11 +997,9 @@ int llama_context::encode(llama_batch & inp_batch) { return -1; } - llama_kv_cache * 
kv_self = static_cast(memory.get()); - // temporary allocate memory for the input batch if needed - // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1); + // note: during encode, we always pass the full sequence starting from pos = 0 + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0); const llama_batch & batch = batch_allocr.batch; const int32_t n_tokens = batch.n_tokens; From 59af92bbacf12efcd4b86315d39b0c321e293535 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Apr 2025 09:55:50 +0300 Subject: [PATCH 07/29] kv-cache : hide padding details in the implementation ggml-ci --- src/llama-context.cpp | 20 +++----------------- src/llama-kv-cache.cpp | 16 +++++++++++++--- src/llama-kv-cache.h | 6 +++++- src/llama-model.cpp | 7 +++++-- 4 files changed, 26 insertions(+), 23 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 7f021123f5727..5cfe944829457 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1274,24 +1274,10 @@ int llama_context::decode(llama_batch & inp_batch) { } // find KV slot - { - if (!kv_self->find_slot(ubatch)) { - LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens); - - return 1; - } + if (!kv_self->find_slot(ubatch)) { + LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens); - if (!is_recurrent) { - auto * kv = static_cast(kv_self); - - // a heuristic, to avoid attending the full cache if it is not yet utilized - // after enough generations, the benefit from this heuristic disappears - // if we start defragmenting the cache, the benefit from this will be more important - const uint32_t pad = kv->get_padding(cparams); - kv->n = std::min(kv->size, std::max(pad, GGML_PAD(kv->cell_max(), pad))); - - //printf("kv.n = %5d, kv.used = %5d, kv.head = %5d\n", kv->n, kv->used, kv->head); - } + return 1; } ggml_backend_sched_reset(sched.get()); diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 82a5065a5e9b6..abf515bfb2719 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -21,14 +21,17 @@ llama_kv_cache_unified::llama_kv_cache_unified( ggml_type type_k, ggml_type type_v, bool v_trans, - uint32_t kv_size) : hparams(hparams), cbs(std::move(cbs)), v_trans(v_trans) { + uint32_t kv_size, + uint32_t padding) : hparams(hparams), cbs(std::move(cbs)), v_trans(v_trans), padding(padding) { const int32_t n_layer = hparams.n_layer; has_shift = false; can_shift = true; - LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n", - __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift); + LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d, padding = %d\n", + __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift, padding); + + GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding"); head = 0; size = kv_size; @@ -463,6 +466,13 @@ bool llama_kv_cache_unified::find_slot( pending.ranges.push_back({head, head + n_tokens}); + // a heuristic, to avoid attending the full cache if it is not yet utilized + // after enough generations, the benefit from this heuristic disappears + // if we start defragmenting the cache, the benefit from this will be more important + n = std::min(size, std::max(padding, 
GGML_PAD(cell_max(), padding))); + + //printf("n = %5d, used = %5d, head = %5d\n", n, used, head); + return true; } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index e2e5d02d8edf9..c902a791d2cbb 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -101,7 +101,8 @@ class llama_kv_cache_unified : public llama_kv_cache { ggml_type type_k, ggml_type type_v, bool v_trans, - uint32_t kv_size); + uint32_t kv_size, + uint32_t padding); ~llama_kv_cache_unified() = default; @@ -196,6 +197,9 @@ class llama_kv_cache_unified : public llama_kv_cache { // computed before each graph build uint32_t n = 0; + // required padding + uint32_t padding = 1; + std::vector cells; std::vector k_l; // per layer diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 87c1607680756..554a1e1b39070 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -12859,7 +12859,9 @@ llama_memory_i * llama_model::create_memory(llama_cparams & cparams, const llama } break; default: { - cparams.n_ctx = GGML_PAD(cparams.n_ctx, llama_kv_cache_unified::get_padding(cparams)); + const auto padding = llama_kv_cache_unified::get_padding(cparams); + + cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); @@ -12883,7 +12885,8 @@ llama_memory_i * llama_model::create_memory(llama_cparams & cparams, const llama params.type_k, params.type_v, !cparams.flash_attn, - cparams.n_ctx); + cparams.n_ctx, + padding); } } From 6413b937f0af41b12412a382784195155cfe673c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Apr 2025 10:33:38 +0300 Subject: [PATCH 08/29] kv-cache : add ubatch_next() ggml-ci --- src/llama-context.cpp | 19 ++----------------- src/llama-kv-cache.cpp | 17 +++++++++++++++++ src/llama-kv-cache.h | 8 ++++++++ 3 files changed, 27 insertions(+), 17 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 5cfe944829457..18d6f6411e99c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1239,22 +1239,7 @@ int llama_context::decode(llama_batch & inp_batch) { int64_t n_outputs_prev = 0; while (sbatch.n_tokens > 0) { - llama_ubatch ubatch = llama_ubatch(); - - const auto & n_ubatch = cparams.n_ubatch; - - if (is_recurrent) { - if (embd_pooled) { - // Pooled embeddings cannot be split across ubatches (yet) - ubatch = sbatch.split_seq(cparams.n_ubatch); - } else { - // recurrent model architectures are easier to implement - // with equal-length sequences - ubatch = sbatch.split_equal(cparams.n_ubatch); - } - } else { - ubatch = sbatch.split_simple(n_ubatch); - } + llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled); // count the outputs in this u_batch { @@ -1424,7 +1409,7 @@ int llama_context::decode(llama_batch & inp_batch) { // - do not defrag small contexts (i.e. < 2048 tokens) // - count the padding towards the number of used tokens - const float fragmentation = kv->n >= 2048 ? std::max(0.0f, 1.0f - float(kv->used + kv->get_padding(cparams))/float(kv->n)) : 0.0f; + const float fragmentation = kv->n >= 2048 ? 
std::max(0.0f, 1.0f - float(kv->used + kv->padding)/float(kv->n)) : 0.0f; // queue defragmentation for next llama_kv_cache_update if (fragmentation > cparams.defrag_thold) { diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index abf515bfb2719..2422858c6930b 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -476,6 +476,14 @@ bool llama_kv_cache_unified::find_slot( return true; } +llama_ubatch llama_kv_cache_unified::ubatch_next( + llama_sbatch & sbatch, + uint32_t n_ubatch, + bool embd_pooled) const { + GGML_UNUSED(embd_pooled); + return sbatch.split_simple(n_ubatch); +} + uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { // the FA kernels require padding to avoid extra runtime boundary checks return cparams.flash_attn ? 256u : 32u; @@ -1539,6 +1547,15 @@ bool llama_kv_cache_recurrent::find_slot( return n >= n_seqs; } +llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + return sbatch.split_seq(n_ubatch); + } + + return sbatch.split_equal(n_ubatch); +} + uint32_t llama_kv_cache_recurrent::cell_max() const { for (uint32_t i = size; i > 0; --i) { const llama_kv_cell & cell = cells[i - 1]; diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index c902a791d2cbb..2df0475d2cdfd 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -13,6 +13,7 @@ struct llama_cparams; struct llama_hparams; struct llama_ubatch; +struct llama_sbatch; struct llama_kv_cache : public llama_memory_i { // can be used to query data from the model if needed @@ -44,6 +45,9 @@ struct llama_kv_cache : public llama_memory_i { virtual bool find_slot(const llama_ubatch & batch) = 0; + // different KV caches require different batch splitting strategies + virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0; + // simulate full cache, used for allocating worst-case compute buffers virtual void set_full() = 0; @@ -139,6 +143,8 @@ class llama_kv_cache_unified : public llama_kv_cache { // to the first cell of the slot. bool find_slot(const llama_ubatch & batch) override; + llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; + static uint32_t get_padding(const llama_cparams & cparams); // find how many cells are currently in use @@ -263,6 +269,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache { // to the first cell of the slot. 
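A bit further up, find_slot() now also computes the attention window n = min(size, max(padding, GGML_PAD(cell_max(), padding))), so graphs only attend to the used, padded prefix of the cache instead of the full kv_size. A numeric sketch with illustrative values (pad_to as in the earlier example):

#include <algorithm>
#include <cassert>
#include <cstdint>

static uint32_t pad_to(uint32_t x, uint32_t n) {
    return (x + n - 1) & ~(n - 1);
}

int main() {
    const uint32_t size     = 8192; // total KV cells
    const uint32_t padding  = 256;  // flash-attention padding
    const uint32_t cell_max = 1000; // highest used cell index (illustrative)

    // only attend to the used prefix, rounded up to the padding
    const uint32_t n = std::min(size, std::max(padding, pad_to(cell_max, padding)));

    assert(n == 1024); // 1000 rounded up to a multiple of 256, well below size
    return 0;
}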
bool find_slot(const llama_ubatch & batch) override; + llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; + // find how many cells are currently in use uint32_t cell_max() const; From e869515b853ff20629c9b3f842bb1f1003955f3e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Apr 2025 11:50:43 +0300 Subject: [PATCH 09/29] context : simplify sbatch logic ggml-ci --- src/llama-batch.cpp | 6 ++- src/llama-batch.h | 3 +- src/llama-context.cpp | 98 +++++++++++++++++------------------------- src/llama-context.h | 5 --- src/llama-kv-cache.cpp | 12 ++++++ src/llama-kv-cache.h | 6 +++ 6 files changed, 65 insertions(+), 65 deletions(-) diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 01d5ca57fd82b..a88b2fe3082c9 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { return ubatch; } -void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) { +llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) { GGML_ASSERT(batch.n_tokens >= 0); this->batch = &batch; this->n_embd = n_embd; @@ -203,6 +203,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim for (size_t i = 0; i < n_tokens; ++i) { ids[i] = i; } + if (simple_split) { seq.resize(1); llama_sbatch_seq & s = seq[0]; @@ -212,6 +213,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim s.length = n_tokens; return; } + std::sort(ids.begin(), ids.end(), [&batch](size_t a, size_t b) { int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1; @@ -239,6 +241,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim return n_seq_a > n_seq_b; } ); + // init seq llama_sbatch_seq * last_seq = nullptr; @@ -262,6 +265,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim seq.push_back(new_seq); last_seq = &seq.back(); } + // keep shared prompts first at the end, then sort by length descending. 
std::sort(seq.begin(), seq.end(), [](llama_sbatch_seq & a, llama_sbatch_seq & b) { diff --git a/src/llama-batch.h b/src/llama-batch.h index f1df40d27086e..6305051b62b79 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -70,7 +70,8 @@ struct llama_sbatch { // sequence-wise split llama_ubatch split_seq(size_t n_ubatch); - void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); + llama_sbatch() = default; + llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); }; // temporary allocate memory for the input batch if needed diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 18d6f6411e99c..bc757fc5f1961 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -799,9 +799,6 @@ enum llama_pooling_type llama_context::pooling_type() const { } float * llama_context::get_logits() { - // reorder logits for backward compatibility - output_reorder(); - return logits; } @@ -844,9 +841,6 @@ float * llama_context::get_logits_ith(int32_t i) { } float * llama_context::get_embeddings() { - // reorder embeddings for backward compatibility - output_reorder(); - return embd; } @@ -1028,7 +1022,7 @@ int llama_context::encode(llama_batch & inp_batch) { const int64_t n_embd = hparams.n_embd; - sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); + llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); const llama_ubatch ubatch = sbatch.split_simple(n_tokens); @@ -1219,13 +1213,7 @@ int llama_context::decode(llama_batch & inp_batch) { n_outputs_all = 1; } - const bool logits_all = n_outputs_all == n_tokens_all; - - const bool is_recurrent = llama_model_is_recurrent(&model); - - sbatch.from_batch(batch, n_embd, - /* simple_split */ !is_recurrent, - /* logits_all */ logits_all); + llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all); // reserve output buffer if (output_reserve(n_outputs_all) < n_outputs_all) { @@ -1382,18 +1370,52 @@ int llama_context::decode(llama_batch & inp_batch) { { bool sorted_output = true; - GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all); + auto & out_ids = sbatch.out_ids; + + GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all); for (int64_t i = 0; i < n_outputs_all; ++i) { - int64_t out_id = sbatch.out_ids[i]; + int64_t out_id = out_ids[i]; output_ids[out_id] = i; if (out_id != i) { sorted_output = false; } } - if (sorted_output) { - sbatch.out_ids.clear(); + // make the outputs have the same order they had in the user-provided batch + // note: this is mostly relevant for recurrent models atm + if (!sorted_output) { + const uint32_t n_vocab = model.vocab.n_tokens(); + const uint32_t n_embd = model.hparams.n_embd; + + GGML_ASSERT((size_t) n_outputs == out_ids.size()); + + // TODO: is there something more efficient which also minimizes swaps? 
+ // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) + for (int32_t i = 0; i < n_outputs - 1; ++i) { + int32_t j_min = i; + for (int32_t j = i + 1; j < n_outputs; ++j) { + if (out_ids[j] < out_ids[j_min]) { + j_min = j; + } + } + if (j_min == i) { continue; } + std::swap(out_ids[i], out_ids[j_min]); + if (logits_size > 0) { + for (uint32_t k = 0; k < n_vocab; k++) { + std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]); + } + } + if (embd_size > 0) { + for (uint32_t k = 0; k < n_embd; k++) { + std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]); + } + } + } + std::fill(output_ids.begin(), output_ids.end(), -1); + for (int32_t i = 0; i < n_outputs; ++i) { + output_ids[out_ids[i]] = i; + } } } @@ -1502,44 +1524,6 @@ int32_t llama_context::output_reserve(int32_t n_outputs) { return n_outputs_max; } -void llama_context::output_reorder() { - auto & out_ids = sbatch.out_ids; - if (!out_ids.empty()) { - const uint32_t n_vocab = model.vocab.n_tokens(); - const uint32_t n_embd = model.hparams.n_embd; - - GGML_ASSERT((size_t) n_outputs == out_ids.size()); - - // TODO: is there something more efficient which also minimizes swaps? - // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) - for (int32_t i = 0; i < n_outputs - 1; ++i) { - int32_t j_min = i; - for (int32_t j = i + 1; j < n_outputs; ++j) { - if (out_ids[j] < out_ids[j_min]) { - j_min = j; - } - } - if (j_min == i) { continue; } - std::swap(out_ids[i], out_ids[j_min]); - if (logits_size > 0) { - for (uint32_t k = 0; k < n_vocab; k++) { - std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]); - } - } - if (embd_size > 0) { - for (uint32_t k = 0; k < n_embd; k++) { - std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]); - } - } - } - std::fill(output_ids.begin(), output_ids.end(), -1); - for (int32_t i = 0; i < n_outputs; ++i) { - output_ids[out_ids[i]] = i; - } - out_ids.clear(); - } -} - // // graph // @@ -1980,8 +1964,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { { LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__); - output_reorder(); - const auto n_outputs = this->n_outputs; const auto & output_ids = this->output_ids; diff --git a/src/llama-context.h b/src/llama-context.h index 879eef014f547..a211416f85d3b 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -137,10 +137,6 @@ struct llama_context { // Returns max number of outputs for which space was reserved. 
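The reordering block folded into decode() above is a plain selection sort over out_ids that drags each output row along with it, chosen to keep the number of (potentially large) row swaps small. The same idea on a toy array, one float per row for brevity:

#include <cassert>
#include <cstdint>
#include <utility>
#include <vector>

int main() {
    // out_ids[i] = position in the user batch that produced output row i
    std::vector<int32_t> out_ids = {2, 0, 3, 1};
    std::vector<float>   logits  = {0.2f, 0.0f, 0.3f, 0.1f};

    const int32_t n_outputs = (int32_t) out_ids.size();

    // selection sort, to minimize swaps
    for (int32_t i = 0; i < n_outputs - 1; ++i) {
        int32_t j_min = i;
        for (int32_t j = i + 1; j < n_outputs; ++j) {
            if (out_ids[j] < out_ids[j_min]) {
                j_min = j;
            }
        }
        if (j_min == i) {
            continue;
        }
        std::swap(out_ids[i], out_ids[j_min]);
        std::swap(logits[i],  logits[j_min]);
    }

    // rows are back in the order of the user-provided batch
    assert((logits == std::vector<float>{0.0f, 0.1f, 0.2f, 0.3f}));
    return 0;
}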
int32_t output_reserve(int32_t n_outputs); - // make the outputs have the same order they had in the user-provided batch - // TODO: maybe remove this - void output_reorder(); - // // graph // @@ -196,7 +192,6 @@ struct llama_context { llama_cparams cparams; llama_adapter_cvec cvec; llama_adapter_loras loras; - llama_sbatch sbatch; llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 2422858c6930b..938e54916731a 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -476,6 +476,12 @@ bool llama_kv_cache_unified::find_slot( return true; } +llama_sbatch llama_kv_cache_unified::sbatch_init( + const llama_batch & batch, + bool logits_all) { + return llama_sbatch(batch, hparams.n_embd, true, logits_all); +} + llama_ubatch llama_kv_cache_unified::ubatch_next( llama_sbatch & sbatch, uint32_t n_ubatch, @@ -1547,6 +1553,12 @@ bool llama_kv_cache_recurrent::find_slot( return n >= n_seqs; } +llama_sbatch llama_kv_cache_recurrent::sbatch_init( + const llama_batch & batch, + bool logits_all) { + return llama_sbatch(batch, hparams.n_embd, false, logits_all); +} + llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { if (embd_pooled) { // Pooled embeddings cannot be split across ubatches (yet) diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 2df0475d2cdfd..92b1f7b90e740 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -45,6 +45,8 @@ struct llama_kv_cache : public llama_memory_i { virtual bool find_slot(const llama_ubatch & batch) = 0; + virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0; + // different KV caches require different batch splitting strategies virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0; @@ -143,6 +145,8 @@ class llama_kv_cache_unified : public llama_kv_cache { // to the first cell of the slot. bool find_slot(const llama_ubatch & batch) override; + llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; + llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; static uint32_t get_padding(const llama_cparams & cparams); @@ -269,6 +273,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache { // to the first cell of the slot. 
bool find_slot(const llama_ubatch & batch) override; + llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; + llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; // find how many cells are currently in use From ae2cd005c9c3be19f68e8218e1c058195c2114ec Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Apr 2025 12:03:30 +0300 Subject: [PATCH 10/29] kv-cache : hide defrag logic in the implementation ggml-ci --- src/llama-context.cpp | 18 ++++-------------- src/llama-kv-cache.cpp | 16 +++++++++++++--- src/llama-kv-cache.h | 6 ++++-- src/llama-memory.h | 1 - 4 files changed, 21 insertions(+), 20 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index bc757fc5f1961..952452396beb1 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1426,19 +1426,8 @@ int llama_context::decode(llama_batch & inp_batch) { //synchronize(); // decide if we need to defrag the kv cache - if (!llama_model_is_recurrent(&model) && cparams.causal_attn && cparams.defrag_thold > 0.0f) { - auto * kv = static_cast(kv_self); - - // - do not defrag small contexts (i.e. < 2048 tokens) - // - count the padding towards the number of used tokens - const float fragmentation = kv->n >= 2048 ? std::max(0.0f, 1.0f - float(kv->used + kv->padding)/float(kv->n)) : 0.0f; - - // queue defragmentation for next llama_kv_cache_update - if (fragmentation > cparams.defrag_thold) { - LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); - - kv_self->defrag(); - } + if (cparams.defrag_thold > 0.0f) { + kv_self->defrag(cparams.defrag_thold); } // Reset state for the next token before backend sync, to allow the CPU activities in the reset to @@ -2586,7 +2575,8 @@ void llama_kv_self_defrag(llama_context * ctx) { return; } - return kv->defrag(); + // force defrag + return kv->defrag(-1.0f); } // deprecated diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 938e54916731a..48ad3380b0719 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -357,8 +357,17 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { return result; } -void llama_kv_cache_unified::defrag() { - do_defrag = true; +void llama_kv_cache_unified::defrag(float thold) { + // - do not defrag small contexts (i.e. < 2048 tokens) + // - count the padding towards the number of used tokens + const float fragmentation = n >= 2048 ? 
std::max(0.0f, 1.0f - float(used + padding)/float(n)) : 0.0f; + + // queue defragmentation for next llama_kv_cache_update + if (fragmentation > thold) { + LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); + + do_defrag = true; + } } void llama_kv_cache_unified::restore() { @@ -1358,7 +1367,8 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } -void llama_kv_cache_recurrent::defrag() { +void llama_kv_cache_recurrent::defrag(float thold) { + GGML_UNUSED(thold); LLAMA_LOG_ERROR("%s: not supported\n", __func__); } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 92b1f7b90e740..f58c2a165daf7 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -31,6 +31,8 @@ struct llama_kv_cache : public llama_memory_i { virtual void restore() = 0; // call if batch processing fails - restores the cache state virtual void commit() = 0; // call after successful batch processing - clears any pending state + virtual void defrag(float thold) = 0; + virtual int32_t get_n_tokens() const = 0; virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache @@ -124,7 +126,7 @@ class llama_kv_cache_unified : public llama_kv_cache { llama_pos get_pos_max() const override; void clear() override; - void defrag() override; + void defrag(float thold) override; void restore() override; void commit() override; @@ -252,7 +254,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache { llama_pos get_pos_max() const override; void clear() override; - void defrag() override; + void defrag(float thold) override; void restore() override; void commit() override; diff --git a/src/llama-memory.h b/src/llama-memory.h index bd1f6955cbd4d..4a8c396529236 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -23,7 +23,6 @@ class llama_memory_i { virtual ~llama_memory_i() = default; virtual void clear() = 0; - virtual void defrag() = 0; virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0; virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; From fdb7206d39574bcc398120918cdd6da576c1c164 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Apr 2025 16:36:53 +0300 Subject: [PATCH 11/29] context : hide kv cache details in implementation ggml-ci --- src/llama-context.cpp | 337 ++------------------------------------ src/llama-context.h | 17 -- src/llama-graph.h | 8 +- src/llama-kv-cache.cpp | 357 ++++++++++++++++++++++++++++++++++++++++- src/llama-kv-cache.h | 47 +++++- 5 files changed, 415 insertions(+), 351 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 952452396beb1..1fe6e6203b314 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -436,338 +436,21 @@ const llama_kv_cache * llama_context::get_kv_self() const { return kv_self; } -ggml_tensor * llama_context::build_rope_shift( - ggml_context * ctx0, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale) const { - const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; - - const auto & yarn_ext_factor = cparams.yarn_ext_factor; - const auto & yarn_beta_fast = cparams.yarn_beta_fast; - const auto & yarn_beta_slow = cparams.yarn_beta_slow; - - const auto & hparams = model.hparams; - - const auto & n_rot = hparams.n_rot; - const auto & rope_type = hparams.rope_type; - - // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. 
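The check that defrag(thold) / defrag_sched(thold) now perform internally reduces to a single ratio over the cache counters. A minimal stand-alone version of that heuristic, with made-up cell counts in place of the real cache state:

#include <algorithm>
#include <cstdint>
#include <cstdio>

// same shape as the check inside the unified cache; the numbers below are invented
static bool defrag_needed(uint32_t n, uint32_t used, uint32_t padding, float thold) {
    // - do not defrag small contexts (i.e. < 2048 tokens)
    // - count the padding towards the number of used tokens
    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - float(used + padding)/float(n)) : 0.0f;

    return fragmentation > thold;
}

int main() {
    // 4096 cells in the active window, 2900 used, padding 32, threshold 0.1
    printf("defrag: %s\n", defrag_needed(4096, 2900, 32, 0.1f) ? "yes" : "no");
    return 0;
}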
- // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. - const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor; - - ggml_tensor * tmp; - - if (ggml_is_quantized(cur->type)) { - // dequantize to f32 -> RoPE -> quantize back - tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32); - - tmp = ggml_rope_ext(ctx0, tmp, - shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); - - tmp = ggml_cpy(ctx0, tmp, cur); - } else { - // we rotate only the first n_rot dimensions - tmp = ggml_rope_ext_inplace(ctx0, cur, - shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); - } - - return tmp; -} - -class llm_graph_input_k_shift : public llm_graph_input_i { -public: - llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} - virtual ~llm_graph_input_k_shift() = default; - - void set_input(const llama_ubatch * ubatch) override; - - ggml_tensor * k_shift; // I32 [kv_size] - - const llama_kv_cache_unified * kv_self; -}; - -void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { - GGML_UNUSED(ubatch); - - if (k_shift) { - assert(ggml_backend_buffer_is_host(k_shift->buffer)); - - int32_t * data = (int32_t *) k_shift->data; - - for (uint32_t i = 0; i < kv_self->size; ++i) { - data[i] = kv_self->cells[i].delta; - } - } -} - -llm_graph_result_ptr llama_context::build_kv_self_shift( - ggml_context * ctx0, - ggml_cgraph * gf) const { - auto res = std::make_unique(); - - const auto & hparams = model.hparams; - - const auto & n_layer = hparams.n_layer; - - const auto & n_embd_head_k = hparams.n_embd_head_k; - //const auto & n_embd_head_v = hparams.n_embd_head_v; - - //GGML_ASSERT(kv_self->size == n_ctx); - - const auto * kv = static_cast(memory.get()); - - auto inp = std::make_unique(kv); - - inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx); - ggml_set_input(inp->k_shift); - - for (uint32_t il = 0; il < n_layer; ++il) { - const int64_t n_head_kv = hparams.n_head_kv(il); - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - - const bool is_swa = hparams.is_swa(il); - - // note: the swa rope params could become part of the cparams in the future - // if we decide to make them configurable, like the non-sliding ones - const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; - const float freq_scale_l = is_swa ? 
hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; - - ggml_tensor * rope_factors = kv->cbs.get_rope_factors(n_ctx_per_seq(), il); - - ggml_tensor * k = - ggml_view_3d(ctx0, kv->k_l[il], - n_embd_head_k, n_head_kv, kv->size, - ggml_row_size(kv->k_l[il]->type, n_embd_head_k), - ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa), - 0); - - ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); - - ggml_build_forward_expand(gf, cur); - } - - res->add_input(std::move(inp)); - - return res; -} - -llm_graph_result_ptr llama_context::build_kv_self_defrag( - ggml_context * ctx0, - ggml_cgraph * gf) const { - auto res = std::make_unique(); - - auto * kv = static_cast(memory.get()); - - const auto & hparams = model.hparams; - - const auto & ids = kv->defrag_info.ids; - -#if 0 - // CPU defrag - // - // TODO: optimizations are possible: - // - multiple threads - // - avoid copying to the host memory when already there - // - // likely not worth the effort, as we have ggml_graph based defrag - // - - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - - const uint32_t kv_size = size; - - std::vector buf_k; - std::vector buf_v; - - for (uint32_t il = 0; il < n_layer; ++il) { - const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); - const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size); - - const size_t v_size_el = ggml_type_size(v_l[il]->type); - const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size); - - buf_k.resize(k_size); - buf_v.resize(v_size); - - ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size()); - ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size()); - - // batch move [i, i+nm) to [id, id+nm) - // note: cells can move only to a lower index - for (uint32_t i = 0; i < n_kv; ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == n_kv) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < n_kv && ids[i + nm] == id + nm) { - nm++; - } - - // move keys - { - const int64_t os = i*k_size_row; - const int64_t od = id*k_size_row; - - memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row); - } - - // move values (note: they are transposed) - { - const int64_t os = i; - const int64_t od = id; - - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el); - } - } - - i += nm - 1; - } - - ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size()); - ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); - } -#else - for (uint32_t i = 0; i < ids.size(); ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == ids.size()) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < ids.size() && ids[i + nm] == id + nm) { - nm++; - } - - for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); - - ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv->k_l[il], - n_embd_k_gqa, nm, - ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*i)); - - ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv->k_l[il], - n_embd_k_gqa, nm, - ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv->k_l[il]->type, n_embd_k_gqa*id)); - - ggml_tensor * view_v_src; - ggml_tensor * view_v_dst; - - if (cparams.flash_attn) { - // NOTE: the V cache is not 
transposed when using flash attention - view_v_src = ggml_view_2d(ctx0, kv->v_l[il], - n_embd_v_gqa, nm, - ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*i)); - - view_v_dst = ggml_view_2d(ctx0, kv->v_l[il], - n_embd_v_gqa, nm, - ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv->v_l[il]->type, n_embd_v_gqa*id)); - } else { - view_v_src = ggml_view_2d(ctx0, kv->v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv->v_l[il]->type, kv->size), - ggml_row_size(kv->v_l[il]->type, i)); - - view_v_dst = ggml_view_2d(ctx0, kv->v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv->v_l[il]->type, kv->size), - ggml_row_size(kv->v_l[il]->type, id)); - } - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst)); - } - - i += nm - 1; - } - - //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); -#endif - - return res; -} - void llama_context::kv_self_update() { bool need_reserve = false; llama_kv_cache * kv_self = static_cast(memory.get()); - if (kv_self->get_has_shift()) { - if (!kv_self->get_can_shift()) { - GGML_ABORT("The current KV cache / model configuration does not support K-shift"); - } - - LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__); - - // apply K-shift if needed - if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { - ggml_backend_sched_reset(sched.get()); - - auto * gf = graph_init(); - - auto res = build_kv_self_shift(ctx_compute.get(), gf); - - ggml_backend_sched_alloc_graph(sched.get(), gf); - - res->set_inputs(nullptr); - - graph_compute(gf, false); - - need_reserve = true; - } - - { - auto * kv = static_cast(kv_self); - - kv->has_shift = false; - - for (uint32_t i = 0; i < kv->size; ++i) { - kv->cells[i].delta = 0; - } - } - } - - // defragment the KV cache if needed - if (kv_self->get_do_defrag()) { - LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); - - auto * kv = static_cast(kv_self); - - if (kv->defrag_prepare(graph_max_nodes())) { - ggml_backend_sched_reset(sched.get()); - - auto * gf = graph_init(); - - auto res = build_kv_self_defrag(ctx_compute.get(), gf); - - ggml_backend_sched_alloc_graph(sched.get(), gf); - - res->set_inputs(nullptr); - - graph_compute(gf, false); - - need_reserve = true; - } - - kv->do_defrag = false; - } + need_reserve = kv_self->update({ + /*.arch =*/ model.arch, + /*.cparams =*/ cparams, + /*.sched =*/ sched.get(), + /*.backends =*/ backends, + /*.n_max_nodes =*/ graph_max_nodes(), + /*.get_ctx_compute =*/ [this]() { return ctx_compute.get(); }, + /*.graph_init =*/ [this]() { return graph_init(); }, + /*.graph_compute =*/ [this](ggml_cgraph * gf) { graph_compute(gf, false); }, + }); // reserve a worst case graph if needed if (need_reserve) { diff --git a/src/llama-context.h b/src/llama-context.h index a211416f85d3b..3e79cff8250bf 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -159,23 +159,6 @@ struct llama_context { llm_graph_cb graph_get_cb() const; - // used by kv_self_update() - ggml_tensor * build_rope_shift( - ggml_context * ctx0, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale) const; - - llm_graph_result_ptr build_kv_self_shift( - ggml_context * ctx0, - ggml_cgraph * gf) const; - - llm_graph_result_ptr build_kv_self_defrag( - ggml_context * ctx0, - ggml_cgraph * gf) const; - // TODO: read/write lora adapters and cvec size_t state_write_data(llama_io_write_i & io); size_t state_read_data (llama_io_read_i & io); diff 
--git a/src/llama-graph.h b/src/llama-graph.h index 23397a76fd3bd..5b404366dc145 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -351,8 +351,8 @@ struct llm_graph_params { const llama_cparams & cparams; const llama_ubatch & ubatch; - ggml_backend_sched * sched; - ggml_backend * backend_cpu; + ggml_backend_sched_t sched; + ggml_backend_t backend_cpu; const llama_adapter_cvec * cvec; const llama_adapter_loras * loras; @@ -403,9 +403,9 @@ struct llm_graph_context { ggml_context * ctx0 = nullptr; - ggml_backend_sched * sched; + ggml_backend_sched_t sched; - ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? + ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? const llama_adapter_cvec * cvec; const llama_adapter_loras * loras; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 48ad3380b0719..4797f5b128cb3 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -75,7 +75,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); - ggml_backend_buffer_type_t buft = this->cbs.get_buft(i); + ggml_backend_buffer_type_t buft = cbs.get_buft(i); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { @@ -408,6 +408,354 @@ void llama_kv_cache_unified::commit() { pending.ranges.clear(); } +ggml_tensor * llama_kv_cache_unified::build_rope_shift( + const graph_params & params, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale, + ggml_backend_buffer * bbuf) const { + const auto & arch = params.arch; + const auto & cparams = params.cparams; + const auto & backends = params.backends; + const auto & sched = params.sched; + + const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; + + const auto & yarn_ext_factor = cparams.yarn_ext_factor; + const auto & yarn_beta_fast = cparams.yarn_beta_fast; + const auto & yarn_beta_slow = cparams.yarn_beta_slow; + + const auto & n_rot = hparams.n_rot; + const auto & rope_type = hparams.rope_type; + + // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. + // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. + const float yarn_attn_factor = arch == LLM_ARCH_DEEPSEEK2 ? 
1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor; + + ggml_tensor * tmp; + + if (ggml_is_quantized(cur->type)) { + // dequantize to f32 -> RoPE -> quantize back + tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); + + if (bbuf) { + for (const auto & backend : backends) { + // Figure out which backend KV cache belongs to + if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) { + ggml_backend_sched_set_tensor_backend(sched, tmp, backend.get()); + break; + } + } + } + + tmp = ggml_rope_ext_inplace(ctx, tmp, + shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + + tmp = ggml_cpy(ctx, tmp, cur); + } else { + // we rotate only the first n_rot dimensions + tmp = ggml_rope_ext_inplace(ctx, cur, + shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + } + + return tmp; +} + +class llm_graph_input_k_shift : public llm_graph_input_i { +public: + llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + virtual ~llm_graph_input_k_shift() = default; + + void set_input(const llama_ubatch * ubatch) override; + + ggml_tensor * k_shift; // I32 [kv_size] + + const llama_kv_cache_unified * kv_self; +}; + +void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + if (k_shift) { + assert(ggml_backend_buffer_is_host(k_shift->buffer)); + + int32_t * data = (int32_t *) k_shift->data; + + for (uint32_t i = 0; i < kv_self->size; ++i) { + data[i] = kv_self->cells[i].delta; + } + } +} + +llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( + const graph_params & params, + ggml_cgraph * gf) const { + auto res = std::make_unique(); + + auto * ctx = params.get_ctx_compute(); + + const auto & cparams = params.cparams; + + const auto & n_layer = hparams.n_layer; + + const auto & n_embd_head_k = hparams.n_embd_head_k; + //const auto & n_embd_head_v = hparams.n_embd_head_v; + + const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; + + //GGML_ASSERT(kv_self->size == n_ctx); + + auto inp = std::make_unique(this); + + inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx); + ggml_set_input(inp->k_shift); + + for (uint32_t il = 0; il < n_layer; ++il) { + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + + const bool is_swa = hparams.is_swa(il); + + // note: the swa rope params could become part of the cparams in the future + // if we decide to make them configurable, like the non-sliding ones + const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; + const float freq_scale_l = is_swa ? 
hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; + + ggml_tensor * rope_factors = cbs.get_rope_factors(n_ctx_per_seq, il); + + ggml_tensor * k = + ggml_view_3d(ctx, k_l[il], + n_embd_head_k, n_head_kv, size, + ggml_row_size(k_l[il]->type, n_embd_head_k), + ggml_row_size(k_l[il]->type, n_embd_k_gqa), + 0); + + ggml_tensor * cur = build_rope_shift(params, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, k_l[il]->buffer); + + ggml_build_forward_expand(gf, cur); + } + + res->add_input(std::move(inp)); + + return std::move(res); +} + +llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( + const graph_params & params, + ggml_cgraph * gf) const { + auto res = std::make_unique(); + + auto * ctx = params.get_ctx_compute(); + + const auto & ids = defrag_info.ids; + + const auto & cparams = params.cparams; + +#if 0 + // CPU defrag + // + // TODO: optimizations are possible: + // - multiple threads + // - avoid copying to the host memory when already there + // + // likely not worth the effort, as we have ggml_graph based defrag + // + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + + const uint32_t kv_size = size; + + std::vector buf_k; + std::vector buf_v; + + for (uint32_t il = 0; il < n_layer; ++il) { + const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size); + + const size_t v_size_el = ggml_type_size(v_l[il]->type); + const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size); + + buf_k.resize(k_size); + buf_v.resize(v_size); + + ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size()); + ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size()); + + // batch move [i, i+nm) to [id, id+nm) + // note: cells can move only to a lower index + for (uint32_t i = 0; i < n_kv; ++i) { + const uint32_t id = ids[i]; + + if (i == id || id == n_kv) { + continue; + } + + uint32_t nm = 1; + + while (i + nm < n_kv && ids[i + nm] == id + nm) { + nm++; + } + + // move keys + { + const int64_t os = i*k_size_row; + const int64_t od = id*k_size_row; + + memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row); + } + + // move values (note: they are transposed) + { + const int64_t os = i; + const int64_t od = id; + + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el); + } + } + + i += nm - 1; + } + + ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size()); + ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); + } +#else + for (uint32_t i = 0; i < ids.size(); ++i) { + const uint32_t id = ids[i]; + + if (i == id || id == ids.size()) { + continue; + } + + uint32_t nm = 1; + + while (i + nm < ids.size() && ids[i + nm] == id + nm) { + nm++; + } + + for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + + ggml_tensor * view_k_src = ggml_view_2d(ctx, k_l[il], + n_embd_k_gqa, nm, + ggml_row_size(k_l[il]->type, n_embd_k_gqa), + ggml_row_size(k_l[il]->type, n_embd_k_gqa*i)); + + ggml_tensor * view_k_dst = ggml_view_2d(ctx, k_l[il], + n_embd_k_gqa, nm, + ggml_row_size(k_l[il]->type, n_embd_k_gqa), + ggml_row_size(k_l[il]->type, n_embd_k_gqa*id)); + + ggml_tensor * view_v_src; + ggml_tensor * view_v_dst; + + if (cparams.flash_attn) { + // NOTE: the V cache is not 
transposed when using flash attention + view_v_src = ggml_view_2d(ctx, v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(v_l[il]->type, n_embd_v_gqa), + ggml_row_size(v_l[il]->type, n_embd_v_gqa*i)); + + view_v_dst = ggml_view_2d(ctx, v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(v_l[il]->type, n_embd_v_gqa), + ggml_row_size(v_l[il]->type, n_embd_v_gqa*id)); + } else { + view_v_src = ggml_view_2d(ctx, v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(v_l[il]->type, size), + ggml_row_size(v_l[il]->type, i)); + + view_v_dst = ggml_view_2d(ctx, v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(v_l[il]->type, size), + ggml_row_size(v_l[il]->type, id)); + } + + ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst)); + ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst)); + } + + i += nm - 1; + } + + //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); +#endif + + return res; +} + +bool llama_kv_cache_unified::update(const graph_params & params) { + bool need_reserve = false; + + const auto & sched = params.sched; + + if (get_has_shift()) { + if (!get_can_shift()) { + GGML_ABORT("The current KV cache / model configuration does not support K-shift"); + } + + LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__); + + // apply K-shift if needed + if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { + ggml_backend_sched_reset(sched); + + auto * gf = params.graph_init(); + + auto res = build_graph_shift(params, gf); + + ggml_backend_sched_alloc_graph(sched, gf); + + res->set_inputs(nullptr); + + params.graph_compute(gf); + + need_reserve = true; + } + + { + has_shift = false; + + for (uint32_t i = 0; i < size; ++i) { + cells[i].delta = 0; + } + } + } + + if (get_do_defrag()) { + LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); + + if (defrag_prepare(params.n_max_nodes)) { + ggml_backend_sched_reset(sched); + + auto * gf = params.graph_init(); + + auto res = build_graph_defrag(params, gf); + + ggml_backend_sched_alloc_graph(sched, gf); + + res->set_inputs(nullptr); + + params.graph_compute(gf); + + need_reserve = true; + } + + do_defrag = false; + } + + return need_reserve; +} + bool llama_kv_cache_unified::get_can_shift() const { return can_shift; } @@ -1369,7 +1717,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { void llama_kv_cache_recurrent::defrag(float thold) { GGML_UNUSED(thold); - LLAMA_LOG_ERROR("%s: not supported\n", __func__); + // noop } void llama_kv_cache_recurrent::restore() { @@ -1384,6 +1732,11 @@ void llama_kv_cache_recurrent::commit() { pending.ranges.clear(); } +bool llama_kv_cache_recurrent::update(const graph_params & params) { + GGML_UNUSED(params); + return false; +} + bool llama_kv_cache_recurrent::get_can_shift() const { return false; } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index f58c2a165daf7..056861bf7fc8f 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -2,6 +2,7 @@ #include "llama.h" #include "llama-io.h" +#include "llama-graph.h" #include "llama-memory.h" #include "ggml-cpp.h" @@ -24,12 +25,34 @@ struct llama_kv_cache : public llama_memory_i { std::function get_buft; }; + struct graph_params { + const llm_arch arch; + + const llama_cparams & cparams; + + const ggml_backend_sched_t & sched; + + const std::vector & backends; + + int32_t n_max_nodes; + + std::function get_ctx_compute; + + // function for creating ggml graphs + std::function graph_init; + + // function for computing ggml graphs + std::function graph_compute; + }; + virtual ~llama_kv_cache() = default; using 
llama_memory_i::llama_memory_i; virtual void restore() = 0; // call if batch processing fails - restores the cache state - virtual void commit() = 0; // call after successful batch processing - clears any pending state + virtual void commit() = 0; // call after successful batch processing - clears any pending state + + virtual bool update(const graph_params & params) = 0; virtual void defrag(float thold) = 0; @@ -131,6 +154,8 @@ class llama_kv_cache_unified : public llama_kv_cache { void restore() override; void commit() override; + bool update(const graph_params & params) override; + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; void seq_keep(llama_seq_id seq_id) override; @@ -224,6 +249,24 @@ class llama_kv_cache_unified : public llama_kv_cache { std::vector ctxs; std::vector bufs; + ggml_tensor * build_rope_shift( + const graph_params & params, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale, + ggml_backend_buffer * bbuf) const; + + llm_graph_result_ptr build_graph_shift( + const graph_params & params, + ggml_cgraph * gf) const; + + llm_graph_result_ptr build_graph_defrag( + const graph_params & params, + ggml_cgraph * gf) const; + void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; @@ -259,6 +302,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache { void restore() override; void commit() override; + bool update(const graph_params & params) override; + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; void seq_keep(llama_seq_id seq_id) override; From 13d69a524f801d1d99445a278f51b763a48dc9b3 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Apr 2025 16:52:04 +0300 Subject: [PATCH 12/29] build : fix ggml-ci --- src/llama-context.cpp | 2 -- src/llama-kv-cache.cpp | 3 ++- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 1fe6e6203b314..ce391c8f48e94 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -6,11 +6,9 @@ #include "llama-model.h" #include "llama-kv-cache.h" -#include #include #include #include -#include // // llama_context diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 4797f5b128cb3..076ea0c4440de 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include #include @@ -542,7 +543,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( res->add_input(std::move(inp)); - return std::move(res); + return res; } llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( From 5ef7559a11ac6d413805950132c054cae424a5a5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 23 Apr 2025 16:57:54 +0300 Subject: [PATCH 13/29] cont : another fix ggml-ci --- src/llama-kv-cache.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 076ea0c4440de..ee522fda9c07f 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -76,7 +76,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + 
hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); - ggml_backend_buffer_type_t buft = cbs.get_buft(i); + ggml_backend_buffer_type_t buft = this->cbs.get_buft(i); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { From 6b50ba752ce571aa14de1343bd91854e42289bf5 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 24 Apr 2025 10:09:09 +0300 Subject: [PATCH 14/29] kv-cache : simplify interface (wip) ggml-ci --- src/llama-context.cpp | 22 ++++++------ src/llama-graph.cpp | 32 ++--------------- src/llama-kv-cache.cpp | 66 +++++++++++++++++++++++----------- src/llama-kv-cache.h | 81 ++++++++++++++++++++++++++---------------- 4 files changed, 108 insertions(+), 93 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index ce391c8f48e94..4cbc1b9f3f6fa 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1108,7 +1108,7 @@ int llama_context::decode(llama_batch & inp_batch) { // decide if we need to defrag the kv cache if (cparams.defrag_thold > 0.0f) { - kv_self->defrag(cparams.defrag_thold); + kv_self->defrag_sched(cparams.defrag_thold); } // Reset state for the next token before backend sync, to allow the CPU activities in the reset to @@ -2150,7 +2150,7 @@ void llama_kv_cache_seq_cp( llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); + llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); } void llama_kv_self_seq_cp( @@ -2164,14 +2164,14 @@ void llama_kv_self_seq_cp( return; } - return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv->seq_cp(seq_id_src, seq_id_dst, p0, p1); } // deprecated void llama_kv_cache_seq_keep( llama_context * ctx, llama_seq_id seq_id) { - return llama_kv_self_seq_keep(ctx, seq_id); + llama_kv_self_seq_keep(ctx, seq_id); } void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { @@ -2180,7 +2180,7 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { return; } - return kv->seq_keep(seq_id); + kv->seq_keep(seq_id); } // deprecated @@ -2190,7 +2190,7 @@ void llama_kv_cache_seq_add( llama_pos p0, llama_pos p1, llama_pos delta) { - return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta); + llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta); } void llama_kv_self_seq_add( @@ -2204,7 +2204,7 @@ void llama_kv_self_seq_add( return; } - return kv->seq_add(seq_id, p0, p1, delta); + kv->seq_add(seq_id, p0, p1, delta); } // deprecated @@ -2214,7 +2214,7 @@ void llama_kv_cache_seq_div( llama_pos p0, llama_pos p1, int d) { - return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d); + llama_kv_self_seq_div(ctx, seq_id, p0, p1, d); } void llama_kv_self_seq_div( @@ -2228,7 +2228,7 @@ void llama_kv_self_seq_div( return; } - return kv->seq_div(seq_id, p0, p1, d); + kv->seq_div(seq_id, p0, p1, d); } // deprecated @@ -2247,7 +2247,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { // deprecated void llama_kv_cache_defrag(llama_context * ctx) { - return llama_kv_self_defrag(ctx); + llama_kv_self_defrag(ctx); } void llama_kv_self_defrag(llama_context * ctx) { @@ -2257,7 +2257,7 @@ void llama_kv_self_defrag(llama_context * ctx) { } // force defrag - return kv->defrag(-1.0f); + kv->defrag_sched(-1.0f); } // deprecated diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index cd3896be05c97..0da4e7d2b0547 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -284,24 +284,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { // assuming copy 
destinations ALWAYS happen ONLY on the cells between head and head+n for (uint32_t i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self->head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - llama_kv_cell & kv_cell = const_cast(kv_self)->cells[i]; - - // prevent out-of-bound sources - if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) { - kv_cell.src = cell_id; - } - - data[i] = kv_cell.src; - - // TODO: do not mutate the KV cache - // ensure copy only happens once - if (kv_cell.src != (int32_t) cell_id) { - kv_cell.src = cell_id; - } + data[i] = kv_self->s_copy(i); } } } @@ -317,18 +300,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) { // clear unused states for (int i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self->head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - llama_kv_cell & kv_cell = const_cast(kv_self)->cells[i]; - - data[i] = (float) (kv_cell.src >= 0); - - // only clear once - if (kv_cell.src < 0) { - kv_cell.src = cell_id; - } + data[i] = kv_self->s_mask(i); } } } diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index ee522fda9c07f..95a105219093a 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -130,14 +130,6 @@ int32_t llama_kv_cache_unified::get_used_cells() const { return used; } -bool llama_kv_cache_unified::get_has_shift() const { - return has_shift; -} - -bool llama_kv_cache_unified::get_do_defrag() const { - return do_defrag; -} - size_t llama_kv_cache_unified::total_size() const { size_t size = 0; for (const auto & buf : bufs) { @@ -358,10 +350,10 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { return result; } -void llama_kv_cache_unified::defrag(float thold) { +void llama_kv_cache_unified::defrag_sched(float thold) { // - do not defrag small contexts (i.e. < 2048 tokens) // - count the padding towards the number of used tokens - const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - float(used + padding)/float(n)) : 0.0f; + const float fragmentation = n >= 2048 ? 
std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f; // queue defragmentation for next llama_kv_cache_update if (fragmentation > thold) { @@ -699,7 +691,7 @@ bool llama_kv_cache_unified::update(const graph_params & params) { const auto & sched = params.sched; - if (get_has_shift()) { + if (has_shift) { if (!get_can_shift()) { GGML_ABORT("The current KV cache / model configuration does not support K-shift"); } @@ -732,7 +724,7 @@ bool llama_kv_cache_unified::update(const graph_params & params) { } } - if (get_do_defrag()) { + if (do_defrag) { LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); if (defrag_prepare(params.n_max_nodes)) { @@ -1496,14 +1488,6 @@ int32_t llama_kv_cache_recurrent::get_used_cells() const { return used; } -bool llama_kv_cache_recurrent::get_has_shift() const { - return false; -} - -bool llama_kv_cache_recurrent::get_do_defrag() const { - return false; -} - size_t llama_kv_cache_recurrent::total_size() const { size_t size = 0; for (const auto & buf : bufs) { @@ -1716,7 +1700,7 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } -void llama_kv_cache_recurrent::defrag(float thold) { +void llama_kv_cache_recurrent::defrag_sched(float thold) { GGML_UNUSED(thold); // noop } @@ -1742,6 +1726,46 @@ bool llama_kv_cache_recurrent::get_can_shift() const { return false; } +int32_t llama_kv_cache_recurrent::s_copy(int i) const { + const uint32_t cell_id = i + head; + + ////////////////////////////////////////////// + // TODO: this should not mutate the KV cache ! + llama_kv_cell & kv_cell = const_cast(cells[i]); + + // prevent out-of-bound sources + if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= size) { + kv_cell.src = cell_id; + } + + int32_t res = kv_cell.src; + + // TODO: do not mutate the KV cache + // ensure copy only happens once + if (kv_cell.src != (int32_t) cell_id) { + kv_cell.src = cell_id; + } + + return res; +} + +float llama_kv_cache_recurrent::s_mask(int i) const { + const uint32_t cell_id = i + head; + + ////////////////////////////////////////////// + // TODO: this should not mutate the KV cache ! + llama_kv_cell & kv_cell = const_cast(cells[i]); + + float res = (float) (kv_cell.src >= 0); + + // only clear once + if (kv_cell.src < 0) { + kv_cell.src = cell_id; + } + + return res; +} + bool llama_kv_cache_recurrent::find_slot( const llama_ubatch & ubatch) { const uint32_t n_tokens = ubatch.n_tokens; diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 056861bf7fc8f..cc390d09a5343 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -47,44 +47,55 @@ struct llama_kv_cache : public llama_memory_i { virtual ~llama_kv_cache() = default; - using llama_memory_i::llama_memory_i; + // call if batch processing fails - restores the cache state + virtual void restore() = 0; - virtual void restore() = 0; // call if batch processing fails - restores the cache state - virtual void commit() = 0; // call after successful batch processing - clears any pending state + // call after successful batch processing - clears any pending state + virtual void commit() = 0; + // process any pending defrag/shift/etc. operations + // optionally call once before processing a new batch virtual bool update(const graph_params & params) = 0; - virtual void defrag(float thold) = 0; + // schedule a defrag if the fragmentation threshold is exceeded. 
otherwise, do nothing + virtual void defrag_sched(float thold) = 0; - virtual int32_t get_n_tokens() const = 0; - virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache + // simulate full cache, used for allocating worst-case compute buffers + virtual void set_full() = 0; - virtual bool get_has_shift() const = 0; - virtual bool get_do_defrag() const = 0; + // + // batch processing + // - virtual llama_pos get_pos_max() const = 0; - - virtual bool get_can_shift() const = 0; + virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0; - bool get_can_edit() const override { return get_can_shift(); } + // different KV caches require different batch splitting strategies + virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0; virtual bool find_slot(const llama_ubatch & batch) = 0; - virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0; + // getters + virtual int32_t get_n_tokens() const = 0; + virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache - // different KV caches require different batch splitting strategies - virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0; + virtual llama_pos get_pos_max() const = 0; - // simulate full cache, used for allocating worst-case compute buffers - virtual void set_full() = 0; + virtual bool get_can_shift() const = 0; + + bool get_can_edit() const override { return get_can_shift(); } - virtual size_t size_k_bytes() const = 0; - virtual size_t size_v_bytes() const = 0; + // + // state write/read + // virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0; virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; }; +// +// llama_kv_cache_guard +// + struct llama_kv_cache_guard { llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {} @@ -100,6 +111,8 @@ struct llama_kv_cache_guard { llama_kv_cache * kv; }; +// TODO: create separate cells for unified/recurrent caches +// TODO: move in the source file struct llama_kv_cell { llama_pos pos = -1; llama_pos delta = 0; @@ -121,7 +134,11 @@ struct llama_kv_cell { } }; +// +// llama_kv_cache_unified // ring-buffer of cached KV data +// + // TODO: pimpl // TODO: add notion of max sequences class llama_kv_cache_unified : public llama_kv_cache { @@ -140,16 +157,13 @@ class llama_kv_cache_unified : public llama_kv_cache { int32_t get_n_tokens() const override; int32_t get_used_cells() const override; - bool get_has_shift() const override; - bool get_do_defrag() const override; - size_t total_size() const; // TODO: better data structures to reduce the cost of this operation llama_pos get_pos_max() const override; void clear() override; - void defrag(float thold) override; + void defrag_sched(float thold) override; void restore() override; void commit() override; @@ -183,8 +197,8 @@ class llama_kv_cache_unified : public llama_kv_cache { void set_full() override; - size_t size_k_bytes() const override; - size_t size_v_bytes() const override; + size_t size_k_bytes() const; + size_t size_v_bytes() const; // defrag @@ -274,6 +288,10 @@ class llama_kv_cache_unified : public llama_kv_cache { bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; +// +// llama_kv_cache_recurrent +// + class llama_kv_cache_recurrent : public llama_kv_cache { public: llama_kv_cache_recurrent( @@ -288,16 +306,13 @@ class llama_kv_cache_recurrent 
: public llama_kv_cache { int32_t get_n_tokens() const override; int32_t get_used_cells() const override; - bool get_has_shift() const override; - bool get_do_defrag() const override; - size_t total_size() const; // TODO: better data structures to reduce the cost of this operation llama_pos get_pos_max() const override; void clear() override; - void defrag(float thold) override; + void defrag_sched(float thold) override; void restore() override; void commit() override; @@ -314,6 +329,10 @@ class llama_kv_cache_recurrent : public llama_kv_cache { bool get_can_shift() const override; + // TODO: temporary methods - they are not really const as they do const_cast<>, fix this + int32_t s_copy(int i) const; + float s_mask(int i) const; + // find an empty slot of size "n_tokens" in the cache // updates the cache head // Note: On success, it's important that cache.head points @@ -329,8 +348,8 @@ class llama_kv_cache_recurrent : public llama_kv_cache { void set_full() override; - size_t size_k_bytes() const override; - size_t size_v_bytes() const override; + size_t size_k_bytes() const; + size_t size_v_bytes() const; // commit/restore cache From cb02ac80861dbf04df8943d86b2984e43117a42f Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 24 Apr 2025 11:56:36 +0300 Subject: [PATCH 15/29] kv-cache : use separate KV cell structs for unified/recurrent ggml-ci --- src/llama-kv-cache.cpp | 78 +++++++++++++++++++----------------------- src/llama-kv-cache.h | 66 +++++++++++++++++++++-------------- 2 files changed, 77 insertions(+), 67 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 95a105219093a..e302822e36bea 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -152,8 +152,6 @@ void llama_kv_cache_unified::clear() { for (int32_t i = 0; i < (int32_t) size; ++i) { cells[i].pos = -1; cells[i].seq_id.clear(); - cells[i].src = -1; - cells[i].tail = -1; } head = 0; used = 0; @@ -190,7 +188,6 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos } cells[i].pos = -1; - cells[i].src = -1; if (new_head == size) { new_head = i; @@ -245,7 +242,6 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { } cells[i].pos = -1; - cells[i].src = -1; cells[i].seq_id.clear(); if (new_head == size){ @@ -380,7 +376,6 @@ void llama_kv_cache_unified::restore() { } cells[i].pos = -1; - cells[i].src = -1; } new_head = std::min(new_head, range.c0); @@ -847,7 +842,7 @@ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { uint32_t llama_kv_cache_unified::cell_max() const { for (uint32_t i = size; i > 0; --i) { - const llama_kv_cell & cell = cells[i - 1]; + const kv_cell & cell = cells[i - 1]; if (cell.pos >= 0 && !cell.is_empty()) { return i; @@ -983,7 +978,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { cells[i0 + nf] = cell1; // clear the old cell and move the head there - cell1 = llama_kv_cell(); + cell1 = kv_cell(); head = n_used; if (!cont) { @@ -1226,7 +1221,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell clear(); for (uint32_t i = 0; i < cell_count; ++i) { - llama_kv_cell & cell = cells[i]; + kv_cell & cell = cells[i]; llama_pos pos; uint32_t n_seq_id; @@ -1538,7 +1533,7 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p if (0 <= seq_id) { int32_t & tail_id = cells[seq_id].tail; if (tail_id >= 0) { - const llama_kv_cell & cell = cells[tail_id]; + const kv_cell & cell = cells[tail_id]; // partial intersection is invalid if 
((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { return false; @@ -1572,23 +1567,22 @@ void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_ } if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { - llama_kv_cell & tail_src = cells[seq_id_src]; - llama_kv_cell & tail_dst = cells[seq_id_dst]; + kv_cell & tail_src = cells[seq_id_src]; + kv_cell & tail_dst = cells[seq_id_dst]; if (tail_dst.tail >= 0) { // clear destination seq_id if it wasn't empty - llama_kv_cell & cell_dst = cells[tail_dst.tail]; + kv_cell & cell_dst = cells[tail_dst.tail]; cell_dst.seq_id.erase(seq_id_dst); tail_dst.tail = -1; if (cell_dst.seq_id.empty()) { cell_dst.pos = -1; - cell_dst.delta = -1; cell_dst.src = -1; used -= 1; } } if (tail_src.tail >= 0) { - llama_kv_cell & cell_src = cells[tail_src.tail]; + kv_cell & cell_src = cells[tail_src.tail]; cell_src.seq_id.insert(seq_id_dst); tail_dst.tail = tail_src.tail; @@ -1650,7 +1644,7 @@ void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_ if (0 <= seq_id && seq_id < (int64_t) size) { const int32_t tail_id = cells[seq_id].tail; if (tail_id >= 0) { - llama_kv_cell & cell = cells[tail_id]; + kv_cell & cell = cells[tail_id]; if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { cell.pos += delta; } @@ -1680,7 +1674,7 @@ void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_ if (0 <= seq_id && seq_id < (int64_t) size) { const int32_t tail_id = cells[seq_id].tail; if (tail_id >= 0) { - llama_kv_cell & cell = cells[tail_id]; + kv_cell & cell = cells[tail_id]; if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { cell.pos /= d; } @@ -1731,19 +1725,19 @@ int32_t llama_kv_cache_recurrent::s_copy(int i) const { ////////////////////////////////////////////// // TODO: this should not mutate the KV cache ! - llama_kv_cell & kv_cell = const_cast(cells[i]); + kv_cell & cell = const_cast(cells[i]); // prevent out-of-bound sources - if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= size) { - kv_cell.src = cell_id; + if (cell.src < 0 || (uint32_t) cell.src >= size) { + cell.src = cell_id; } - int32_t res = kv_cell.src; + int32_t res = cell.src; // TODO: do not mutate the KV cache // ensure copy only happens once - if (kv_cell.src != (int32_t) cell_id) { - kv_cell.src = cell_id; + if (cell.src != (int32_t) cell_id) { + cell.src = cell_id; } return res; @@ -1754,13 +1748,13 @@ float llama_kv_cache_recurrent::s_mask(int i) const { ////////////////////////////////////////////// // TODO: this should not mutate the KV cache ! 
- llama_kv_cell & kv_cell = const_cast(cells[i]); + kv_cell & cell = const_cast(cells[i]); - float res = (float) (kv_cell.src >= 0); + float res = (float) (cell.src >= 0); // only clear once - if (kv_cell.src < 0) { - kv_cell.src = cell_id; + if (cell.src < 0) { + cell.src = cell_id; } return res; @@ -1802,9 +1796,9 @@ bool llama_kv_cache_recurrent::find_slot( return false; } if (j > 0) { - llama_kv_cell & seq = cells[seq_id]; + kv_cell & seq = cells[seq_id]; if (seq.tail >= 0) { - llama_kv_cell & cell = cells[seq.tail]; + kv_cell & cell = cells[seq.tail]; // clear cells from seq_ids that become shared // (should not normally happen, but let's handle it anyway) cell.seq_id.erase(seq_id); @@ -1824,7 +1818,7 @@ bool llama_kv_cache_recurrent::find_slot( std::vector tails_verif; tails_verif.assign(size, -1); for (uint32_t i = 0; i < size; ++i) { - llama_kv_cell & cell = cells[i]; + kv_cell & cell = cells[i]; for (llama_seq_id seq_id : cell.seq_id) { if (tails_verif[seq_id] != -1) { LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); @@ -1845,7 +1839,7 @@ bool llama_kv_cache_recurrent::find_slot( for (uint32_t i = 0; i < size; ++i) { if (next_empty_cell >= size) { next_empty_cell -= size; } - llama_kv_cell & cell = cells[next_empty_cell]; + kv_cell & cell = cells[next_empty_cell]; if (cell.is_empty()) { break; } next_empty_cell += 1; } @@ -1853,20 +1847,20 @@ bool llama_kv_cache_recurrent::find_slot( // find usable cell range for (uint32_t s = 0; s < n_seqs; ++s) { const llama_seq_id seq_id = ubatch.seq_id[s][0]; - llama_kv_cell & seq_meta = cells[seq_id]; + kv_cell & seq_meta = cells[seq_id]; bool has_cell = false; if (seq_meta.tail >= 0) { - llama_kv_cell & cell = cells[seq_meta.tail]; + kv_cell & cell = cells[seq_meta.tail]; GGML_ASSERT(cell.has_seq_id(seq_id)); // does this seq_id "own" the cell? if (cell.seq_id.size() == 1) { has_cell = true; } } if (!has_cell) { - llama_kv_cell & empty_cell = cells[next_empty_cell]; + kv_cell & empty_cell = cells[next_empty_cell]; GGML_ASSERT(empty_cell.is_empty()); // copy old tail into the empty cell if (seq_meta.tail >= 0) { - llama_kv_cell & orig_cell = cells[seq_meta.tail]; + kv_cell & orig_cell = cells[seq_meta.tail]; empty_cell.pos = orig_cell.pos; empty_cell.src = orig_cell.src; orig_cell.seq_id.erase(seq_id); @@ -1878,7 +1872,7 @@ bool llama_kv_cache_recurrent::find_slot( next_empty_cell += 1; for (uint32_t i = 0; i < size; ++i) { if (next_empty_cell >= size) { next_empty_cell -= size; } - llama_kv_cell & cell = cells[next_empty_cell]; + kv_cell & cell = cells[next_empty_cell]; if (cell.is_empty()) { break; } next_empty_cell += 1; } @@ -1893,8 +1887,8 @@ bool llama_kv_cache_recurrent::find_slot( int32_t dst_id = s + min; int32_t src_id = cells[ubatch.seq_id[s][0]].tail; if (dst_id != src_id) { - llama_kv_cell & dst_cell = cells[dst_id]; - llama_kv_cell & src_cell = cells[src_id]; + kv_cell & dst_cell = cells[dst_id]; + kv_cell & src_cell = cells[src_id]; std::swap(dst_cell.pos, src_cell.pos); std::swap(dst_cell.src, src_cell.src); @@ -1914,7 +1908,7 @@ bool llama_kv_cache_recurrent::find_slot( for (uint32_t s = 0; s < n_seqs; ++s) { const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; int32_t cell_id = s + min; - llama_kv_cell & cell = cells[cell_id]; + kv_cell & cell = cells[cell_id]; if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { // What should happen when the pos backtracks or skips a value? 
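These hunks replace the shared llama_kv_cell with per-cache kv_cell structs; the recurrent variant keeps a src index for state copies and a per-sequence tail. A stripped-down model of that bookkeeping (not the real structs, only the idea), showing how a seq_cp-style operation can share a state cell:

#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

using llama_pos    = int32_t;
using llama_seq_id = int32_t;

// simplified stand-in for the recurrent kv_cell introduced by this patch
struct kv_cell {
    llama_pos pos  = -1;
    int32_t   src  = -1; // cell to copy the state from
    int32_t   tail = -1; // for cell i: index of the cell holding sequence i's state

    std::set<llama_seq_id> seq_id;

    bool has_seq_id(const llama_seq_id & id) const { return seq_id.find(id) != seq_id.end(); }
    bool is_empty() const { return seq_id.empty(); }
};

int main() {
    std::vector<kv_cell> cells(4);

    // sequence 0 keeps its state in cell 2
    cells[0].tail = 2;
    cells[2].pos  = 7;
    cells[2].seq_id.insert(0);

    // a seq_cp(0, 1) in this model only adds seq 1 to the same cell and updates its tail
    cells[1].tail = cells[0].tail;
    cells[cells[1].tail].seq_id.insert(1);

    printf("cell 2 shared with seq 1: %s, empty: %s\n",
           cells[2].has_seq_id(1) ? "yes" : "no",
           cells[2].is_empty()    ? "yes" : "no");

    return 0;
}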
@@ -1935,7 +1929,7 @@ bool llama_kv_cache_recurrent::find_slot( head = min; n = max - min + 1; used = std::count_if(cells.begin(), cells.end(), - [](const llama_kv_cell& cell){ return !cell.is_empty(); }); + [](const kv_cell & cell){ return !cell.is_empty(); }); // sanity check return n >= n_seqs; @@ -1958,7 +1952,7 @@ llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32 uint32_t llama_kv_cache_recurrent::cell_max() const { for (uint32_t i = size; i > 0; --i) { - const llama_kv_cell & cell = cells[i - 1]; + const kv_cell & cell = cells[i - 1]; if (cell.pos >= 0 && !cell.is_empty()) { return i; @@ -2200,7 +2194,7 @@ bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t ce clear(); for (uint32_t i = 0; i < cell_count; ++i) { - llama_kv_cell & cell = cells[i]; + kv_cell & cell = cells[i]; llama_pos pos; uint32_t n_seq_id; @@ -2412,7 +2406,7 @@ void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache view->cells_sequences = (llama_seq_id *)p; } - const std::vector & kv_cells = kvu->cells; + const std::vector & kv_cells = kvu->cells; llama_kv_cache_view_cell * c_curr = view->cells; llama_seq_id * cs_curr = view->cells_sequences; int32_t used_cells = 0; diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index cc390d09a5343..0ef3bf2091c65 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -111,29 +111,6 @@ struct llama_kv_cache_guard { llama_kv_cache * kv; }; -// TODO: create separate cells for unified/recurrent caches -// TODO: move in the source file -struct llama_kv_cell { - llama_pos pos = -1; - llama_pos delta = 0; - int32_t src = -1; // used by recurrent state models to copy states - int32_t tail = -1; - - std::set seq_id; - - bool has_seq_id(const llama_seq_id & id) const { - return seq_id.find(id) != seq_id.end(); - } - - bool is_empty() const { - return seq_id.empty(); - } - - bool is_same_seq(const llama_kv_cell & other) const { - return seq_id == other.seq_id; - } -}; - // // llama_kv_cache_unified // ring-buffer of cached KV data @@ -143,6 +120,25 @@ struct llama_kv_cell { // TODO: add notion of max sequences class llama_kv_cache_unified : public llama_kv_cache { public: + struct kv_cell { + llama_pos pos = -1; + llama_pos delta = 0; + + std::set seq_id; + + bool has_seq_id(const llama_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } + + bool is_empty() const { + return seq_id.empty(); + } + + bool is_same_seq(const kv_cell & other) const { + return seq_id == other.seq_id; + } + }; + llama_kv_cache_unified( const llama_hparams & hparams, callbacks cbs, @@ -251,7 +247,7 @@ class llama_kv_cache_unified : public llama_kv_cache { // required padding uint32_t padding = 1; - std::vector cells; + std::vector cells; std::vector k_l; // per layer std::vector v_l; @@ -294,6 +290,26 @@ class llama_kv_cache_unified : public llama_kv_cache { class llama_kv_cache_recurrent : public llama_kv_cache { public: + struct kv_cell { + llama_pos pos = -1; + int32_t src = -1; // used by recurrent state models to copy states + int32_t tail = -1; + + std::set seq_id; + + bool has_seq_id(const llama_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } + + bool is_empty() const { + return seq_id.empty(); + } + + bool is_same_seq(const kv_cell & other) const { + return seq_id == other.seq_id; + } + }; + llama_kv_cache_recurrent( const llama_hparams & hparams, callbacks cbs, @@ -384,7 +400,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache { // computed before each graph build 
uint32_t n = 0; - std::vector cells; + std::vector cells; std::vector k_l; // per layer std::vector v_l; From f584750df535cfebd2dabf0d698108f0c20258fa Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 24 Apr 2025 16:25:41 +0300 Subject: [PATCH 16/29] kv-cache : clean-up ggml-ci --- src/llama-kv-cache.cpp | 752 ++++++++++++++++++++--------------------- src/llama-kv-cache.h | 224 ++++++------ 2 files changed, 490 insertions(+), 486 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index e302822e36bea..c8df50a66c0c2 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -16,6 +16,11 @@ // llama_kv_cache_unified // +uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { + // the FA kernels require padding to avoid extra runtime boundary checks + return cparams.flash_attn ? 256u : 32u; +} + llama_kv_cache_unified::llama_kv_cache_unified( const llama_hparams & hparams, callbacks cbs, @@ -23,7 +28,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( ggml_type type_v, bool v_trans, uint32_t kv_size, - uint32_t padding) : hparams(hparams), cbs(std::move(cbs)), v_trans(v_trans), padding(padding) { + uint32_t padding) : cbs(std::move(cbs)), hparams(hparams), v_trans(v_trans), padding(padding) { const int32_t n_layer = hparams.n_layer; has_shift = false; @@ -116,38 +121,6 @@ llama_kv_cache_unified::llama_kv_cache_unified( } } -int32_t llama_kv_cache_unified::get_n_tokens() const { - int32_t result = 0; - - for (uint32_t i = 0; i < size; i++) { - result += cells[i].seq_id.size(); - } - - return result; -} - -int32_t llama_kv_cache_unified::get_used_cells() const { - return used; -} - -size_t llama_kv_cache_unified::total_size() const { - size_t size = 0; - for (const auto & buf : bufs) { - size += ggml_backend_buffer_get_size(buf.get()); - } - - return size; -} - -llama_pos llama_kv_cache_unified::get_pos_max() const { - llama_pos pos_max = -1; - for (const auto & cell : cells) { - pos_max = std::max(pos_max, cell.pos); - } - - return pos_max; -} - void llama_kv_cache_unified::clear() { for (int32_t i = 0; i < (int32_t) size; ++i) { cells[i].pos = -1; @@ -346,19 +319,6 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { return result; } -void llama_kv_cache_unified::defrag_sched(float thold) { - // - do not defrag small contexts (i.e. < 2048 tokens) - // - count the padding towards the number of used tokens - const float fragmentation = n >= 2048 ? 
std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f; - - // queue defragmentation for next llama_kv_cache_update - if (fragmentation > thold) { - LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); - - do_defrag = true; - } -} - void llama_kv_cache_unified::restore() { if (pending.ranges.empty()) { return; @@ -396,6 +356,229 @@ void llama_kv_cache_unified::commit() { pending.ranges.clear(); } +bool llama_kv_cache_unified::update(const graph_params & params) { + bool need_reserve = false; + + const auto & sched = params.sched; + + if (has_shift) { + if (!get_can_shift()) { + GGML_ABORT("The current KV cache / model configuration does not support K-shift"); + } + + LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__); + + // apply K-shift if needed + if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { + ggml_backend_sched_reset(sched); + + auto * gf = params.graph_init(); + + auto res = build_graph_shift(params, gf); + + ggml_backend_sched_alloc_graph(sched, gf); + + res->set_inputs(nullptr); + + params.graph_compute(gf); + + need_reserve = true; + } + + { + has_shift = false; + + for (uint32_t i = 0; i < size; ++i) { + cells[i].delta = 0; + } + } + } + + if (do_defrag) { + LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); + + if (defrag_prepare(params.n_max_nodes)) { + ggml_backend_sched_reset(sched); + + auto * gf = params.graph_init(); + + auto res = build_graph_defrag(params, gf); + + ggml_backend_sched_alloc_graph(sched, gf); + + res->set_inputs(nullptr); + + params.graph_compute(gf); + + need_reserve = true; + } + + do_defrag = false; + } + + return need_reserve; +} + +void llama_kv_cache_unified::defrag_sched(float thold) { + // - do not defrag small contexts (i.e. < 2048 tokens) + // - count the padding towards the number of used tokens + const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f; + + // queue defragmentation for next llama_kv_cache_update + if (fragmentation > thold) { + LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); + + do_defrag = true; + } +} + +void llama_kv_cache_unified::set_full() { + n = size; +} + +llama_sbatch llama_kv_cache_unified::sbatch_init( + const llama_batch & batch, + bool logits_all) { + return llama_sbatch(batch, hparams.n_embd, true, logits_all); +} + +llama_ubatch llama_kv_cache_unified::ubatch_next( + llama_sbatch & sbatch, + uint32_t n_ubatch, + bool embd_pooled) const { + GGML_UNUSED(embd_pooled); + return sbatch.split_simple(n_ubatch); +} + +bool llama_kv_cache_unified::find_slot( + const llama_ubatch & ubatch) { + const uint32_t n_tokens = ubatch.n_tokens; + const uint32_t n_seqs = ubatch.n_seqs; + const uint32_t n_seq_tokens = ubatch.n_seq_tokens; + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head > used + 2*ubatch.n_tokens) { + head = 0; + } + + // otherwise, one cell per token. 
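+    // e.g. with size = 8, head = 6 and n_tokens = 3: the scan below first tests
+    // cells 6..7, wraps back to head = 0 (n_tested = 2) and succeeds if cells
+    // 0..2 are free; once n_tested reaches size there is no contiguous run of
+    // n_tokens free cells and the search fails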
+ + if (n_tokens > size) { + LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size); + return false; + } + + uint32_t n_tested = 0; + + while (true) { + if (head + n_tokens > size) { + n_tested += size - head; + head = 0; + continue; + } + + bool found = true; + for (uint32_t i = 0; i < n_tokens; i++) { + if (cells[head + i].pos >= 0) { + found = false; + head += i + 1; + n_tested += i + 1; + break; + } + } + + if (found) { + break; + } + + if (n_tested >= size) { + //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); + return false; + } + } + + for (uint32_t s = 0; s < n_seqs; s++) { + for (uint32_t i = 0; i < n_seq_tokens; ++i) { + uint32_t k = s*n_seq_tokens + i; + cells[head + k].pos = ubatch.pos[k]; + + for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) { + cells[head + k].seq_id.insert(ubatch.seq_id[s][j]); + } + } + } + + used += n_tokens; + + pending.ranges.push_back({head, head + n_tokens}); + + // a heuristic, to avoid attending the full cache if it is not yet utilized + // after enough generations, the benefit from this heuristic disappears + // if we start defragmenting the cache, the benefit from this will be more important + n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding))); + + //printf("n = %5d, used = %5d, head = %5d\n", n, used, head); + + return true; +} + +int32_t llama_kv_cache_unified::get_n_tokens() const { + int32_t result = 0; + + for (uint32_t i = 0; i < size; i++) { + result += cells[i].seq_id.size(); + } + + return result; +} + +int32_t llama_kv_cache_unified::get_used_cells() const { + return used; +} + +bool llama_kv_cache_unified::get_can_shift() const { + return can_shift; +} + +llama_pos llama_kv_cache_unified::get_pos_max() const { + llama_pos pos_max = -1; + for (const auto & cell : cells) { + pos_max = std::max(pos_max, cell.pos); + } + + return pos_max; +} + +size_t llama_kv_cache_unified::total_size() const { + size_t size = 0; + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; +} + +size_t llama_kv_cache_unified::size_k_bytes() const { + size_t size_k_bytes = 0; + + for (const auto & k : k_l) { + size_k_bytes += ggml_nbytes(k); + } + + return size_k_bytes; +} + +size_t llama_kv_cache_unified::size_v_bytes() const { + size_t size_v_bytes = 0; + + for (const auto & v : v_l) { + size_v_bytes += ggml_nbytes(v); + } + + return size_v_bytes; +} + ggml_tensor * llama_kv_cache_unified::build_rope_shift( const graph_params & params, ggml_context * ctx, @@ -615,265 +798,70 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); } #else - for (uint32_t i = 0; i < ids.size(); ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == ids.size()) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < ids.size() && ids[i + nm] == id + nm) { - nm++; - } - - for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); - - ggml_tensor * view_k_src = ggml_view_2d(ctx, k_l[il], - n_embd_k_gqa, nm, - ggml_row_size(k_l[il]->type, n_embd_k_gqa), - ggml_row_size(k_l[il]->type, n_embd_k_gqa*i)); - - ggml_tensor * view_k_dst = ggml_view_2d(ctx, k_l[il], - n_embd_k_gqa, nm, - ggml_row_size(k_l[il]->type, n_embd_k_gqa), - ggml_row_size(k_l[il]->type, n_embd_k_gqa*id)); - - ggml_tensor * view_v_src; - ggml_tensor * view_v_dst; - - if (cparams.flash_attn) { - // 
NOTE: the V cache is not transposed when using flash attention - view_v_src = ggml_view_2d(ctx, v_l[il], - n_embd_v_gqa, nm, - ggml_row_size(v_l[il]->type, n_embd_v_gqa), - ggml_row_size(v_l[il]->type, n_embd_v_gqa*i)); - - view_v_dst = ggml_view_2d(ctx, v_l[il], - n_embd_v_gqa, nm, - ggml_row_size(v_l[il]->type, n_embd_v_gqa), - ggml_row_size(v_l[il]->type, n_embd_v_gqa*id)); - } else { - view_v_src = ggml_view_2d(ctx, v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(v_l[il]->type, size), - ggml_row_size(v_l[il]->type, i)); - - view_v_dst = ggml_view_2d(ctx, v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(v_l[il]->type, size), - ggml_row_size(v_l[il]->type, id)); - } - - ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst)); - ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst)); - } - - i += nm - 1; - } - - //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); -#endif - - return res; -} - -bool llama_kv_cache_unified::update(const graph_params & params) { - bool need_reserve = false; - - const auto & sched = params.sched; - - if (has_shift) { - if (!get_can_shift()) { - GGML_ABORT("The current KV cache / model configuration does not support K-shift"); - } - - LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__); - - // apply K-shift if needed - if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { - ggml_backend_sched_reset(sched); - - auto * gf = params.graph_init(); - - auto res = build_graph_shift(params, gf); - - ggml_backend_sched_alloc_graph(sched, gf); - - res->set_inputs(nullptr); - - params.graph_compute(gf); - - need_reserve = true; - } - - { - has_shift = false; - - for (uint32_t i = 0; i < size; ++i) { - cells[i].delta = 0; - } - } - } - - if (do_defrag) { - LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); - - if (defrag_prepare(params.n_max_nodes)) { - ggml_backend_sched_reset(sched); - - auto * gf = params.graph_init(); - - auto res = build_graph_defrag(params, gf); - - ggml_backend_sched_alloc_graph(sched, gf); - - res->set_inputs(nullptr); - - params.graph_compute(gf); - - need_reserve = true; - } - - do_defrag = false; - } - - return need_reserve; -} - -bool llama_kv_cache_unified::get_can_shift() const { - return can_shift; -} - -bool llama_kv_cache_unified::find_slot( - const llama_ubatch & ubatch) { - const uint32_t n_tokens = ubatch.n_tokens; - const uint32_t n_seqs = ubatch.n_seqs; - const uint32_t n_seq_tokens = ubatch.n_seq_tokens; - - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (head > used + 2*ubatch.n_tokens) { - head = 0; - } - - // otherwise, one cell per token. 
- - if (n_tokens > size) { - LLAMA_LOG_ERROR("%s: n_tokens = %d > size = %d\n", __func__, n_tokens, size); - return false; - } - - uint32_t n_tested = 0; - - while (true) { - if (head + n_tokens > size) { - n_tested += size - head; - head = 0; - continue; - } - - bool found = true; - for (uint32_t i = 0; i < n_tokens; i++) { - if (cells[head + i].pos >= 0) { - found = false; - head += i + 1; - n_tested += i + 1; - break; - } - } - - if (found) { - break; - } - - if (n_tested >= size) { - //LLAMA_LOG_ERROR("%s: failed to find a slot for %d tokens\n", __func__, n_tokens); - return false; - } - } - - for (uint32_t s = 0; s < n_seqs; s++) { - for (uint32_t i = 0; i < n_seq_tokens; ++i) { - uint32_t k = s*n_seq_tokens + i; - cells[head + k].pos = ubatch.pos[k]; - - for (int32_t j = 0; j < ubatch.n_seq_id[s]; j++) { - cells[head + k].seq_id.insert(ubatch.seq_id[s][j]); - } - } - } - - used += n_tokens; - - pending.ranges.push_back({head, head + n_tokens}); - - // a heuristic, to avoid attending the full cache if it is not yet utilized - // after enough generations, the benefit from this heuristic disappears - // if we start defragmenting the cache, the benefit from this will be more important - n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding))); - - //printf("n = %5d, used = %5d, head = %5d\n", n, used, head); - - return true; -} - -llama_sbatch llama_kv_cache_unified::sbatch_init( - const llama_batch & batch, - bool logits_all) { - return llama_sbatch(batch, hparams.n_embd, true, logits_all); -} - -llama_ubatch llama_kv_cache_unified::ubatch_next( - llama_sbatch & sbatch, - uint32_t n_ubatch, - bool embd_pooled) const { - GGML_UNUSED(embd_pooled); - return sbatch.split_simple(n_ubatch); -} + for (uint32_t i = 0; i < ids.size(); ++i) { + const uint32_t id = ids[i]; -uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { - // the FA kernels require padding to avoid extra runtime boundary checks - return cparams.flash_attn ? 
256u : 32u; -} + if (i == id || id == ids.size()) { + continue; + } -uint32_t llama_kv_cache_unified::cell_max() const { - for (uint32_t i = size; i > 0; --i) { - const kv_cell & cell = cells[i - 1]; + uint32_t nm = 1; - if (cell.pos >= 0 && !cell.is_empty()) { - return i; + while (i + nm < ids.size() && ids[i + nm] == id + nm) { + nm++; } - } - return 0; -} + for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); -void llama_kv_cache_unified::set_full() { - n = size; -} + ggml_tensor * view_k_src = ggml_view_2d(ctx, k_l[il], + n_embd_k_gqa, nm, + ggml_row_size(k_l[il]->type, n_embd_k_gqa), + ggml_row_size(k_l[il]->type, n_embd_k_gqa*i)); -size_t llama_kv_cache_unified::size_k_bytes() const { - size_t size_k_bytes = 0; + ggml_tensor * view_k_dst = ggml_view_2d(ctx, k_l[il], + n_embd_k_gqa, nm, + ggml_row_size(k_l[il]->type, n_embd_k_gqa), + ggml_row_size(k_l[il]->type, n_embd_k_gqa*id)); - for (const auto & k : k_l) { - size_k_bytes += ggml_nbytes(k); - } + ggml_tensor * view_v_src; + ggml_tensor * view_v_dst; - return size_k_bytes; -} + if (cparams.flash_attn) { + // NOTE: the V cache is not transposed when using flash attention + view_v_src = ggml_view_2d(ctx, v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(v_l[il]->type, n_embd_v_gqa), + ggml_row_size(v_l[il]->type, n_embd_v_gqa*i)); -size_t llama_kv_cache_unified::size_v_bytes() const { - size_t size_v_bytes = 0; + view_v_dst = ggml_view_2d(ctx, v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(v_l[il]->type, n_embd_v_gqa), + ggml_row_size(v_l[il]->type, n_embd_v_gqa*id)); + } else { + view_v_src = ggml_view_2d(ctx, v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(v_l[il]->type, size), + ggml_row_size(v_l[il]->type, i)); - for (const auto & v : v_l) { - size_v_bytes += ggml_nbytes(v); + view_v_dst = ggml_view_2d(ctx, v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(v_l[il]->type, size), + ggml_row_size(v_l[il]->type, id)); + } + + ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst)); + ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst)); + } + + i += nm - 1; } - return size_v_bytes; + //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); +#endif + + return res; } bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { @@ -1013,6 +1001,18 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { return true; } +uint32_t llama_kv_cache_unified::cell_max() const { + for (uint32_t i = size; i > 0; --i) { + const kv_cell & cell = cells[i - 1]; + + if (cell.pos >= 0 && !cell.is_empty()) { + return i; + } + } + + return 0; +} + void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { std::vector> cell_ranges; // ranges, from inclusive, to exclusive uint32_t cell_count = 0; @@ -1381,7 +1381,7 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( callbacks cbs, ggml_type type_k, ggml_type type_v, - uint32_t kv_size) : hparams(hparams), cbs(std::move(cbs)) { + uint32_t kv_size) : cbs(std::move(cbs)), hparams(hparams) { const int32_t n_layer = hparams.n_layer; LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n", @@ -1469,38 +1469,6 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( } } -int32_t llama_kv_cache_recurrent::get_n_tokens() const { - int32_t result = 0; - - for (uint32_t i = 0; i < size; i++) { - result += cells[i].seq_id.size(); - } - - return result; -} - -int32_t llama_kv_cache_recurrent::get_used_cells() 
const { - return used; -} - -size_t llama_kv_cache_recurrent::total_size() const { - size_t size = 0; - for (const auto & buf : bufs) { - size += ggml_backend_buffer_get_size(buf.get()); - } - - return size; -} - -llama_pos llama_kv_cache_recurrent::get_pos_max() const { - llama_pos pos_max = -1; - for (const auto & cell : cells) { - pos_max = std::max(pos_max, cell.pos); - } - - return pos_max; -} - void llama_kv_cache_recurrent::clear() { for (int32_t i = 0; i < (int32_t) size; ++i) { cells[i].pos = -1; @@ -1694,11 +1662,6 @@ llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { return result; } -void llama_kv_cache_recurrent::defrag_sched(float thold) { - GGML_UNUSED(thold); - // noop -} - void llama_kv_cache_recurrent::restore() { if (pending.ranges.empty()) { return; @@ -1716,48 +1679,28 @@ bool llama_kv_cache_recurrent::update(const graph_params & params) { return false; } -bool llama_kv_cache_recurrent::get_can_shift() const { - return false; +void llama_kv_cache_recurrent::defrag_sched(float thold) { + GGML_UNUSED(thold); + // noop } -int32_t llama_kv_cache_recurrent::s_copy(int i) const { - const uint32_t cell_id = i + head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - kv_cell & cell = const_cast(cells[i]); - - // prevent out-of-bound sources - if (cell.src < 0 || (uint32_t) cell.src >= size) { - cell.src = cell_id; - } - - int32_t res = cell.src; - - // TODO: do not mutate the KV cache - // ensure copy only happens once - if (cell.src != (int32_t) cell_id) { - cell.src = cell_id; - } - - return res; +void llama_kv_cache_recurrent::set_full() { + n = size; } -float llama_kv_cache_recurrent::s_mask(int i) const { - const uint32_t cell_id = i + head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! 
- kv_cell & cell = const_cast(cells[i]); - - float res = (float) (cell.src >= 0); +llama_sbatch llama_kv_cache_recurrent::sbatch_init( + const llama_batch & batch, + bool logits_all) { + return llama_sbatch(batch, hparams.n_embd, false, logits_all); +} - // only clear once - if (cell.src < 0) { - cell.src = cell_id; +llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + return sbatch.split_seq(n_ubatch); } - return res; + return sbatch.split_equal(n_ubatch); } bool llama_kv_cache_recurrent::find_slot( @@ -1935,19 +1878,71 @@ bool llama_kv_cache_recurrent::find_slot( return n >= n_seqs; } -llama_sbatch llama_kv_cache_recurrent::sbatch_init( - const llama_batch & batch, - bool logits_all) { - return llama_sbatch(batch, hparams.n_embd, false, logits_all); +int32_t llama_kv_cache_recurrent::get_n_tokens() const { + int32_t result = 0; + + for (uint32_t i = 0; i < size; i++) { + result += cells[i].seq_id.size(); + } + + return result; } -llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { - if (embd_pooled) { - // Pooled embeddings cannot be split across ubatches (yet) - return sbatch.split_seq(n_ubatch); +int32_t llama_kv_cache_recurrent::get_used_cells() const { + return used; +} + +llama_pos llama_kv_cache_recurrent::get_pos_max() const { + llama_pos pos_max = -1; + for (const auto & cell : cells) { + pos_max = std::max(pos_max, cell.pos); } - return sbatch.split_equal(n_ubatch); + return pos_max; +} + +bool llama_kv_cache_recurrent::get_can_shift() const { + return false; +} + +int32_t llama_kv_cache_recurrent::s_copy(int i) const { + const uint32_t cell_id = i + head; + + ////////////////////////////////////////////// + // TODO: this should not mutate the KV cache ! + kv_cell & cell = const_cast(cells[i]); + + // prevent out-of-bound sources + if (cell.src < 0 || (uint32_t) cell.src >= size) { + cell.src = cell_id; + } + + int32_t res = cell.src; + + // TODO: do not mutate the KV cache + // ensure copy only happens once + if (cell.src != (int32_t) cell_id) { + cell.src = cell_id; + } + + return res; +} + +float llama_kv_cache_recurrent::s_mask(int i) const { + const uint32_t cell_id = i + head; + + ////////////////////////////////////////////// + // TODO: this should not mutate the KV cache ! 
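+    // the returned mask is 1.0 when the cell already holds a valid state
+    // (src >= 0) and 0.0 when the state has to be cleared first; in the latter
+    // case src is set to cell_id so the clear only happens once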
+ kv_cell & cell = const_cast(cells[i]); + + float res = (float) (cell.src >= 0); + + // only clear once + if (cell.src < 0) { + cell.src = cell_id; + } + + return res; } uint32_t llama_kv_cache_recurrent::cell_max() const { @@ -1962,8 +1957,13 @@ uint32_t llama_kv_cache_recurrent::cell_max() const { return 0; } -void llama_kv_cache_recurrent::set_full() { - n = size; +size_t llama_kv_cache_recurrent::total_size() const { + size_t size = 0; + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; } size_t llama_kv_cache_recurrent::size_k_bytes() const { diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 0ef3bf2091c65..95806ae27b275 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -72,15 +72,14 @@ struct llama_kv_cache : public llama_memory_i { // different KV caches require different batch splitting strategies virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0; + // find an empty slot of size "n_tokens" in the cache virtual bool find_slot(const llama_ubatch & batch) = 0; // getters - virtual int32_t get_n_tokens() const = 0; - virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache - - virtual llama_pos get_pos_max() const = 0; - - virtual bool get_can_shift() const = 0; + virtual int32_t get_n_tokens() const = 0; + virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache + virtual llama_pos get_pos_max() const = 0; + virtual bool get_can_shift() const = 0; bool get_can_edit() const override { return get_can_shift(); } @@ -113,10 +112,8 @@ struct llama_kv_cache_guard { // // llama_kv_cache_unified -// ring-buffer of cached KV data // -// TODO: pimpl // TODO: add notion of max sequences class llama_kv_cache_unified : public llama_kv_cache { public: @@ -139,6 +136,8 @@ class llama_kv_cache_unified : public llama_kv_cache { } }; + static uint32_t get_padding(const llama_cparams & cparams); + llama_kv_cache_unified( const llama_hparams & hparams, callbacks cbs, @@ -150,21 +149,11 @@ class llama_kv_cache_unified : public llama_kv_cache { ~llama_kv_cache_unified() = default; - int32_t get_n_tokens() const override; - int32_t get_used_cells() const override; - - size_t total_size() const; - - // TODO: better data structures to reduce the cost of this operation - llama_pos get_pos_max() const override; + // + // llama_memory_i + // void clear() override; - void defrag_sched(float thold) override; - - void restore() override; - void commit() override; - - bool update(const graph_params & params) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; @@ -174,66 +163,41 @@ class llama_kv_cache_unified : public llama_kv_cache { llama_pos seq_pos_max(llama_seq_id seq_id) const override; - bool get_can_shift() const override; - - // find an empty slot of size "n_tokens" in the cache - // updates the cache head - // Note: On success, it's important that cache.head points - // to the first cell of the slot. 
- bool find_slot(const llama_ubatch & batch) override; - - llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; + // + // llama_kv_cache + // - llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; + void restore() override; + void commit() override; - static uint32_t get_padding(const llama_cparams & cparams); + bool update(const graph_params & params) override; - // find how many cells are currently in use - uint32_t cell_max() const; + void defrag_sched(float thold) override; void set_full() override; - size_t size_k_bytes() const; - size_t size_v_bytes() const; - - // defrag + llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; - struct { - std::vector ids; - } defrag_info; + llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; - // return true if cells have been moved - bool defrag_prepare(int32_t n_max_nodes); + // updates the cache head + // Note: On success, it's important that cache.head points + // to the first cell of the slot. + bool find_slot(const llama_ubatch & batch) override; - // commit/restore cache + int32_t get_n_tokens() const override; + int32_t get_used_cells() const override; - struct slot_range { - uint32_t c0 = 0; // note: these are cell indices, not sequence positions - uint32_t c1 = 0; - }; + // TODO: better data structures to reduce the cost of this operation + llama_pos get_pos_max() const override; - // pending cell updates that are not yet committed - struct { - std::vector ranges; - } pending; + bool get_can_shift() const override; // state write/load void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - // members - - const llama_hparams & hparams; - - callbacks cbs; - - bool has_shift = false; - bool do_defrag = false; - - bool v_trans = true; // the value tensor is transposed - bool can_shift = false; - // Note: The value of head isn't only used to optimize searching // for a free KV slot. llama_decode_impl also uses it, so it // cannot be freely changed after a slot has been allocated. 
@@ -244,8 +208,7 @@ class llama_kv_cache_unified : public llama_kv_cache { // computed before each graph build uint32_t n = 0; - // required padding - uint32_t padding = 1; + callbacks cbs; std::vector cells; @@ -253,12 +216,50 @@ class llama_kv_cache_unified : public llama_kv_cache { std::vector v_l; private: + const llama_hparams & hparams; + + bool has_shift = false; + bool do_defrag = false; + + bool v_trans = true; // the value tensor is transposed + bool can_shift = false; + + // required padding + uint32_t padding = 1; + ggml_type type_k = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16; std::vector ctxs; std::vector bufs; + // defrag + struct { + std::vector ids; + } defrag_info; + + // return true if cells have been moved + bool defrag_prepare(int32_t n_max_nodes); + + // commit/restore cache + struct slot_range { + uint32_t c0 = 0; // note: these are cell indices, not sequence positions + uint32_t c1 = 0; + }; + + // pending cell updates that are not yet committed + struct { + std::vector ranges; + } pending; + + // find how many cells are currently in use + uint32_t cell_max() const; + + size_t total_size() const; + + size_t size_k_bytes() const; + size_t size_v_bytes() const; + ggml_tensor * build_rope_shift( const graph_params & params, ggml_context * ctx, @@ -291,9 +292,9 @@ class llama_kv_cache_unified : public llama_kv_cache { class llama_kv_cache_recurrent : public llama_kv_cache { public: struct kv_cell { - llama_pos pos = -1; - int32_t src = -1; // used by recurrent state models to copy states - int32_t tail = -1; + llama_pos pos = -1; + int32_t src = -1; // used to copy states + int32_t tail = -1; std::set seq_id; @@ -319,21 +320,11 @@ class llama_kv_cache_recurrent : public llama_kv_cache { ~llama_kv_cache_recurrent() = default; - int32_t get_n_tokens() const override; - int32_t get_used_cells() const override; - - size_t total_size() const; - - // TODO: better data structures to reduce the cost of this operation - llama_pos get_pos_max() const override; + // + // llama_memory_i + // void clear() override; - void defrag_sched(float thold) override; - - void restore() override; - void commit() override; - - bool update(const graph_params & params) override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; @@ -343,51 +334,42 @@ class llama_kv_cache_recurrent : public llama_kv_cache { llama_pos seq_pos_max(llama_seq_id seq_id) const override; - bool get_can_shift() const override; + // + // llama_kv_cache + // - // TODO: temporary methods - they are not really const as they do const_cast<>, fix this - int32_t s_copy(int i) const; - float s_mask(int i) const; + void restore() override; + void commit() override; - // find an empty slot of size "n_tokens" in the cache - // updates the cache head - // Note: On success, it's important that cache.head points - // to the first cell of the slot. 
- bool find_slot(const llama_ubatch & batch) override; + bool update(const graph_params & params) override; + + void defrag_sched(float thold) override; + + void set_full() override; llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; - // find how many cells are currently in use - uint32_t cell_max() const; - - void set_full() override; + bool find_slot(const llama_ubatch & batch) override; - size_t size_k_bytes() const; - size_t size_v_bytes() const; + int32_t get_n_tokens() const override; + int32_t get_used_cells() const override; - // commit/restore cache + // TODO: better data structures to reduce the cost of this operation + llama_pos get_pos_max() const override; - struct slot_range { - uint32_t c0 = 0; // note: these are cell indices, not sequence positions - uint32_t c1 = 0; - }; + bool get_can_shift() const override; - // pending cell updates that are not yet committed - struct { - std::vector ranges; - } pending; + // TODO: temporary methods - they are not really const as they do const_cast<>, fix this + int32_t s_copy(int i) const; + float s_mask(int i) const; // state write/load void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - // members - - const llama_hparams & hparams; - callbacks cbs; // Note: The value of head isn't only used to optimize searching @@ -406,12 +388,34 @@ class llama_kv_cache_recurrent : public llama_kv_cache { std::vector v_l; private: + const llama_hparams & hparams; + + // commit/restore cache + // TODO: rework for recurrent cache + struct slot_range { + uint32_t c0 = 0; // note: these are cell indices, not sequence positions + uint32_t c1 = 0; + }; + + // pending cell updates that are not yet committed + struct { + std::vector ranges; + } pending; + ggml_type type_k = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16; std::vector ctxs; std::vector bufs; + // find how many cells are currently in use + uint32_t cell_max() const; + + size_t total_size() const; + + size_t size_k_bytes() const; + size_t size_v_bytes() const; + void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; From 458f2a5f69d5e809f0c37a4616d249003e415e3c Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Thu, 24 Apr 2025 16:28:30 +0300 Subject: [PATCH 17/29] model : better llama_model::create_model() signature ggml-ci --- src/llama-context.cpp | 2 +- src/llama-model.cpp | 2 +- src/llama-model.h | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 4cbc1b9f3f6fa..3bc9fbf0214a3 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -181,7 +181,7 @@ llama_context::llama_context( /*.type_v =*/ params.type_v, }; - memory.reset(model.create_memory(cparams, params_mem)); + memory.reset(model.create_memory(params_mem, cparams)); } // init backends diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 554a1e1b39070..f02abb9ab9818 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -12815,7 +12815,7 @@ struct llm_build_bailingmoe : public llm_graph_context { } }; -llama_memory_i * llama_model::create_memory(llama_cparams & cparams, const llama_memory_params & params) const { +llama_memory_i * llama_model::create_memory(const 
llama_memory_params & params, llama_cparams & cparams) const { llama_memory_i * res; const bool offload = cparams.offload_kqv; diff --git a/src/llama-model.h b/src/llama-model.h index c7dbf829765a1..09cd14cc22c69 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -397,7 +397,7 @@ struct llama_model { // note: can mutate `cparams` // TODO: move this to new llm_arch_model_i interface - llama_memory_i * create_memory(llama_cparams & cparams, const llama_memory_params & params) const; + llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const; // TODO: move this to new llm_arch_model_i interface llm_graph_result_ptr build_graph( From 92e626bde86e1cef9bcc5a6936268754e643a325 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 25 Apr 2025 13:28:33 +0300 Subject: [PATCH 18/29] kv-cache : fix recurrent seq_rm() ggml-ci --- src/llama-kv-cache.cpp | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index c8df50a66c0c2..1dd6df8a84cd1 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1485,6 +1485,8 @@ void llama_kv_cache_recurrent::clear() { } bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + uint32_t new_head = size; + if (p0 < 0) { p0 = 0; } @@ -1518,6 +1520,34 @@ bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_p } } + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].pos >= p0 && cells[i].pos < p1) { + if (seq_id < 0) { + cells[i].seq_id.clear(); + } else if (cells[i].has_seq_id(seq_id)) { + cells[i].seq_id.erase(seq_id); + } else { + continue; + } + if (cells[i].is_empty()) { + // keep count of the number of used cells + if (cells[i].pos >= 0) { + used--; + } + cells[i].pos = -1; + cells[i].src = -1; + if (new_head == size) { + new_head = i; + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. 
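+    // (head only ever moves backwards here, so the next find_slot() pass starts
+    // from the earliest freed cell)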
+ if (new_head != size && new_head < head) { + head = new_head; + } + return true; } From 43cbf38bfe7c69086d1b942f992e6ba1f094f8a7 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 30 Apr 2025 10:11:58 +0300 Subject: [PATCH 19/29] kv-cache : replace `struct callbacks` with `llama_model &` ggml-ci --- src/llama-context.cpp | 1 - src/llama-kv-cache.cpp | 55 ++++++++++++++++++++---------- src/llama-kv-cache.h | 41 +++++++++-------------- src/llama-model.cpp | 76 ++++++++++++++---------------------------- src/llama-model.h | 2 ++ 5 files changed, 80 insertions(+), 95 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 3bc9fbf0214a3..e747ed69a325c 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -440,7 +440,6 @@ void llama_context::kv_self_update() { llama_kv_cache * kv_self = static_cast(memory.get()); need_reserve = kv_self->update({ - /*.arch =*/ model.arch, /*.cparams =*/ cparams, /*.sched =*/ sched.get(), /*.backends =*/ backends, diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 1dd6df8a84cd1..c796437ccbe34 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -22,13 +22,13 @@ uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { } llama_kv_cache_unified::llama_kv_cache_unified( - const llama_hparams & hparams, - callbacks cbs, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - uint32_t kv_size, - uint32_t padding) : cbs(std::move(cbs)), hparams(hparams), v_trans(v_trans), padding(padding) { + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + uint32_t kv_size, + uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) { const int32_t n_layer = hparams.n_layer; has_shift = false; @@ -81,7 +81,18 @@ llama_kv_cache_unified::llama_kv_cache_unified( const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); - ggml_backend_buffer_type_t buft = this->cbs.get_buft(i); + const char * dev_name = "CPU"; + + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); + + if (!offload) { + auto * dev = model.dev_layer(i); + buft = ggml_backend_dev_buffer_type(dev); + + dev_name = ggml_backend_dev_name(dev); + } + + LLAMA_LOG_DEBUG("layer %3d: dev = %s\n", i, dev_name); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { @@ -588,7 +599,6 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift( float freq_base, float freq_scale, ggml_backend_buffer * bbuf) const { - const auto & arch = params.arch; const auto & cparams = params.cparams; const auto & backends = params.backends; const auto & sched = params.sched; @@ -604,7 +614,7 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift( // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. - const float yarn_attn_factor = arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor; + const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor; ggml_tensor * tmp; @@ -697,7 +707,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; const float freq_scale_l = is_swa ? 
hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; - ggml_tensor * rope_factors = cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); ggml_tensor * k = ggml_view_3d(ctx, k_l[il], @@ -1377,11 +1387,11 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell // llama_kv_cache_recurrent::llama_kv_cache_recurrent( - const llama_hparams & hparams, - callbacks cbs, - ggml_type type_k, - ggml_type type_v, - uint32_t kv_size) : cbs(std::move(cbs)), hparams(hparams) { + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size) : hparams(model.hparams) { const int32_t n_layer = hparams.n_layer; LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n", @@ -1429,7 +1439,18 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); - ggml_backend_buffer_type_t buft = this->cbs.get_buft(i); + const char * dev_name = "CPU"; + + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); + + if (!offload) { + auto * dev = model.dev_layer(i); + buft = ggml_backend_dev_buffer_type(dev); + + dev_name = ggml_backend_dev_name(dev); + } + + LLAMA_LOG_DEBUG("layer %3d: dev = %s\n", i, dev_name); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 95806ae27b275..ef63aaf21febe 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -15,19 +15,10 @@ struct llama_cparams; struct llama_hparams; struct llama_ubatch; struct llama_sbatch; +struct llama_model; struct llama_kv_cache : public llama_memory_i { - // can be used to query data from the model if needed - struct callbacks { - std::function get_rope_factors; - - // get the buffer type of layer il, can be used to offload KV cache layers to a different device - std::function get_buft; - }; - struct graph_params { - const llm_arch arch; - const llama_cparams & cparams; const ggml_backend_sched_t & sched; @@ -139,13 +130,13 @@ class llama_kv_cache_unified : public llama_kv_cache { static uint32_t get_padding(const llama_cparams & cparams); llama_kv_cache_unified( - const llama_hparams & hparams, - callbacks cbs, - ggml_type type_k, - ggml_type type_v, - bool v_trans, - uint32_t kv_size, - uint32_t padding); + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool v_trans, + bool offload, + uint32_t kv_size, + uint32_t padding); ~llama_kv_cache_unified() = default; @@ -208,14 +199,13 @@ class llama_kv_cache_unified : public llama_kv_cache { // computed before each graph build uint32_t n = 0; - callbacks cbs; - std::vector cells; std::vector k_l; // per layer std::vector v_l; private: + const llama_model & model; const llama_hparams & hparams; bool has_shift = false; @@ -312,11 +302,11 @@ class llama_kv_cache_recurrent : public llama_kv_cache { }; llama_kv_cache_recurrent( - const llama_hparams & hparams, - callbacks cbs, - ggml_type type_k, - ggml_type type_v, - uint32_t kv_size); + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size); ~llama_kv_cache_recurrent() = default; @@ -370,8 +360,6 @@ class llama_kv_cache_recurrent : public llama_kv_cache { void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; - 
callbacks cbs; - // Note: The value of head isn't only used to optimize searching // for a free KV slot. llama_decode_impl also uses it, so it // cannot be freely changed after a slot has been allocated. @@ -388,6 +376,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache { std::vector v_l; private: + //const llama_model & model; const llama_hparams & hparams; // commit/restore cache diff --git a/src/llama-model.cpp b/src/llama-model.cpp index f02abb9ab9818..e69601949ed51 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4445,6 +4445,19 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const { return it->second; } +ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const { + // choose long/short freq factors based on the context size + if (layers[il].rope_freqs != nullptr) { + return layers[il].rope_freqs; + } + + if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) { + return layers[il].rope_long; + } + + return layers[il].rope_short; +} + struct llm_build_llama : public llm_graph_context { llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -4485,7 +4498,7 @@ struct llm_build_llama : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -4710,7 +4723,7 @@ struct llm_build_deci : public llm_graph_context { } else if (n_head > 0) { // self-attention // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -7192,7 +7205,7 @@ struct llm_build_phi3 : public llm_graph_context { // self-attention { // rope freq factors for 128k context - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); ggml_tensor* attn_norm_output = build_norm(inpL, model.layers[il].attn_norm, @@ -7944,7 +7957,7 @@ struct llm_build_minicpm3 : public llm_graph_context { for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // norm cur = build_norm(inpL, @@ -9012,7 +9025,7 @@ struct llm_build_cohere2 : public llm_graph_context { // self-attention { // rope freq factors for 128k context - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -9950,7 +9963,7 @@ struct llm_build_deepseek : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = 
model.get_rope_factors(n_ctx_per_seq, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -11314,7 +11327,7 @@ struct llm_build_exaone : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -12695,7 +12708,7 @@ struct llm_build_bailingmoe : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -12818,28 +12831,6 @@ struct llm_build_bailingmoe : public llm_graph_context { llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { llama_memory_i * res; - const bool offload = cparams.offload_kqv; - - auto get_buft = [this, offload](int il) { - const char * dev_name = "CPU"; - - ggml_backend_buffer_type_t buft; - if (offload) { - auto * dev = dev_layer(il); - buft = ggml_backend_dev_buffer_type(dev); - - dev_name = ggml_backend_dev_name(dev); - } else { - buft = ggml_backend_cpu_buffer_type(); - } - - LLAMA_LOG_DEBUG("layer %3d: dev = %s\n", il, dev_name); - - return buft; - }; - - LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); - switch (arch) { case LLM_ARCH_MAMBA: case LLM_ARCH_RWKV6: @@ -12848,13 +12839,10 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_ARWKV7: { res = new llama_kv_cache_recurrent( - hparams, - { - /*.get_rope_factors =*/ nullptr, - /*.get_buft =*/ get_buft, - }, + *this, GGML_TYPE_F32, GGML_TYPE_F32, + cparams.offload_kqv, std::max((uint32_t) 1, cparams.n_seq_max)); } break; default: @@ -12866,25 +12854,11 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); res = new llama_kv_cache_unified( - hparams, - { - /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) { - // choose long/short freq factors based on the context size - if (layers[il].rope_freqs != nullptr) { - return layers[il].rope_freqs; - } - - if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) { - return layers[il].rope_long; - } - - return layers[il].rope_short; - }, - /*.get_buft =*/ get_buft, - }, + *this, params.type_k, params.type_v, !cparams.flash_attn, + cparams.offload_kqv, cparams.n_ctx, padding); } diff --git a/src/llama-model.h b/src/llama-model.h index 09cd14cc22c69..4c7e7a335b4e2 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -395,6 +395,8 @@ struct llama_model { const struct ggml_tensor * get_tensor(const char * name) const; + ggml_tensor * get_rope_factors(uint32_t n_ctx_per_seq, int il) const; + // note: can mutate `cparams` // TODO: move this to new llm_arch_model_i interface llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const; From 66198324db481a6229d8a4bab9b71c551ae55bbc Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 30 Apr 2025 11:21:22 +0300 Subject: [PATCH 20/29] kv-cache : replace `struct 
graph_params` with `llama_context &` ggml-ci --- src/llama-context.cpp | 26 ++++++++++----- src/llama-context.h | 21 ++++++++---- src/llama-kv-cache.cpp | 74 ++++++++++++++++++++++-------------------- src/llama-kv-cache.h | 50 +++++++++------------------- 4 files changed, 86 insertions(+), 85 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index e747ed69a325c..b2e1a132dd7a0 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -396,6 +396,22 @@ const llama_model & llama_context::get_model() const { return model; } +const llama_cparams & llama_context::get_cparams() const { + return cparams; +} + +const ggml_backend_sched_ptr & llama_context::get_sched() const { + return sched; +} + +const ggml_context_ptr & llama_context::get_ctx_compute() const { + return ctx_compute; +} + +const std::vector & llama_context::get_backends() const { + return backends; +} + uint32_t llama_context::n_ctx() const { return cparams.n_ctx; } @@ -439,15 +455,7 @@ void llama_context::kv_self_update() { llama_kv_cache * kv_self = static_cast(memory.get()); - need_reserve = kv_self->update({ - /*.cparams =*/ cparams, - /*.sched =*/ sched.get(), - /*.backends =*/ backends, - /*.n_max_nodes =*/ graph_max_nodes(), - /*.get_ctx_compute =*/ [this]() { return ctx_compute.get(); }, - /*.graph_init =*/ [this]() { return graph_init(); }, - /*.graph_compute =*/ [this](ggml_cgraph * gf) { graph_compute(gf, false); }, - }); + need_reserve = kv_self->update(*this); // reserve a worst case graph if needed if (need_reserve) { diff --git a/src/llama-context.h b/src/llama-context.h index 3e79cff8250bf..5aabed2fe15c8 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -27,7 +27,14 @@ struct llama_context { void synchronize(); - const llama_model & get_model() const; + const llama_model & get_model() const; + const llama_cparams & get_cparams() const; + + const ggml_backend_sched_ptr & get_sched() const; + + const ggml_context_ptr & get_ctx_compute() const; + + const std::vector & get_backends() const; uint32_t n_ctx() const; uint32_t n_ctx_per_seq() const; @@ -141,22 +148,24 @@ struct llama_context { // graph // +public: int32_t graph_max_nodes() const; // zero-out inputs and create the ctx_compute for the compute graph ggml_cgraph * graph_init(); + // returns the result of ggml_backend_sched_graph_compute_async execution + ggml_status graph_compute( + ggml_cgraph * gf, + bool batched); + +private: llm_graph_result_ptr graph_build( ggml_context * ctx, ggml_cgraph * gf, const llama_ubatch & ubatch, llm_graph_type gtype); - // returns the result of ggml_backend_sched_graph_compute_async execution - ggml_status graph_compute( - ggml_cgraph * gf, - bool batched); - llm_graph_cb graph_get_cb() const; // TODO: read/write lora adapters and cvec diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index c796437ccbe34..0de32e278e6e4 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -4,6 +4,7 @@ #include "llama-batch.h" #include "llama-cparams.h" #include "llama-model.h" +#include "llama-context.h" #include #include @@ -367,10 +368,10 @@ void llama_kv_cache_unified::commit() { pending.ranges.clear(); } -bool llama_kv_cache_unified::update(const graph_params & params) { +bool llama_kv_cache_unified::update(llama_context & lctx) { bool need_reserve = false; - const auto & sched = params.sched; + const auto & sched = lctx.get_sched(); if (has_shift) { if (!get_can_shift()) { @@ -381,17 +382,17 @@ bool llama_kv_cache_unified::update(const graph_params & params) { // apply 
K-shift if needed if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { - ggml_backend_sched_reset(sched); + ggml_backend_sched_reset(sched.get()); - auto * gf = params.graph_init(); + auto * gf = lctx.graph_init(); - auto res = build_graph_shift(params, gf); + auto res = build_graph_shift(lctx, gf); - ggml_backend_sched_alloc_graph(sched, gf); + ggml_backend_sched_alloc_graph(sched.get(), gf); res->set_inputs(nullptr); - params.graph_compute(gf); + lctx.graph_compute(gf, false); need_reserve = true; } @@ -408,18 +409,18 @@ bool llama_kv_cache_unified::update(const graph_params & params) { if (do_defrag) { LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); - if (defrag_prepare(params.n_max_nodes)) { - ggml_backend_sched_reset(sched); + if (defrag_prepare(lctx.graph_max_nodes())) { + ggml_backend_sched_reset(sched.get()); - auto * gf = params.graph_init(); + auto * gf = lctx.graph_init(); - auto res = build_graph_defrag(params, gf); + auto res = build_graph_defrag(lctx, gf); - ggml_backend_sched_alloc_graph(sched, gf); + ggml_backend_sched_alloc_graph(sched.get(), gf); res->set_inputs(nullptr); - params.graph_compute(gf); + lctx.graph_compute(gf, false); need_reserve = true; } @@ -591,17 +592,17 @@ size_t llama_kv_cache_unified::size_v_bytes() const { } ggml_tensor * llama_kv_cache_unified::build_rope_shift( - const graph_params & params, - ggml_context * ctx, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale, - ggml_backend_buffer * bbuf) const { - const auto & cparams = params.cparams; - const auto & backends = params.backends; - const auto & sched = params.sched; + llama_context & lctx, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale, + ggml_backend_buffer * bbuf) const { + const auto & cparams = lctx.get_cparams(); + const auto & backends = lctx.get_backends(); + const auto & sched = lctx.get_sched(); const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; @@ -622,11 +623,12 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift( // dequantize to f32 -> RoPE -> quantize back tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); + // TODO: can we simplify/avoid this? 
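+        // the loop below pins the dequantized tensor to whichever backend
+        // owns the KV buffer, so the cast -> RoPE -> cast chain is scheduled
+        // on the same device as the cache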
if (bbuf) { for (const auto & backend : backends) { // Figure out which backend KV cache belongs to if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) { - ggml_backend_sched_set_tensor_backend(sched, tmp, backend.get()); + ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get()); break; } } @@ -674,13 +676,13 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { } llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( - const graph_params & params, - ggml_cgraph * gf) const { + llama_context & lctx, + ggml_cgraph * gf) const { auto res = std::make_unique(); - auto * ctx = params.get_ctx_compute(); + auto * ctx = lctx.get_ctx_compute().get(); - const auto & cparams = params.cparams; + const auto & cparams = lctx.get_cparams(); const auto & n_layer = hparams.n_layer; @@ -716,7 +718,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( ggml_row_size(k_l[il]->type, n_embd_k_gqa), 0); - ggml_tensor * cur = build_rope_shift(params, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, k_l[il]->buffer); + ggml_tensor * cur = build_rope_shift(lctx, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, k_l[il]->buffer); ggml_build_forward_expand(gf, cur); } @@ -727,15 +729,15 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( } llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( - const graph_params & params, - ggml_cgraph * gf) const { + llama_context & lctx, + ggml_cgraph * gf) const { auto res = std::make_unique(); - auto * ctx = params.get_ctx_compute(); + auto * ctx = lctx.get_ctx_compute().get(); const auto & ids = defrag_info.ids; - const auto & cparams = params.cparams; + const auto & cparams = lctx.get_cparams(); #if 0 // CPU defrag @@ -1725,8 +1727,8 @@ void llama_kv_cache_recurrent::commit() { pending.ranges.clear(); } -bool llama_kv_cache_recurrent::update(const graph_params & params) { - GGML_UNUSED(params); +bool llama_kv_cache_recurrent::update(llama_context & lctx) { + GGML_UNUSED(lctx); return false; } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index ef63aaf21febe..86841d3fc002c 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -7,7 +7,6 @@ #include "ggml-cpp.h" -#include #include #include @@ -16,26 +15,9 @@ struct llama_hparams; struct llama_ubatch; struct llama_sbatch; struct llama_model; +struct llama_context; struct llama_kv_cache : public llama_memory_i { - struct graph_params { - const llama_cparams & cparams; - - const ggml_backend_sched_t & sched; - - const std::vector & backends; - - int32_t n_max_nodes; - - std::function get_ctx_compute; - - // function for creating ggml graphs - std::function graph_init; - - // function for computing ggml graphs - std::function graph_compute; - }; - virtual ~llama_kv_cache() = default; // call if batch processing fails - restores the cache state @@ -46,7 +28,7 @@ struct llama_kv_cache : public llama_memory_i { // process any pending defrag/shift/etc. operations // optionally call once before processing a new batch - virtual bool update(const graph_params & params) = 0; + virtual bool update(llama_context & lctx) = 0; // schedule a defrag if the fragmentation threshold is exceeded. 
otherwise, do nothing virtual void defrag_sched(float thold) = 0; @@ -161,7 +143,7 @@ class llama_kv_cache_unified : public llama_kv_cache { void restore() override; void commit() override; - bool update(const graph_params & params) override; + bool update(llama_context & ctx) override; void defrag_sched(float thold) override; @@ -251,22 +233,22 @@ class llama_kv_cache_unified : public llama_kv_cache { size_t size_v_bytes() const; ggml_tensor * build_rope_shift( - const graph_params & params, - ggml_context * ctx, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale, - ggml_backend_buffer * bbuf) const; + llama_context & lctx, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale, + ggml_backend_buffer * bbuf) const; llm_graph_result_ptr build_graph_shift( - const graph_params & params, - ggml_cgraph * gf) const; + llama_context & lctx, + ggml_cgraph * gf) const; llm_graph_result_ptr build_graph_defrag( - const graph_params & params, - ggml_cgraph * gf) const; + llama_context & lctx, + ggml_cgraph * gf) const; void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; @@ -331,7 +313,7 @@ class llama_kv_cache_recurrent : public llama_kv_cache { void restore() override; void commit() override; - bool update(const graph_params & params) override; + bool update(llama_context & lctx) override; void defrag_sched(float thold) override; From 95a9f8b59ae3dbd681447737c3ba6fad6b0bac25 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 30 Apr 2025 15:50:27 +0300 Subject: [PATCH 21/29] kv-cache : fix offload check ggml-ci --- src/llama-kv-cache.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 0de32e278e6e4..c0b232f59b6a2 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -86,7 +86,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); - if (!offload) { + if (offload) { auto * dev = model.dev_layer(i); buft = ggml_backend_dev_buffer_type(dev); @@ -1445,7 +1445,7 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); - if (!offload) { + if (offload) { auto * dev = model.dev_layer(i); buft = ggml_backend_dev_buffer_type(dev); From 8737e655a931ac9ceeacb88d06e0df66e0d5e7d0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 30 Apr 2025 15:55:15 +0300 Subject: [PATCH 22/29] context : avoid passing unique_ptr ggml-ci --- src/llama-context.cpp | 8 ++++---- src/llama-context.h | 5 +++-- src/llama-kv-cache.cpp | 21 +++++++++++---------- 3 files changed, 18 insertions(+), 16 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index b2e1a132dd7a0..7ca8367c93f8d 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -400,12 +400,12 @@ const llama_cparams & llama_context::get_cparams() const { return cparams; } -const ggml_backend_sched_ptr & llama_context::get_sched() const { - return sched; +ggml_backend_sched_t llama_context::get_sched() const { + return sched.get(); } -const ggml_context_ptr & llama_context::get_ctx_compute() const { - return ctx_compute; +ggml_context * llama_context::get_ctx_compute() const { + return ctx_compute.get(); } const std::vector & 
llama_context::get_backends() const { diff --git a/src/llama-context.h b/src/llama-context.h index 5aabed2fe15c8..5bb081c26ed4f 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -30,10 +30,11 @@ struct llama_context { const llama_model & get_model() const; const llama_cparams & get_cparams() const; - const ggml_backend_sched_ptr & get_sched() const; + ggml_backend_sched_t get_sched() const; - const ggml_context_ptr & get_ctx_compute() const; + ggml_context * get_ctx_compute() const; + // TODO: this method might be possible to avoid (seach for TAG_BACKENDS) const std::vector & get_backends() const; uint32_t n_ctx() const; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index c0b232f59b6a2..2401254404ccc 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -371,7 +371,7 @@ void llama_kv_cache_unified::commit() { bool llama_kv_cache_unified::update(llama_context & lctx) { bool need_reserve = false; - const auto & sched = lctx.get_sched(); + auto * sched = lctx.get_sched(); if (has_shift) { if (!get_can_shift()) { @@ -382,13 +382,13 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { // apply K-shift if needed if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { - ggml_backend_sched_reset(sched.get()); + ggml_backend_sched_reset(sched); auto * gf = lctx.graph_init(); auto res = build_graph_shift(lctx, gf); - ggml_backend_sched_alloc_graph(sched.get(), gf); + ggml_backend_sched_alloc_graph(sched, gf); res->set_inputs(nullptr); @@ -410,13 +410,13 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); if (defrag_prepare(lctx.graph_max_nodes())) { - ggml_backend_sched_reset(sched.get()); + ggml_backend_sched_reset(sched); auto * gf = lctx.graph_init(); auto res = build_graph_defrag(lctx, gf); - ggml_backend_sched_alloc_graph(sched.get(), gf); + ggml_backend_sched_alloc_graph(sched, gf); res->set_inputs(nullptr); @@ -602,7 +602,8 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift( ggml_backend_buffer * bbuf) const { const auto & cparams = lctx.get_cparams(); const auto & backends = lctx.get_backends(); - const auto & sched = lctx.get_sched(); + + auto * sched = lctx.get_sched(); const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; @@ -623,12 +624,12 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift( // dequantize to f32 -> RoPE -> quantize back tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); - // TODO: can we simplify/avoid this? + // TODO: can we simplify/avoid this? 
[TAG_BACKENDS] if (bbuf) { for (const auto & backend : backends) { // Figure out which backend KV cache belongs to if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) { - ggml_backend_sched_set_tensor_backend(sched.get(), tmp, backend.get()); + ggml_backend_sched_set_tensor_backend(sched, tmp, backend.get()); break; } } @@ -680,7 +681,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( ggml_cgraph * gf) const { auto res = std::make_unique(); - auto * ctx = lctx.get_ctx_compute().get(); + auto * ctx = lctx.get_ctx_compute(); const auto & cparams = lctx.get_cparams(); @@ -733,7 +734,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( ggml_cgraph * gf) const { auto res = std::make_unique(); - auto * ctx = lctx.get_ctx_compute().get(); + auto * ctx = lctx.get_ctx_compute(); const auto & ids = defrag_info.ids; From c9bddfc0c89b7922f225ccb0a8ad1ec20bd05aa8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 30 Apr 2025 16:03:25 +0300 Subject: [PATCH 23/29] kv-cache : avoid using the backends from the llama_context ref #13113 ggml-ci --- src/llama-context.cpp | 4 ---- src/llama-context.h | 3 --- src/llama-kv-cache.cpp | 19 ++----------------- src/llama-kv-cache.h | 3 +-- 4 files changed, 3 insertions(+), 26 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 7ca8367c93f8d..a88b9a5ff90da 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -408,10 +408,6 @@ ggml_context * llama_context::get_ctx_compute() const { return ctx_compute.get(); } -const std::vector & llama_context::get_backends() const { - return backends; -} - uint32_t llama_context::n_ctx() const { return cparams.n_ctx; } diff --git a/src/llama-context.h b/src/llama-context.h index 5bb081c26ed4f..cf41ac57b9fba 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -34,9 +34,6 @@ struct llama_context { ggml_context * get_ctx_compute() const; - // TODO: this method might be possible to avoid (seach for TAG_BACKENDS) - const std::vector & get_backends() const; - uint32_t n_ctx() const; uint32_t n_ctx_per_seq() const; uint32_t n_batch() const; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 2401254404ccc..741742608cfd4 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -598,12 +598,8 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift( ggml_tensor * shift, ggml_tensor * factors, float freq_base, - float freq_scale, - ggml_backend_buffer * bbuf) const { + float freq_scale) const { const auto & cparams = lctx.get_cparams(); - const auto & backends = lctx.get_backends(); - - auto * sched = lctx.get_sched(); const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; @@ -624,17 +620,6 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift( // dequantize to f32 -> RoPE -> quantize back tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); - // TODO: can we simplify/avoid this? 
[TAG_BACKENDS] - if (bbuf) { - for (const auto & backend : backends) { - // Figure out which backend KV cache belongs to - if (ggml_backend_supports_buft(backend.get(), ggml_backend_buffer_get_type(bbuf))) { - ggml_backend_sched_set_tensor_backend(sched, tmp, backend.get()); - break; - } - } - } - tmp = ggml_rope_ext_inplace(ctx, tmp, shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); @@ -719,7 +704,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( ggml_row_size(k_l[il]->type, n_embd_k_gqa), 0); - ggml_tensor * cur = build_rope_shift(lctx, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l, k_l[il]->buffer); + ggml_tensor * cur = build_rope_shift(lctx, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); ggml_build_forward_expand(gf, cur); } diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 86841d3fc002c..d36cd78ea3f67 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -239,8 +239,7 @@ class llama_kv_cache_unified : public llama_kv_cache { ggml_tensor * shift, ggml_tensor * factors, float freq_base, - float freq_scale, - ggml_backend_buffer * bbuf) const; + float freq_scale) const; llm_graph_result_ptr build_graph_shift( llama_context & lctx, From 09195eb2ee0de9a61752cd81e82311bf873ed7d9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 30 Apr 2025 16:09:39 +0300 Subject: [PATCH 24/29] kv-cache : more consistent debug logs [no ci] --- src/llama-kv-cache.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 741742608cfd4..1d794a2fcc4fe 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -93,7 +93,7 @@ llama_kv_cache_unified::llama_kv_cache_unified( dev_name = ggml_backend_dev_name(dev); } - LLAMA_LOG_DEBUG("layer %3d: dev = %s\n", i, dev_name); + LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, i, dev_name); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { @@ -992,9 +992,9 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { return false; } - LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves); + LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves); - LLAMA_LOG_DEBUG("expected gf nodes: %u\n", 6*n_moves*n_layer); + LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer); return true; } @@ -1438,7 +1438,7 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent( dev_name = ggml_backend_dev_name(dev); } - LLAMA_LOG_DEBUG("layer %3d: dev = %s\n", i, dev_name); + LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { From 58e1d40f803f2dffbd7216e0e611e3ac4f6c3363 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 30 Apr 2025 16:13:12 +0300 Subject: [PATCH 25/29] kv-cache : do not pass the full llama_context for kv graphs ggml-ci --- src/llama-kv-cache.cpp | 40 ++++++++++++++++------------------------ src/llama-kv-cache.h | 24 +++++++++++++----------- 2 files changed, 29 insertions(+), 35 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 1d794a2fcc4fe..5842c18d3d7fe 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -386,7 +386,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { auto * gf = lctx.graph_init(); - auto res = build_graph_shift(lctx, gf); + auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf); 
ggml_backend_sched_alloc_graph(sched, gf); @@ -414,7 +414,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) { auto * gf = lctx.graph_init(); - auto res = build_graph_defrag(lctx, gf); + auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf); ggml_backend_sched_alloc_graph(sched, gf); @@ -592,15 +592,13 @@ size_t llama_kv_cache_unified::size_v_bytes() const { } ggml_tensor * llama_kv_cache_unified::build_rope_shift( - llama_context & lctx, - ggml_context * ctx, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale) const { - const auto & cparams = lctx.get_cparams(); - + const llama_cparams & cparams, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale) const { const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; const auto & yarn_ext_factor = cparams.yarn_ext_factor; @@ -662,14 +660,11 @@ void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { } llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( - llama_context & lctx, - ggml_cgraph * gf) const { + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const { auto res = std::make_unique(); - auto * ctx = lctx.get_ctx_compute(); - - const auto & cparams = lctx.get_cparams(); - const auto & n_layer = hparams.n_layer; const auto & n_embd_head_k = hparams.n_embd_head_k; @@ -704,7 +699,7 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( ggml_row_size(k_l[il]->type, n_embd_k_gqa), 0); - ggml_tensor * cur = build_rope_shift(lctx, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); + ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); ggml_build_forward_expand(gf, cur); } @@ -715,16 +710,13 @@ llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( } llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( - llama_context & lctx, - ggml_cgraph * gf) const { + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const { auto res = std::make_unique(); - auto * ctx = lctx.get_ctx_compute(); - const auto & ids = defrag_info.ids; - const auto & cparams = lctx.get_cparams(); - #if 0 // CPU defrag // diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index d36cd78ea3f67..bf3b4b6a4430f 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -233,21 +233,23 @@ class llama_kv_cache_unified : public llama_kv_cache { size_t size_v_bytes() const; ggml_tensor * build_rope_shift( - llama_context & lctx, - ggml_context * ctx, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale) const; + const llama_cparams & cparams, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale) const; llm_graph_result_ptr build_graph_shift( - llama_context & lctx, - ggml_cgraph * gf) const; + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const; llm_graph_result_ptr build_graph_defrag( - llama_context & lctx, - ggml_cgraph * gf) const; + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const; void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; From 903e46f1e49c797da257398026ff67235bd45e3b Mon Sep 17 00:00:00 2001 From: Georgi 
Gerganov Date: Fri, 2 May 2025 13:26:30 +0300 Subject: [PATCH 26/29] kv-cache : remove comment --- src/llama-kv-cache.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 5842c18d3d7fe..c535abb41393f 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -216,11 +216,6 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { uint32_t new_head = size; for (uint32_t i = 0; i < size; ++i) { - // TODO: remove tail - //if (recurrent && (llama_seq_id) i != seq_id) { - // cells[i].tail = -1; - //} - if (!cells[i].has_seq_id(seq_id)) { if (cells[i].pos >= 0) { used--; From 00cde5fe4bc1560a63c7a9211f3d5e7a9ccea6c9 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 2 May 2025 13:26:49 +0300 Subject: [PATCH 27/29] kv-cache : ggml_rope_ext_inplace -> ggml_rope_ext ggml-ci --- src/llama-kv-cache.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index c535abb41393f..a767c050f0dd8 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -613,7 +613,7 @@ ggml_tensor * llama_kv_cache_unified::build_rope_shift( // dequantize to f32 -> RoPE -> quantize back tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); - tmp = ggml_rope_ext_inplace(ctx, tmp, + tmp = ggml_rope_ext(ctx, tmp, shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); From 7e79a4271df836aa72066020666c18a4a970de09 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 2 May 2025 13:27:46 +0300 Subject: [PATCH 28/29] kv-cache : fix recurrent multi-user case ggml-ci --- src/llama-kv-cache.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index a767c050f0dd8..3dcad65bb6a85 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -1936,7 +1936,7 @@ int32_t llama_kv_cache_recurrent::s_copy(int i) const { ////////////////////////////////////////////// // TODO: this should not mutate the KV cache ! - kv_cell & cell = const_cast(cells[i]); + kv_cell & cell = const_cast(cells[cell_id]); // prevent out-of-bound sources if (cell.src < 0 || (uint32_t) cell.src >= size) { @@ -1959,7 +1959,7 @@ float llama_kv_cache_recurrent::s_mask(int i) const { ////////////////////////////////////////////// // TODO: this should not mutate the KV cache ! - kv_cell & cell = const_cast(cells[i]); + kv_cell & cell = const_cast(cells[cell_id]); float res = (float) (cell.src >= 0); From 5883c9060f337772d8d6bde2cc71f18a34924fe8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 2 May 2025 16:49:51 +0300 Subject: [PATCH 29/29] memory : remove comments [no ci] --- src/llama-memory.h | 7 +------ 1 file changed, 1 insertion(+), 6 deletions(-) diff --git a/src/llama-memory.h b/src/llama-memory.h index 4a8c396529236..c7412d5911ed7 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -7,12 +7,7 @@ struct llama_memory_params { ggml_type type_k; ggml_type type_v; - //bool v_trans; - //bool offload_kqv; - - //uint32_t kv_size; - - // other types of memory + // parameters for other types of memory // ... };
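
Note: the net effect of PATCH 22-25 above is easiest to see side by side: llama_context hands out raw ggml_backend_sched_t / ggml_context * pointers instead of references to its unique_ptr owners, and the KV cache's graph builders (build_rope_shift, build_graph_shift, build_graph_defrag) now take only the cparams and the compute context rather than the whole llama_context, while update(llama_context &) stays the single entry point. The following is a condensed, self-contained C++ sketch of that shape, not llama.cpp code: llama_context_mock, kv_cache_unified_mock and the no-op helpers are illustrative stand-ins, and the real update() additionally resets the backend scheduler, allocates the graph on it and calls res->set_inputs(nullptr) before computing, exactly as shown in the diffs.

// Sketch only: mock types mirroring the post-refactor interface shape.
#include <cstdio>

struct ggml_context;   // opaque, stands in for the ggml compute context
struct ggml_cgraph;    // opaque, stands in for a ggml graph
typedef struct ggml_backend_sched * ggml_backend_sched_t;

struct llama_cparams_mock {
    unsigned n_ctx_orig_yarn = 0; // ... other compute params elided
};

// Stand-in for llama_context: raw pointers are handed out,
// the unique_ptr owners stay inside the real context (PATCH 22).
struct llama_context_mock {
    llama_cparams_mock   cparams;
    ggml_backend_sched_t sched       = nullptr;
    ggml_context *       ctx_compute = nullptr;

    const llama_cparams_mock & get_cparams()     const { return cparams; }
    ggml_backend_sched_t       get_sched()       const { return sched; }
    ggml_context *             get_ctx_compute() const { return ctx_compute; }

    ggml_cgraph * graph_init()                          { return nullptr; } // placeholder
    void          graph_compute(ggml_cgraph *, bool)    {}                  // placeholder
    int           graph_max_nodes() const               { return 8192; }    // placeholder
};

struct kv_cache_unified_mock {
    bool has_shift = false;
    bool do_defrag = false;

    // Graph builders no longer receive the full context (PATCH 25).
    void build_graph_shift (const llama_cparams_mock &, ggml_context *, ggml_cgraph *) const {}
    void build_graph_defrag(const llama_cparams_mock &, ggml_context *, ggml_cgraph *) const {}
    bool defrag_prepare(int /*n_max_nodes*/) { return true; }

    // update() still takes the context, which supplies the scheduler and graph helpers.
    bool update(llama_context_mock & lctx) {
        bool need_reserve = false;

        ggml_backend_sched_t sched = lctx.get_sched();
        (void) sched; // real code: ggml_backend_sched_reset(sched) + ggml_backend_sched_alloc_graph(sched, gf)

        if (has_shift) {
            ggml_cgraph * gf = lctx.graph_init();
            build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
            lctx.graph_compute(gf, false);
            need_reserve = true;
            has_shift    = false;
        }

        if (do_defrag && defrag_prepare(lctx.graph_max_nodes())) {
            ggml_cgraph * gf = lctx.graph_init();
            build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf);
            lctx.graph_compute(gf, false);
            need_reserve = true;
            do_defrag    = false;
        }

        return need_reserve;
    }
};

int main() {
    llama_context_mock    lctx;
    kv_cache_unified_mock kv;
    kv.has_shift = true;
    std::printf("need_reserve = %d\n", kv.update(lctx) ? 1 : 0);
    return 0;
}

With this split, the cache depends on the context only at the update() boundary, which is what allowed PATCH 23 to drop get_backends() and the bbuf-based backend lookup from build_rope_shift altogether.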