diff --git a/common/arg.cpp b/common/arg.cpp index b1754f30fca91..997f732cc472a 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1678,7 +1678,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.warmup = false; } - ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING})); + ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL})); add_opt(common_arg( {"--spm-infill"}, string_format( diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 0efe20d4b3f5d..e3d0c9542ed8d 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -81,14 +81,14 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke } } -static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { +static void batch_encode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { // clear previous kv_cache values (irrelevant for embeddings) llama_kv_self_clear(ctx); // run model LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); - if (llama_decode(ctx, batch) < 0) { - LOG_ERR("%s : failed to decode\n", __func__); + if (llama_encode(ctx, batch) < 0) { + LOG_ERR("%s : failed to encode\n", __func__); } for (int i = 0; i < batch.n_tokens; i++) { @@ -233,7 +233,7 @@ int main(int argc, char ** argv) { // encode if at capacity if (batch.n_tokens + n_toks > n_batch) { float * out = emb + p * n_embd; - batch_decode(ctx, batch, out, s, n_embd); + batch_encode(ctx, batch, out, s, n_embd); common_batch_clear(batch); p += s; s = 0; @@ -246,7 +246,7 @@ int main(int argc, char ** argv) { // final batch float * out = emb + p * n_embd; - batch_decode(ctx, batch, out, s, n_embd); + batch_encode(ctx, batch, out, s, n_embd); // save embeddings to chunks for (int i = 0; i < n_chunks; i++) { @@ -267,7 +267,7 @@ int main(int argc, char ** argv) { batch_add_seq(query_batch, query_tokens, 0); std::vector query_emb(n_embd, 0); - batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd); + batch_encode(ctx, query_batch, query_emb.data(), 1, n_embd); common_batch_clear(query_batch);