diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 01d5ca57fd82b..a88b2fe3082c9 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -189,7 +189,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { return ubatch; } -void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) { +llama_sbatch::llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) { GGML_ASSERT(batch.n_tokens >= 0); this->batch = &batch; this->n_embd = n_embd; @@ -203,6 +203,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim for (size_t i = 0; i < n_tokens; ++i) { ids[i] = i; } + if (simple_split) { seq.resize(1); llama_sbatch_seq & s = seq[0]; @@ -212,6 +213,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim s.length = n_tokens; return; } + std::sort(ids.begin(), ids.end(), [&batch](size_t a, size_t b) { int32_t n_seq_a = batch.n_seq_id ? batch.n_seq_id[a] : 1; @@ -239,6 +241,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim return n_seq_a > n_seq_b; } ); + // init seq llama_sbatch_seq * last_seq = nullptr; @@ -262,6 +265,7 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim seq.push_back(new_seq); last_seq = &seq.back(); } + // keep shared prompts first at the end, then sort by length descending. std::sort(seq.begin(), seq.end(), [](llama_sbatch_seq & a, llama_sbatch_seq & b) { diff --git a/src/llama-batch.h b/src/llama-batch.h index f1df40d27086e..6305051b62b79 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -70,7 +70,8 @@ struct llama_sbatch { // sequence-wise split llama_ubatch split_seq(size_t n_ubatch); - void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); + llama_sbatch() = default; + llama_sbatch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); }; // temporary allocate memory for the input batch if needed diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 5a2eef9b784a1..a88b9a5ff90da 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -6,11 +6,9 @@ #include "llama-model.h" #include "llama-kv-cache.h" -#include #include #include #include -#include // // llama_context @@ -177,44 +175,13 @@ llama_context::llama_context( } // init the memory module - // TODO: for now, always create a unified KV cache if (!hparams.vocab_only) { - kv_self.reset(static_cast(model.create_memory())); + llama_memory_params params_mem = { + /*.type_k =*/ params.type_k, + /*.type_v =*/ params.type_v, + }; - LLAMA_LOG_DEBUG("%s: n_ctx = %u\n", __func__, cparams.n_ctx); - - cparams.n_ctx = GGML_PAD(cparams.n_ctx, kv_self->get_padding(cparams)); - - LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); - - uint32_t kv_size = cparams.n_ctx; - ggml_type type_k = params.type_k; - ggml_type type_v = params.type_v; - - if (llama_model_is_recurrent(&model)) { - // Mamba needs at least as many KV cells as there are sequences kept at any time - kv_size = std::max((uint32_t) 1, params.n_seq_max); - // it's probably best to keep as much precision as possible for the states - type_k = GGML_TYPE_F32; // required by ggml_ssm_conv for Mamba's conv_states - type_v = GGML_TYPE_F32; // required by ggml_ssm_scan for Mamba's ssm_states - } - - GGML_ASSERT(hparams.n_embd_head_k % ggml_blck_size(type_k) == 0); - GGML_ASSERT(hparams.n_embd_head_v % 
ggml_blck_size(type_v) == 0); - - if (!kv_self->init(model, cparams, type_k, type_v, kv_size, cparams.offload_kqv)) { - throw std::runtime_error("failed to initialize self-attention cache"); - } - - { - const size_t memory_size_k = kv_self->size_k_bytes(); - const size_t memory_size_v = kv_self->size_v_bytes(); - - LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, - (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), - ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), - ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); - } + memory.reset(model.create_memory(params_mem, cparams)); } // init backends @@ -305,7 +272,9 @@ llama_context::llama_context( int n_nodes_tg = -1; // simulate full KV cache - kv_self->n = kv_self->size; + llama_kv_cache * kv_self = static_cast(memory.get()); + + kv_self->set_full(); cross.v_embd.clear(); @@ -427,6 +396,18 @@ const llama_model & llama_context::get_model() const { return model; } +const llama_cparams & llama_context::get_cparams() const { + return cparams; +} + +ggml_backend_sched_t llama_context::get_sched() const { + return sched.get(); +} + +ggml_context * llama_context::get_ctx_compute() const { + return ctx_compute.get(); +} + uint32_t llama_context::n_ctx() const { return cparams.n_ctx; } @@ -456,337 +437,21 @@ uint32_t llama_context::n_threads_batch() const { } llama_kv_cache * llama_context::get_kv_self() { - return kv_self.get(); + llama_kv_cache * kv_self = static_cast(memory.get()); + return kv_self; } const llama_kv_cache * llama_context::get_kv_self() const { - return kv_self.get(); -} - -ggml_tensor * llama_context::build_rope_shift( - ggml_context * ctx0, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale) const { - const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; - - const auto & yarn_ext_factor = cparams.yarn_ext_factor; - const auto & yarn_beta_fast = cparams.yarn_beta_fast; - const auto & yarn_beta_slow = cparams.yarn_beta_slow; - - const auto & hparams = model.hparams; - - const auto & n_rot = hparams.n_rot; - const auto & rope_type = hparams.rope_type; - - // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. - // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. - const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 
1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor; - - ggml_tensor * tmp; - - if (ggml_is_quantized(cur->type)) { - // dequantize to f32 -> RoPE -> quantize back - tmp = ggml_cast(ctx0, cur, GGML_TYPE_F32); - - tmp = ggml_rope_ext(ctx0, tmp, - shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); - - tmp = ggml_cpy(ctx0, tmp, cur); - } else { - // we rotate only the first n_rot dimensions - tmp = ggml_rope_ext_inplace(ctx0, cur, - shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, - yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); - } - - return tmp; -} - -class llm_graph_input_k_shift : public llm_graph_input_i { -public: - llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} - virtual ~llm_graph_input_k_shift() = default; - - void set_input(const llama_ubatch * ubatch) override; - - ggml_tensor * k_shift; // I32 [kv_size] - - const llama_kv_cache_unified * kv_self; -}; - -void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { - GGML_UNUSED(ubatch); - - if (k_shift) { - assert(ggml_backend_buffer_is_host(k_shift->buffer)); - - int32_t * data = (int32_t *) k_shift->data; - - for (uint32_t i = 0; i < kv_self->size; ++i) { - data[i] = kv_self->cells[i].delta; - } - } -} - -llm_graph_result_ptr llama_context::build_kv_self_shift( - ggml_context * ctx0, - ggml_cgraph * gf) const { - auto res = std::make_unique(); - - const auto & hparams = model.hparams; - - const auto & n_layer = hparams.n_layer; - - const auto & n_embd_head_k = hparams.n_embd_head_k; - //const auto & n_embd_head_v = hparams.n_embd_head_v; - - //GGML_ASSERT(kv_self->size == n_ctx); - - auto inp = std::make_unique(kv_self.get()); - - inp->k_shift = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, cparams.n_ctx); - ggml_set_input(inp->k_shift); - - for (uint32_t il = 0; il < n_layer; ++il) { - const int64_t n_head_kv = hparams.n_head_kv(il); - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - - const bool is_swa = hparams.is_swa(il); - - // note: the swa rope params could become part of the cparams in the future - // if we decide to make them configurable, like the non-sliding ones - const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; - const float freq_scale_l = is_swa ? 
hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; - - ggml_tensor * rope_factors = kv_self->cbs.get_rope_factors(n_ctx_per_seq(), il); - - ggml_tensor * k = - ggml_view_3d(ctx0, kv_self->k_l[il], - n_embd_head_k, n_head_kv, kv_self->size, - ggml_row_size(kv_self->k_l[il]->type, n_embd_head_k), - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), - 0); - - ggml_tensor * cur = build_rope_shift(ctx0, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); - - ggml_build_forward_expand(gf, cur); - } - - res->add_input(std::move(inp)); - - return res; -} - -llm_graph_result_ptr llama_context::build_kv_self_defrag( - ggml_context * ctx0, - ggml_cgraph * gf) const { - auto res = std::make_unique(); - - const auto & hparams = model.hparams; - - const auto & ids = kv_self->defrag_info.ids; - -#if 0 - // CPU defrag - // - // TODO: optimizations are possible: - // - multiple threads - // - avoid copying to the host memory when already there - // - // likely not worth the effort, as we have ggml_graph based defrag - // - - const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); - const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); - - const uint32_t kv_size = size; - - std::vector buf_k; - std::vector buf_v; - - for (uint32_t il = 0; il < n_layer; ++il) { - const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); - const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size); - - const size_t v_size_el = ggml_type_size(v_l[il]->type); - const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size); - - buf_k.resize(k_size); - buf_v.resize(v_size); - - ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size()); - ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size()); - - // batch move [i, i+nm) to [id, id+nm) - // note: cells can move only to a lower index - for (uint32_t i = 0; i < n_kv; ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == n_kv) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < n_kv && ids[i + nm] == id + nm) { - nm++; - } - - // move keys - { - const int64_t os = i*k_size_row; - const int64_t od = id*k_size_row; - - memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row); - } - - // move values (note: they are transposed) - { - const int64_t os = i; - const int64_t od = id; - - for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { - memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el); - } - } - - i += nm - 1; - } - - ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size()); - ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); - } -#else - for (uint32_t i = 0; i < ids.size(); ++i) { - const uint32_t id = ids[i]; - - if (i == id || id == ids.size()) { - continue; - } - - uint32_t nm = 1; - - while (i + nm < ids.size() && ids[i + nm] == id + nm) { - nm++; - } - - for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT - const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); - const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); - - ggml_tensor * view_k_src = ggml_view_2d(ctx0, kv_self->k_l[il], - n_embd_k_gqa, nm, - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*i)); - - ggml_tensor * view_k_dst = ggml_view_2d(ctx0, kv_self->k_l[il], - n_embd_k_gqa, nm, - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa), - ggml_row_size(kv_self->k_l[il]->type, n_embd_k_gqa*id)); - - ggml_tensor * view_v_src; - ggml_tensor * view_v_dst; - - if (cparams.flash_attn) { - // NOTE: the 
V cache is not transposed when using flash attention - view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il], - n_embd_v_gqa, nm, - ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*i)); - - view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il], - n_embd_v_gqa, nm, - ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa), - ggml_row_size(kv_self->v_l[il]->type, n_embd_v_gqa*id)); - } else { - view_v_src = ggml_view_2d(ctx0, kv_self->v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self->v_l[il]->type, kv_self->size), - ggml_row_size(kv_self->v_l[il]->type, i)); - - view_v_dst = ggml_view_2d(ctx0, kv_self->v_l[il], - nm, n_embd_v_gqa, - ggml_row_size(kv_self->v_l[il]->type, kv_self->size), - ggml_row_size(kv_self->v_l[il]->type, id)); - } - - ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_k_src, view_k_dst)); - ggml_build_forward_expand(gf, ggml_cpy(ctx0, view_v_src, view_v_dst)); - } - - i += nm - 1; - } - - //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); -#endif - - return res; + llama_kv_cache * kv_self = static_cast(memory.get()); + return kv_self; } void llama_context::kv_self_update() { - auto & kv = kv_self; - bool need_reserve = false; - if (kv->has_shift) { - if (!kv->get_can_shift()) { - GGML_ABORT("The current context does not support K-shift"); - } - - LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__); - - // apply K-shift if needed - if (model.hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { - ggml_backend_sched_reset(sched.get()); - - auto * gf = graph_init(); - - auto res = build_kv_self_shift(ctx_compute.get(), gf); - - ggml_backend_sched_alloc_graph(sched.get(), gf); - - res->set_inputs(nullptr); - - graph_compute(gf, false); - - need_reserve = true; - } - - { - kv->has_shift = false; - - for (uint32_t i = 0; i < kv->size; ++i) { - kv->cells[i].delta = 0; - } - } - } - - // defragment the KV cache if needed - if (kv->do_defrag) { - LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); - - if (kv->defrag_prepare(graph_max_nodes())) { - ggml_backend_sched_reset(sched.get()); - - auto * gf = graph_init(); - - auto res = build_kv_self_defrag(ctx_compute.get(), gf); - - ggml_backend_sched_alloc_graph(sched.get(), gf); - - res->set_inputs(nullptr); - - graph_compute(gf, false); - - need_reserve = true; - } + llama_kv_cache * kv_self = static_cast(memory.get()); - kv->do_defrag = false; - } + need_reserve = kv_self->update(*this); // reserve a worst case graph if needed if (need_reserve) { @@ -797,7 +462,7 @@ void llama_context::kv_self_update() { uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); // simulate full KV cache - kv_self->n = kv_self->size; + kv_self->set_full(); llama_token token = model.vocab.token_bos(); // not actually used by llama_build_graph, but required to choose between token and embedding inputs graph llama_ubatch ubatch = { true, n_tokens, n_tokens / n_seqs, n_seqs, &token, nullptr, nullptr, nullptr, nullptr, nullptr}; @@ -818,9 +483,6 @@ enum llama_pooling_type llama_context::pooling_type() const { } float * llama_context::get_logits() { - // reorder logits for backward compatibility - output_reorder(); - return logits; } @@ -863,9 +525,6 @@ float * llama_context::get_logits_ith(int32_t i) { } float * llama_context::get_embeddings() { - // reorder embeddings for backward compatibility - output_reorder(); - return embd; } @@ -1017,8 +676,8 @@ int llama_context::encode(llama_batch & inp_batch) { } // temporary allocate memory for the input batch if needed - // TODO: this is incorrect for 
multiple sequences because pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); + // note: during encode, we always pass the full sequence starting from pos = 0 + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : 0); const llama_batch & batch = batch_allocr.batch; const int32_t n_tokens = batch.n_tokens; @@ -1047,7 +706,7 @@ int llama_context::encode(llama_batch & inp_batch) { const int64_t n_embd = hparams.n_embd; - sbatch.from_batch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); + llama_sbatch sbatch = llama_sbatch(batch, n_embd, /* simple_split */ true, /* logits_all */ true); const llama_ubatch ubatch = sbatch.split_simple(n_tokens); @@ -1181,9 +840,11 @@ int llama_context::decode(llama_batch & inp_batch) { return -1; } + llama_kv_cache * kv_self = static_cast(memory.get()); + // temporary allocate memory for the input batch if needed - // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); + // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1); const llama_batch & batch = batch_allocr.batch; @@ -1195,7 +856,7 @@ int llama_context::decode(llama_batch & inp_batch) { const int64_t n_tokens_all = batch.n_tokens; const int64_t n_embd = hparams.n_embd; - llama_kv_cache_guard kv_guard(kv_self.get()); + llama_kv_cache_guard kv_guard(kv_self); GGML_ASSERT((!batch.token && batch.embd) || (batch.token && !batch.embd)); // NOLINT @@ -1236,11 +897,7 @@ int llama_context::decode(llama_batch & inp_batch) { n_outputs_all = 1; } - const bool logits_all = n_outputs_all == n_tokens_all; - - sbatch.from_batch(batch, n_embd, - /* simple_split */ !kv_self->recurrent, - /* logits_all */ logits_all); + llama_sbatch sbatch = kv_self->sbatch_init(batch, /* logits_all */ n_outputs_all == n_tokens_all); // reserve output buffer if (output_reserve(n_outputs_all) < n_outputs_all) { @@ -1254,22 +911,7 @@ int llama_context::decode(llama_batch & inp_batch) { int64_t n_outputs_prev = 0; while (sbatch.n_tokens > 0) { - llama_ubatch ubatch = llama_ubatch(); - - const auto & n_ubatch = cparams.n_ubatch; - - if (kv_self->recurrent) { - if (embd_pooled) { - // Pooled embeddings cannot be split across ubatches (yet) - ubatch = sbatch.split_seq(cparams.n_ubatch); - } else { - // recurrent model architectures are easier to implement - // with equal-length sequences - ubatch = sbatch.split_equal(cparams.n_ubatch); - } - } else { - ubatch = sbatch.split_simple(n_ubatch); - } + llama_ubatch ubatch = kv_self->ubatch_next(sbatch, cparams.n_ubatch, embd_pooled); // count the outputs in this u_batch { @@ -1289,24 +931,12 @@ int llama_context::decode(llama_batch & inp_batch) { } // find KV slot - { - if (!kv_self->find_slot(ubatch)) { - LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens); + if (!kv_self->find_slot(ubatch)) { + LLAMA_LOG_WARN("%s: failed to find KV cache slot for ubatch of size %d\n", __func__, ubatch.n_tokens); - return 1; - } - - if (!kv_self->recurrent) { - // a heuristic, to avoid attending the full cache if it is not yet utilized - // after enough generations, the benefit from this heuristic disappears - // if we start defragmenting the cache, the benefit from this 
will be more important - const uint32_t pad = kv_self->get_padding(cparams); - kv_self->n = std::min(kv_self->size, std::max(pad, GGML_PAD(kv_self->cell_max(), pad))); - } + return 1; } - //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self->n, kv_self->used, kv_self->head); - ggml_backend_sched_reset(sched.get()); ggml_backend_sched_set_eval_callback(sched.get(), cparams.cb_eval, cparams.cb_eval_user_data); @@ -1424,18 +1054,52 @@ int llama_context::decode(llama_batch & inp_batch) { { bool sorted_output = true; - GGML_ASSERT(sbatch.out_ids.size() == (size_t) n_outputs_all); + auto & out_ids = sbatch.out_ids; + + GGML_ASSERT(out_ids.size() == (size_t) n_outputs_all); for (int64_t i = 0; i < n_outputs_all; ++i) { - int64_t out_id = sbatch.out_ids[i]; + int64_t out_id = out_ids[i]; output_ids[out_id] = i; if (out_id != i) { sorted_output = false; } } - if (sorted_output) { - sbatch.out_ids.clear(); + // make the outputs have the same order they had in the user-provided batch + // note: this is mostly relevant for recurrent models atm + if (!sorted_output) { + const uint32_t n_vocab = model.vocab.n_tokens(); + const uint32_t n_embd = model.hparams.n_embd; + + GGML_ASSERT((size_t) n_outputs == out_ids.size()); + + // TODO: is there something more efficient which also minimizes swaps? + // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) + for (int32_t i = 0; i < n_outputs - 1; ++i) { + int32_t j_min = i; + for (int32_t j = i + 1; j < n_outputs; ++j) { + if (out_ids[j] < out_ids[j_min]) { + j_min = j; + } + } + if (j_min == i) { continue; } + std::swap(out_ids[i], out_ids[j_min]); + if (logits_size > 0) { + for (uint32_t k = 0; k < n_vocab; k++) { + std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]); + } + } + if (embd_size > 0) { + for (uint32_t k = 0; k < n_embd; k++) { + std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]); + } + } + } + std::fill(output_ids.begin(), output_ids.end(), -1); + for (int32_t i = 0; i < n_outputs; ++i) { + output_ids[out_ids[i]] = i; + } } } @@ -1446,17 +1110,8 @@ int llama_context::decode(llama_batch & inp_batch) { //synchronize(); // decide if we need to defrag the kv cache - if (cparams.causal_attn && cparams.defrag_thold > 0.0f) { - // - do not defrag small contexts (i.e. < 2048 tokens) - // - count the padding towards the number of used tokens - const float fragmentation = kv_self->n >= 2048 ? std::max(0.0f, 1.0f - float(kv_self->used + kv_self->get_padding(cparams))/float(kv_self->n)) : 0.0f; - - // queue defragmentation for next llama_kv_cache_update - if (fragmentation > cparams.defrag_thold) { - LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); - - kv_self->defrag(); - } + if (cparams.defrag_thold > 0.0f) { + kv_self->defrag_sched(cparams.defrag_thold); } // Reset state for the next token before backend sync, to allow the CPU activities in the reset to @@ -1542,44 +1197,6 @@ int32_t llama_context::output_reserve(int32_t n_outputs) { return n_outputs_max; } -void llama_context::output_reorder() { - auto & out_ids = sbatch.out_ids; - if (!out_ids.empty()) { - const uint32_t n_vocab = model.vocab.n_tokens(); - const uint32_t n_embd = model.hparams.n_embd; - - GGML_ASSERT((size_t) n_outputs == out_ids.size()); - - // TODO: is there something more efficient which also minimizes swaps? 
- // selection sort, to minimize swaps (from https://en.wikipedia.org/wiki/Selection_sort) - for (int32_t i = 0; i < n_outputs - 1; ++i) { - int32_t j_min = i; - for (int32_t j = i + 1; j < n_outputs; ++j) { - if (out_ids[j] < out_ids[j_min]) { - j_min = j; - } - } - if (j_min == i) { continue; } - std::swap(out_ids[i], out_ids[j_min]); - if (logits_size > 0) { - for (uint32_t k = 0; k < n_vocab; k++) { - std::swap(logits[i*n_vocab + k], logits[j_min*n_vocab + k]); - } - } - if (embd_size > 0) { - for (uint32_t k = 0; k < n_embd; k++) { - std::swap(embd[i*n_embd + k], embd[j_min*n_embd + k]); - } - } - } - std::fill(output_ids.begin(), output_ids.end(), -1); - for (int32_t i = 0; i < n_outputs; ++i) { - output_ids[out_ids[i]] = i; - } - out_ids.clear(); - } -} - // // graph // @@ -1616,7 +1233,7 @@ llm_graph_result_ptr llama_context::graph_build( /*.backend_cpu =*/ backend_cpu, /*.cvec =*/ &cvec, /*.loras =*/ &loras, - /*.memory =*/ kv_self.get(), + /*.memory =*/ memory.get(), /*.cross =*/ &cross, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), @@ -2020,8 +1637,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { { LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__); - output_reorder(); - const auto n_outputs = this->n_outputs; const auto & output_ids = this->output_ids; @@ -2075,6 +1690,8 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { } LLAMA_LOG_DEBUG("%s: - writing KV self\n", __func__); + llama_kv_cache * kv_self = static_cast(memory.get()); + kv_self->state_write(io); return io.n_bytes(); @@ -2159,6 +1776,8 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { } LLAMA_LOG_DEBUG("%s: - reading KV self\n", __func__); + llama_kv_cache * kv_self = static_cast(memory.get()); + kv_self->state_read(io); return io.n_bytes(); @@ -2167,6 +1786,8 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id seq_id) { GGML_UNUSED(seq_id); + llama_kv_cache * kv_self = static_cast(memory.get()); + kv_self->state_write(io, seq_id); return io.n_bytes(); @@ -2175,6 +1796,8 @@ size_t llama_context::state_seq_write_data(llama_io_write_i & io, llama_seq_id s size_t llama_context::state_seq_read_data(llama_io_read_i & io, llama_seq_id seq_id) { GGML_UNUSED(seq_id); + llama_kv_cache * kv_self = static_cast(memory.get()); + kv_self->state_read(io, seq_id); return io.n_bytes(); @@ -2530,7 +2153,7 @@ void llama_kv_cache_seq_cp( llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { - return llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); + llama_kv_self_seq_cp(ctx, seq_id_src, seq_id_dst, p0, p1); } void llama_kv_self_seq_cp( @@ -2544,14 +2167,14 @@ void llama_kv_self_seq_cp( return; } - return kv->seq_cp(seq_id_src, seq_id_dst, p0, p1); + kv->seq_cp(seq_id_src, seq_id_dst, p0, p1); } // deprecated void llama_kv_cache_seq_keep( llama_context * ctx, llama_seq_id seq_id) { - return llama_kv_self_seq_keep(ctx, seq_id); + llama_kv_self_seq_keep(ctx, seq_id); } void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { @@ -2560,7 +2183,7 @@ void llama_kv_self_seq_keep(llama_context * ctx, llama_seq_id seq_id) { return; } - return kv->seq_keep(seq_id); + kv->seq_keep(seq_id); } // deprecated @@ -2570,7 +2193,7 @@ void llama_kv_cache_seq_add( llama_pos p0, llama_pos p1, llama_pos delta) { - return llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta); + llama_kv_self_seq_add(ctx, seq_id, p0, p1, delta); } void llama_kv_self_seq_add( @@ -2584,7 +2207,7 @@ void 
llama_kv_self_seq_add( return; } - return kv->seq_add(seq_id, p0, p1, delta); + kv->seq_add(seq_id, p0, p1, delta); } // deprecated @@ -2594,7 +2217,7 @@ void llama_kv_cache_seq_div( llama_pos p0, llama_pos p1, int d) { - return llama_kv_self_seq_div(ctx, seq_id, p0, p1, d); + llama_kv_self_seq_div(ctx, seq_id, p0, p1, d); } void llama_kv_self_seq_div( @@ -2608,7 +2231,7 @@ void llama_kv_self_seq_div( return; } - return kv->seq_div(seq_id, p0, p1, d); + kv->seq_div(seq_id, p0, p1, d); } // deprecated @@ -2627,7 +2250,7 @@ llama_pos llama_kv_self_seq_pos_max(llama_context * ctx, llama_seq_id seq_id) { // deprecated void llama_kv_cache_defrag(llama_context * ctx) { - return llama_kv_self_defrag(ctx); + llama_kv_self_defrag(ctx); } void llama_kv_self_defrag(llama_context * ctx) { @@ -2636,7 +2259,8 @@ void llama_kv_self_defrag(llama_context * ctx) { return; } - return kv->defrag(); + // force defrag + kv->defrag_sched(-1.0f); } // deprecated diff --git a/src/llama-context.h b/src/llama-context.h index 5457f077c15bf..cf41ac57b9fba 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -27,7 +27,12 @@ struct llama_context { void synchronize(); - const llama_model & get_model() const; + const llama_model & get_model() const; + const llama_cparams & get_cparams() const; + + ggml_backend_sched_t get_sched() const; + + ggml_context * get_ctx_compute() const; uint32_t n_ctx() const; uint32_t n_ctx_per_seq() const; @@ -137,49 +142,30 @@ struct llama_context { // Returns max number of outputs for which space was reserved. int32_t output_reserve(int32_t n_outputs); - // make the outputs have the same order they had in the user-provided batch - // TODO: maybe remove this - void output_reorder(); - // // graph // +public: int32_t graph_max_nodes() const; // zero-out inputs and create the ctx_compute for the compute graph ggml_cgraph * graph_init(); + // returns the result of ggml_backend_sched_graph_compute_async execution + ggml_status graph_compute( + ggml_cgraph * gf, + bool batched); + +private: llm_graph_result_ptr graph_build( ggml_context * ctx, ggml_cgraph * gf, const llama_ubatch & ubatch, llm_graph_type gtype); - // returns the result of ggml_backend_sched_graph_compute_async execution - ggml_status graph_compute( - ggml_cgraph * gf, - bool batched); - llm_graph_cb graph_get_cb() const; - // used by kv_self_update() - ggml_tensor * build_rope_shift( - ggml_context * ctx0, - ggml_tensor * cur, - ggml_tensor * shift, - ggml_tensor * factors, - float freq_base, - float freq_scale) const; - - llm_graph_result_ptr build_kv_self_shift( - ggml_context * ctx0, - ggml_cgraph * gf) const; - - llm_graph_result_ptr build_kv_self_defrag( - ggml_context * ctx0, - ggml_cgraph * gf) const; - // TODO: read/write lora adapters and cvec size_t state_write_data(llama_io_write_i & io); size_t state_read_data (llama_io_read_i & io); @@ -196,11 +182,10 @@ struct llama_context { llama_cparams cparams; llama_adapter_cvec cvec; llama_adapter_loras loras; - llama_sbatch sbatch; llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably - std::unique_ptr kv_self; + std::unique_ptr memory; // TODO: remove bool logits_all = false; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index fabb9ca237653..0da4e7d2b0547 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -284,24 +284,7 @@ void llm_graph_input_s_copy::set_input(const llama_ubatch * ubatch) { // assuming copy destinations ALWAYS happen ONLY on the cells between head and head+n for (uint32_t i = 0; i < n_kv; 
++i) { - const uint32_t cell_id = i + kv_self->head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - llama_kv_cell & kv_cell = const_cast(kv_self)->cells[i]; - - // prevent out-of-bound sources - if (kv_cell.src < 0 || (uint32_t) kv_cell.src >= kv_self->size) { - kv_cell.src = cell_id; - } - - data[i] = kv_cell.src; - - // TODO: do not mutate the KV cache - // ensure copy only happens once - if (kv_cell.src != (int32_t) cell_id) { - kv_cell.src = cell_id; - } + data[i] = kv_self->s_copy(i); } } } @@ -317,18 +300,7 @@ void llm_graph_input_s_mask::set_input(const llama_ubatch * ubatch) { // clear unused states for (int i = 0; i < n_kv; ++i) { - const uint32_t cell_id = i + kv_self->head; - - ////////////////////////////////////////////// - // TODO: this should not mutate the KV cache ! - llama_kv_cell & kv_cell = const_cast(kv_self)->cells[i]; - - data[i] = (float) (kv_cell.src >= 0); - - // only clear once - if (kv_cell.src < 0) { - kv_cell.src = cell_id; - } + data[i] = kv_self->s_mask(i); } } } @@ -1105,7 +1077,7 @@ ggml_tensor * llm_graph_context::build_inp_cls() const { } ggml_tensor * llm_graph_context::build_inp_s_copy() const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); auto inp = std::make_unique(kv_self); @@ -1122,7 +1094,7 @@ ggml_tensor * llm_graph_context::build_inp_s_copy() const { } ggml_tensor * llm_graph_context::build_inp_s_mask() const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); auto inp = std::make_unique(kv_self); @@ -1436,8 +1408,6 @@ ggml_tensor * llm_graph_context::build_attn( // store to KV cache { - GGML_ASSERT(!kv_self->recurrent); - const auto kv_head = kv_self->head; GGML_ASSERT(kv_self->size == n_ctx); @@ -1587,7 +1557,7 @@ ggml_tensor * llm_graph_context::build_copy_mask_state( ggml_tensor * state_mask, int32_t n_state, int32_t n_seqs) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto n_kv = kv_self->n; const auto kv_head = kv_self->head; @@ -1619,7 +1589,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_load( ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto token_shift_count = hparams.token_shift_count; @@ -1640,7 +1610,7 @@ ggml_tensor * llm_graph_context::build_rwkv_token_shift_store( ggml_tensor * token_shift, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto token_shift_count = hparams.token_shift_count; const auto n_embd = hparams.n_embd; diff --git a/src/llama-graph.h b/src/llama-graph.h index d0c8d32192784..5b404366dc145 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -19,6 +19,7 @@ struct llama_cparams; class llama_memory_i; class llama_kv_cache_unified; +class llama_kv_cache_recurrent; // certain models (typically multi-modal) can produce different types of graphs enum llm_graph_type { @@ -186,26 +187,26 @@ class llm_graph_input_cls : public llm_graph_input_i { class llm_graph_input_s_copy : public llm_graph_input_i { public: - llm_graph_input_s_copy(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + 
llm_graph_input_s_copy(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {} virtual ~llm_graph_input_s_copy() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * s_copy; // I32 [kv_size] - const llama_kv_cache_unified * kv_self; + const llama_kv_cache_recurrent * kv_self; }; class llm_graph_input_s_mask : public llm_graph_input_i { public: - llm_graph_input_s_mask(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + llm_graph_input_s_mask(const llama_kv_cache_recurrent * kv_self) : kv_self(kv_self) {} virtual ~llm_graph_input_s_mask() = default; void set_input(const llama_ubatch * ubatch) override; ggml_tensor * s_mask; // F32 [1, n_kv] - const llama_kv_cache_unified * kv_self; + const llama_kv_cache_recurrent * kv_self; }; class llm_graph_input_cross_embd : public llm_graph_input_i { @@ -350,8 +351,8 @@ struct llm_graph_params { const llama_cparams & cparams; const llama_ubatch & ubatch; - ggml_backend_sched * sched; - ggml_backend * backend_cpu; + ggml_backend_sched_t sched; + ggml_backend_t backend_cpu; const llama_adapter_cvec * cvec; const llama_adapter_loras * loras; @@ -402,9 +403,9 @@ struct llm_graph_context { ggml_context * ctx0 = nullptr; - ggml_backend_sched * sched; + ggml_backend_sched_t sched; - ggml_backend * backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? + ggml_backend_t backend_cpu; // TODO: needed by build_attn_mha, figure out a way to remove? const llama_adapter_cvec * cvec; const llama_adapter_loras * loras; diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index 7c9d46d8119b3..3dcad65bb6a85 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -4,33 +4,41 @@ #include "llama-batch.h" #include "llama-cparams.h" #include "llama-model.h" +#include "llama-context.h" #include #include +#include #include #include #include -llama_kv_cache_unified::llama_kv_cache_unified(const llama_hparams & hparams, callbacks cbs) : hparams(hparams), cbs(std::move(cbs)) { +// +// llama_kv_cache_unified +// + +uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) { + // the FA kernels require padding to avoid extra runtime boundary checks + return cparams.flash_attn ? 
256u : 32u; } -bool llama_kv_cache_unified::init( +llama_kv_cache_unified::llama_kv_cache_unified( const llama_model & model, - const llama_cparams & cparams, ggml_type type_k, ggml_type type_v, + bool v_trans, + bool offload, uint32_t kv_size, - bool offload) { + uint32_t padding) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding) { const int32_t n_layer = hparams.n_layer; has_shift = false; + can_shift = true; - recurrent = llama_model_is_recurrent(&model); - v_trans = !recurrent && !cparams.flash_attn; - can_shift = !recurrent; + LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d, padding = %d\n", + __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift, padding); - LLAMA_LOG_INFO("%s: kv_size = %d, offload = %d, type_k = '%s', type_v = '%s', n_layer = %d, can_shift = %d\n", - __func__, kv_size, offload, ggml_type_name(type_k), ggml_type_name(type_v), n_layer, can_shift); + GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding"); head = 0; size = kv_size; @@ -76,23 +84,20 @@ bool llama_kv_cache_unified::init( const char * dev_name = "CPU"; - ggml_backend_buffer_type_t buft; + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); + if (offload) { auto * dev = model.dev_layer(i); buft = ggml_backend_dev_buffer_type(dev); dev_name = ggml_backend_dev_name(dev); - } else { - buft = ggml_backend_cpu_buffer_type(); } - LLAMA_LOG_DEBUG("%s: layer %3d: n_embd_k_gqa = %d, n_embd_v_gqa = %d, dev = %s\n", __func__, - i, n_embd_k_gqa, n_embd_v_gqa, dev_name); + LLAMA_LOG_DEBUG("%s: layer %3d: dev = %s\n", __func__, i, dev_name); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { - LLAMA_LOG_ERROR("%s: failed to create ggml context for kv cache\n", __func__); - return false; + throw std::runtime_error("failed to create ggml context for kv cache"); } ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); @@ -110,55 +115,28 @@ bool llama_kv_cache_unified::init( ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); if (!buf) { - LLAMA_LOG_ERROR("%s: failed to allocate buffer for kv cache\n", __func__); - return false; + throw std::runtime_error("failed to allocate buffer for kv cache"); } ggml_backend_buffer_clear(buf, 0); LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); bufs.emplace_back(buf); } - return true; -} - -int32_t llama_kv_cache_unified::get_n_tokens() const { - int32_t result = 0; - - for (uint32_t i = 0; i < size; i++) { - result += cells[i].seq_id.size(); - } - - return result; -} - -int32_t llama_kv_cache_unified::get_used_cells() const { - return used; -} - -size_t llama_kv_cache_unified::total_size() const { - size_t size = 0; - for (const auto & buf : bufs) { - size += ggml_backend_buffer_get_size(buf.get()); - } - - return size; -} + { + const size_t memory_size_k = size_k_bytes(); + const size_t memory_size_v = size_v_bytes(); -llama_pos llama_kv_cache_unified::pos_max() const { - llama_pos pos_max = -1; - for (const auto & cell : cells) { - pos_max = std::max(pos_max, cell.pos); + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); } - - return pos_max; } void 
llama_kv_cache_unified::clear() { for (int32_t i = 0; i < (int32_t) size; ++i) { cells[i].pos = -1; cells[i].seq_id.clear(); - cells[i].src = -1; - cells[i].tail = -1; } head = 0; used = 0; @@ -179,35 +157,6 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1 = std::numeric_limits::max(); } - // models like Mamba or RWKV can't have a state partially erased - if (recurrent) { - if (seq_id >= (int64_t) size) { - // could be fatal - return false; - } - if (0 <= seq_id) { - int32_t & tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - const llama_kv_cell & cell = cells[tail_id]; - // partial intersection is invalid - if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { - return false; - } - // invalidate tails which will be cleared - if (p0 <= cell.pos && cell.pos < p1) { - tail_id = -1; - } - } - } else { - // seq_id is negative, then the range should include everything or nothing - if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { - return false; - } - } - - return true; - } - for (uint32_t i = 0; i < size; ++i) { if (cells[i].pos >= p0 && cells[i].pos < p1) { if (seq_id < 0) { @@ -224,7 +173,6 @@ bool llama_kv_cache_unified::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos } cells[i].pos = -1; - cells[i].src = -1; if (new_head == size) { new_head = i; @@ -254,34 +202,6 @@ void llama_kv_cache_unified::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id p1 = std::numeric_limits::max(); } - if (recurrent) { - if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { - llama_kv_cell & tail_src = cells[seq_id_src]; - llama_kv_cell & tail_dst = cells[seq_id_dst]; - if (tail_dst.tail >= 0) { - // clear destination seq_id if it wasn't empty - llama_kv_cell & cell_dst = cells[tail_dst.tail]; - - cell_dst.seq_id.erase(seq_id_dst); - tail_dst.tail = -1; - if (cell_dst.seq_id.empty()) { - cell_dst.pos = -1; - cell_dst.delta = -1; - cell_dst.src = -1; - used -= 1; - } - } - if (tail_src.tail >= 0) { - llama_kv_cell & cell_src = cells[tail_src.tail]; - - cell_src.seq_id.insert(seq_id_dst); - tail_dst.tail = tail_src.tail; - } - } - - return; - } - // otherwise, this is the KV of a Transformer-like model head = 0; @@ -296,17 +216,12 @@ void llama_kv_cache_unified::seq_keep(llama_seq_id seq_id) { uint32_t new_head = size; for (uint32_t i = 0; i < size; ++i) { - if (recurrent && (llama_seq_id) i != seq_id) { - cells[i].tail = -1; - } - if (!cells[i].has_seq_id(seq_id)) { if (cells[i].pos >= 0) { used--; } cells[i].pos = -1; - cells[i].src = -1; cells[i].seq_id.clear(); if (new_head == size){ @@ -344,20 +259,6 @@ void llama_kv_cache_unified::seq_add(llama_seq_id seq_id, llama_pos p0, llama_po return; } - if (recurrent) { - // for Mamba-like or RWKV models, only the pos needs to be shifted - if (0 <= seq_id && seq_id < (int64_t) size) { - const int32_t tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - llama_kv_cell & cell = cells[tail_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos += delta; - } - } - } - return; - } - for (uint32_t i = 0; i < size; ++i) { if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { has_shift = true; @@ -400,21 +301,6 @@ void llama_kv_cache_unified::seq_div(llama_seq_id seq_id, llama_pos p0, llama_po return; } - if (recurrent) { - // for Mamba-like or RWKV models, only the pos needs to be changed - if (0 <= seq_id && seq_id < (int64_t) size) { - const int32_t tail_id = cells[seq_id].tail; - if (tail_id >= 0) { - llama_kv_cell & cell = 
cells[tail_id]; - if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { - cell.pos /= d; - } - } - } - - return; - } - for (uint32_t i = 0; i < size; ++i) { if (cells[i].has_seq_id(seq_id) && cells[i].pos >= p0 && cells[i].pos < p1) { has_shift = true; @@ -440,23 +326,11 @@ llama_pos llama_kv_cache_unified::seq_pos_max(llama_seq_id seq_id) const { return result; } -void llama_kv_cache_unified::defrag() { - if (!recurrent) { - do_defrag = true; - } -} - void llama_kv_cache_unified::restore() { if (pending.ranges.empty()) { return; } - // TODO: tmp - move to llama_kv_cache_recurrent - if (recurrent) { - seq_rm(-1, -1, -1); - return; - } - uint32_t new_head = size; for (auto & range : pending.ranges) { @@ -469,7 +343,6 @@ void llama_kv_cache_unified::restore() { } cells[i].pos = -1; - cells[i].src = -1; } new_head = std::min(new_head, range.c0); @@ -481,11 +354,6 @@ void llama_kv_cache_unified::restore() { } void llama_kv_cache_unified::commit() { - // TODO: tmp - move to llama_kv_cache_recurrent - if (recurrent) { - return; - } - if (pending.ranges.empty()) { LLAMA_LOG_WARN("%s: no pending KV cache updates to commit - might indicate a bug (ref: %s)\n", __func__, "https://github.com/ggml-org/llama.cpp/pull/12695"); @@ -495,183 +363,110 @@ void llama_kv_cache_unified::commit() { pending.ranges.clear(); } -bool llama_kv_cache_unified::get_can_shift() const { - return can_shift; -} +bool llama_kv_cache_unified::update(llama_context & lctx) { + bool need_reserve = false; -bool llama_kv_cache_unified::find_slot( - const llama_ubatch & ubatch) { - const uint32_t n_tokens = ubatch.n_tokens; - const uint32_t n_seqs = ubatch.n_seqs; - const uint32_t n_seq_tokens = ubatch.n_seq_tokens; + auto * sched = lctx.get_sched(); - // if we have enough unused cells before the current head -> - // better to start searching from the beginning of the cache, hoping to fill it - if (head > used + 2*ubatch.n_tokens) { - head = 0; - } + if (has_shift) { + if (!get_can_shift()) { + GGML_ABORT("The current KV cache / model configuration does not support K-shift"); + } - if (recurrent) { - // For recurrent state architectures (like Mamba or RWKV), - // each cache cell can store the state for a whole sequence. - // A slot should be always be contiguous. + LLAMA_LOG_DEBUG("%s: applying K-shift\n", __func__); - // can only process batches with an equal number of new tokens in each sequence - GGML_ASSERT(ubatch.equal_seqs); + // apply K-shift if needed + if (hparams.rope_type != LLAMA_ROPE_TYPE_NONE) { + ggml_backend_sched_reset(sched); - int32_t min = size - 1; - int32_t max = 0; + auto * gf = lctx.graph_init(); - // everything should fit if all seq_ids are smaller than the max - for (uint32_t s = 0; s < n_seqs; ++s) { - const uint32_t n_seq_id = ubatch.n_seq_id[s]; - for (uint32_t j = 0; j < n_seq_id; ++j) { - const llama_seq_id seq_id = ubatch.seq_id[s][j]; + auto res = build_graph_shift(lctx.get_cparams(), lctx.get_ctx_compute(), gf); - if (seq_id < 0 || (uint32_t) seq_id >= size) { - // too big seq_id - // TODO: would it be possible to resize the cache instead? 
- LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size); - return false; - } - if (j > 0) { - llama_kv_cell & seq = cells[seq_id]; - if (seq.tail >= 0) { - llama_kv_cell & cell = cells[seq.tail]; - // clear cells from seq_ids that become shared - // (should not normally happen, but let's handle it anyway) - cell.seq_id.erase(seq_id); - seq.tail = -1; - if (cell.seq_id.empty()) { - cell.pos = -1; - cell.src = -1; - used -= 1; - } - } - } - } + ggml_backend_sched_alloc_graph(sched, gf); + + res->set_inputs(nullptr); + + lctx.graph_compute(gf, false); + + need_reserve = true; } -#ifndef NDEBUG { - std::vector tails_verif; - tails_verif.assign(size, -1); - for (uint32_t i = 0; i < size; ++i) { - llama_kv_cell & cell = cells[i]; - for (llama_seq_id seq_id : cell.seq_id) { - if (tails_verif[seq_id] != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); - } - tails_verif[seq_id] = i; - } - } + has_shift = false; + for (uint32_t i = 0; i < size; ++i) { - if (tails_verif[i] != cells[i].tail) { - LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]); - } + cells[i].delta = 0; } } -#endif + } - // find next empty cell - uint32_t next_empty_cell = head; + if (do_defrag) { + LLAMA_LOG_DEBUG("%s: defragmenting KV cache\n", __func__); - for (uint32_t i = 0; i < size; ++i) { - if (next_empty_cell >= size) { next_empty_cell -= size; } - llama_kv_cell & cell = cells[next_empty_cell]; - if (cell.is_empty()) { break; } - next_empty_cell += 1; - } + if (defrag_prepare(lctx.graph_max_nodes())) { + ggml_backend_sched_reset(sched); - // find usable cell range - for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_seq_id seq_id = ubatch.seq_id[s][0]; - llama_kv_cell & seq_meta = cells[seq_id]; - bool has_cell = false; - if (seq_meta.tail >= 0) { - llama_kv_cell & cell = cells[seq_meta.tail]; - GGML_ASSERT(cell.has_seq_id(seq_id)); - // does this seq_id "own" the cell? 
- if (cell.seq_id.size() == 1) { has_cell = true; } - } - if (!has_cell) { - llama_kv_cell & empty_cell = cells[next_empty_cell]; - GGML_ASSERT(empty_cell.is_empty()); - // copy old tail into the empty cell - if (seq_meta.tail >= 0) { - llama_kv_cell & orig_cell = cells[seq_meta.tail]; - empty_cell.pos = orig_cell.pos; - empty_cell.src = orig_cell.src; - orig_cell.seq_id.erase(seq_id); - empty_cell.seq_id.insert(seq_id); // will be overwritten - } - seq_meta.tail = next_empty_cell; - // find next empty cell - if (s + 1 < n_seqs) { - next_empty_cell += 1; - for (uint32_t i = 0; i < size; ++i) { - if (next_empty_cell >= size) { next_empty_cell -= size; } - llama_kv_cell & cell = cells[next_empty_cell]; - if (cell.is_empty()) { break; } - next_empty_cell += 1; - } - } - } - if (min > seq_meta.tail) { min = seq_meta.tail; } - if (max < seq_meta.tail) { max = seq_meta.tail; } - } + auto * gf = lctx.graph_init(); - // gather and re-order - for (uint32_t s = 0; s < n_seqs; ++s) { - int32_t dst_id = s + min; - int32_t src_id = cells[ubatch.seq_id[s][0]].tail; - if (dst_id != src_id) { - llama_kv_cell & dst_cell = cells[dst_id]; - llama_kv_cell & src_cell = cells[src_id]; + auto res = build_graph_defrag(lctx.get_cparams(), lctx.get_ctx_compute(), gf); - std::swap(dst_cell.pos, src_cell.pos); - std::swap(dst_cell.src, src_cell.src); - std::swap(dst_cell.seq_id, src_cell.seq_id); + ggml_backend_sched_alloc_graph(sched, gf); - // swap tails (assuming they NEVER overlap) - for (const llama_seq_id seq_id : src_cell.seq_id) { - cells[seq_id].tail = src_id; - } - for (const llama_seq_id seq_id : dst_cell.seq_id) { - cells[seq_id].tail = dst_id; - } - } - } + res->set_inputs(nullptr); - // update the pos of the used seqs - for (uint32_t s = 0; s < n_seqs; ++s) { - const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; - int32_t cell_id = s + min; - llama_kv_cell & cell = cells[cell_id]; + lctx.graph_compute(gf, false); - if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { - // What should happen when the pos backtracks or skips a value? - // Clearing the state mid-batch would require special-casing which isn't done. - LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", - __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); - } - cell.pos = last_pos; - cell.seq_id.clear(); - for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { - const llama_seq_id seq_id = ubatch.seq_id[s][j]; - cell.seq_id.insert(seq_id); - cells[seq_id].tail = cell_id; - } + need_reserve = true; } - // allow getting the range of used cells, from head to head + n - head = min; - n = max - min + 1; - used = std::count_if(cells.begin(), cells.end(), - [](const llama_kv_cell& cell){ return !cell.is_empty(); }); + do_defrag = false; + } + + return need_reserve; +} + +void llama_kv_cache_unified::defrag_sched(float thold) { + // - do not defrag small contexts (i.e. < 2048 tokens) + // - count the padding towards the number of used tokens + const float fragmentation = n >= 2048 ? 
std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f; + + // queue defragmentation for next llama_kv_cache_update + if (fragmentation > thold) { + LLAMA_LOG_DEBUG("%s: fragmentation: %.2f - requesting defrag\n", __func__, fragmentation); + + do_defrag = true; + } +} + +void llama_kv_cache_unified::set_full() { + n = size; +} + +llama_sbatch llama_kv_cache_unified::sbatch_init( + const llama_batch & batch, + bool logits_all) { + return llama_sbatch(batch, hparams.n_embd, true, logits_all); +} + +llama_ubatch llama_kv_cache_unified::ubatch_next( + llama_sbatch & sbatch, + uint32_t n_ubatch, + bool embd_pooled) const { + GGML_UNUSED(embd_pooled); + return sbatch.split_simple(n_ubatch); +} + +bool llama_kv_cache_unified::find_slot( + const llama_ubatch & ubatch) { + const uint32_t n_tokens = ubatch.n_tokens; + const uint32_t n_seqs = ubatch.n_seqs; + const uint32_t n_seq_tokens = ubatch.n_seq_tokens; - // sanity check - return n >= n_seqs; + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head > used + 2*ubatch.n_tokens) { + head = 0; } // otherwise, one cell per token. @@ -725,24 +520,50 @@ bool llama_kv_cache_unified::find_slot( pending.ranges.push_back({head, head + n_tokens}); + // a heuristic, to avoid attending the full cache if it is not yet utilized + // after enough generations, the benefit from this heuristic disappears + // if we start defragmenting the cache, the benefit from this will be more important + n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding))); + + //printf("n = %5d, used = %5d, head = %5d\n", n, used, head); + return true; } -uint32_t llama_kv_cache_unified::get_padding(const llama_cparams & cparams) const { - // the FA kernels require padding to avoid extra runtime boundary checks - return cparams.flash_attn ? 
256u : 32u; +int32_t llama_kv_cache_unified::get_n_tokens() const { + int32_t result = 0; + + for (uint32_t i = 0; i < size; i++) { + result += cells[i].seq_id.size(); + } + + return result; } -uint32_t llama_kv_cache_unified::cell_max() const { - for (uint32_t i = size; i > 0; --i) { - const llama_kv_cell & cell = cells[i - 1]; +int32_t llama_kv_cache_unified::get_used_cells() const { + return used; +} - if (cell.pos >= 0 && !cell.is_empty()) { - return i; - } +bool llama_kv_cache_unified::get_can_shift() const { + return can_shift; +} + +llama_pos llama_kv_cache_unified::get_pos_max() const { + llama_pos pos_max = -1; + for (const auto & cell : cells) { + pos_max = std::max(pos_max, cell.pos); } - return 0; + return pos_max; +} + +size_t llama_kv_cache_unified::total_size() const { + size_t size = 0; + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; } size_t llama_kv_cache_unified::size_k_bytes() const { @@ -765,68 +586,331 @@ size_t llama_kv_cache_unified::size_v_bytes() const { return size_v_bytes; } -bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { - const uint32_t n_layer = hparams.n_layer; +ggml_tensor * llama_kv_cache_unified::build_rope_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale) const { + const auto & n_ctx_orig = cparams.n_ctx_orig_yarn; - const uint32_t n_kv = cell_max(); - const uint32_t n_used = used; + const auto & yarn_ext_factor = cparams.yarn_ext_factor; + const auto & yarn_beta_fast = cparams.yarn_beta_fast; + const auto & yarn_beta_slow = cparams.yarn_beta_slow; - assert(n_used <= n_kv); + const auto & n_rot = hparams.n_rot; + const auto & rope_type = hparams.rope_type; - //const int64_t t_start = ggml_time_us(); + // See llm_build_deepseek2() for why attn_factor has to be scaled for YaRN RoPE to work correctly. + // See https://github.com/ggerganov/llama.cpp/discussions/7416 for detailed explanation. + const float yarn_attn_factor = model.arch == LLM_ARCH_DEEPSEEK2 ? 
1.0f / (1.0f + 0.1f * logf(1.0f / freq_scale)) : cparams.yarn_attn_factor; - // number of cells moved - uint32_t n_moves = 0; + ggml_tensor * tmp; - // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag) - // - source view, destination view, copy operation - // - x2 for keys and values - //const uint32_t max_moves = max_nodes()/(6*n_layer); - // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 - const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer); + if (ggml_is_quantized(cur->type)) { + // dequantize to f32 -> RoPE -> quantize back + tmp = ggml_cast(ctx, cur, GGML_TYPE_F32); - // determine which KV cells to move where - // - // cell i moves to ids[i] - // - // if ids[i] == i || ids[i] == n_kv, then cell i is not moved - // - auto & ids = defrag_info.ids; + tmp = ggml_rope_ext(ctx, tmp, + shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); - ids.clear(); - ids.resize(n_kv, n_kv); + tmp = ggml_cpy(ctx, tmp, cur); + } else { + // we rotate only the first n_rot dimensions + tmp = ggml_rope_ext_inplace(ctx, cur, + shift, factors, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + yarn_ext_factor, yarn_attn_factor, yarn_beta_fast, yarn_beta_slow); + } - for (uint32_t i0 = 0; i0 < n_used; ++i0) { - const auto & cell0 = cells[i0]; + return tmp; +} - if (!cell0.is_empty()) { - ids[i0] = i0; +class llm_graph_input_k_shift : public llm_graph_input_i { +public: + llm_graph_input_k_shift(const llama_kv_cache_unified * kv_self) : kv_self(kv_self) {} + virtual ~llm_graph_input_k_shift() = default; - continue; - } + void set_input(const llama_ubatch * ubatch) override; - // found a hole - fill it with data from the end of the cache + ggml_tensor * k_shift; // I32 [kv_size] - uint32_t nh = 1; + const llama_kv_cache_unified * kv_self; +}; - // determine the size of the hole - while (i0 + nh < n_used && cells[i0 + nh].is_empty()) { - nh++; +void llm_graph_input_k_shift::set_input(const llama_ubatch * ubatch) { + GGML_UNUSED(ubatch); + + if (k_shift) { + assert(ggml_backend_buffer_is_host(k_shift->buffer)); + + int32_t * data = (int32_t *) k_shift->data; + + for (uint32_t i = 0; i < kv_self->size; ++i) { + data[i] = kv_self->cells[i].delta; } + } +} - uint32_t nf = 0; - uint32_t is = n_kv - 1; +llm_graph_result_ptr llama_kv_cache_unified::build_graph_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const { + auto res = std::make_unique(); - // starting from the end, find nh non-empty cells - for (; is > i0; --is) { - const auto & cell1 = cells[is]; + const auto & n_layer = hparams.n_layer; - if (cell1.is_empty() || ids[is] != n_kv) { - continue; - } + const auto & n_embd_head_k = hparams.n_embd_head_k; + //const auto & n_embd_head_v = hparams.n_embd_head_v; - // non-empty cell which is not yet moved - nf++; + const uint32_t n_ctx_per_seq = cparams.n_ctx / cparams.n_seq_max; + + //GGML_ASSERT(kv_self->size == n_ctx); + + auto inp = std::make_unique(this); + + inp->k_shift = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, cparams.n_ctx); + ggml_set_input(inp->k_shift); + + for (uint32_t il = 0; il < n_layer; ++il) { + const int64_t n_head_kv = hparams.n_head_kv(il); + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + + const bool is_swa = hparams.is_swa(il); + + // note: the swa rope params could become part of the cparams in the future + // if we decide to make them configurable, like the non-sliding 
ones + const float freq_base_l = is_swa ? hparams.rope_freq_base_train_swa : cparams.rope_freq_base; + const float freq_scale_l = is_swa ? hparams.rope_freq_scale_train_swa : cparams.rope_freq_scale; + + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); + + ggml_tensor * k = + ggml_view_3d(ctx, k_l[il], + n_embd_head_k, n_head_kv, size, + ggml_row_size(k_l[il]->type, n_embd_head_k), + ggml_row_size(k_l[il]->type, n_embd_k_gqa), + 0); + + ggml_tensor * cur = build_rope_shift(cparams, ctx, k, inp->k_shift, rope_factors, freq_base_l, freq_scale_l); + + ggml_build_forward_expand(gf, cur); + } + + res->add_input(std::move(inp)); + + return res; +} + +llm_graph_result_ptr llama_kv_cache_unified::build_graph_defrag( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const { + auto res = std::make_unique(); + + const auto & ids = defrag_info.ids; + +#if 0 + // CPU defrag + // + // TODO: optimizations are possible: + // - multiple threads + // - avoid copying to the host memory when already there + // + // likely not worth the effort, as we have ggml_graph based defrag + // + + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(); + + const uint32_t kv_size = size; + + std::vector buf_k; + std::vector buf_v; + + for (uint32_t il = 0; il < n_layer; ++il) { + const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + const size_t k_size = ggml_row_size(k_l[il]->type, n_embd_k_gqa*kv_size); + + const size_t v_size_el = ggml_type_size(v_l[il]->type); + const size_t v_size = ggml_row_size (v_l[il]->type, n_embd_v_gqa*kv_size); + + buf_k.resize(k_size); + buf_v.resize(v_size); + + ggml_backend_tensor_get(k_l[il], buf_k.data(), 0, buf_k.size()); + ggml_backend_tensor_get(v_l[il], buf_v.data(), 0, buf_v.size()); + + // batch move [i, i+nm) to [id, id+nm) + // note: cells can move only to a lower index + for (uint32_t i = 0; i < n_kv; ++i) { + const uint32_t id = ids[i]; + + if (i == id || id == n_kv) { + continue; + } + + uint32_t nm = 1; + + while (i + nm < n_kv && ids[i + nm] == id + nm) { + nm++; + } + + // move keys + { + const int64_t os = i*k_size_row; + const int64_t od = id*k_size_row; + + memcpy(buf_k.data() + od, buf_k.data() + os, nm*k_size_row); + } + + // move values (note: they are transposed) + { + const int64_t os = i; + const int64_t od = id; + + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + memcpy(buf_v.data() + (od + j*kv_size)*v_size_el, buf_v.data() + (os + j*kv_size)*v_size_el, nm*v_size_el); + } + } + + i += nm - 1; + } + + ggml_backend_tensor_set(k_l[il], buf_k.data(), 0, buf_k.size()); + ggml_backend_tensor_set(v_l[il], buf_v.data(), 0, buf_v.size()); + } +#else + for (uint32_t i = 0; i < ids.size(); ++i) { + const uint32_t id = ids[i]; + + if (i == id || id == ids.size()) { + continue; + } + + uint32_t nm = 1; + + while (i + nm < ids.size() && ids[i + nm] == id + nm) { + nm++; + } + + for (uint32_t il = 0; il < hparams.n_layer; ++il) { // NOLINT + const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa(il); + const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa(il); + + ggml_tensor * view_k_src = ggml_view_2d(ctx, k_l[il], + n_embd_k_gqa, nm, + ggml_row_size(k_l[il]->type, n_embd_k_gqa), + ggml_row_size(k_l[il]->type, n_embd_k_gqa*i)); + + ggml_tensor * view_k_dst = ggml_view_2d(ctx, k_l[il], + n_embd_k_gqa, nm, + ggml_row_size(k_l[il]->type, n_embd_k_gqa), + ggml_row_size(k_l[il]->type, n_embd_k_gqa*id)); + + ggml_tensor * view_v_src; + ggml_tensor * view_v_dst; + + if 
(cparams.flash_attn) { + // NOTE: the V cache is not transposed when using flash attention + view_v_src = ggml_view_2d(ctx, v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(v_l[il]->type, n_embd_v_gqa), + ggml_row_size(v_l[il]->type, n_embd_v_gqa*i)); + + view_v_dst = ggml_view_2d(ctx, v_l[il], + n_embd_v_gqa, nm, + ggml_row_size(v_l[il]->type, n_embd_v_gqa), + ggml_row_size(v_l[il]->type, n_embd_v_gqa*id)); + } else { + view_v_src = ggml_view_2d(ctx, v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(v_l[il]->type, size), + ggml_row_size(v_l[il]->type, i)); + + view_v_dst = ggml_view_2d(ctx, v_l[il], + nm, n_embd_v_gqa, + ggml_row_size(v_l[il]->type, size), + ggml_row_size(v_l[il]->type, id)); + } + + ggml_build_forward_expand(gf, ggml_cpy(ctx, view_k_src, view_k_dst)); + ggml_build_forward_expand(gf, ggml_cpy(ctx, view_v_src, view_v_dst)); + } + + i += nm - 1; + } + + //LLAMA_LOG_INFO("gf->n_nodes = %d\n", gf->n_nodes); +#endif + + return res; +} + +bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { + const uint32_t n_layer = hparams.n_layer; + + const uint32_t n_kv = cell_max(); + const uint32_t n_used = used; + + assert(n_used <= n_kv); + + //const int64_t t_start = ggml_time_us(); + + // number of cells moved + uint32_t n_moves = 0; + + // each move requires 6*n_layer tensors (see graph_build_kv_self_defrag) + // - source view, destination view, copy operation + // - x2 for keys and values + //const uint32_t max_moves = max_nodes()/(6*n_layer); + // TODO: tmp fix https://github.com/ggerganov/llama.cpp/issues/6685#issuecomment-2057579516 + const uint32_t max_moves = (n_max_nodes - 2*n_layer)/(6*n_layer); + + // determine which KV cells to move where + // + // cell i moves to ids[i] + // + // if ids[i] == i || ids[i] == n_kv, then cell i is not moved + // + auto & ids = defrag_info.ids; + + ids.clear(); + ids.resize(n_kv, n_kv); + + for (uint32_t i0 = 0; i0 < n_used; ++i0) { + const auto & cell0 = cells[i0]; + + if (!cell0.is_empty()) { + ids[i0] = i0; + + continue; + } + + // found a hole - fill it with data from the end of the cache + + uint32_t nh = 1; + + // determine the size of the hole + while (i0 + nh < n_used && cells[i0 + nh].is_empty()) { + nh++; + } + + uint32_t nf = 0; + uint32_t is = n_kv - 1; + + // starting from the end, find nh non-empty cells + for (; is > i0; --is) { + const auto & cell1 = cells[is]; + + if (cell1.is_empty() || ids[is] != n_kv) { + continue; + } + + // non-empty cell which is not yet moved + nf++; if (nf == nh) { break; @@ -867,7 +951,7 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { cells[i0 + nf] = cell1; // clear the old cell and move the head there - cell1 = llama_kv_cell(); + cell1 = kv_cell(); head = n_used; if (!cont) { @@ -895,13 +979,25 @@ bool llama_kv_cache_unified::defrag_prepare(int32_t n_max_nodes) { return false; } - LLAMA_LOG_DEBUG("(tmp log) KV defrag cell moves: %u\n", n_moves); + LLAMA_LOG_DEBUG("%s: (tmp log) KV defrag cell moves: %u\n", __func__, n_moves); - LLAMA_LOG_DEBUG("expected gf nodes: %u\n", 6*n_moves*n_layer); + LLAMA_LOG_DEBUG("%s: expected gf nodes: %u\n", __func__, 6*n_moves*n_layer); return true; } +uint32_t llama_kv_cache_unified::cell_max() const { + for (uint32_t i = size; i > 0; --i) { + const kv_cell & cell = cells[i - 1]; + + if (cell.pos >= 0 && !cell.is_empty()) { + return i; + } + } + + return 0; +} + void llama_kv_cache_unified::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { std::vector> cell_ranges; // ranges, from inclusive, to exclusive 
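// [editor's aside - illustrative sketch, not part of the patch] The node budget
// used by defrag_prepare() above: every cell move expands into 6 graph nodes per
// layer (source view, destination view and copy, for both K and V), and 2*n_layer
// nodes are subtracted as headroom per the temporary fix referenced in the TODO.
// The numbers below are hypothetical.
#include <cstdint>
#include <cstdio>

static uint32_t max_defrag_moves(int32_t n_max_nodes, uint32_t n_layer) {
    // mirrors: (n_max_nodes - 2*n_layer)/(6*n_layer)
    return ((uint32_t) n_max_nodes - 2*n_layer) / (6*n_layer);
}

int main() {
    const int32_t  n_max_nodes = 8192; // hypothetical graph node limit
    const uint32_t n_layer     = 32;   // hypothetical layer count
    std::printf("defrag can move at most %u cells per graph\n",
                max_defrag_moves(n_max_nodes, n_layer));
    return 0;
}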
uint32_t cell_count = 0; @@ -1110,7 +1206,7 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell clear(); for (uint32_t i = 0; i < cell_count; ++i) { - llama_kv_cell & cell = cells[i]; + kv_cell & cell = cells[i]; llama_pos pos; uint32_t n_seq_id; @@ -1133,15 +1229,6 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell } cell.seq_id.insert(seq_id); - - if (recurrent) { - int32_t & tail = cells[seq_id].tail; - if (tail != -1) { - LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); - return false; - } - tail = i; - } } } @@ -1149,14 +1236,6 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell used = cell_count; } - if (recurrent) { - for (uint32_t i = 0; i < cell_count; ++i) { - uint32_t cell_id = head + i; - // make sure the recurrent states will keep their restored state - cells[cell_id].src = cell_id; - } - } - return true; } @@ -1174,7 +1253,1034 @@ bool llama_kv_cache_unified::state_read_data(llama_io_read_i & io, uint32_t cell LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); return false; } - if (v_trans != (bool) v_trans) { + if (this->v_trans != (bool) v_trans) { + LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); + return false; + } + + // For each layer, read the keys for each cell, one row is one cell, read as one contiguous block + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Read type of key + int32_t k_type_i_ref; + io.read_to(&k_type_i_ref, sizeof(k_type_i_ref)); + const int32_t k_type_i = (int32_t) k_l[il]->type; + if (k_type_i != k_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched key type (%d != %d, layer %d)\n", __func__, k_type_i, k_type_i_ref, il); + return false; + } + + // Read row size of key + uint64_t k_size_row_ref; + io.read_to(&k_size_row_ref, sizeof(k_size_row_ref)); + const size_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + if (k_size_row != k_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched key row size (%zu != %zu, layer %d)\n", __func__, k_size_row, (size_t) k_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the keys for the whole cell range + ggml_backend_tensor_set(k_l[il], io.read(cell_count * k_size_row), head * k_size_row, cell_count * k_size_row); + } + } + + if (!this->v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read row size of value + uint64_t v_size_row_ref; + io.read_to(&v_size_row_ref, sizeof(v_size_row_ref)); + const size_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + if (v_size_row != v_size_row_ref) { + LLAMA_LOG_ERROR("%s: mismatched value row size (%zu != %zu, layer %d)\n", __func__, v_size_row, (size_t) v_size_row_ref, il); + return false; + } + + if (cell_count) { + // Read and set the values for the whole cell range + ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_row), head * v_size_row, cell_count * v_size_row); + } + } + } else { + // For each layer, read the 
values for each cell (transposed) + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Read type of value + int32_t v_type_i_ref; + io.read_to(&v_type_i_ref, sizeof(v_type_i_ref)); + const int32_t v_type_i = (int32_t)v_l[il]->type; + if (v_type_i != v_type_i_ref) { + LLAMA_LOG_ERROR("%s: mismatched value type (%d != %d, layer %d)\n", __func__, v_type_i, v_type_i_ref, il); + return false; + } + + // Read element size of value + uint32_t v_size_el_ref; + io.read_to(&v_size_el_ref, sizeof(v_size_el_ref)); + const size_t v_size_el = ggml_type_size(v_l[il]->type); + if (v_size_el != v_size_el_ref) { + LLAMA_LOG_ERROR("%s: mismatched value element size (%zu != %zu, layer %d)\n", __func__, v_size_el, (size_t) v_size_el_ref, il); + return false; + } + + // Read GQA embedding size + uint32_t n_embd_v_gqa_ref; + io.read_to(&n_embd_v_gqa_ref, sizeof(n_embd_v_gqa_ref)); + if (n_embd_v_gqa != n_embd_v_gqa_ref) { + LLAMA_LOG_ERROR("%s: mismatched GQA embedding size (%u != %u, layer %d)\n", __func__, n_embd_v_gqa, n_embd_v_gqa_ref, il); + return false; + } + + if (cell_count) { + // For each row in the transposed matrix, read the values for the whole cell range + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + const size_t dst_offset = (head + j * size) * v_size_el; + ggml_backend_tensor_set(v_l[il], io.read(cell_count * v_size_el), dst_offset, cell_count * v_size_el); + } + } + } + } + + return true; +} + +// +// llama_kv_cache_recurrent +// + +llama_kv_cache_recurrent::llama_kv_cache_recurrent( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size) : hparams(model.hparams) { + const int32_t n_layer = hparams.n_layer; + + LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n", + __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer); + + head = 0; + size = kv_size; + used = 0; + + this->type_k = type_k; + this->type_v = type_v; + + cells.clear(); + cells.resize(kv_size); + + // create a context for each buffer type + std::map ctx_map; + auto ctx_for_buft = [&](ggml_backend_buffer_type_t buft) -> ggml_context * { + auto it = ctx_map.find(buft); + if (it == ctx_map.end()) { + ggml_init_params params = { + /*.mem_size =*/ size_t(2u*n_layer*ggml_tensor_overhead()), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ctx = ggml_init(params); + if (!ctx) { + return nullptr; + } + + ctx_map[buft] = ctx; + ctxs.emplace_back(ctx); + + return ctx; + } + + return it->second; + }; + + k_l.reserve(n_layer); + v_l.reserve(n_layer); + + for (int i = 0; i < n_layer; i++) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); + + const char * dev_name = "CPU"; + + ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type(); + + if (offload) { + auto * dev = model.dev_layer(i); + buft = ggml_backend_dev_buffer_type(dev); + + dev_name = ggml_backend_dev_name(dev); + } + + LLAMA_LOG_DEBUG("%s, layer %3d: dev = %s\n", __func__, i, dev_name); + + ggml_context * ctx = ctx_for_buft(buft); + if (!ctx) { + throw std::runtime_error("failed to create ggml context for kv cache"); + } + + ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); + ggml_tensor * v = ggml_new_tensor_1d(ctx, type_v, n_embd_v_gqa*kv_size); + ggml_format_name(k, "cache_k_l%d", i); + ggml_format_name(v, "cache_v_l%d", i); + 
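// [editor's aside - illustrative sketch, not part of the patch] Rough per-layer
// footprint of the recurrent cache constructed above: kv_size is one cell per
// sequence and the states are stored in F32, so each layer holds about
// (n_embd_k_gqa + n_embd_k_s)*kv_size and (n_embd_v_gqa + n_embd_v_s)*kv_size
// floats; for purely recurrent models the gqa terms are typically zero. The
// widths below are hypothetical placeholders, not values from a real model.
#include <cstddef>
#include <cstdio>

int main() {
    const size_t n_seq_max  = 8;     // kv_size of a recurrent cache
    const size_t n_embd_k_s = 2048;  // hypothetical conv-state width per layer
    const size_t n_embd_v_s = 16384; // hypothetical ssm-state width per layer

    const size_t bytes_k = n_embd_k_s * n_seq_max * sizeof(float);
    const size_t bytes_v = n_embd_v_s * n_seq_max * sizeof(float);

    std::printf("per layer: K ~ %zu bytes, V ~ %zu bytes\n", bytes_k, bytes_v);
    return 0;
}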
k_l.push_back(k); + v_l.push_back(v); + } + + // allocate tensors and initialize the buffers to avoid NaNs in the padding + for (auto it : ctx_map) { + auto * buft = it.first; + auto * ctx = it.second; + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors_from_buft(ctx, buft); + if (!buf) { + throw std::runtime_error("failed to allocate buffer for kv cache"); + } + ggml_backend_buffer_clear(buf, 0); + LLAMA_LOG_INFO("%s: %10s KV buffer size = %8.2f MiB\n", __func__, ggml_backend_buffer_name(buf), ggml_backend_buffer_get_size(buf)/1024.0/1024.0); + bufs.emplace_back(buf); + } + + { + const size_t memory_size_k = size_k_bytes(); + const size_t memory_size_v = size_v_bytes(); + + LLAMA_LOG_INFO("%s: KV self size = %7.2f MiB, K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__, + (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), + ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f), + ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f)); + } +} + +void llama_kv_cache_recurrent::clear() { + for (int32_t i = 0; i < (int32_t) size; ++i) { + cells[i].pos = -1; + cells[i].seq_id.clear(); + cells[i].src = -1; + cells[i].tail = -1; + } + head = 0; + used = 0; + + for (auto & buf : bufs) { + ggml_backend_buffer_clear(buf.get(), 0); + } +} + +bool llama_kv_cache_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + uint32_t new_head = size; + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // models like Mamba or RWKV can't have a state partially erased + if (seq_id >= (int64_t) size) { + // could be fatal + return false; + } + if (0 <= seq_id) { + int32_t & tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + const kv_cell & cell = cells[tail_id]; + // partial intersection is invalid + if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + return false; + } + // invalidate tails which will be cleared + if (p0 <= cell.pos && cell.pos < p1) { + tail_id = -1; + } + } + } else { + // seq_id is negative, then the range should include everything or nothing + if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + return false; + } + } + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].pos >= p0 && cells[i].pos < p1) { + if (seq_id < 0) { + cells[i].seq_id.clear(); + } else if (cells[i].has_seq_id(seq_id)) { + cells[i].seq_id.erase(seq_id); + } else { + continue; + } + if (cells[i].is_empty()) { + // keep count of the number of used cells + if (cells[i].pos >= 0) { + used--; + } + cells[i].pos = -1; + cells[i].src = -1; + if (new_head == size) { + new_head = i; + } + } + } + } + + // If we freed up a slot, set head to it so searching can start there. 
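// [editor's aside - illustrative sketch, not part of the patch] The partial-erase
// rule enforced by seq_rm() above for recurrent states: a removal range [p0, p1)
// is rejected when it only partially intersects the state kept in the sequence's
// tail cell, i.e. when either bound lands strictly inside (0, cell.pos]. The
// positions below are made up.
#include <cstdio>

using llama_pos_t = int; // stand-in for the real llama_pos typedef

static bool is_partial_erase(llama_pos_t pos, llama_pos_t p0, llama_pos_t p1) {
    // mirrors: (0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)
    return (0 < p0 && p0 <= pos) || (0 < p1 && p1 <= pos);
}

int main() {
    const llama_pos_t pos = 10; // position stored in the tail cell
    std::printf("[0, 5)  -> %s\n", is_partial_erase(pos, 0, 5)  ? "rejected" : "allowed");
    std::printf("[0, 11) -> %s\n", is_partial_erase(pos, 0, 11) ? "rejected" : "allowed");
    return 0;
}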
+ if (new_head != size && new_head < head) { + head = new_head; + } + + return true; +} + +void llama_kv_cache_recurrent::seq_cp(llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) { + if (seq_id_src == seq_id_dst) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + if ((uint32_t) seq_id_dst < size && (uint32_t) seq_id_src < size) { + kv_cell & tail_src = cells[seq_id_src]; + kv_cell & tail_dst = cells[seq_id_dst]; + if (tail_dst.tail >= 0) { + // clear destination seq_id if it wasn't empty + kv_cell & cell_dst = cells[tail_dst.tail]; + + cell_dst.seq_id.erase(seq_id_dst); + tail_dst.tail = -1; + if (cell_dst.seq_id.empty()) { + cell_dst.pos = -1; + cell_dst.src = -1; + used -= 1; + } + } + if (tail_src.tail >= 0) { + kv_cell & cell_src = cells[tail_src.tail]; + + cell_src.seq_id.insert(seq_id_dst); + tail_dst.tail = tail_src.tail; + } + } +} + +void llama_kv_cache_recurrent::seq_keep(llama_seq_id seq_id) { + uint32_t new_head = size; + + for (uint32_t i = 0; i < size; ++i) { + if ((llama_seq_id) i != seq_id) { + cells[i].tail = -1; + } + + if (!cells[i].has_seq_id(seq_id)) { + if (cells[i].pos >= 0) { + used--; + } + + cells[i].pos = -1; + cells[i].src = -1; + cells[i].seq_id.clear(); + + if (new_head == size){ + new_head = i; + } + } else { + cells[i].seq_id.clear(); + cells[i].seq_id.insert(seq_id); + } + } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != size && new_head < head) { + head = new_head; + } +} + +void llama_kv_cache_recurrent::seq_add(llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) { + if (delta == 0) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the + if (p0 == p1) { + return; + } + + // for Mamba-like or RWKV models, only the pos needs to be shifted + if (0 <= seq_id && seq_id < (int64_t) size) { + const int32_t tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + kv_cell & cell = cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos += delta; + } + } + } +} + +void llama_kv_cache_recurrent::seq_div(llama_seq_id seq_id, llama_pos p0, llama_pos p1, int d) { + if (d == 1) { + return; + } + + if (p0 < 0) { + p0 = 0; + } + + if (p1 < 0) { + p1 = std::numeric_limits::max(); + } + + // If there is no range then return early to avoid looping over the cache. 
+ if (p0 == p1) { + return; + } + + // for Mamba-like or RWKV models, only the pos needs to be changed + if (0 <= seq_id && seq_id < (int64_t) size) { + const int32_t tail_id = cells[seq_id].tail; + if (tail_id >= 0) { + kv_cell & cell = cells[tail_id]; + if (cell.has_seq_id(seq_id) && p0 <= cell.pos && cell.pos < p1) { + cell.pos /= d; + } + } + } +} + +llama_pos llama_kv_cache_recurrent::seq_pos_max(llama_seq_id seq_id) const { + llama_pos result = 0; + + for (uint32_t i = 0; i < size; ++i) { + if (cells[i].has_seq_id(seq_id)) { + result = std::max(result, cells[i].pos); + } + } + + return result; +} + +void llama_kv_cache_recurrent::restore() { + if (pending.ranges.empty()) { + return; + } + + seq_rm(-1, -1, -1); +} + +void llama_kv_cache_recurrent::commit() { + pending.ranges.clear(); +} + +bool llama_kv_cache_recurrent::update(llama_context & lctx) { + GGML_UNUSED(lctx); + return false; +} + +void llama_kv_cache_recurrent::defrag_sched(float thold) { + GGML_UNUSED(thold); + // noop +} + +void llama_kv_cache_recurrent::set_full() { + n = size; +} + +llama_sbatch llama_kv_cache_recurrent::sbatch_init( + const llama_batch & batch, + bool logits_all) { + return llama_sbatch(batch, hparams.n_embd, false, logits_all); +} + +llama_ubatch llama_kv_cache_recurrent::ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const { + if (embd_pooled) { + // Pooled embeddings cannot be split across ubatches (yet) + return sbatch.split_seq(n_ubatch); + } + + return sbatch.split_equal(n_ubatch); +} + +bool llama_kv_cache_recurrent::find_slot( + const llama_ubatch & ubatch) { + const uint32_t n_tokens = ubatch.n_tokens; + const uint32_t n_seqs = ubatch.n_seqs; + + const uint32_t n_seq_tokens = ubatch.n_seq_tokens; + + // if we have enough unused cells before the current head -> + // better to start searching from the beginning of the cache, hoping to fill it + if (head > used + 2*n_tokens) { + head = 0; + } + + // For recurrent state architectures (like Mamba or RWKV), + // each cache cell can store the state for a whole sequence. + // A slot should be always be contiguous. + + // can only process batches with an equal number of new tokens in each sequence + GGML_ASSERT(ubatch.equal_seqs); + + int32_t min = size - 1; + int32_t max = 0; + + // everything should fit if all seq_ids are smaller than the max + for (uint32_t s = 0; s < n_seqs; ++s) { + const uint32_t n_seq_id = ubatch.n_seq_id[s]; + for (uint32_t j = 0; j < n_seq_id; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + + if (seq_id < 0 || (uint32_t) seq_id >= size) { + // too big seq_id + // TODO: would it be possible to resize the cache instead? 
+ LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size); + return false; + } + if (j > 0) { + kv_cell & seq = cells[seq_id]; + if (seq.tail >= 0) { + kv_cell & cell = cells[seq.tail]; + // clear cells from seq_ids that become shared + // (should not normally happen, but let's handle it anyway) + cell.seq_id.erase(seq_id); + seq.tail = -1; + if (cell.seq_id.empty()) { + cell.pos = -1; + cell.src = -1; + used -= 1; + } + } + } + } + } + +#ifndef NDEBUG + { + std::vector tails_verif; + tails_verif.assign(size, -1); + for (uint32_t i = 0; i < size; ++i) { + kv_cell & cell = cells[i]; + for (llama_seq_id seq_id : cell.seq_id) { + if (tails_verif[seq_id] != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tails_verif[seq_id]); + } + tails_verif[seq_id] = i; + } + } + for (uint32_t i = 0; i < size; ++i) { + if (tails_verif[i] != cells[i].tail) { + LLAMA_LOG_ERROR("%s: wrong tail for seq_id %d, (%d instead of %d)\n", __func__, i, cells[i].tail, tails_verif[i]); + } + } + } +#endif + + // find next empty cell + uint32_t next_empty_cell = head; + + for (uint32_t i = 0; i < size; ++i) { + if (next_empty_cell >= size) { next_empty_cell -= size; } + kv_cell & cell = cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + + // find usable cell range + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + kv_cell & seq_meta = cells[seq_id]; + bool has_cell = false; + if (seq_meta.tail >= 0) { + kv_cell & cell = cells[seq_meta.tail]; + GGML_ASSERT(cell.has_seq_id(seq_id)); + // does this seq_id "own" the cell? + if (cell.seq_id.size() == 1) { has_cell = true; } + } + if (!has_cell) { + kv_cell & empty_cell = cells[next_empty_cell]; + GGML_ASSERT(empty_cell.is_empty()); + // copy old tail into the empty cell + if (seq_meta.tail >= 0) { + kv_cell & orig_cell = cells[seq_meta.tail]; + empty_cell.pos = orig_cell.pos; + empty_cell.src = orig_cell.src; + orig_cell.seq_id.erase(seq_id); + empty_cell.seq_id.insert(seq_id); // will be overwritten + } + seq_meta.tail = next_empty_cell; + // find next empty cell + if (s + 1 < n_seqs) { + next_empty_cell += 1; + for (uint32_t i = 0; i < size; ++i) { + if (next_empty_cell >= size) { next_empty_cell -= size; } + kv_cell & cell = cells[next_empty_cell]; + if (cell.is_empty()) { break; } + next_empty_cell += 1; + } + } + } + if (min > seq_meta.tail) { min = seq_meta.tail; } + if (max < seq_meta.tail) { max = seq_meta.tail; } + } + + // gather and re-order + for (uint32_t s = 0; s < n_seqs; ++s) { + int32_t dst_id = s + min; + int32_t src_id = cells[ubatch.seq_id[s][0]].tail; + if (dst_id != src_id) { + kv_cell & dst_cell = cells[dst_id]; + kv_cell & src_cell = cells[src_id]; + + std::swap(dst_cell.pos, src_cell.pos); + std::swap(dst_cell.src, src_cell.src); + std::swap(dst_cell.seq_id, src_cell.seq_id); + + // swap tails (assuming they NEVER overlap) + for (const llama_seq_id seq_id : src_cell.seq_id) { + cells[seq_id].tail = src_id; + } + for (const llama_seq_id seq_id : dst_cell.seq_id) { + cells[seq_id].tail = dst_id; + } + } + } + + // update the pos of the used seqs + for (uint32_t s = 0; s < n_seqs; ++s) { + const llama_pos last_pos = ubatch.pos[n_seq_tokens * s + n_seq_tokens - 1]; + int32_t cell_id = s + min; + kv_cell & cell = cells[cell_id]; + + if (cell.pos >= 0 && last_pos != cell.pos + (llama_pos) n_seq_tokens) { + // What should happen when the pos backtracks or skips a value? 
+ // Clearing the state mid-batch would require special-casing which isn't done. + LLAMA_LOG_WARN("%s: non-consecutive token position %d after %d for sequence %d with %u new tokens\n", + __func__, last_pos, cell.pos, ubatch.seq_id[s][0], n_seq_tokens); + } + cell.pos = last_pos; + cell.seq_id.clear(); + for (int32_t j = 0; j < ubatch.n_seq_id[s]; ++j) { + const llama_seq_id seq_id = ubatch.seq_id[s][j]; + cell.seq_id.insert(seq_id); + cells[seq_id].tail = cell_id; + } + } + + // allow getting the range of used cells, from head to head + n + head = min; + n = max - min + 1; + used = std::count_if(cells.begin(), cells.end(), + [](const kv_cell & cell){ return !cell.is_empty(); }); + + // sanity check + return n >= n_seqs; +} + +int32_t llama_kv_cache_recurrent::get_n_tokens() const { + int32_t result = 0; + + for (uint32_t i = 0; i < size; i++) { + result += cells[i].seq_id.size(); + } + + return result; +} + +int32_t llama_kv_cache_recurrent::get_used_cells() const { + return used; +} + +llama_pos llama_kv_cache_recurrent::get_pos_max() const { + llama_pos pos_max = -1; + for (const auto & cell : cells) { + pos_max = std::max(pos_max, cell.pos); + } + + return pos_max; +} + +bool llama_kv_cache_recurrent::get_can_shift() const { + return false; +} + +int32_t llama_kv_cache_recurrent::s_copy(int i) const { + const uint32_t cell_id = i + head; + + ////////////////////////////////////////////// + // TODO: this should not mutate the KV cache ! + kv_cell & cell = const_cast(cells[cell_id]); + + // prevent out-of-bound sources + if (cell.src < 0 || (uint32_t) cell.src >= size) { + cell.src = cell_id; + } + + int32_t res = cell.src; + + // TODO: do not mutate the KV cache + // ensure copy only happens once + if (cell.src != (int32_t) cell_id) { + cell.src = cell_id; + } + + return res; +} + +float llama_kv_cache_recurrent::s_mask(int i) const { + const uint32_t cell_id = i + head; + + ////////////////////////////////////////////// + // TODO: this should not mutate the KV cache ! 
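// [editor's aside - illustrative sketch, not part of the patch] How the getters
// above derive their counts from the cell metadata: get_n_tokens() sums
// seq_id.size() over all cells, so a cell shared by several sequences counts once
// per sequence, while get_used_cells() reports the number of non-empty cells
// (tracked incrementally as `used`). This is a standalone mock, not the real
// kv_cell type.
#include <cstdint>
#include <cstdio>
#include <set>
#include <vector>

struct kv_cell_mock {
    int32_t pos = -1;
    std::set<int32_t> seq_id;
    bool is_empty() const { return seq_id.empty(); }
};

int main() {
    std::vector<kv_cell_mock> cells(4);
    cells[0] = {0, {0}};    // token of sequence 0
    cells[1] = {0, {1}};    // token of sequence 1
    cells[2] = {1, {0, 1}}; // prompt token shared by both sequences

    int32_t n_tokens = 0;
    int32_t n_used   = 0;
    for (const auto & cell : cells) {
        n_tokens += (int32_t) cell.seq_id.size();
        n_used   += cell.is_empty() ? 0 : 1;
    }

    std::printf("n_tokens = %d, used cells = %d\n", n_tokens, n_used); // 4, 3
    return 0;
}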
+ kv_cell & cell = const_cast(cells[cell_id]); + + float res = (float) (cell.src >= 0); + + // only clear once + if (cell.src < 0) { + cell.src = cell_id; + } + + return res; +} + +uint32_t llama_kv_cache_recurrent::cell_max() const { + for (uint32_t i = size; i > 0; --i) { + const kv_cell & cell = cells[i - 1]; + + if (cell.pos >= 0 && !cell.is_empty()) { + return i; + } + } + + return 0; +} + +size_t llama_kv_cache_recurrent::total_size() const { + size_t size = 0; + for (const auto & buf : bufs) { + size += ggml_backend_buffer_get_size(buf.get()); + } + + return size; +} + +size_t llama_kv_cache_recurrent::size_k_bytes() const { + size_t size_k_bytes = 0; + + for (const auto & k : k_l) { + size_k_bytes += ggml_nbytes(k); + } + + return size_k_bytes; +} + +size_t llama_kv_cache_recurrent::size_v_bytes() const { + size_t size_v_bytes = 0; + + for (const auto & v : v_l) { + size_v_bytes += ggml_nbytes(v); + } + + return size_v_bytes; +} + +void llama_kv_cache_recurrent::state_write(llama_io_write_i & io, llama_seq_id seq_id) const { + std::vector> cell_ranges; // ranges, from inclusive, to exclusive + uint32_t cell_count = 0; + + // Count the number of cells with the specified seq_id + // Find all the ranges of cells with this seq id (or all, when -1) + uint32_t cell_range_begin = size; + for (uint32_t i = 0; i < size; ++i) { + const auto & cell = cells[i]; + if ((seq_id == -1 && !cell.is_empty()) || cell.has_seq_id(seq_id)) { + ++cell_count; + if (cell_range_begin == size) { + cell_range_begin = i; + } + } else { + if (cell_range_begin != size) { + cell_ranges.emplace_back(cell_range_begin, i); + cell_range_begin = size; + } + } + } + if (cell_range_begin != size) { + cell_ranges.emplace_back(cell_range_begin, size); + } + + // DEBUG CHECK: Sum of cell counts in ranges should equal the total cell count + uint32_t cell_count_check = 0; + for (const auto & range : cell_ranges) { + cell_count_check += range.second - range.first; + } + GGML_ASSERT(cell_count == cell_count_check); + + io.write(&cell_count, sizeof(cell_count)); + + state_write_meta(io, cell_ranges, seq_id); + state_write_data(io, cell_ranges); +} + +void llama_kv_cache_recurrent::state_read(llama_io_read_i & io, llama_seq_id seq_id) { + uint32_t cell_count; + io.read_to(&cell_count, sizeof(cell_count)); + + bool res = true; + res = res && state_read_meta(io, cell_count, seq_id); + res = res && state_read_data(io, cell_count); + + if (!res) { + if (seq_id == -1) { + clear(); + } else { + seq_rm(seq_id, -1, -1); + } + throw std::runtime_error("failed to restore kv cache"); + } +} + +void llama_kv_cache_recurrent::state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id) const { + for (const auto & range : cell_ranges) { + for (uint32_t i = range.first; i < range.second; ++i) { + const auto & cell = cells[i]; + const llama_pos pos = cell.pos; + const uint32_t n_seq_id = seq_id == -1 ? 
cell.seq_id.size() : 0; + + io.write(&pos, sizeof(pos)); + io.write(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id) { + for (auto seq_id : cell.seq_id) { + io.write(&seq_id, sizeof(seq_id)); + } + } + } + } +} + +void llama_kv_cache_recurrent::state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const { + const uint32_t v_trans = 0; + const uint32_t n_layer = hparams.n_layer; + + io.write(&v_trans, sizeof(v_trans)); + io.write(&n_layer, sizeof(n_layer)); + + std::vector tmp_buf; + + // Iterate and write all the keys first, each row is a cell + // Get whole range at a time + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(il) + hparams.n_embd_k_s(); + + // Write key type + const int32_t k_type_i = (int32_t)k_l[il]->type; + io.write(&k_type_i, sizeof(k_type_i)); + + // Write row size of key + const uint64_t k_size_row = ggml_row_size(k_l[il]->type, n_embd_k_gqa); + io.write(&k_size_row, sizeof(k_size_row)); + + // Read each range of cells of k_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * k_size_row; + io.write_tensor(k_l[il], range.first * k_size_row, buf_size); + } + } + + if (!v_trans) { + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)v_l[il]->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write row size of value + const uint64_t v_size_row = ggml_row_size(v_l[il]->type, n_embd_v_gqa); + io.write(&v_size_row, sizeof(v_size_row)); + + // Read each range of cells of v_size length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t buf_size = range_size * v_size_row; + io.write_tensor(v_l[il], range.first * v_size_row, buf_size); + } + } + } else { + // When v is transposed, we also need the element size and get the element ranges from each row + const uint32_t kv_size = size; + for (uint32_t il = 0; il < n_layer; ++il) { + const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(il) + hparams.n_embd_v_s(); + + // Write value type + const int32_t v_type_i = (int32_t)v_l[il]->type; + io.write(&v_type_i, sizeof(v_type_i)); + + // Write element size + const uint32_t v_size_el = ggml_type_size(v_l[il]->type); + io.write(&v_size_el, sizeof(v_size_el)); + + // Write GQA embedding size + io.write(&n_embd_v_gqa, sizeof(n_embd_v_gqa)); + + // For each row, we get the element values of each cell + for (uint32_t j = 0; j < n_embd_v_gqa; ++j) { + // Read each range of cells of v_size_el length each into tmp_buf and write out + for (const auto & range : cell_ranges) { + const size_t range_size = range.second - range.first; + const size_t src_offset = (range.first + j * kv_size) * v_size_el; + const size_t buf_size = range_size * v_size_el; + io.write_tensor(v_l[il], src_offset, buf_size); + } + } + } + } +} + +bool llama_kv_cache_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { + if (dest_seq_id != -1) { + // single sequence + + seq_rm(dest_seq_id, -1, -1); + + llama_sbatch sbatch; + llama_ubatch batch = sbatch.reserve_ubatch(cell_count, /* has_embd */ false); + + batch.n_tokens = cell_count; + batch.n_seq_tokens = cell_count; + batch.n_seqs = 1; + + for (uint32_t i = 0; i < cell_count; ++i) { + llama_pos pos; + uint32_t 
n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + if (n_seq_id != 0) { + LLAMA_LOG_ERROR("%s: invalid seq_id-agnostic kv cell\n", __func__); + return false; + } + + batch.pos[i] = pos; + } + batch.n_seq_id[0] = 1; + batch.seq_id[0] = &dest_seq_id; + if (!find_slot(batch)) { + LLAMA_LOG_ERROR("%s: failed to find available cells in kv cache\n", __func__); + return false; + } + commit(); + + // DEBUG CHECK: kv.head should be our first cell, kv.head + cell_count - 1 should be our last cell (verify seq_id and pos values) + // Assume that this is one contiguous block of cells + GGML_ASSERT(head + cell_count <= size); + GGML_ASSERT(cells[head].pos == batch.pos[0]); + GGML_ASSERT(cells[head + cell_count - 1].pos == batch.pos[cell_count - 1]); + GGML_ASSERT(cells[head].has_seq_id(dest_seq_id)); + GGML_ASSERT(cells[head + cell_count - 1].has_seq_id(dest_seq_id)); + } else { + // whole KV cache restore + + if (cell_count > size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache\n", __func__); + return false; + } + + clear(); + + for (uint32_t i = 0; i < cell_count; ++i) { + kv_cell & cell = cells[i]; + + llama_pos pos; + uint32_t n_seq_id; + + io.read_to(&pos, sizeof(pos)); + io.read_to(&n_seq_id, sizeof(n_seq_id)); + + cell.pos = pos; + + for (uint32_t j = 0; j < n_seq_id; ++j) { + llama_seq_id seq_id; + io.read_to(&seq_id, sizeof(seq_id)); + + // TODO: llama_kv_cache_recurrent should have a notion of max sequences + //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) { + if (seq_id < 0) { + //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx)); + LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id); + return false; + } + + cell.seq_id.insert(seq_id); + + int32_t & tail = cells[seq_id].tail; + if (tail != -1) { + LLAMA_LOG_ERROR("%s: duplicate tail for seq_id %d in cell %d and %d\n", __func__, seq_id, i, tail); + return false; + } + tail = i; + } + } + + head = 0; + used = cell_count; + } + + for (uint32_t i = 0; i < cell_count; ++i) { + uint32_t cell_id = head + i; + // make sure the recurrent states will keep their restored state + cells[cell_id].src = cell_id; + } + + return true; +} + +bool llama_kv_cache_recurrent::state_read_data(llama_io_read_i & io, uint32_t cell_count) { + uint32_t v_trans; + uint32_t n_layer; + io.read_to(&v_trans, sizeof(v_trans)); + io.read_to(&n_layer, sizeof(n_layer)); + + if (n_layer != hparams.n_layer) { + LLAMA_LOG_ERROR("%s: mismatched layer count (%u instead of %u)\n", __func__, n_layer, hparams.n_layer); + return false; + } + if (cell_count > size) { + LLAMA_LOG_ERROR("%s: not enough cells in kv cache to restore state (%u > %u)\n", __func__, cell_count, size); + return false; + } + if (false != (bool) v_trans) { LLAMA_LOG_ERROR("%s: incompatible V transposition\n", __func__); return false; } @@ -1326,7 +2432,7 @@ void llama_kv_cache_view_update(llama_kv_cache_view * view, const llama_kv_cache view->cells_sequences = (llama_seq_id *)p; } - const std::vector & kv_cells = kvu->cells; + const std::vector & kv_cells = kvu->cells; llama_kv_cache_view_cell * c_curr = view->cells; llama_seq_id * cs_curr = view->cells_sequences; int32_t used_cells = 0; diff --git a/src/llama-kv-cache.h b/src/llama-kv-cache.h index 56c74035ae1b9..bf3b4b6a4430f 100644 --- a/src/llama-kv-cache.h +++ b/src/llama-kv-cache.h @@ -2,32 +2,72 @@ #include "llama.h" #include "llama-io.h" +#include "llama-graph.h" #include "llama-memory.h" #include 
"ggml-cpp.h" -#include #include #include struct llama_cparams; struct llama_hparams; struct llama_ubatch; +struct llama_sbatch; +struct llama_model; +struct llama_context; struct llama_kv_cache : public llama_memory_i { - using llama_memory_i::llama_memory_i; + virtual ~llama_kv_cache() = default; - virtual void restore() = 0; // call if batch processing fails - restores the cache state - virtual void commit() = 0; // call after successful batch processing - clears any pending state + // call if batch processing fails - restores the cache state + virtual void restore() = 0; - virtual int32_t get_n_tokens() const = 0; - virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache + // call after successful batch processing - clears any pending state + virtual void commit() = 0; - virtual bool get_can_shift() const = 0; + // process any pending defrag/shift/etc. operations + // optionally call once before processing a new batch + virtual bool update(llama_context & lctx) = 0; + + // schedule a defrag if the fragmentation threshold is exceeded. otherwise, do nothing + virtual void defrag_sched(float thold) = 0; + + // simulate full cache, used for allocating worst-case compute buffers + virtual void set_full() = 0; + + // + // batch processing + // + + virtual llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) = 0; + + // different KV caches require different batch splitting strategies + virtual llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const = 0; + + // find an empty slot of size "n_tokens" in the cache + virtual bool find_slot(const llama_ubatch & batch) = 0; + + // getters + virtual int32_t get_n_tokens() const = 0; + virtual int32_t get_used_cells() const = 0; // TODO: remove, this is too-specific to the unified cache + virtual llama_pos get_pos_max() const = 0; + virtual bool get_can_shift() const = 0; bool get_can_edit() const override { return get_can_shift(); } + + // + // state write/read + // + + virtual void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const = 0; + virtual void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) = 0; }; +// +// llama_kv_cache_guard +// + struct llama_kv_cache_guard { llama_kv_cache_guard(llama_kv_cache * kv) : kv(kv) {} @@ -43,65 +83,50 @@ struct llama_kv_cache_guard { llama_kv_cache * kv; }; -struct llama_kv_cell { - llama_pos pos = -1; - llama_pos delta = 0; - int32_t src = -1; // used by recurrent state models to copy states - int32_t tail = -1; +// +// llama_kv_cache_unified +// - std::set seq_id; +// TODO: add notion of max sequences +class llama_kv_cache_unified : public llama_kv_cache { +public: + struct kv_cell { + llama_pos pos = -1; + llama_pos delta = 0; - bool has_seq_id(const llama_seq_id & id) const { - return seq_id.find(id) != seq_id.end(); - } + std::set seq_id; - bool is_empty() const { - return seq_id.empty(); - } + bool has_seq_id(const llama_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } - bool is_same_seq(const llama_kv_cell & other) const { - return seq_id == other.seq_id; - } -}; + bool is_empty() const { + return seq_id.empty(); + } -// ring-buffer of cached KV data -// TODO: pimpl -// TODO: add notion of max sequences -class llama_kv_cache_unified : public llama_kv_cache { -public: - // can be used to query data from the model if needed - struct callbacks { - std::function get_rope_factors; + bool is_same_seq(const kv_cell & other) const { + return seq_id == other.seq_id; + } }; 
- llama_kv_cache_unified( - const llama_hparams & hparams, - callbacks cbs); - - virtual ~llama_kv_cache_unified() = default; + static uint32_t get_padding(const llama_cparams & cparams); - // TODO: become constructor - bool init( - const llama_model & model, // TODO: do not reference the model - const llama_cparams & cparams, + llama_kv_cache_unified( + const llama_model & model, ggml_type type_k, ggml_type type_v, + bool v_trans, + bool offload, uint32_t kv_size, - bool offload); - - int32_t get_n_tokens() const override; - int32_t get_used_cells() const override; + uint32_t padding); - size_t total_size() const; + ~llama_kv_cache_unified() = default; - // TODO: better data structures to reduce the cost of this operation - llama_pos pos_max() const; + // + // llama_memory_i + // void clear() override; - void defrag() override; - - virtual void restore() override; - virtual void commit() override; bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; @@ -111,25 +136,76 @@ class llama_kv_cache_unified : public llama_kv_cache { llama_pos seq_pos_max(llama_seq_id seq_id) const override; - bool get_can_shift() const override; + // + // llama_kv_cache + // + + void restore() override; + void commit() override; + + bool update(llama_context & ctx) override; + + void defrag_sched(float thold) override; + + void set_full() override; + + llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; + + llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; - // find an empty slot of size "n_tokens" in the cache // updates the cache head // Note: On success, it's important that cache.head points // to the first cell of the slot. - bool find_slot(const llama_ubatch & batch); + bool find_slot(const llama_ubatch & batch) override; - // TODO: maybe not needed - uint32_t get_padding(const llama_cparams & cparams) const; + int32_t get_n_tokens() const override; + int32_t get_used_cells() const override; - // find how many cells are currently in use - uint32_t cell_max() const; + // TODO: better data structures to reduce the cost of this operation + llama_pos get_pos_max() const override; - size_t size_k_bytes() const; - size_t size_v_bytes() const; + bool get_can_shift() const override; - // defrag + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; + + // Note: The value of head isn't only used to optimize searching + // for a free KV slot. llama_decode_impl also uses it, so it + // cannot be freely changed after a slot has been allocated. + uint32_t head = 0; + uint32_t size = 0; + uint32_t used = 0; // used cells (i.e. 
at least one seq_id) + + // computed before each graph build + uint32_t n = 0; + + std::vector cells; + + std::vector k_l; // per layer + std::vector v_l; + +private: + const llama_model & model; + const llama_hparams & hparams; + + bool has_shift = false; + bool do_defrag = false; + + bool v_trans = true; // the value tensor is transposed + bool can_shift = false; + + // required padding + uint32_t padding = 1; + + ggml_type type_k = GGML_TYPE_F16; + ggml_type type_v = GGML_TYPE_F16; + + std::vector ctxs; + std::vector bufs; + // defrag struct { std::vector ids; } defrag_info; @@ -138,7 +214,6 @@ class llama_kv_cache_unified : public llama_kv_cache { bool defrag_prepare(int32_t n_max_nodes); // commit/restore cache - struct slot_range { uint32_t c0 = 0; // note: these are cell indices, not sequence positions uint32_t c1 = 0; @@ -149,25 +224,124 @@ class llama_kv_cache_unified : public llama_kv_cache { std::vector ranges; } pending; - // state write/load + // find how many cells are currently in use + uint32_t cell_max() const; - void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const; - void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1); + size_t total_size() const; - // members + size_t size_k_bytes() const; + size_t size_v_bytes() const; - const llama_hparams & hparams; + ggml_tensor * build_rope_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_tensor * cur, + ggml_tensor * shift, + ggml_tensor * factors, + float freq_base, + float freq_scale) const; + + llm_graph_result_ptr build_graph_shift( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const; + + llm_graph_result_ptr build_graph_defrag( + const llama_cparams & cparams, + ggml_context * ctx, + ggml_cgraph * gf) const; - callbacks cbs; + void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; + void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; - bool has_shift = false; - bool do_defrag = false; + bool state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id = -1); + bool state_read_data(llama_io_read_i & io, uint32_t cell_count); +}; - // TODO: remove this and implement llama_kv_cache_recurrent instead - bool recurrent = false; // with recurrent state models, a cell can hold the state for more than one past token +// +// llama_kv_cache_recurrent +// - bool v_trans = true; // the value tensor is transposed - bool can_shift = false; +class llama_kv_cache_recurrent : public llama_kv_cache { +public: + struct kv_cell { + llama_pos pos = -1; + int32_t src = -1; // used to copy states + int32_t tail = -1; + + std::set seq_id; + + bool has_seq_id(const llama_seq_id & id) const { + return seq_id.find(id) != seq_id.end(); + } + + bool is_empty() const { + return seq_id.empty(); + } + + bool is_same_seq(const kv_cell & other) const { + return seq_id == other.seq_id; + } + }; + + llama_kv_cache_recurrent( + const llama_model & model, + ggml_type type_k, + ggml_type type_v, + bool offload, + uint32_t kv_size); + + ~llama_kv_cache_recurrent() = default; + + // + // llama_memory_i + // + + void clear() override; + + bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) override; + void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) override; + void seq_keep(llama_seq_id seq_id) override; + void seq_add (llama_seq_id seq_id, llama_pos p0, llama_pos p1, llama_pos delta) override; + void seq_div (llama_seq_id 
seq_id, llama_pos p0, llama_pos p1, int d) override; + + llama_pos seq_pos_max(llama_seq_id seq_id) const override; + + // + // llama_kv_cache + // + + void restore() override; + void commit() override; + + bool update(llama_context & lctx) override; + + void defrag_sched(float thold) override; + + void set_full() override; + + llama_sbatch sbatch_init(const llama_batch & batch, bool logits_all) override; + + llama_ubatch ubatch_next(llama_sbatch & sbatch, uint32_t n_ubatch, bool embd_pooled) const override; + + bool find_slot(const llama_ubatch & batch) override; + + int32_t get_n_tokens() const override; + int32_t get_used_cells() const override; + + // TODO: better data structures to reduce the cost of this operation + llama_pos get_pos_max() const override; + + bool get_can_shift() const override; + + // TODO: temporary methods - they are not really const as they do const_cast<>, fix this + int32_t s_copy(int i) const; + float s_mask(int i) const; + + // state write/load + + void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1) const override; + void state_read (llama_io_read_i & io, llama_seq_id seq_id = -1) override; // Note: The value of head isn't only used to optimize searching // for a free KV slot. llama_decode_impl also uses it, so it @@ -179,18 +353,41 @@ class llama_kv_cache_unified : public llama_kv_cache { // computed before each graph build uint32_t n = 0; - std::vector cells; + std::vector cells; std::vector k_l; // per layer std::vector v_l; private: + //const llama_model & model; + const llama_hparams & hparams; + + // commit/restore cache + // TODO: rework for recurrent cache + struct slot_range { + uint32_t c0 = 0; // note: these are cell indices, not sequence positions + uint32_t c1 = 0; + }; + + // pending cell updates that are not yet committed + struct { + std::vector ranges; + } pending; + ggml_type type_k = GGML_TYPE_F16; ggml_type type_v = GGML_TYPE_F16; std::vector ctxs; std::vector bufs; + // find how many cells are currently in use + uint32_t cell_max() const; + + size_t total_size() const; + + size_t size_k_bytes() const; + size_t size_v_bytes() const; + void state_write_meta(llama_io_write_i & io, const std::vector> & cell_ranges, llama_seq_id seq_id = -1) const; void state_write_data(llama_io_write_i & io, const std::vector> & cell_ranges) const; @@ -198,11 +395,6 @@ class llama_kv_cache_unified : public llama_kv_cache { bool state_read_data(llama_io_read_i & io, uint32_t cell_count); }; -// TODO: temporary reusing llama_kv_cache_unified -- implement recurrent cache and simplify llama_kv_cache_unified -//class llama_kv_cache_recurrent : public llama_kv_cache_unified { -//public: -// using llama_kv_cache_unified::llama_kv_cache_unified; -//}; // // kv cache view diff --git a/src/llama-memory.h b/src/llama-memory.h index dfa8c4e90fc2a..c7412d5911ed7 100644 --- a/src/llama-memory.h +++ b/src/llama-memory.h @@ -2,12 +2,22 @@ #include "llama.h" +struct llama_memory_params { + // kv cache + ggml_type type_k; + ggml_type type_v; + + // parameters for other types of memory + // ... 
+}; + // general concept of LLM memory // the KV cache is a type of LLM memory, but there can be other types class llama_memory_i { public: + virtual ~llama_memory_i() = default; + virtual void clear() = 0; - virtual void defrag() = 0; virtual bool seq_rm (llama_seq_id seq_id, llama_pos p0, llama_pos p1) = 0; virtual void seq_cp (llama_seq_id seq_id_src, llama_seq_id seq_id_dst, llama_pos p0, llama_pos p1) = 0; diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 51092a128c5c6..e69601949ed51 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4445,6 +4445,19 @@ const ggml_tensor * llama_model::get_tensor(const char * name) const { return it->second; } +ggml_tensor * llama_model::get_rope_factors(uint32_t n_ctx_per_seq, int il) const { + // choose long/short freq factors based on the context size + if (layers[il].rope_freqs != nullptr) { + return layers[il].rope_freqs; + } + + if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) { + return layers[il].rope_long; + } + + return layers[il].rope_short; +} + struct llm_build_llama : public llm_graph_context { llm_build_llama(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; @@ -4485,7 +4498,7 @@ struct llm_build_llama : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -4710,7 +4723,7 @@ struct llm_build_deci : public llm_graph_context { } else if (n_head > 0) { // self-attention // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -7192,7 +7205,7 @@ struct llm_build_phi3 : public llm_graph_context { // self-attention { // rope freq factors for 128k context - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); ggml_tensor* attn_norm_output = build_norm(inpL, model.layers[il].attn_norm, @@ -7944,7 +7957,7 @@ struct llm_build_minicpm3 : public llm_graph_context { for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // norm cur = build_norm(inpL, @@ -8711,7 +8724,7 @@ struct llm_build_mamba : public llm_graph_context { ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto kv_head = kv_self->head; @@ -9012,7 +9025,7 @@ struct llm_build_cohere2 : public llm_graph_context { // self-attention { // rope freq factors for 128k context - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // compute Q and K and RoPE them ggml_tensor * Qcur = 
build_lora_mm(model.layers[il].wq, cur); @@ -9950,7 +9963,7 @@ struct llm_build_deepseek : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -11314,7 +11327,7 @@ struct llm_build_exaone : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -11459,7 +11472,7 @@ struct llm_build_rwkv6_base : public llm_graph_context { ggml_tensor * state_mask, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -11855,7 +11868,7 @@ struct llm_build_rwkv7_base : public llm_graph_context { ggml_tensor *& first_layer_value, const llama_ubatch & ubatch, int il) const { - const llama_kv_cache_unified * kv_self = static_cast(memory); + const llama_kv_cache_recurrent * kv_self = static_cast(memory); const auto n_tokens = ubatch.n_tokens; const auto n_seqs = ubatch.n_seqs; @@ -12695,7 +12708,7 @@ struct llm_build_bailingmoe : public llm_graph_context { // self-attention { // rope freq factors for llama3; may return nullptr for llama2 and other models - ggml_tensor * rope_factors = static_cast(memory)->cbs.get_rope_factors(n_ctx_per_seq, il); + ggml_tensor * rope_factors = model.get_rope_factors(n_ctx_per_seq, il); // compute Q and K and RoPE them ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); @@ -12815,7 +12828,7 @@ struct llm_build_bailingmoe : public llm_graph_context { } }; -llama_memory_i * llama_model::create_memory() const { +llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { llama_memory_i * res; switch (arch) { @@ -12825,26 +12838,29 @@ llama_memory_i * llama_model::create_memory() const { case LLM_ARCH_RWKV7: case LLM_ARCH_ARWKV7: { - res = new llama_kv_cache_unified(hparams, { - /*.get_rope_factors =*/ nullptr - }); + res = new llama_kv_cache_recurrent( + *this, + GGML_TYPE_F32, + GGML_TYPE_F32, + cparams.offload_kqv, + std::max((uint32_t) 1, cparams.n_seq_max)); } break; default: { - res = new llama_kv_cache_unified(hparams, { - /*.get_rope_factors =*/ [this](uint32_t n_ctx_per_seq, int il) { - // choose long/short freq factors based on the context size - if (layers[il].rope_freqs != nullptr) { - return layers[il].rope_freqs; - } + const auto padding = llama_kv_cache_unified::get_padding(cparams); - if (n_ctx_per_seq > hparams.n_ctx_orig_yarn) { - return layers[il].rope_long; - } + cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding); - return layers[il].rope_short; - } - }); + LLAMA_LOG_DEBUG("%s: n_ctx = %u (padded)\n", __func__, cparams.n_ctx); + + res = new llama_kv_cache_unified( + *this, + params.type_k, + params.type_v, + !cparams.flash_attn, + cparams.offload_kqv, + cparams.n_ctx, + padding); } } diff --git a/src/llama-model.h b/src/llama-model.h index 
34aac337cff27..4c7e7a335b4e2 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -395,8 +395,11 @@ struct llama_model { const struct ggml_tensor * get_tensor(const char * name) const; + ggml_tensor * get_rope_factors(uint32_t n_ctx_per_seq, int il) const; + + // note: can mutate `cparams` // TODO: move this to new llm_arch_model_i interface - llama_memory_i * create_memory() const; // TODO: params + llama_memory_i * create_memory(const llama_memory_params & params, llama_cparams & cparams) const; // TODO: move this to new llm_arch_model_i interface llm_graph_result_ptr build_graph(
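// [editor's aside - illustrative sketch, not part of the patch] The context-size
// padding applied by create_memory() for the unified cache: the padding comes
// from llama_kv_cache_unified::get_padding() (the 256u/32u values near the top of
// this section, which appear to depend on whether flash attention is enabled),
// and cparams.n_ctx is rounded up to a multiple of it, which is what GGML_PAD
// amounts to. The requested context size below is made up.
#include <cstdint>
#include <cstdio>

static uint32_t pad_to_multiple(uint32_t x, uint32_t n) {
    return ((x + n - 1) / n) * n;
}

int main() {
    const bool     flash_attn = false;
    const uint32_t padding    = flash_attn ? 256u : 32u; // assumed selection rule
    const uint32_t n_ctx      = 5000;                    // hypothetical request

    std::printf("n_ctx = %u -> padded to %u\n", n_ctx, pad_to_multiple(n_ctx, padding));
    return 0;
}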