From b97df76c543855e23779e56c200ed3ab4b2b125e Mon Sep 17 00:00:00 2001
From: strikingLoo
Date: Sat, 18 Mar 2023 14:10:16 -0700
Subject: [PATCH 1/7] working but ugly

---
 main.cpp | 23 ++++++++++++++++++++---
 1 file changed, 20 insertions(+), 3 deletions(-)

diff --git a/main.cpp b/main.cpp
index ca0fca8b36455..2adeb50458224 100644
--- a/main.cpp
+++ b/main.cpp
@@ -721,6 +721,25 @@ bool llama_eval(
                 inpL);
     }
 
+    // run the computation
+    ggml_build_forward_expand(&gf, inpL);
+    ggml_graph_compute (ctx0, &gf);
+
+    // capture input sentence embedding
+    {
+        std::vector<float> embedding_representation;
+        embedding_representation.resize(n_embd);
+        memcpy(embedding_representation.data(), (float *) ggml_get_data(inpL) + (n_embd * (N - 2)), sizeof(float) * n_embd);
+        fprintf(stdout, "\n[\n");
+        for (int j = 0; j < embedding_representation.size()-1 ; j++){
+            fprintf(stdout, "%f, ", embedding_representation[j]);
+        }
+        fprintf(stdout, "%f", embedding_representation[embedding_representation.size()-1]);
+        fprintf(stdout, "\n]\n");
+
+    }
+
+
     // lm_head
     {
         inpL = ggml_mul_mat(ctx0, model.output, inpL);
@@ -729,9 +748,7 @@ bool llama_eval(
     // logits -> probs
     //inpL = ggml_soft_max(ctx0, inpL);
 
-    // run the computation
-    ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute (ctx0, &gf);
+
 
     //if (n_past%100 == 0) {
     //    ggml_graph_print (&gf);

From 801071ec4f8d275cfff858057754fe5b5e1c2140 Mon Sep 17 00:00:00 2001
From: strikingLoo
Date: Sat, 18 Mar 2023 23:34:20 -0700
Subject: [PATCH 2/7] add arg flag, not working on embedding mode

---
 main.cpp  | 124 ++++++++++++++++++++++++++++++++----------------------
 utils.cpp |   2 +
 utils.h   |   2 +-
 3 files changed, 76 insertions(+), 52 deletions(-)

diff --git a/main.cpp b/main.cpp
index 2adeb50458224..4c2f85e239de0 100644
--- a/main.cpp
+++ b/main.cpp
@@ -519,6 +519,17 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
     return true;
 }
 
+// Prints the provided embedding vector to stdout
+// in a neat format
+void display_embedding(const std::vector<float> & embedding_representation){
+    fprintf(stdout, "\n[\n");
+    for (int j = 0; j < embedding_representation.size()-1 ; j++){
+        fprintf(stdout, "%f, ", embedding_representation[j]);
+    }
+    fprintf(stdout, "%f", embedding_representation[embedding_representation.size()-1]);
+    fprintf(stdout, "\n]\n");
+}
+
 // evaluate the transformer
 //
 //   - model: the model
@@ -535,7 +546,8 @@ bool llama_eval(
         const int n_past,
         const std::vector<gpt_vocab::id> & embd_inp,
               std::vector<float>         & embd_w,
-              size_t                     & mem_per_token) {
+              size_t                     & mem_per_token,
+        const bool embeding_mode) {
     const int N = embd_inp.size();
 
     const auto & hparams = model.hparams;
@@ -720,56 +732,52 @@ bool llama_eval(
                 ggml_repeat(ctx0, model.norm, inpL),
                 inpL);
     }
-
-    // run the computation
-    ggml_build_forward_expand(&gf, inpL);
-    ggml_graph_compute (ctx0, &gf);
-
-    // capture input sentence embedding
-    {
-        std::vector<float> embedding_representation;
-        embedding_representation.resize(n_embd);
-        memcpy(embedding_representation.data(), (float *) ggml_get_data(inpL) + (n_embd * (N - 2)), sizeof(float) * n_embd);
-        fprintf(stdout, "\n[\n");
-        for (int j = 0; j < embedding_representation.size()-1 ; j++){
-            fprintf(stdout, "%f, ", embedding_representation[j]);
-        }
-        fprintf(stdout, "%f", embedding_representation[embedding_representation.size()-1]);
-        fprintf(stdout, "\n]\n");
-
-    }
-
-
-    // lm_head
-    {
-        inpL = ggml_mul_mat(ctx0, model.output, inpL);
+    if(!embeding_mode){
+    // lm_head
+    {
+        inpL = ggml_mul_mat(ctx0, model.output, inpL);
+    }
+    // logits -> probs
+    //inpL = ggml_soft_max(ctx0, inpL);
+
+
+    // run the computation
+    ggml_build_forward_expand(&gf, inpL);
+    ggml_graph_compute (ctx0, &gf);
+
+    //if (n_past%100 == 0) {
+    //    ggml_graph_print (&gf);
+    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
+    //}
+
+    //embd_w.resize(n_vocab*N);
+    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+
+    // return result for just the last token
+    embd_w.resize(n_vocab);
+    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+    if (mem_per_token == 0) {
+        mem_per_token = ggml_used_mem(ctx0)/N;
+    }
+    //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
+
+    ggml_free(ctx0);
+    return true;
+    } else {
+        // capture input sentence embedding
+        ggml_build_forward_expand(&gf, inpL);
+        ggml_graph_compute (ctx0, &gf);
+        printf("Compute went ok\n");
+        std::vector<float> embedding_representation;
+        embedding_representation.resize(n_embd);
+        memcpy(embedding_representation.data(), (float *) ggml_get_data(inpL) + (n_embd * (N - 1)), sizeof(float) * n_embd);
+        printf("About to display\n");
+        display_embedding(embedding_representation);
+        printf("About to free\n");
+        ggml_free(ctx0);
+        return true;
     }
-
-    // logits -> probs
-    //inpL = ggml_soft_max(ctx0, inpL);
-
-
-    //if (n_past%100 == 0) {
-    //    ggml_graph_print (&gf);
-    //    ggml_graph_dump_dot(&gf, NULL, "gpt-2.dot");
-    //}
-
-    //embd_w.resize(n_vocab*N);
-    //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
-
-    // return result for just the last token
-    embd_w.resize(n_vocab);
-    memcpy(embd_w.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
-
-    if (mem_per_token == 0) {
-        mem_per_token = ggml_used_mem(ctx0)/N;
-    }
-    //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
-
-    ggml_free(ctx0);
-
-    return true;
 }
 
 static bool is_interacting = false;
 
@@ -906,13 +914,12 @@ int main(int argc, char ** argv) {
 
     // determine the required inference memory per token:
     size_t mem_per_token = 0;
-    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token);
+    llama_eval(model, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, false);
 
     int last_n_size = params.repeat_last_n;
     std::vector<gpt_vocab::id> last_n_tokens(last_n_size);
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 
-
     if (params.interactive) {
         fprintf(stderr, "== Running in interactive mode. ==\n"
==\n" #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) @@ -936,12 +943,27 @@ int main(int argc, char ** argv) { printf(ANSI_COLOR_YELLOW); } + if (params.embedding){ + printf("got right before second call.\n"); + const int64_t t_start_us = ggml_time_us(); //HERE + if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, true)) { + fprintf(stderr, "Failed to predict\n"); + return 1; + } + //ggml_free(model.ctx); + + if (params.use_color) { + printf(ANSI_COLOR_RESET); + } + return 0; + } + while (remaining_tokens > 0) { // predict if (embd.size() > 0) { const int64_t t_start_us = ggml_time_us(); - if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) { + if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, false)) { fprintf(stderr, "Failed to predict\n"); return 1; } diff --git a/utils.cpp b/utils.cpp index aa3ad1053da02..20b6a86ce341c 100644 --- a/utils.cpp +++ b/utils.cpp @@ -53,6 +53,8 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { params.model = argv[++i]; } else if (arg == "-i" || arg == "--interactive") { params.interactive = true; + } else if (arg == "--embedding") { + params.embedding = true; } else if (arg == "--interactive-start") { params.interactive = true; params.interactive_start = true; diff --git a/utils.h b/utils.h index 021120b0513c7..dca497d063d84 100644 --- a/utils.h +++ b/utils.h @@ -31,7 +31,7 @@ struct gpt_params { std::string prompt; bool use_color = false; // use color to distinguish generations and inputs - + bool embedding = false; // get only sentence embedding bool interactive = false; // interactive mode bool interactive_start = false; // reverse prompt immediately std::string antiprompt = ""; // string upon seeing which more user input is prompted From d2b1d3a439feca392a1c29f3be8333b695668309 Mon Sep 17 00:00:00 2001 From: strikingLoo Date: Sat, 18 Mar 2023 23:36:36 -0700 Subject: [PATCH 3/7] typo --- main.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/main.cpp b/main.cpp index a824d46c62826..4c0806168b5e1 100644 --- a/main.cpp +++ b/main.cpp @@ -539,7 +539,7 @@ bool llama_eval( const std::vector & embd_inp, std::vector & embd_w, size_t & mem_per_token, - const bool embeding_mode) { + const bool embedding_mode) { const int N = embd_inp.size(); const auto & hparams = model.hparams; @@ -725,7 +725,7 @@ bool llama_eval( inpL); } - if(!embeding_mode){ + if(!embedding_mode){ // lm_head { inpL = ggml_mul_mat(ctx0, model.output, inpL); From 76dde268447ad900897373c6454f51bc9c686ae5 Mon Sep 17 00:00:00 2001 From: strikingLoo Date: Tue, 21 Mar 2023 18:32:51 -0700 Subject: [PATCH 4/7] Working! 

Thanks to @nullhook
---
 main.cpp | 16 +++++++---------
 1 file changed, 7 insertions(+), 9 deletions(-)

diff --git a/main.cpp b/main.cpp
index 4c0806168b5e1..7c1e108aa2044 100644
--- a/main.cpp
+++ b/main.cpp
@@ -759,13 +759,10 @@ bool llama_eval(
         // capture input sentence embedding
         ggml_build_forward_expand(&gf, inpL);
         ggml_graph_compute (ctx0, &gf);
-        printf("Compute went ok\n");
         std::vector<float> embedding_representation;
         embedding_representation.resize(n_embd);
         memcpy(embedding_representation.data(), (float *) ggml_get_data(inpL) + (n_embd * (N - 1)), sizeof(float) * n_embd);
-        printf("About to display\n");
         display_embedding(embedding_representation);
-        printf("About to free\n");
         ggml_free(ctx0);
         return true;
     }
@@ -943,13 +940,14 @@ int main(int argc, char ** argv) {
     }
 
     if (params.embedding){
-        printf("got right before second call.\n");
-        const int64_t t_start_us = ggml_time_us(); //HERE
-        if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, true)) {
-            fprintf(stderr, "Failed to predict\n");
-            return 1;
+        embd = embd_inp;
+        if (embd.size() > 0) {
+            const int64_t t_start_us = ggml_time_us();
+            if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, true)) {
+                fprintf(stderr, "Failed to predict\n");
+                return 1;
+            }
         }
-        //ggml_free(model.ctx);
 
         if (params.use_color) {
             printf(ANSI_COLOR_RESET);

From 78cff5842756a4474b55a4c1f4de6c1f628ef7cd Mon Sep 17 00:00:00 2001
From: strikingLoo
Date: Tue, 21 Mar 2023 23:21:07 -0700
Subject: [PATCH 5/7] make params argument instead of hardcoded boolean.
 remove useless time check

---
 main.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/main.cpp b/main.cpp
index a7b283baf63ea..af7b41fc4343a 100644
--- a/main.cpp
+++ b/main.cpp
@@ -1117,8 +1117,7 @@ int main(int argc, char ** argv) {
     if (params.embedding){
         embd = embd_inp;
         if (embd.size() > 0) {
-            const int64_t t_start_us = ggml_time_us();
-            if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, true)) {
+            if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, params.embedding)) {
                 fprintf(stderr, "Failed to predict\n");
                 return 1;
             }
@@ -1135,7 +1134,7 @@ int main(int argc, char ** argv) {
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
 
-            if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, false)) {
+            if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, params.embedding)) {
                 fprintf(stderr, "Failed to predict\n");
                 return 1;
             }

From 859e70899a25abcd08e190890da05c201d9bb72b Mon Sep 17 00:00:00 2001
From: strikingLoo
Date: Wed, 22 Mar 2023 17:52:46 -0700
Subject: [PATCH 6/7] start doing the instructions but not finished.

This probably doesnt compile
---
 llama.cpp | 16 ++++++++++++++--
 llama.h   |  5 +++++
 main.cpp  |  1 +
 3 files changed, 20 insertions(+), 2 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 77680f46edf14..111801e893f02 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -101,6 +101,8 @@ struct llama_context {
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
+    // input embedding (1-dimensional array: [n_embd])
+    std::vector<float> embedding;
     bool logits_all = false;
 };
 
@@ -112,6 +114,7 @@ struct llama_context_params llama_context_default_params() {
         /*.f16_kv =*/ false,
         /*.logits_all =*/ false,
         /*.vocab_only =*/ false,
+        /*.embedding =*/ false,
     };
 
     return result;
@@ -127,7 +130,8 @@ static bool llama_model_load(
         int n_ctx,
         int n_parts,
         ggml_type memory_type,
-        bool vocab_only) {
+        bool vocab_only,
+        bool embedding) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
     const int64_t t_start_us = ggml_time_us();
@@ -594,6 +598,10 @@ static bool llama_model_load(
 
     lctx.logits.reserve(lctx.model.hparams.n_ctx);
 
+    if (embedding){
+        lctx.embedding.reserve(lctx.model.hparams.n_embd);
+    }
+
     lctx.t_load_us = ggml_time_us() - t_start_us;
 
     return true;
@@ -1433,7 +1441,7 @@ struct llama_context * llama_init_from_file(
 
     ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
-    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only)) {
+    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only, params.embedding)) {
         fprintf(stderr, "%s: failed to load model\n", __func__);
         delete ctx;
         return nullptr;
@@ -1508,6 +1516,10 @@ float * llama_get_logits(struct llama_context * ctx) {
     return ctx->logits.data();
 }
 
+float * llama_get_embeddings(struct llama_context * ctx) {
+    return ctx->embedding.data();
+}
+
 const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
     if (token >= llama_n_vocab(ctx)) {
         return nullptr;
diff --git a/llama.h b/llama.h
index 0fc5438a87823..393a896eb6672 100644
--- a/llama.h
+++ b/llama.h
@@ -53,6 +53,7 @@ extern "C" {
         bool f16_kv;     // use fp16 for KV cache
         bool logits_all; // the llama_eval() call computes all logits, not just the last one
         bool vocab_only; // only load the vocabulary, no weights
+        bool embedding; // embedding mode only
     };
 
     LLAMA_API struct llama_context_params llama_context_default_params();
@@ -109,6 +110,10 @@ extern "C" {
     // Cols: n_vocab
     LLAMA_API float * llama_get_logits(struct llama_context * ctx);
 
+    // Get the embeddings for the input
+    // shape: [n_embd] (1-dimensional)
+    LLAMA_API float * llama_get_embeddings(struct llama_context * ctx)
+
     // Token Id -> String. Uses the vocabulary in the provided context
     LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
diff --git a/main.cpp b/main.cpp
index 44b4cec28a324..8a639660c9bd2 100644
--- a/main.cpp
+++ b/main.cpp
@@ -199,6 +199,7 @@ int main(int argc, char ** argv) {
         lparams.seed = params.seed;
         lparams.f16_kv = params.memory_f16;
         lparams.logits_all = params.perplexity;
+        lparams.embedding = params.embedding;
 
         ctx = llama_init_from_file(params.model.c_str(), lparams);
 

From 8a3c34bb54f0675f693846efc2ffac010f40fbfc Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Thu, 23 Mar 2023 22:02:14 +0200
Subject: [PATCH 7/7] Embeddings extraction support

---
 llama.cpp | 92 +++++++++++++++++++++++++++----------------------------
 llama.h   |  7 ++---
 main.cpp  | 16 +++++++---
 3 files changed, 59 insertions(+), 56 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 111801e893f02..8a3f8b53f9356 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -101,9 +101,10 @@ struct llama_context {
 
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
+    bool logits_all = false;
+
     // input embedding (1-dimensional array: [n_embd])
     std::vector<float> embedding;
-    bool logits_all = false;
 };
 
 struct llama_context_params llama_context_default_params() {
@@ -114,7 +115,7 @@ struct llama_context_params llama_context_default_params() {
         /*.f16_kv =*/ false,
         /*.logits_all =*/ false,
         /*.vocab_only =*/ false,
-        /*.embedding =*/ false,
+        /*.embedding  =*/ false,
     };
 
     return result;
@@ -130,8 +131,7 @@ static bool llama_model_load(
         int n_ctx,
         int n_parts,
         ggml_type memory_type,
-        bool vocab_only,
-        bool embedding) {
+        bool vocab_only) {
     fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
 
     const int64_t t_start_us = ggml_time_us();
@@ -596,29 +596,11 @@ static bool llama_model_load(
         fin.close();
     }
 
-    lctx.logits.reserve(lctx.model.hparams.n_ctx);
-
-    if (embedding){
-        lctx.embedding.reserve(lctx.model.hparams.n_embd);
-    }
-
     lctx.t_load_us = ggml_time_us() - t_start_us;
 
     return true;
 }
 
-// Prints the provided embedding vector to stdout
-// in a neat format
-void display_embedding(const std::vector<float> & embedding_representation){
-    fprintf(stdout, "\n[\n");
-    for (int j = 0; j < embedding_representation.size()-1 ; j++){
-        fprintf(stdout, "%f, ", embedding_representation[j]);
-    }
-    fprintf(stdout, "%f", embedding_representation[embedding_representation.size()-1]);
-    fprintf(stdout, "\n]\n");
-}
-
-
 // evaluate the transformer
 //
 //   - lctx: llama context
@@ -631,8 +613,7 @@ static bool llama_eval_internal(
         const llama_token * tokens,
         const int n_tokens,
         const int n_past,
-        const int n_threads,
-        const bool embedding_mode = false) {
+        const int n_threads) {
     const int64_t t_start_us = ggml_time_us();
 
     const int N = n_tokens;
@@ -810,6 +791,9 @@ static bool llama_eval_internal(
         inpL = cur;
     }
 
+    // used at the end to optionally extract the embeddings
+    struct ggml_tensor * embeddings = NULL;
+
     // norm
     {
         inpL = ggml_rms_norm(ctx0, inpL);
@@ -818,18 +802,8 @@ static bool llama_eval_internal(
         inpL = ggml_mul(ctx0,
                     ggml_repeat(ctx0, model.norm, inpL),
                     inpL);
-    }
 
-    if(embedding_mode){
-        // capture input sentence embedding
-        ggml_build_forward_expand(&gf, inpL);
-        ggml_graph_compute (ctx0, &gf);
-        std::vector<float> embedding_representation;
-        embedding_representation.resize(n_embd);
-        memcpy(embedding_representation.data(), (float *) ggml_get_data(inpL) + (n_embd * (N - 1)), sizeof(float) * n_embd);
-        display_embedding(embedding_representation);
-        ggml_free(ctx0);
-        return true;
+        embeddings = inpL;
     }
 
     // lm_head
@@ -852,15 +826,26 @@ static bool llama_eval_internal(
     //embd_w.resize(n_vocab*N);
     //memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);
 
-    auto & logits_out = lctx.logits;
+    // extract logits
+    {
+        auto & logits_out = lctx.logits;
+
+        if (lctx.logits_all) {
+            logits_out.resize(n_vocab * N);
+            memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
+        } else {
+            // return result for just the last token
+            logits_out.resize(n_vocab);
+            memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+        }
+    }
+
+    // extract embeddings
+    if (lctx.embedding.size()) {
+        auto & embedding_out = lctx.embedding;
 
-    if (lctx.logits_all) {
-        logits_out.resize(n_vocab * N);
-        memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
-    } else {
-        // return result for just the last token
-        logits_out.resize(n_vocab);
-        memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
+        embedding_out.resize(n_embd);
+        memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
     }
 
     if (mem_per_token == 0) {
@@ -1441,12 +1426,26 @@ struct llama_context * llama_init_from_file(
 
     ggml_type type_memory = params.f16_kv ? GGML_TYPE_F16 : GGML_TYPE_F32;
 
-    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only, params.embedding)) {
+    if (!llama_model_load(path_model, *ctx, params.n_ctx, params.n_parts, type_memory, params.vocab_only)) {
        fprintf(stderr, "%s: failed to load model\n", __func__);
        delete ctx;
        return nullptr;
    }
 
+    // reserve memory for context buffers
+    {
+        const auto & hparams = ctx->model.hparams;
+
+        if (params.logits_all) {
+            ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
+        } else {
+            ctx->logits.reserve(hparams.n_ctx);
+        }
+
+        if (params.embedding){
+            ctx->embedding.reserve(hparams.n_embd);
+        }
+    }
+
     return ctx;
 }
 
@@ -1474,9 +1473,8 @@ int llama_eval(
         const llama_token * tokens,
         int n_tokens,
         int n_past,
-        int n_threads,
-        bool embedding_mode = false) {
-    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads, embedding_mode)) {
+        int n_threads) {
+    if (!llama_eval_internal(*ctx, tokens, n_tokens, n_past, n_threads)) {
         fprintf(stderr, "%s: failed to eval\n", __func__);
         return 1;
     }
diff --git a/llama.h b/llama.h
index 393a896eb6672..209b4dbe81d6a 100644
--- a/llama.h
+++ b/llama.h
@@ -53,7 +53,7 @@ extern "C" {
         bool f16_kv;     // use fp16 for KV cache
         bool logits_all; // the llama_eval() call computes all logits, not just the last one
         bool vocab_only; // only load the vocabulary, no weights
-        bool embedding; // embedding mode only
+        bool embedding;  // embedding mode only
     };
 
     LLAMA_API struct llama_context_params llama_context_default_params();
@@ -85,8 +85,7 @@ extern "C" {
         const llama_token * tokens,
         int n_tokens,
         int n_past,
-        int n_threads,
-        bool embedding_mode);
+        int n_threads);
 
     // Convert the provided text into tokens.
     // The tokens pointer must be large enough to hold the resulting tokens.
@@ -112,7 +111,7 @@ extern "C" {
 
     // Get the embeddings for the input
     // shape: [n_embd] (1-dimensional)
-    LLAMA_API float * llama_get_embeddings(struct llama_context * ctx)
+    LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 
     // Token Id -> String. Uses the vocabulary in the provided context
     LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);
diff --git a/main.cpp b/main.cpp
index 8a639660c9bd2..dd2709f543b13 100644
--- a/main.cpp
+++ b/main.cpp
@@ -98,7 +98,7 @@ void perplexity(llama_context * ctx, const gpt_params & params) {
         int end = start + params.n_ctx - 1;
         std::vector<llama_token> embd(tokens.begin() + start, tokens.begin() + end);
         auto start_t = std::chrono::high_resolution_clock::now();
-        if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads, false)) {
+        if (llama_eval(ctx, embd.data(), embd.size(), 0, params.n_threads)) {
             fprintf(stderr, "%s : failed to eval\n", __func__);
             return;
         }
@@ -220,7 +220,7 @@ int main(int argc, char ** argv) {
     // TODO: better way to do that
     {
         const std::vector<llama_token> tmp = { 0, 1, 2, 3 };
-        llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads, false);
+        llama_eval(ctx, tmp.data(), tmp.size(), 0, params.n_threads);
     }
 
     if (params.perplexity) {
@@ -302,7 +302,7 @@ int main(int argc, char ** argv) {
 #endif
                " - Press Return to return control to LLaMa.\n"
                " - If you want to submit another line, end your input in '\\'.\n\n");
-        is_interacting = params.interactive_start;
+        is_interacting = params.interactive_start || params.instruct;
     }
 
     int input_consumed = 0;
@@ -325,23 +325,29 @@ int main(int argc, char ** argv) {
 
     if (params.embedding){
         embd = embd_inp;
+
         if (embd.size() > 0) {
-            if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads, params.embedding)) {
+            if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return 1;
             }
         }
 
+        const auto embeddings = llama_get_embeddings(ctx);
+
+        // TODO: print / use the embeddings
+
         if (params.use_color) {
             printf(ANSI_COLOR_RESET);
         }
+
         return 0;
     }
 
     while (remaining_tokens > 0 || params.interactive) {
         // predict
         if (embd.size() > 0) {
-            if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads, params.embedding)) {
+            if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
                 fprintf(stderr, "%s : failed to eval\n", __func__);
                 return 1;
             }
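
Usage note (not part of the patch series): the final patch leaves the printout as
"// TODO: print / use the embeddings" in main.cpp. The sketch below shows one way a
caller could consume the new API once the series is applied. It is an illustration
under a few assumptions: the file name example_embedding.cpp is hypothetical, the
thread count is arbitrary, and llama_tokenize(), llama_n_embd() and llama_free() are
assumed to be the accessors already declared in the same llama.h (they are not touched
by these patches).

// example_embedding.cpp - minimal sketch, not part of the patches
#include <cstdio>
#include <vector>
#include "llama.h"

int main(int argc, char ** argv) {
    if (argc < 3) {
        fprintf(stderr, "usage: %s <model path> <prompt>\n", argv[0]);
        return 1;
    }

    // enable embedding mode so the context reserves its n_embd-sized buffer
    llama_context_params lparams = llama_context_default_params();
    lparams.embedding = true;

    llama_context * ctx = llama_init_from_file(argv[1], lparams);
    if (ctx == NULL) {
        return 1;
    }

    // tokenize the prompt (with BOS) and evaluate it in a single batch
    std::vector<llama_token> tokens(512);
    const int n_tokens = llama_tokenize(ctx, argv[2], tokens.data(), tokens.size(), true);
    if (n_tokens < 1 || llama_eval(ctx, tokens.data(), n_tokens, 0, 4) != 0) {
        llama_free(ctx);
        return 1;
    }

    // llama_get_embeddings() points at n_embd floats for the last evaluated token
    const int     n_embd     = llama_n_embd(ctx);
    const float * embeddings = llama_get_embeddings(ctx);

    printf("[");
    for (int i = 0; i < n_embd; i++) {
        printf("%f%s", embeddings[i], i == n_embd - 1 ? "" : ", ");
    }
    printf("]\n");

    llama_free(ctx);
    return 0;
}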