@@ -829,6 +829,29 @@ void sigint_handler(int signo) {
 }
 #endif
 
+
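+// Escapes control and quote characters so that a token or prompt always fits on a single protocol line.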
+std::string escapeString(const std::string & str) {
+    std::string escapedStr;
+    escapedStr.reserve(str.size());
+    for (const char c : str) {
+        switch (c) {
+            case '\a': escapedStr += "\\a";  break;
+            case '\b': escapedStr += "\\b";  break;
+            case '\f': escapedStr += "\\f";  break;
+            case '\n': escapedStr += "\\n";  break;
+            case '\r': escapedStr += "\\r";  break;
+            case '\t': escapedStr += "\\t";  break;
+            case '\v': escapedStr += "\\v";  break;
+            case '\\': escapedStr += "\\\\"; break;
+            case '\"': escapedStr += "\\\""; break;
+            case '\'': escapedStr += "\\\'"; break;
+            default:   escapedStr += c;      break;
+        }
+    }
+    return escapedStr;
+}
+
 int llama_main(
     gpt_params params,
     llama_vocab vocab,
@@ -842,8 +865,12 @@ int llama_main(
     if (params.seed < 0) {
         params.seed = time(NULL);
     }
-
-    fprintf(errstream, "%s: seed = %d\n", __func__, params.seed);
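+    // In protocol mode, human-readable logs on errstream are replaced by machine-readable
+    // records on outstream: a HELO handshake first, then "KV key=value" lines.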
+    if (params.protocol_mode) {
+        fprintf(outstream, "HELO\n");
+        fprintf(outstream, "KV seed=%d\n", params.seed);
+    } else {
+        fprintf(errstream, "%s: seed = %d\n", __func__, params.seed);
+    }
 
     std::mt19937 rng(params.seed);
     if (params.random_prompt) {
@@ -891,13 +918,24 @@ int llama_main(
         params.interactive = true;
     }
 
-    fprintf(errstream, "\n");
-    fprintf(errstream, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
-    fprintf(errstream, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+    if (params.protocol_mode) {
+        fprintf(outstream, "PROMPT %s\n", escapeString(params.prompt).c_str());
+        fprintf(outstream, "KV prompt_tokens=%zu\n", embd_inp.size());
+    } else {
+        fprintf(errstream, "\n");
+        fprintf(errstream, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
+        fprintf(errstream, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
+    }
     for (int i = 0; i < (int) embd_inp.size(); i++) {
-        fprintf(errstream, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
+        if (params.protocol_mode) {
+            fprintf(outstream, "DEBUG %d -> '%s'\n", embd_inp[i], escapeString(vocab.id_to_token.at(embd_inp[i])).c_str());
+        } else {
+            fprintf(errstream, "%6d -> '%s'\n", embd_inp[i], vocab.id_to_token.at(embd_inp[i]).c_str());
+        }
+    }
+    if (!params.protocol_mode) {
+        fprintf(errstream, "\n");
     }
-    fprintf(errstream, "\n");
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
@@ -909,16 +947,32 @@ int llama_main(
         signal(SIGINT, sigint_handler);
 #endif
 
-        fprintf(errstream, "%s: interactive mode on.\n", __func__);
+        if (params.protocol_mode) {
+            fprintf(outstream, "KV interactive_mode=true\n");
+        } else {
+            fprintf(errstream, "%s: interactive mode on.\n", __func__);
+        }
 
         if (params.antiprompt.size()) {
             for (auto antiprompt : params.antiprompt) {
-                fprintf(errstream, "Reverse prompt: '%s'\n", antiprompt.c_str());
+                if (params.protocol_mode) {
+                    fprintf(outstream, "KV reverse_prompt=\"%s\"\n", escapeString(antiprompt).c_str());
+                } else {
+                    fprintf(errstream, "Reverse prompt: '%s'\n", antiprompt.c_str());
+                }
             }
         }
     }
-    fprintf(errstream, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
-    fprintf(errstream, "\n\n");
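+    // each sampling parameter becomes its own KV record so clients can parse them independently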
+    if (params.protocol_mode) {
+        // these go to outstream like every other KV record, not errstream
+        fprintf(outstream, "KV temp=%f\n", params.temp);
+        fprintf(outstream, "KV top_k=%d\n", params.top_k);
+        fprintf(outstream, "KV top_p=%f\n", params.top_p);
+        fprintf(outstream, "KV repeat_last_n=%i\n", params.repeat_last_n);
+        fprintf(outstream, "KV repeat_penalty=%f\n", params.repeat_penalty);
+    } else {
+        fprintf(errstream, "sampling parameters: temp = %f, top_k = %d, top_p = %f, repeat_last_n = %i, repeat_penalty = %f\n", params.temp, params.top_k, params.top_p, params.repeat_last_n, params.repeat_penalty);
+        fprintf(errstream, "\n\n");
+    }
 
     std::vector<llama_vocab::id> embd;
 
@@ -927,12 +981,14 @@ int llama_main(
     std::fill(last_n_tokens.begin(), last_n_tokens.end(), 0);
 
     if (params.interactive) {
-        fprintf(errstream, "== Running in interactive mode. ==\n"
+        if (!params.protocol_mode) {
+            fprintf(errstream, "== Running in interactive mode. ==\n"
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32)
-               " - Press Ctrl+C to interject at any time.\n"
+                " - Press Ctrl+C to interject at any time.\n"
 #endif
-               " - Press Return to return control to LLaMa.\n"
-               " - If you want to submit another line, end your input in '\\'.\n\n");
+                " - Press Return to return control to LLaMa.\n"
+                " - If you want to submit another line, end your input in '\\'.\n\n");
+        }
         is_interacting = true;
     }
 
@@ -955,12 +1011,19 @@ int llama_main(
     }
 
     while (remaining_tokens > 0 || params.interactive) {
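+        // non-interactive protocol clients get a progress record before each prediction step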
+        if (params.protocol_mode && !params.interactive) {
+            fprintf(outstream, "KV remaining_tokens=%d\n", remaining_tokens);
+        }
         // predict
         if (embd.size() > 0) {
             const int64_t t_start_us = ggml_time_us();
 
             if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token)) {
-                fprintf(errstream, "Failed to predict\n");
+                if (params.protocol_mode) {
+                    fprintf(outstream, "FATAL Error: Failed to predict\n");
+                } else {
+                    fprintf(errstream, "Failed to predict\n");
+                }
                 return 1;
             }
 
@@ -1020,8 +1083,16 @@ int llama_main(
 
         // display text
         if (!input_noecho) {
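+            // protocol mode frames this batch of generated tokens as a single OUTPUT record;
+            // escaping keeps embedded newlines from breaking the one-record-per-line format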
+            if (params.protocol_mode) {
+                fprintf(outstream, "OUTPUT ");
+            }
             for (auto id : embd) {
-                fprintf(outstream, "%s", vocab.id_to_token[id].c_str());
+                fprintf(outstream, "%s", params.protocol_mode ?
+                    escapeString(vocab.id_to_token[id]).c_str() :
+                    vocab.id_to_token[id].c_str());
+            }
+            if (params.protocol_mode) {
+                fprintf(outstream, "\n");
             }
             fflush(outstream);
         }
@@ -1047,11 +1118,17 @@ int llama_main(
                 }
             }
             if (is_interacting) {
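+                // tell the driving process we are about to block reading a prompt from stdin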
+                if (params.protocol_mode) {
+                    fprintf(outstream, "KV awaiting_prompt=true\n");
+                    fflush(outstream);
+                }
                 if (params.instruct) {
                     input_consumed = embd_inp.size();
                     embd_inp.insert(embd_inp.end(), inp_pfx.begin(), inp_pfx.end());
 
-                    fprintf(outstream, "\n> ");
+                    if (!params.protocol_mode) {
+                        fprintf(outstream, "\n> ");
+                    }
                 }
 
                 // currently being interactive
@@ -1068,6 +1145,7 @@ int llama_main(
                     }
                     buffer += line + '\n'; // Append the line to the result
                 } while (another_line);
+                if (params.protocol_mode) {
+                    // report the full buffered input (not just the last line) as a PROMPT record
+                    fprintf(outstream, "PROMPT %s\n", escapeString(buffer).c_str());
+                }
                 if (params.use_color) fprintf(outstream, ANSI_COLOR_RESET);
 
                 std::vector<llama_vocab::id> line_inp = ::llama_tokenize(vocab, buffer, false);
@@ -1080,6 +1158,10 @@ int llama_main(
                 remaining_tokens -= line_inp.size();
 
                 input_noecho = true; // do not echo this again
+                if (params.protocol_mode) {
+                    fprintf(outstream, "KV awaiting_prompt=false\n");
+                    fflush(outstream);
+                }
             }
             is_interacting = false;
         }
@@ -1089,7 +1171,13 @@ int llama_main(
             if (params.interactive) {
                 is_interacting = true;
             } else {
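+                // protocol clients watch for a dedicated END_OF_TEXT record instead of a log message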
-                fprintf(errstream, " [end of text]\n");
+                if (params.protocol_mode) {
+                    fprintf(outstream, "END_OF_TEXT\n");
+                    fflush(outstream);
+                } else {
+                    fprintf(errstream, " [end of text]\n");
+                    fflush(errstream);
+                }
                 break;
             }
         }