@@ -523,7 +523,7 @@ int main(int argc, char ** argv) {
523
523
524
524
const llama_token id = llama_sampling_sample (ctx_sampling, ctx, ctx_guidance);
525
525
526
- llama_sampling_accept (ctx_sampling, ctx, id);
526
+ llama_sampling_accept (ctx_sampling, ctx, id, true );
527
527
528
528
LOG (" last: %s\n " , LOG_TOKENS_TOSTR_PRETTY (ctx, ctx_sampling->prev ).c_str ());
529
529
@@ -541,8 +541,11 @@ int main(int argc, char ** argv) {
541
541
LOG (" embd_inp.size(): %d, n_consumed: %d\n " , (int ) embd_inp.size (), n_consumed);
542
542
while ((int ) embd_inp.size () > n_consumed) {
543
543
embd.push_back (embd_inp[n_consumed]);
544
- ctx_sampling->prev .erase (ctx_sampling->prev .begin ());
545
- ctx_sampling->prev .push_back (embd_inp[n_consumed]);
544
+
545
+ // push the prompt in the sampling context in order to apply repetition penalties later
546
+ // for the prompt, we don't apply grammar rules
547
+ llama_sampling_accept (ctx_sampling, ctx, embd_inp[n_consumed], false );
548
+
546
549
++n_consumed;
547
550
if ((int ) embd.size () >= params.n_batch ) {
548
551
break ;
@@ -574,7 +577,7 @@ int main(int argc, char ** argv) {
574
577
if ((int ) embd_inp.size () <= n_consumed) {
575
578
576
579
// deal with eot token in infill mode
577
- if ((ctx_sampling-> prev . back ( ) == llama_token_eot (ctx) || is_interacting) && params.interactive ){
580
+ if ((llama_sampling_last (ctx_sampling ) == llama_token_eot (ctx) || is_interacting) && params.interactive ){
578
581
if (is_interacting && !params.interactive_first ) {
579
582
// print an eot token
580
583
printf (" %s" , llama_token_to_piece (ctx, llama_token_eot (ctx)).c_str ());
@@ -591,7 +594,7 @@ int main(int argc, char ** argv) {
591
594
buffer += line;
592
595
} while (another_line);
593
596
// check if we got an empty line, if so we use the old input
594
- if (!buffer.empty () && !(buffer.length () == 1 && buffer[0 ] == ' \n ' )) {
597
+ if (!buffer.empty () && !(buffer.length () == 1 && buffer[0 ] == ' \n ' )) {
595
598
params.input_prefix = buffer;
596
599
}
597
600
buffer.clear ();
@@ -601,7 +604,7 @@ int main(int argc, char ** argv) {
601
604
buffer += line;
602
605
} while (another_line);
603
606
// check if we got an empty line
604
- if (!buffer.empty () && !(buffer.length () == 1 && buffer[0 ] == ' \n ' )) {
607
+ if (!buffer.empty () && !(buffer.length () == 1 && buffer[0 ] == ' \n ' )) {
605
608
params.input_suffix = buffer;
606
609
}
607
610
buffer.clear ();
@@ -614,7 +617,7 @@ int main(int argc, char ** argv) {
614
617
process_escapes (params.input_suffix );
615
618
}
616
619
suff_rm_leading_spc = params.escape ;
617
- if (suff_rm_leading_spc && params.input_suffix .find_first_of (" " ) == 0 && params.input_suffix .size () > 1 ) {
620
+ if (suff_rm_leading_spc && params.input_suffix .find_first_of (' ' ) == 0 && params.input_suffix .size () > 1 ) {
618
621
params.input_suffix .erase (0 , 1 );
619
622
suff_rm_leading_spc = false ;
620
623
}
@@ -641,7 +644,7 @@ int main(int argc, char ** argv) {
641
644
is_interacting = false ;
642
645
}
643
646
// deal with end of text token in interactive mode
644
- else if (ctx_sampling-> prev . back ( ) == llama_token_eos (ctx)) {
647
+ else if (llama_sampling_last (ctx_sampling ) == llama_token_eos (ctx)) {
645
648
LOG (" found EOS token\n " );
646
649
647
650
if (params.interactive ) {
0 commit comments