@@ -209,6 +209,13 @@ int main(int argc, char ** argv) {
209
209
fprintf (stderr, " Input prefix: '%s'\n " , params.input_prefix .c_str ());
210
210
}
211
211
}
212
+
213
+ if (params.stop_keywords .size ()) {
214
+ for (auto stop_keyword : params.stop_keywords ) {
215
+ fprintf (stderr, " Stop keyword: '%s'\n " , stop_keyword.c_str ());
216
+ }
217
+ }
218
+
212
219
fprintf (stderr, " sampling: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n " ,
213
220
params.temp , params.top_k , params.top_p , params.repeat_last_n , params.repeat_penalty );
214
221
fprintf (stderr, " generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n " , n_ctx, params.n_batch , params.n_predict , params.n_keep );
@@ -344,13 +351,28 @@ int main(int argc, char ** argv) {
344
351
// check if we should prompt the user for more
345
352
if (params.interactive && (int ) embd_inp.size () <= n_consumed) {
346
353
347
- // check for reverse prompt
348
- if (params.antiprompt .size ()) {
349
- std::string last_output;
354
+ std::string last_output;
355
+ if (params.antiprompt .size () || params.stop_keywords .size ()) {
350
356
for (auto id : last_n_tokens) {
351
357
last_output += llama_token_to_str (ctx, id);
352
358
}
359
+ }
360
+
361
+ // Check for stop keywords, a configurable alternative to the end-of-text token
362
+ // This should stop also the interactive mode, useful to stop interactive mode without SIGTERM
363
+ bool stop = false ;
364
+ for (std::string stop_keyword : params.stop_keywords ) {
365
+ if (last_output.find (stop_keyword.c_str (), last_output.length () - stop_keyword.length (), stop_keyword.length ()) != std::string::npos) {
366
+ stop = true ;
367
+ break ;
368
+ }
369
+ }
370
+ if (stop) {
371
+ break ;
372
+ }
353
373
374
+ // check for reverse prompt
375
+ if (params.antiprompt .size ()) {
354
376
is_antiprompt = false ;
355
377
// Check if each of the reverse prompts appears at the end of the output.
356
378
for (std::string & antiprompt : params.antiprompt ) {
@@ -430,6 +452,24 @@ int main(int argc, char ** argv) {
430
452
}
431
453
}
432
454
455
+ // Check for stop keywords, a configurable alternative to the end-of-text token
456
+ if (!params.interactive && params.stop_keywords .size () && !is_interacting) {
457
+ std::string last_output;
458
+ for (auto id : last_n_tokens) {
459
+ last_output += llama_token_to_str (ctx, id);
460
+ }
461
+ bool stop = false ;
462
+ for (std::string stop_keyword : params.stop_keywords ) {
463
+ if (last_output.find (stop_keyword.c_str (), last_output.length () - stop_keyword.length (), stop_keyword.length ()) != std::string::npos) {
464
+ stop = true ;
465
+ break ;
466
+ }
467
+ }
468
+ if (stop) {
469
+ break ;
470
+ }
471
+ }
472
+
433
473
// end of text token
434
474
if (!embd.empty () && embd.back () == llama_token_eos ()) {
435
475
if (params.instruct ) {
0 commit comments