Skip to content

Commit 9fd062f

Browse files
author
Claude Doppler
committed
feat: add "stop" keywords as alternative to eot token
1 parent 62cfc54 commit 9fd062f

File tree

3 files changed

+54
-5
lines changed

3 files changed

+54
-5
lines changed

examples/common.cpp

+10-2
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,12 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
164164
break;
165165
}
166166
params.antiprompt.push_back(argv[i]);
167+
} else if (arg == "--stop") {
168+
if (++i >= argc) {
169+
invalid_param = true;
170+
break;
171+
}
172+
params.stop_keywords.push_back(argv[i]);
167173
} else if (arg == "--perplexity") {
168174
params.perplexity = true;
169175
} else if (arg == "--ignore-eos") {
@@ -209,8 +215,10 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
209215
fprintf(stderr, " --interactive-first run in interactive mode and wait for input right away\n");
210216
fprintf(stderr, " -ins, --instruct run in instruction mode (use with Alpaca models)\n");
211217
fprintf(stderr, " -r PROMPT, --reverse-prompt PROMPT\n");
212-
fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT (can be\n");
213-
fprintf(stderr, " specified more than once for multiple prompts).\n");
218+
fprintf(stderr, " run in interactive mode and poll user input upon seeing PROMPT");
219+
fprintf(stderr, " (can be specified more than once for multiple reverse prompts).\n");
220+
fprintf(stderr, " --stop KEYWORD a string that, when output by the model, will stop generation\n");
221+
fprintf(stderr, " (can be specified more than once for multiple keywords).\n");
214222
fprintf(stderr, " --color colorise output to distinguish prompt and user input from generations\n");
215223
fprintf(stderr, " -s SEED, --seed SEED RNG seed (default: -1, use random seed for <= 0)\n");
216224
fprintf(stderr, " -t N, --threads N number of threads to use during computation (default: %d)\n", params.n_threads);

examples/common.h

+1
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ struct gpt_params {
3535

3636

3737
std::vector<std::string> antiprompt; // string upon seeing which more user input is prompted
38+
std::vector<std::string> stop_keywords; // string upon seeing which the model will stop
3839

3940
bool memory_f16 = true; // use f16 instead of f32 for memory kv
4041
bool random_prompt = false; // do not randomize prompt if none provided

examples/main/main.cpp

+43-3
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,13 @@ int main(int argc, char ** argv) {
209209
fprintf(stderr, "Input prefix: '%s'\n", params.input_prefix.c_str());
210210
}
211211
}
212+
213+
if (params.stop_keywords.size()) {
214+
for (auto stop_keyword : params.stop_keywords) {
215+
fprintf(stderr, "Stop keyword: '%s'\n", stop_keyword.c_str());
216+
}
217+
}
218+
212219
fprintf(stderr, "sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n",
213220
params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
214221
fprintf(stderr, "generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
@@ -344,13 +351,28 @@ int main(int argc, char ** argv) {
344351
// check if we should prompt the user for more
345352
if (params.interactive && (int) embd_inp.size() <= n_consumed) {
346353

347-
// check for reverse prompt
348-
if (params.antiprompt.size()) {
349-
std::string last_output;
354+
std::string last_output;
355+
if (params.antiprompt.size() || params.stop_keywords.size()) {
350356
for (auto id : last_n_tokens) {
351357
last_output += llama_token_to_str(ctx, id);
352358
}
359+
}
360+
361+
// Check for stop keywords, a configurable alternative to the end-of-text token
362+
// This should stop also the interactive mode, useful to stop interactive mode without SIGTERM
363+
bool stop = false;
364+
for (std::string stop_keyword : params.stop_keywords) {
365+
if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) {
366+
stop = true;
367+
break;
368+
}
369+
}
370+
if (stop) {
371+
break;
372+
}
353373

374+
// check for reverse prompt
375+
if (params.antiprompt.size()) {
354376
is_antiprompt = false;
355377
// Check if each of the reverse prompts appears at the end of the output.
356378
for (std::string & antiprompt : params.antiprompt) {
@@ -430,6 +452,24 @@ int main(int argc, char ** argv) {
430452
}
431453
}
432454

455+
// Check for stop keywords, a configurable alternative to the end-of-text token
456+
if (!params.interactive && params.stop_keywords.size() && !is_interacting) {
457+
std::string last_output;
458+
for (auto id : last_n_tokens) {
459+
last_output += llama_token_to_str(ctx, id);
460+
}
461+
bool stop = false;
462+
for (std::string stop_keyword : params.stop_keywords) {
463+
if (last_output.find(stop_keyword.c_str(), last_output.length() - stop_keyword.length(), stop_keyword.length()) != std::string::npos) {
464+
stop = true;
465+
break;
466+
}
467+
}
468+
if (stop) {
469+
break;
470+
}
471+
}
472+
433473
// end of text token
434474
if (!embd.empty() && embd.back() == llama_token_eos()) {
435475
if (params.instruct) {

0 commit comments

Comments
 (0)