Add embedding mode with arg flag. Currently working #282


Merged: 11 commits, Mar 24, 2023
56 changes: 46 additions & 10 deletions llama.cpp
@@ -102,6 +102,9 @@ struct llama_context {
// decode output (2-dimensional array: [n_tokens][n_vocab])
std::vector<float> logits;
bool logits_all = false;

// input embedding (1-dimensional array: [n_embd])
std::vector<float> embedding;
};

struct llama_context_params llama_context_default_params() {
@@ -112,6 +115,7 @@ struct llama_context_params llama_context_default_params() {
/*.f16_kv =*/ false,
/*.logits_all =*/ false,
/*.vocab_only =*/ false,
/*.embedding =*/ false,
};

return result;
@@ -592,8 +596,6 @@ static bool llama_model_load(
fin.close();
}

lctx.logits.reserve(lctx.model.hparams.n_ctx);

lctx.t_load_us = ggml_time_us() - t_start_us;

return true;
@@ -789,6 +791,9 @@ static bool llama_eval_internal(
inpL = cur;
}

// used at the end to optionally extract the embeddings
struct ggml_tensor * embeddings = NULL;

// norm
{
inpL = ggml_rms_norm(ctx0, inpL);
@@ -797,6 +802,8 @@
inpL = ggml_mul(ctx0,
ggml_repeat(ctx0, model.norm, inpL),
inpL);

embeddings = inpL;
}

// lm_head
@@ -819,15 +826,26 @@
//embd_w.resize(n_vocab*N);
//memcpy(embd_w.data(), ggml_get_data(inpL), sizeof(float)*n_vocab*N);

// extract logits
{
auto & logits_out = lctx.logits;

if (lctx.logits_all) {
logits_out.resize(n_vocab * N);
memcpy(logits_out.data(), (float *) ggml_get_data(inpL), sizeof(float)*n_vocab*N);
} else {
// return result for just the last token
logits_out.resize(n_vocab);
memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
}
}

// extract embeddings
if (lctx.embedding.size()) {
auto & embedding_out = lctx.embedding;

embedding_out.resize(n_embd);
memcpy(embedding_out.data(), (float *) ggml_get_data(embeddings) + (n_embd*(N - 1)), sizeof(float)*n_embd);
}

if (mem_per_token == 0) {
@@ -1414,6 +1432,20 @@ struct llama_context * llama_init_from_file(
return nullptr;
}

// reserve memory for context buffers
{
const auto & hparams = ctx->model.hparams;
if (params.logits_all) {
ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab);
} else {
ctx->logits.reserve(hparams.n_ctx);
}

if (params.embedding){
ctx->embedding.reserve(hparams.n_embd);
}
}

return ctx;
}

@@ -1482,6 +1514,10 @@ float * llama_get_logits(struct llama_context * ctx) {
return ctx->logits.data();
}

float * llama_get_embeddings(struct llama_context * ctx) {
return ctx->embedding.data();
}

const char * llama_token_to_str(struct llama_context * ctx, llama_token token) {
if (token >= llama_n_vocab(ctx)) {
return nullptr;
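As a note on the pointer arithmetic in llama_eval_internal above: both the logits and the embedding are copied out of a flat, row-major buffer of shape [N][n_vocab] or [N][n_embd], and the single-vector case takes the last row, i.e. the last evaluated token. A small standalone sketch of that slicing (illustrative only, not part of this diff):

#include <cstring>
#include <vector>

// Copy the last row of a flat [N][row_size] row-major buffer,
// mirroring the memcpy used above for the last-token logits/embedding.
std::vector<float> last_row(const float * data, int N, int row_size) {
    std::vector<float> out(row_size);
    std::memcpy(out.data(), data + (size_t) row_size * (N - 1), sizeof(float) * row_size);
    return out;
}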
5 changes: 5 additions & 0 deletions llama.h
@@ -53,6 +53,7 @@ extern "C" {
bool f16_kv; // use fp16 for KV cache
bool logits_all; // the llama_eval() call computes all logits, not just the last one
bool vocab_only; // only load the vocabulary, no weights
bool embedding; // embedding mode only
};

LLAMA_API struct llama_context_params llama_context_default_params();
@@ -108,6 +109,10 @@ extern "C" {
// Cols: n_vocab
LLAMA_API float * llama_get_logits(struct llama_context * ctx);

// Get the embeddings for the input
// shape: [n_embd] (1-dimensional)
LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);

// Token Id -> String. Uses the vocabulary in the provided context
LLAMA_API const char * llama_token_to_str(struct llama_context * ctx, llama_token token);

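For reference, a minimal sketch of how a client could drive the new API end to end: enable the flag in the context params, evaluate a tokenized prompt once, then read the n_embd floats returned by llama_get_embeddings(). The model path and thread count below are placeholders, and llama_n_embd() is assumed to exist alongside llama_n_vocab(); none of this is part of the diff itself.

#include <cstdio>
#include <vector>
#include "llama.h"

int main() {
    llama_context_params lparams = llama_context_default_params();
    lparams.embedding = true; // new flag introduced by this PR

    // placeholder model path
    llama_context * ctx = llama_init_from_file("models/7B/ggml-model.bin", lparams);
    if (ctx == nullptr) {
        return 1;
    }

    // tokenize the prompt and evaluate it in a single pass
    std::vector<llama_token> tokens(llama_n_ctx(ctx));
    const int n = llama_tokenize(ctx, "Hello world", tokens.data(), (int) tokens.size(), true);
    if (n < 1 || llama_eval(ctx, tokens.data(), n, 0, /*n_threads =*/ 4) != 0) {
        llama_free(ctx);
        return 1;
    }

    // embedding of the last evaluated token: n_embd floats
    const float * emb = llama_get_embeddings(ctx);
    const int n_embd  = llama_n_embd(ctx); // assumed getter
    printf("embedding[0] = %f (%d values)\n", emb[0], n_embd);

    llama_free(ctx);
    return 0;
}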
25 changes: 24 additions & 1 deletion main.cpp
@@ -199,6 +199,7 @@ int main(int argc, char ** argv) {
lparams.seed = params.seed;
lparams.f16_kv = params.memory_f16;
lparams.logits_all = params.perplexity;
lparams.embedding = params.embedding;

ctx = llama_init_from_file(params.model.c_str(), lparams);

@@ -289,6 +290,7 @@

std::vector<llama_token> embd;


int last_n_size = params.repeat_last_n;
std::vector<llama_token> last_n_tokens(last_n_size);
std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
@@ -300,7 +302,7 @@
#endif
" - Press Return to return control to LLaMa.\n"
" - If you want to submit another line, end your input in '\\'.\n\n");
is_interacting = params.interactive_start;
is_interacting = params.interactive_start || params.instruct;
}

int input_consumed = 0;
@@ -321,6 +323,27 @@
// the first thing we will do is to output the prompt, so set color accordingly
set_console_state(CONSOLE_STATE_PROMPT);

if (params.embedding){
embd = embd_inp;

if (embd.size() > 0) {
if (llama_eval(ctx, embd.data(), embd.size(), n_past, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
}

const auto embeddings = llama_get_embeddings(ctx);

// TODO: print / use the embeddings

if (params.use_color) {
printf(ANSI_COLOR_RESET);
}

return 0;
}

while (remaining_tokens > 0 || params.interactive) {
// predict
if (embd.size() > 0) {
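The TODO above is left open in this PR; one possible way to fill it in, simply dumping the raw vector, is sketched below (llama_n_embd() is assumed to be available, and the output format is arbitrary). With something like that in place, ./main -m <model> -p "some text" --embedding would evaluate the prompt once, print its embedding, and exit.

// Hypothetical body for the "TODO: print / use the embeddings" above
const int n_embd = llama_n_embd(ctx); // assumed getter; the vector holds n_embd floats
for (int i = 0; i < n_embd; i++) {
    printf("%f ", embeddings[i]);
}
printf("\n");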
4 changes: 4 additions & 0 deletions utils.cpp
@@ -63,6 +63,10 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.model = argv[++i];
} else if (arg == "-i" || arg == "--interactive") {
params.interactive = true;
} else if (arg == "--embedding") {
params.embedding = true;
} else if (arg == "--interactive-start") {
params.interactive = true;
} else if (arg == "--interactive-first") {
params.interactive_start = true;
} else if (arg == "-ins" || arg == "--instruct") {
4 changes: 4 additions & 0 deletions utils.h
@@ -32,13 +32,17 @@ struct gpt_params {
std::string model = "models/lamma-7B/ggml-model.bin"; // model path
std::string prompt = "";


std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted

bool memory_f16 = false; // use f16 instead of f32 for memory kv
bool random_prompt = false; // do not randomize prompt if none provided
bool use_color = false; // use color to distinguish generations and inputs
bool interactive = false; // interactive mode

bool embedding = false; // get only sentence embedding
bool interactive_start = false; // wait for user input immediately

bool instruct = false; // instruction mode (used for Alpaca models)
bool ignore_eos = false; // do not stop generating after eos
bool perplexity = false; // compute perplexity over the prompt