diff --git a/examples/common.cpp b/examples/common.cpp
index f3085b08e5b25..c19cf92feb9b1 100644
--- a/examples/common.cpp
+++ b/examples/common.cpp
@@ -285,6 +285,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                 break;
             }
             params.antiprompt.push_back(argv[i]);
+        } else if (arg == "--stop") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            params.stop_keywords.push_back(argv[i]);
         } else if (arg == "--perplexity") {
             params.perplexity = true;
         } else if (arg == "--ignore-eos") {
@@ -370,6 +376,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
     fprintf(stderr, "  -r PROMPT, --reverse-prompt PROMPT\n");
     fprintf(stderr, "                        run in interactive mode and poll user input upon seeing PROMPT (can be\n");
     fprintf(stderr, "                        specified more than once for multiple prompts).\n");
+    fprintf(stderr, "  --stop KEYWORD        a string that, when output by the model, will stop generation\n");
+    fprintf(stderr, "                        (can be specified more than once for multiple keywords).\n");
     fprintf(stderr, "  --color               colorise output to distinguish prompt and user input from generations\n");
     fprintf(stderr, "  -s SEED, --seed SEED  RNG seed (default: -1, use random seed for < 0)\n");
     fprintf(stderr, "  -t N, --threads N     number of threads to use during computation (default: %d)\n", params.n_threads);
diff --git a/examples/common.h b/examples/common.h
index 499671b2e8d6d..0c8d95ccf63c1 100644
--- a/examples/common.h
+++ b/examples/common.h
@@ -46,10 +46,12 @@ struct gpt_params {
     std::string model = "models/lamma-7B/ggml-model.bin"; // model path
     std::string prompt = "";
-    std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state
-    std::string input_prefix = "";      // string to prefix user inputs with
-    std::string input_suffix = "";      // string to suffix user inputs with
-    std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
+
+    std::string path_prompt_cache = "";     // path to file for saving/loading prompt eval state
+    std::string input_prefix = "";          // string to prefix user inputs with
+    std::string input_suffix = "";          // string to suffix user inputs with
+    std::vector<std::string> antiprompt;    // string upon seeing which more user input is prompted
+    std::vector<std::string> stop_keywords; // string upon seeing which the model will stop
 
     std::string lora_adapter = ""; // lora adapter path
     std::string lora_base = "";    // base model path for the lora adapter
 
diff --git a/examples/main/main.cpp b/examples/main/main.cpp
index bd1c4ab558521..11827157b40cc 100644
--- a/examples/main/main.cpp
+++ b/examples/main/main.cpp
@@ -264,6 +264,13 @@ int main(int argc, char ** argv) {
             fprintf(stderr, "Input suffix: '%s'\n", params.input_suffix.c_str());
         }
     }
+
+    if (params.stop_keywords.size()) {
+        for (auto & stop_keyword : params.stop_keywords) {
+            fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str());
+        }
+    }
+
     fprintf(stderr, "sampling: repeat_last_n = %d, repeat_penalty = %f, presence_penalty = %f, frequency_penalty = %f, top_k = %d, tfs_z = %f, top_p = %f, typical_p = %f, temp = %f, mirostat = %d, mirostat_lr = %f, mirostat_ent = %f\n", params.repeat_last_n, params.repeat_penalty, params.presence_penalty, params.frequency_penalty, params.top_k, params.tfs_z, params.top_p, params.typical_p, params.temp, params.mirostat, params.mirostat_eta, params.mirostat_tau);
     fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
     fprintf(stderr, "\n\n");
@@ -504,6 +511,26 @@ int main(int argc, char ** argv) {
             console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
         }
 
+        // check for stop keywords if we're processing generations
+        if (params.stop_keywords.size() && (int) embd_inp.size() <= n_consumed) {
+            std::string last_output;
+            for (auto id : last_n_tokens) {
+                last_output += llama_token_to_str(ctx, id);
+            }
+            bool stop = false;
+            for (auto & stop_keyword : params.stop_keywords) {
+                const size_t stop_pos = last_output.find(stop_keyword.c_str(),
+                    last_output.length() - stop_keyword.length(), stop_keyword.length());
+                if (stop_pos != std::string::npos) {
+                    stop = true;
+                    break;
+                }
+            }
+            if (stop) {
+                break;
+            }
+        }
+
         // in interactive mode, and not currently processing queued inputs;
         // check if we should prompt the user for more
         if (params.interactive && (int) embd_inp.size() <= n_consumed) {
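
A note on the stop check in the main loop: `last_output` is rebuilt from the whole `last_n_tokens` window, and the three-argument `find()` is pinned to the single start position `last_output.length() - stop_keyword.length()`, so this is effectively an "ends with" test. A keyword stops generation only when it sits at the very end of the output; an occurrence earlier in the prompt or in previous output is ignored. If `last_output` is shorter than the keyword, the unsigned subtraction wraps around to a huge position, which is safe only because `std::string::find` returns `npos` for any out-of-range start. Below is a minimal standalone sketch of the same test with that guard made explicit (the helper name `ends_with_stop_keyword` is illustrative, not part of the patch):

```cpp
#include <iostream>
#include <string>
#include <vector>

// Suffix test equivalent to the patch's find()-based check, with the
// unsigned underflow guarded explicitly instead of relying on find()
// returning npos for an out-of-range start position.
static bool ends_with_stop_keyword(const std::string & output,
                                   const std::vector<std::string> & stop_keywords) {
    for (const auto & keyword : stop_keywords) {
        if (keyword.empty() || output.length() < keyword.length()) {
            continue; // a keyword longer than the output cannot match
        }
        // compare only the tail of the output against the keyword
        if (output.compare(output.length() - keyword.length(),
                           keyword.length(), keyword) == 0) {
            return true;
        }
    }
    return false;
}

int main() {
    const std::vector<std::string> stop_keywords = { "###", "User:" };
    std::cout << ends_with_stop_keyword("some output ###",  stop_keywords) << "\n"; // prints 1
    std::cout << ends_with_stop_keyword("still generating", stop_keywords) << "\n"; // prints 0
}
```

Unlike `-r`/`--reverse-prompt`, which pauses for user input in interactive mode, `--stop` breaks out of the generation loop entirely. With the patch applied, an invocation might look like `./main -m ./models/7B/ggml-model.bin -p "Once upon a time" --stop "###" --stop "User:"` (model path and keywords illustrative).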