From 72f102a4aeb60ae7364d8ec1816e566ebbac0cc1 Mon Sep 17 00:00:00 2001 From: Claude Doppler Date: Tue, 4 Apr 2023 20:33:09 +0000 Subject: [PATCH 1/3] feat: add "stop" keywords as alternative to eot token --- examples/common.cpp | 12 +++++++++-- examples/common.h | 1 + examples/main/main.cpp | 46 +++++++++++++++++++++++++++++++++++++++--- 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 7aa77587b4605..edb27b4c17d48 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -283,6 +283,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) { break; } params.antiprompt.push_back(argv[i]); + } else if (arg == "--stop") { + if (++i >= argc) { + invalid_param = true; + break; + } + params.stop_keywords.push_back(argv[i]); } else if (arg == "--perplexity") { params.perplexity = true; } else if (arg == "--ignore-eos") { @@ -359,8 +365,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); - fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n"); - fprintf(stderr, " specified more than once for multiple prompts).\n"); + fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT"); + fprintf(stderr, " (can be specified more than once for multiple reverse prompts).\n"); + fprintf(stderr, " --stop KEYWORD a string that, when output by the model, will stop generation\n"); + fprintf(stderr, " (can be specified more than once for multiple keywords).\n"); fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n"); fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for < 0)\n"); fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads); diff --git a/examples/common.h b/examples/common.h index 43f1cc9ef09d5..1ca16dcb62f76 100644 --- a/examples/common.h +++ b/examples/common.h @@ -50,6 +50,7 @@ struct gpt_params { std::string input_prefix = ""; // string to prefix user inputs with std::string input_suffix = ""; // string to suffix user inputs with std::vector antiprompt; // string upon seeing which more user input is prompted + std::vector stop_keywords; // string upon seeing which the model will stop std::string lora_adapter = ""; // lora adapter path std::string lora_base = ""; // base model path for the lora adapter diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 6e1172a48367d..713edadb62f58 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -264,6 +264,13 @@ int main(int argc, char ** argv) { fprintf(stderr, "Input suffix: '%s'\n", params.input_suffix.c_str()); } } + + if (params.stop_keywords.size()) { + for (auto stop_keyword : params.stop_keywords) { + fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str()); + } + } + fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n", params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau); fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep); @@ -513,13 +520,28 @@ int main(int argc, char ** argv) { // check if we should prompt the user for more if (params.interactive && (int) embd_inp.size() <= n_consumed) { - // check for reverse prompt - if (params.antiprompt.size()) { - std::string last_output; + std::string last_output; + if (params.antiprompt.size() || params.stop_keywords.size()) { for (auto id : last_n_tokens) { last_output += llama_token_to_str(ctx, id); } + } + + // Check for stop keywords, a configurable alternative to the end-of-text token + // This should stop also the interactive mode, useful to stop interactive mode without SIGTERM + bool stop = false; + for (std::string stop_keyword : params.stop_keywords) { + if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) { + stop = true; + break; + } + } + if (stop) { + break; + } + // check for reverse prompt + if (params.antiprompt.size()) { is_antiprompt = false; // Check if each of the reverse prompts appears at the end of the output. for (std::string & antiprompt : params.antiprompt) { @@ -586,6 +608,24 @@ int main(int argc, char ** argv) { } } + // Check for stop keywords, a configurable alternative to the end-of-text token + if (!params.interactive && params.stop_keywords.size() && !is_interacting) { + std::string last_output; + for (auto id : last_n_tokens) { + last_output += llama_token_to_str(ctx, id); + } + bool stop = false; + for (std::string stop_keyword : params.stop_keywords) { + if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) { + stop = true; + break; + } + } + if (stop) { + break; + } + } + // end of text token if (!embd.empty() && embd.back() == llama_token_eos()) { if (params.instruct) { From b4d04d16131143744dcff08646c121f909574816 Mon Sep 17 00:00:00 2001 From: Claude Doppler Date: Sat, 8 Apr 2023 10:05:34 +0000 Subject: [PATCH 2/3] fix endline --- examples/common.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/common.cpp b/examples/common.cpp index edb27b4c17d48..348c605008c5c 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -365,7 +365,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); - fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT"); + fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT\n"); fprintf(stderr, " (can be specified more than once for multiple reverse prompts).\n"); fprintf(stderr, " --stop KEYWORD a string that, when output by the model, will stop generation\n"); fprintf(stderr, " (can be specified more than once for multiple keywords).\n"); From 2041e1e0b5f1628cee867fe276bf4a147156773f Mon Sep 17 00:00:00 2001 From: Evan Jones Date: Tue, 9 May 2023 23:16:40 -0400 Subject: [PATCH 3/3] simplify code --- examples/common.cpp | 4 ++-- examples/common.h | 8 +++---- examples/main/main.cpp | 51 ++++++++++++++++-------------------------- 3 files changed, 25 insertions(+), 38 deletions(-) diff --git a/examples/common.cpp b/examples/common.cpp index 348c605008c5c..d9a1e55b12e53 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -365,8 +365,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) { fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n"); fprintf(stderr, " --multiline-input allows you to write or paste multiple lines without ending each in '\\'\n"); fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n"); - fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT\n"); - fprintf(stderr, " (can be specified more than once for multiple reverse prompts).\n"); + fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n"); + fprintf(stderr, " specified more than once for multiple prompts).\n"); fprintf(stderr, " --stop KEYWORD a string that, when output by the model, will stop generation\n"); fprintf(stderr, " (can be specified more than once for multiple keywords).\n"); fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n"); diff --git a/examples/common.h b/examples/common.h index 1ca16dcb62f76..468b01b32b5a8 100644 --- a/examples/common.h +++ b/examples/common.h @@ -46,10 +46,10 @@ struct gpt_params { std::string model = "models/lamma-7B/ggml-model.bin"; // model path std::string prompt = ""; - std::string path_session = ""; // path to file for saving/loading model eval state - std::string input_prefix = ""; // string to prefix user inputs with - std::string input_suffix = ""; // string to suffix user inputs with - std::vector antiprompt; // string upon seeing which more user input is prompted + std::string path_session = ""; // path to file for saving/loading model eval state + std::string input_prefix = ""; // string to prefix user inputs with + std::string input_suffix = ""; // string to suffix user inputs with + std::vector antiprompt; // string upon seeing which more user input is prompted std::vector stop_keywords; // string upon seeing which the model will stop std::string lora_adapter = ""; // lora adapter path diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 713edadb62f58..9cc6550fd7329 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -266,7 +266,7 @@ int main(int argc, char ** argv) { } if (params.stop_keywords.size()) { - for (auto stop_keyword : params.stop_keywords) { + for (auto & stop_keyword : params.stop_keywords) { fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str()); } } @@ -516,22 +516,17 @@ int main(int argc, char ** argv) { console_set_color(con_st, CONSOLE_COLOR_DEFAULT); } - // in interactive mode, and not currently processing queued inputs; - // check if we should prompt the user for more - if (params.interactive && (int) embd_inp.size() <= n_consumed) { - + // check for stop keywords if we're processing generations + if (params.stop_keywords.size() && (int) embd_inp.size() <= n_consumed) { std::string last_output; - if (params.antiprompt.size() || params.stop_keywords.size()) { - for (auto id : last_n_tokens) { - last_output += llama_token_to_str(ctx, id); - } + for (auto id : last_n_tokens) { + last_output += llama_token_to_str(ctx, id); } - - // Check for stop keywords, a configurable alternative to the end-of-text token - // This should stop also the interactive mode, useful to stop interactive mode without SIGTERM bool stop = false; - for (std::string stop_keyword : params.stop_keywords) { - if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) { + for (auto & stop_keyword : params.stop_keywords) { + const size_t stop_pos = last_output.find(stop_keyword.c_str(), + last_output.length() - stop_keyword.length(), stop_keyword.length()); + if (stop_pos != std::string::npos) { stop = true; break; } @@ -539,9 +534,19 @@ int main(int argc, char ** argv) { if (stop) { break; } + } + + // in interactive mode, and not currently processing queued inputs; + // check if we should prompt the user for more + if (params.interactive && (int) embd_inp.size() <= n_consumed) { // check for reverse prompt if (params.antiprompt.size()) { + std::string last_output; + for (auto id : last_n_tokens) { + last_output += llama_token_to_str(ctx, id); + } + is_antiprompt = false; // Check if each of the reverse prompts appears at the end of the output. for (std::string & antiprompt : params.antiprompt) { @@ -608,24 +613,6 @@ int main(int argc, char ** argv) { } } - // Check for stop keywords, a configurable alternative to the end-of-text token - if (!params.interactive && params.stop_keywords.size() && !is_interacting) { - std::string last_output; - for (auto id : last_n_tokens) { - last_output += llama_token_to_str(ctx, id); - } - bool stop = false; - for (std::string stop_keyword : params.stop_keywords) { - if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) { - stop = true; - break; - } - } - if (stop) { - break; - } - } - // end of text token if (!embd.empty() && embd.back() == llama_token_eos()) { if (params.instruct) {