From 9770efa343c368b97df3cceb1f3c252718905060 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 25 Apr 2025 14:13:51 +0300 Subject: [PATCH 1/7] context : allow cache-less context for embeddings ggml-ci --- examples/embedding/embedding.cpp | 2 +- src/llama-context.cpp | 11 ++++++++--- src/llama-model.cpp | 6 ++++++ 3 files changed, 15 insertions(+), 4 deletions(-) diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 06fce236e2b85..cf37ffdbd1bf3 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -49,7 +49,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu } } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { // decoder-only model - if (llama_decode(ctx, batch) < 0) { + if (llama_encode(ctx, batch) < 0) { LOG_ERR("%s : failed to decode\n", __func__); } } diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 45591be992d87..0ef198a2fe287 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -253,7 +253,7 @@ llama_context::llama_context( } // reserve worst-case graph - if (!hparams.vocab_only) { + if (!hparams.vocab_only && memory) { const uint32_t n_seqs = 1; // TODO: worst-case number of sequences const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); @@ -763,12 +763,12 @@ int llama_context::encode(llama_batch & inp_batch) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); - GGML_ASSERT(embd != nullptr); - switch (cparams.pooling_type) { case LLAMA_POOLING_TYPE_NONE: { // extract token embeddings + GGML_ASSERT(embd != nullptr); + GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size); ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float)); } break; @@ -840,6 +840,11 @@ int llama_context::decode(llama_batch & inp_batch) { return -1; } + if (!memory) { + LLAMA_LOG_WARN("%s: cannot decode batches with this context\n", __func__); + return -1; + } + llama_kv_cache * kv_self = static_cast(memory.get()); // temporary allocate memory for the input batch if needed diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 08d21301374a3..1f7acba9168c5 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -12833,6 +12833,12 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_memory_i * res; switch (arch) { + case LLM_ARCH_BERT: + case LLM_ARCH_JINA_BERT_V2: + case LLM_ARCH_NOMIC_BERT: + { + res = nullptr; + } break; case LLM_ARCH_MAMBA: case LLM_ARCH_RWKV6: case LLM_ARCH_RWKV6QWEN2: From a21ff6c264da82e20ee8df4a40bec4dded5e7b04 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 25 Apr 2025 14:39:04 +0300 Subject: [PATCH 2/7] context : enable reranking with encode() ggml-ci --- src/llama-context.cpp | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 0ef198a2fe287..a140a52b8ac5a 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -793,11 +793,18 @@ int llama_context::encode(llama_batch & inp_batch) { } break; case LLAMA_POOLING_TYPE_RANK: { - // TODO: this likely should be the same logic as in llama_decoder_internal, but better to - // wait for an encoder model that requires this pooling type in order to test it - // https://github.com/ggerganov/llama.cpp/pull/9510 - GGML_ABORT("RANK pooling not implemented yet"); - } + // extract the rerank score - 
a single float per sequence + auto & embd_seq_out = embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs; ++s) { + const llama_seq_id seq_id = ubatch.seq_id[s][0]; + if (embd_seq_out.find(seq_id) != embd_seq_out.end()) { + continue; + } + embd_seq_out[seq_id].resize(1); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float)); + } + } break; case LLAMA_POOLING_TYPE_UNSPECIFIED: { GGML_ABORT("unknown pooling type"); From c14ee7238ac952cddcb0d85c04b16d70333fe915 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Fri, 25 Apr 2025 15:19:44 +0300 Subject: [PATCH 3/7] context : encode() clears embd_seq ggml-ci --- src/llama-context.cpp | 12 +++++++----- tools/server/server.cpp | 2 +- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index a140a52b8ac5a..18cd2bf8880c7 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -702,6 +702,8 @@ int llama_context::encode(llama_batch & inp_batch) { t_compute_start_us = ggml_time_us(); } + embd_seq.clear(); + n_queued_tokens += n_tokens; const int64_t n_embd = hparams.n_embd; @@ -842,13 +844,13 @@ int llama_context::encode(llama_batch & inp_batch) { } int llama_context::decode(llama_batch & inp_batch) { - if (inp_batch.n_tokens == 0) { - LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); - return -1; - } - if (!memory) { LLAMA_LOG_WARN("%s: cannot decode batches with this context\n", __func__); + return encode(inp_batch); + } + + if (inp_batch.n_tokens == 0) { + LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; } diff --git a/tools/server/server.cpp b/tools/server/server.cpp index c580ec123299c..38d6dc1ceb93d 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3941,7 +3941,7 @@ int main(int argc, char ** argv) { const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok]( server_task_type type, json & data, - std::function is_connection_closed, + const std::function & is_connection_closed, httplib::Response & res, oaicompat_type oaicompat) { GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL); From c709275663db23814ac04c1b8513573b97190a83 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 May 2025 17:46:46 +0300 Subject: [PATCH 4/7] examples : use llama_encode() when appropriate ggml-ci --- examples/embedding/embedding.cpp | 13 ++----------- tools/server/server.cpp | 9 ++++++++- 2 files changed, 10 insertions(+), 12 deletions(-) diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index cf37ffdbd1bf3..01ff6763fff5e 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -35,23 +35,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) { const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); - const struct llama_model * model = llama_get_model(ctx); // clear previous kv_cache values (irrelevant for embeddings) llama_kv_self_clear(ctx); // run model LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); - if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { - // encoder-only model - if (llama_encode(ctx, batch) < 0) { - LOG_ERR("%s : failed to encode\n", __func__); - } - } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { - // decoder-only model - if 
(llama_encode(ctx, batch) < 0) { - LOG_ERR("%s : failed to decode\n", __func__); - } + if (llama_encode(ctx, batch) < 0) { + LOG_ERR("%s : failed to encode\n", __func__); } for (int i = 0; i < batch.n_tokens; i++) { diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 38d6dc1ceb93d..bcfd286fdcb9a 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -3212,7 +3212,14 @@ struct server_context { batch.logits + i, }; - const int ret = llama_decode(ctx, batch_view); + int ret = 0; + + if (params_base.embedding || params_base.reranking) { + ret = llama_encode(ctx, batch_view); + } else { + ret = llama_decode(ctx, batch_view); + } + metrics.on_decoded(slots); if (ret != 0) { From 97b975d4a6cd3bc45422ac838f1b58caebcae5c4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 May 2025 18:04:48 +0300 Subject: [PATCH 5/7] models : nomic bert moe does not require KV cache --- src/llama-model.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 1f7acba9168c5..ec09754f6c8ab 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -12836,6 +12836,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_BERT: case LLM_ARCH_JINA_BERT_V2: case LLM_ARCH_NOMIC_BERT: + case LLM_ARCH_NOMIC_BERT_MOE: { res = nullptr; } break; From 3b4f6c0eca7c84358b3a4875be88101efdf5a5e8 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 May 2025 18:05:03 +0300 Subject: [PATCH 6/7] llama : update comments for llama_decode/llama_encode ggml-ci --- include/llama.h | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/include/llama.h b/include/llama.h index 06c56395c139f..a56a185966866 100644 --- a/include/llama.h +++ b/include/llama.h @@ -924,14 +924,19 @@ extern "C" { // Frees a batch of tokens allocated with llama_batch_init() LLAMA_API void llama_batch_free(struct llama_batch batch); - // Processes a batch of tokens with the ecoder part of the encoder-decoder model. - // Stores the encoder output internally for later use by the decoder cross-attention layers. + // Process a batch of tokens. + // In contrast to llama_decode() - this call does not use KV cache. + // For encode-decoder contexts, processes the batch using the encoder. + // Can store the encoder output internally for later use by the decoder's cross-attention layers. // 0 - success // < 0 - error. the KV cache state is restored to the state before this call LLAMA_API int32_t llama_encode( struct llama_context * ctx, struct llama_batch batch); + // Process a batch of tokens. + // Requires KV cache. + // For encode-decoder contexts, processes the batch using the decoder. // Positive return values does not mean a fatal error, but rather a warning. 
// 0 - success // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) From abe25e7dab11803245b378d2e5ab209ed0069ca6 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 3 May 2025 18:27:28 +0300 Subject: [PATCH 7/7] context : update warning log [no ci] --- src/llama-context.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 18cd2bf8880c7..fcddeb482ebf6 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -845,7 +845,7 @@ int llama_context::encode(llama_batch & inp_batch) { int llama_context::decode(llama_batch & inp_batch) { if (!memory) { - LLAMA_LOG_WARN("%s: cannot decode batches with this context\n", __func__); + LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__); return encode(inp_batch); }
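
Usage sketch (not part of the patches above): a minimal program exercising the cache-less embedding path this series enables — build a context for a BERT-style embedding model, run the batch through llama_encode(), and read the pooled result. The model path, the input text, and the choice of mean pooling are illustrative placeholders; error handling is trimmed to the essentials.

```cpp
// sketch only: assumes a BERT-style embedding GGUF at a placeholder path
#include "llama.h"

#include <cstdio>
#include <string>
#include <vector>

int main() {
    llama_backend_init();

    llama_model_params mparams = llama_model_default_params();
    llama_model * model = llama_model_load_from_file("bert-embedding.gguf", mparams); // placeholder path
    if (!model) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings   = true;                     // we only want embeddings
    cparams.pooling_type = LLAMA_POOLING_TYPE_MEAN;  // one vector per sequence

    llama_context * ctx = llama_init_from_model(model, cparams);

    // tokenize a single input as sequence 0
    const std::string text = "hello world";
    const llama_vocab * vocab = llama_model_get_vocab(model);

    std::vector<llama_token> tokens(text.size() + 8);
    const int n_tokens = llama_tokenize(vocab, text.c_str(), (int32_t) text.size(),
                                        tokens.data(), (int32_t) tokens.size(),
                                        /*add_special=*/true, /*parse_special=*/false);
    tokens.resize(n_tokens);

    llama_batch batch = llama_batch_init(n_tokens, 0, 1);
    for (int i = 0; i < n_tokens; ++i) {
        batch.token   [i]    = tokens[i];
        batch.pos     [i]    = i;
        batch.n_seq_id[i]    = 1;
        batch.seq_id  [i][0] = 0;
        batch.logits  [i]    = true;
    }
    batch.n_tokens = n_tokens;

    // with this series, llama_encode() is the call to use for models without a
    // KV cache (llama_decode() on such a context now falls back to encode()
    // with a warning)
    if (llama_encode(ctx, batch) < 0) {
        fprintf(stderr, "encode failed\n");
        return 1;
    }

    const int     n_embd = llama_model_n_embd(model);
    const float * embd   = llama_get_embeddings_seq(ctx, 0); // pooled embedding of sequence 0

    printf("embedding[0..2] = %f %f %f (dim = %d)\n", embd[0], embd[1], embd[2], n_embd);

    llama_batch_free(batch);
    llama_free(ctx);
    llama_model_free(model);
    llama_backend_free();

    return 0;
}
```

With a reranker model and cparams.pooling_type = LLAMA_POOLING_TYPE_RANK, the same llama_get_embeddings_seq() call yields a single score per sequence — that is the path the second patch wires up in encode().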