Skip to content

main : add stop keywords #1387

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions examples/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
break;
}
params.antiprompt.push_back(argv[i]);
} else if (arg == "--stop") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.stop_keywords.push_back(argv[i]);
} else if (arg == "--perplexity") {
params.perplexity = true;
} else if (arg == "--ignore-eos") {
Expand Down Expand Up @@ -370,6 +376,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n");
fprintf(stderr, " specified more than once for multiple prompts).\n");
fprintf(stderr, " --stop KEYWORD a string that, when output by the model, will stop generation\n");
fprintf(stderr, " (can be specified more than once for multiple keywords).\n");
fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n");
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n");
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);
Expand Down
10 changes: 6 additions & 4 deletions examples/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,12 @@ struct gpt_params {

std::string model = "models/lamma-7B/ggml-model.bin"; // model path
std::string prompt = "";
std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
std::string input_prefix = ""; // string to prefix user inputs with
std::string input_suffix = ""; // string to suffix user inputs with
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted

std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
std::string input_prefix = ""; // string to prefix user inputs with
std::string input_suffix = ""; // string to suffix user inputs with
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
std::vector<std::string> stop_keywords; // string upon seeing which the model will stop

std::string lora_adapter = ""; // lora adapter path
std::string lora_base = ""; // base model path for the lora adapter
Expand Down
27 changes: 27 additions & 0 deletions examples/main/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,13 @@ int main(int argc, char ** argv) {
fprintf(stderr, "Input suffix: '%s'\n", params.input_suffix.c_str());
}
}

if (params.stop_keywords.size()) {
for (auto & stop_keyword : params.stop_keywords) {
fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str());
}
}

fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n",
params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
Expand Down Expand Up @@ -504,6 +511,26 @@ int main(int argc, char ** argv) {
console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
}

// check for stop keywords if we're processing generations
if (params.stop_keywords.size() && (int) embd_inp.size() <= n_consumed) {
std::string last_output;
for (auto id : last_n_tokens) {
last_output += llama_token_to_str(ctx, id);
}
bool stop = false;
for (auto & stop_keyword : params.stop_keywords) {
const size_t stop_pos = last_output.find(stop_keyword.c_str(),
last_output.length() - stop_keyword.length(), stop_keyword.length());
if (stop_pos != std::string::npos) {
stop = true;
break;
}
}
if (stop) {
break;
}
}

// in interactive mode, and not currently processing queued inputs;
// check if we should prompt the user for more
if (params.interactive && (int) embd_inp.size() <= n_consumed) {
Expand Down