context : allow cache-less context for embeddings #13108

Open · wants to merge 7 commits into master
13 changes: 2 additions & 11 deletions examples/embedding/embedding.cpp
@@ -35,23 +35,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke

static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
const struct llama_model * model = llama_get_model(ctx);

// clear previous kv_cache values (irrelevant for embeddings)
llama_kv_self_clear(ctx);

// run model
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) {
// encoder-only model
if (llama_encode(ctx, batch) < 0) {
LOG_ERR("%s : failed to encode\n", __func__);
}
} else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) {
// decoder-only model
if (llama_decode(ctx, batch) < 0) {
LOG_ERR("%s : failed to decode\n", __func__);
}
if (llama_encode(ctx, batch) < 0) {
LOG_ERR("%s : failed to encode\n", __func__);
}

for (int i = 0; i < batch.n_tokens; i++) {
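With the encoder/decoder branching removed, the example only ever goes through llama_encode(). A minimal sketch of what the simplified helper amounts to (the run_embedding_batch name is illustrative, not part of the PR):

```cpp
#include "llama.h"

#include <cstdio>

// Minimal sketch, assuming `ctx` is an embedding context and `batch` is already
// populated (as in examples/embedding): one llama_encode() call now covers both
// encoder-only models (e.g. BERT) and decoder-only embedding models, and no
// llama_kv_self_clear() is needed because a cache-less context has no KV cache.
static bool run_embedding_batch(llama_context * ctx, llama_batch & batch) {
    if (llama_encode(ctx, batch) < 0) {
        fprintf(stderr, "%s: failed to encode batch of %d tokens\n", __func__, batch.n_tokens);
        return false;
    }
    return true;
}
```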
9 changes: 7 additions & 2 deletions include/llama.h
@@ -924,14 +924,19 @@ extern "C" {
// Frees a batch of tokens allocated with llama_batch_init()
LLAMA_API void llama_batch_free(struct llama_batch batch);

// Processes a batch of tokens with the ecoder part of the encoder-decoder model.
// Stores the encoder output internally for later use by the decoder cross-attention layers.
// Process a batch of tokens.
// In contrast to llama_decode() - this call does not use KV cache.
// For encode-decoder contexts, processes the batch using the encoder.
// Can store the encoder output internally for later use by the decoder's cross-attention layers.
// 0 - success
// < 0 - error. the KV cache state is restored to the state before this call
LLAMA_API int32_t llama_encode(
struct llama_context * ctx,
struct llama_batch batch);

// Process a batch of tokens.
// Requires KV cache.
// For encode-decoder contexts, processes the batch using the decoder.
// Positive return values does not mean a fatal error, but rather a warning.
// 0 - success
// 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
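A hedged sketch of the contract spelled out above, assuming a context created with embeddings enabled and a pooling type other than NONE; embed_sequence is an illustrative helper, not an API introduced by this PR:

```cpp
#include "llama.h"

#include <cstdio>
#include <vector>

// Sketch only: llama_encode() does not touch the KV cache and returns < 0 on
// error; the pooled embedding for a sequence is then read back with
// llama_get_embeddings_seq(). Batch setup (positions, sequence ids, output
// flags) is assumed to follow examples/embedding/embedding.cpp.
static std::vector<float> embed_sequence(llama_context * ctx, llama_batch & batch, int n_embd) {
    const int32_t ret = llama_encode(ctx, batch);
    if (ret < 0) {
        fprintf(stderr, "llama_encode failed: %d\n", ret);
        return {};
    }

    const float * embd = llama_get_embeddings_seq(ctx, 0); // sequence id 0
    if (embd == nullptr) {
        return {};
    }

    return std::vector<float>(embd, embd + n_embd);
}
```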
30 changes: 22 additions & 8 deletions src/llama-context.cpp
@@ -253,7 +253,7 @@ llama_context::llama_context(
}

// reserve worst-case graph
if (!hparams.vocab_only) {
if (!hparams.vocab_only && memory) {
const uint32_t n_seqs = 1; // TODO: worst-case number of sequences
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);

@@ -702,6 +702,8 @@ int llama_context::encode(llama_batch & inp_batch) {
t_compute_start_us = ggml_time_us();
}

embd_seq.clear();

n_queued_tokens += n_tokens;

const int64_t n_embd = hparams.n_embd;
@@ -763,12 +765,12 @@ int llama_context::encode(llama_batch & inp_batch) {
ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
GGML_ASSERT(backend_embd != nullptr);

GGML_ASSERT(embd != nullptr);

switch (cparams.pooling_type) {
case LLAMA_POOLING_TYPE_NONE:
{
// extract token embeddings
GGML_ASSERT(embd != nullptr);

GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float));
} break;
@@ -793,11 +795,18 @@
} break;
case LLAMA_POOLING_TYPE_RANK:
{
// TODO: this likely should be the same logic as in llama_decoder_internal, but better to
// wait for an encoder model that requires this pooling type in order to test it
// https://github.com/ggerganov/llama.cpp/pull/9510
GGML_ABORT("RANK pooling not implemented yet");
}
// extract the rerank score - a single float per sequence
auto & embd_seq_out = embd_seq;

for (uint32_t s = 0; s < ubatch.n_seqs; ++s) {
const llama_seq_id seq_id = ubatch.seq_id[s][0];
if (embd_seq_out.find(seq_id) != embd_seq_out.end()) {
continue;
}
embd_seq_out[seq_id].resize(1);
ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (seq_id)*sizeof(float), sizeof(float));
}
} break;
case LLAMA_POOLING_TYPE_UNSPECIFIED:
{
GGML_ABORT("unknown pooling type");
@@ -835,6 +844,11 @@ int llama_context::encode(llama_batch & inp_batch) {
}

int llama_context::decode(llama_batch & inp_batch) {
if (!memory) {
LLAMA_LOG_WARN("%s: cannot decode batches with this context (use llama_encode() instead)\n", __func__);
return encode(inp_batch);
}

if (inp_batch.n_tokens == 0) {
LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__);
return -1;
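With the RANK pooling branch implemented, an encode pass on a reranker context yields one relevance score per sequence. A sketch of how a caller might read those scores back (print_rerank_scores is illustrative; it assumes the model was loaded with LLAMA_POOLING_TYPE_RANK and that each query+document pair occupies its own sequence id):

```cpp
#include "llama.h"

#include <cstdio>

// Illustrative only: after llama_encode() on a RANK-pooled context, each
// sequence exposes a single float score via llama_get_embeddings_seq().
static void print_rerank_scores(llama_context * ctx, llama_batch & batch, int n_seq) {
    if (llama_encode(ctx, batch) < 0) {
        fprintf(stderr, "failed to encode rerank batch\n");
        return;
    }

    for (int s = 0; s < n_seq; ++s) {
        const float * score = llama_get_embeddings_seq(ctx, s);
        if (score != nullptr) {
            printf("seq %d: score = %.4f\n", s, score[0]);
        }
    }
}
```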
7 changes: 7 additions & 0 deletions src/llama-model.cpp
@@ -12833,6 +12833,13 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
llama_memory_i * res;

switch (arch) {
case LLM_ARCH_BERT:
case LLM_ARCH_JINA_BERT_V2:
case LLM_ARCH_NOMIC_BERT:
case LLM_ARCH_NOMIC_BERT_MOE:
{
res = nullptr;
} break;
case LLM_ARCH_MAMBA:
case LLM_ARCH_RWKV6:
case LLM_ARCH_RWKV6QWEN2:
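The BERT-family architectures now return a null memory module, which is what marks their contexts as cache-less. A toy sketch of the pattern (names are invented for illustration; this is not the actual llama.cpp internals): the factory returns nullptr when no KV cache is needed, and callers treat that as "encode-only".

```cpp
#include <memory>

// Toy illustration of the pattern used above (not the real llama.cpp types):
// a null memory module signals "no KV cache", so graph reservation and the
// llama_decode()-style paths are skipped for such contexts.
struct memory_i {
    virtual ~memory_i() = default;
};

struct kv_cache_memory : memory_i {};

enum model_arch { ARCH_BERT_LIKE, ARCH_AUTOREGRESSIVE };

static std::unique_ptr<memory_i> create_memory(model_arch arch) {
    switch (arch) {
        case ARCH_BERT_LIKE:      return nullptr;                             // embeddings-only: cache-less
        case ARCH_AUTOREGRESSIVE: return std::make_unique<kv_cache_memory>(); // needs a KV cache
    }
    return nullptr;
}
```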
11 changes: 9 additions & 2 deletions tools/server/server.cpp
@@ -3212,7 +3212,14 @@ struct server_context {
batch.logits + i,
};

const int ret = llama_decode(ctx, batch_view);
int ret = 0;

if (params_base.embedding || params_base.reranking) {
ret = llama_encode(ctx, batch_view);
} else {
ret = llama_decode(ctx, batch_view);
}

metrics.on_decoded(slots);

if (ret != 0) {
@@ -3941,7 +3948,7 @@ int main(int argc, char ** argv) {
const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
server_task_type type,
json & data,
std::function<bool()> is_connection_closed,
const std::function<bool()> & is_connection_closed,
httplib::Response & res,
oaicompat_type oaicompat) {
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
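The server change boils down to a single dispatch: slots serving embeddings or reranking go through llama_encode(), which skips the KV cache, while everything else keeps using llama_decode(). A condensed sketch of that decision (process_batch is an illustrative helper, not code from the PR):

```cpp
#include "llama.h"

// Sketch only: pick the cache-less path for embedding/reranking work and the
// KV-cache path for normal completions; both calls return 0 on success.
static int32_t process_batch(llama_context * ctx, llama_batch & batch_view, bool embedding_or_rerank) {
    if (embedding_or_rerank) {
        return llama_encode(ctx, batch_view);
    }
    return llama_decode(ctx, batch_view);
}
```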