[Fix] Reenable server embedding endpoint #1937

Merged (3 commits) on Jun 19, 2023

13 changes: 11 additions & 2 deletions examples/server/README.md
@@ -21,6 +21,7 @@ Command line options:
- `-to N`, `--timeout N`: Server read/write timeout in seconds. Default `600`.
- `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`.
- `--port`: Set the port to listen. Default: `8080`.
- `--embedding`: Enable embedding extraction. Default: disabled.

## Build

@@ -119,14 +120,14 @@ node .

`top_p`: Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P (default: 0.9).

`n_predict`: Set the number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. (default: 128, -1 = infinity).
`n_predict`: Set the number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. When 0, no tokens will be generated but the prompt is evaluated into the cache. (default: 128, -1 = infinity).

`n_keep`: Specify the number of tokens from the initial prompt to retain when the model resets its internal context.
By default, this value is set to 0 (meaning no tokens are kept). Use `-1` to retain all tokens from the initial prompt.

`stream`: It allows receiving each predicted token in real-time instead of waiting for the completion to finish. To enable this, set to `true`.

`prompt`: Provide a prompt. Internally, the prompt is compared with the previous one; any part that has already been evaluated is detected, and only the remaining part is evaluated.
`prompt`: Provide a prompt. Internally, the prompt is compared with the previous one; any part that has already been evaluated is detected, and only the remaining part is evaluated. A space is inserted at the front, as `main.cpp` does.

`stop`: Specify a JSON array of stopping strings.
These words will not be included in the completion, so make sure to add them to the prompt for the next iteration (default: []).
@@ -163,6 +164,14 @@ node .

`content`: Set the text to tokenize.

Note that the special `BOS` token is not added in front of the text, and a space character is not inserted automatically as it is for `/completion`.
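
For illustration, a minimal C++ client sketch for this endpoint. It assumes the cpp-httplib (`httplib.h`) and nlohmann/json (`json.hpp`) headers bundled with this example, the default `127.0.0.1:8080`, and that the response carries the token ids under a `tokens` key:

```cpp
// Illustrative /tokenize client; the "tokens" response key is an assumption
// about the response shape, everything else follows the options documented above.
#include "httplib.h"
#include "json.hpp"

#include <cstdio>
#include <vector>

int main() {
    httplib::Client cli("127.0.0.1", 8080);

    // no BOS token and no leading space are added by the server for /tokenize
    const nlohmann::json body = { { "content", "Hello world" } };
    const auto res = cli.Post("/tokenize", body.dump(), "application/json");
    if (!res || res->status != 200) {
        fprintf(stderr, "tokenize request failed\n");
        return 1;
    }

    nlohmann::json data = nlohmann::json::parse(res->body);
    for (const int id : data["tokens"].get<std::vector<int>>()) {
        printf("%d ", id);
    }
    printf("\n");
    return 0;
}
```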

- **POST** `/embedding`: Generate the embedding of a given text, just as [the embedding example](../embedding) does.

*Options:*

`content`: Set the text to process.
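
A similar minimal client sketch, under the same assumptions (bundled `httplib.h` and `json.hpp`, default `127.0.0.1:8080`) plus a server started with `--embedding`:

```cpp
// Illustrative /embedding client: posts { "content": ... } and reads back the
// "embedding" array returned by the server.
#include "httplib.h"
#include "json.hpp"

#include <cstdio>
#include <vector>

int main() {
    httplib::Client cli("127.0.0.1", 8080);

    const nlohmann::json body = { { "content", "Hello world" } };
    const auto res = cli.Post("/embedding", body.dump(), "application/json");
    if (!res || res->status != 200) {
        fprintf(stderr, "embedding request failed\n");
        return 1;
    }

    nlohmann::json data = nlohmann::json::parse(res->body);
    const auto embedding = data["embedding"].get<std::vector<float>>();

    // the vector is all zeros when the server was not started with --embedding
    printf("embedding dimension: %zu\n", embedding.size());
    return 0;
}
```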

## More examples

### Interactive mode
44 changes: 43 additions & 1 deletion examples/server/server.cpp
@@ -254,6 +254,11 @@ struct llama_server_context {
n_past += n_eval;
}

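// no new tokens were requested (n_predict == 0): the prompt has been evaluated, so stop here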
if (params.n_predict == 0) {
has_next_token = false;
return llama_token_eos();
}

// out of user input, sample next token
const float temp = params.temp;
const int32_t top_k = params.top_k <= 0 ? llama_n_vocab(ctx) : params.top_k;
@@ -419,6 +424,19 @@ struct llama_server_context {

return token_text;
}

std::vector<float> getEmbedding() {
static const int n_embd = llama_n_embd(ctx);
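// embedding output was not enabled with --embedding, so return a zero vector of the expected size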
if (!params.embedding) {
LOG_WARNING("embedding disabled", {
{ "params.embedding", params.embedding },
});
return std::vector<float>(n_embd, 0.0f);
}
const float * data = llama_get_embeddings(ctx);
std::vector<float> embedding(data, data + n_embd);
return embedding;
}
};

static void server_print_usage(const char * argv0, const gpt_params & params,
@@ -457,6 +475,7 @@ static void server_print_usage(const char * argv0, const gpt_params & params,
fprintf(stderr, " --host ip address to listen (default: %s)\n", sparams.hostname.c_str());
fprintf(stderr, " --port PORT port to listen (default: %d)\n", sparams.port);
fprintf(stderr, " -to N, --timeout N server read/write timeout in seconds (default: %d)\n", sparams.read_timeout);
fprintf(stderr, " --embedding enable embedding vector output (default: %s)\n", params.embedding ? "enabled" : "disabled");
fprintf(stderr, "\n");
}

@@ -603,6 +622,8 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
params.use_mlock = true;
} else if (arg == "--no-mmap") {
params.use_mmap = false;
} else if (arg == "--embedding") {
params.embedding = true;
} else {
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
server_print_usage(argv[0], default_params, default_sparams);
@@ -646,6 +667,12 @@ static json format_generation_settings(llama_server_context & llama) {
};
}

static json format_embedding_response(llama_server_context & llama) {
return json {
{ "embedding", llama.getEmbedding() },
};
}

static json format_final_response(llama_server_context & llama, const std::string & content) {
return json {
{ "content", content },
@@ -881,12 +908,27 @@ int main(int argc, char ** argv) {

svr.Post("/tokenize", [&llama](const Request & req, Response & res) {
const json body = json::parse(req.body);
const std::string content = body["content"].get<std::string>();
const std::string content = body.value("content", "");
const std::vector<llama_token> tokens = llama_tokenize(llama.ctx, content, false);
const json data = format_tokenizer_response(tokens);
return res.set_content(data.dump(), "application/json");
});

svr.Post("/embedding", [&llama](const Request & req, Response & res) {
const json body = json::parse(req.body);

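// reuse the completion pipeline with n_predict = 0: the prompt is evaluated, no tokens are generated, and the embeddings can then be read from the context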
llama.rewind();
llama_reset_timings(llama.ctx);
llama.params.prompt = body.value("content", "");
llama.params.n_predict = 0;
llama.loadPrompt();
llama.beginCompletion();
llama.doCompletion();

const json data = format_embedding_response(llama);
return res.set_content(data.dump(), "application/json");
});

svr.set_logger(log_server_request);

svr.set_exception_handler([](const Request &, Response & res, std::exception_ptr ep) {