From df197cdf606ab5ef310dd338b08694114c2256bd Mon Sep 17 00:00:00 2001
From: Tindell Lockett
Date: Tue, 21 Mar 2023 12:24:42 -0400
Subject: [PATCH 1/3] Refactor code to use default parameters and add support
 for file input and non-interactive mode.

---
 chat.cpp  | 33 ++++++++++++++++-----------------
 utils.cpp |  5 +++++
 utils.h   | 13 +++++++------
 3 files changed, 28 insertions(+), 23 deletions(-)

diff --git a/chat.cpp b/chat.cpp
index 38b39771ad982..aff0235e5ca3e 100644
--- a/chat.cpp
+++ b/chat.cpp
@@ -318,7 +318,7 @@ bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab
     fin.close();
 
     std::vector tmp;
-    
+
     for (int i = 0; i < n_parts; ++i) {
         const int part_id = i;
         //const int part_id = n_parts - i - 1;
@@ -797,14 +797,6 @@ int main(int argc, char ** argv) {
 
     gpt_params params;
 
-    params.temp = 0.1f;
-    params.top_p = 0.95f;
-    params.n_ctx = 2048;
-    params.interactive = true;
-    params.interactive_start = true;
-    params.use_color = true;
-    params.model = "ggml-alpaca-7b-q4.bin";
-
     if (gpt_params_parse(argc, argv, params) == false) {
         return 1;
     }
@@ -856,13 +848,27 @@ int main(int argc, char ** argv) {
     // Add a space in front of the first character to match OG llama tokenizer behavior
     // params.prompt.insert(0, 1, ' ');
     // tokenize the prompt
-    std::vector<gpt_vocab::id> embd_inp;// = ::llama_tokenize(vocab, params.prompt, true);
+    std::vector<gpt_vocab::id> embd_inp;
 
     // params.n_predict = std::min(params.n_predict, model.hparams.n_ctx - (int) embd_inp.size());
 
     // // tokenize the reverse prompt
     // std::vector<gpt_vocab::id> antiprompt_inp = ::llama_tokenize(vocab, params.antiprompt, false);
 
+
+    std::vector<gpt_vocab::id> instruct_inp = ::llama_tokenize(vocab, " Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n", true);
+    std::vector<gpt_vocab::id> prompt_inp = ::llama_tokenize(vocab, "### Instruction:\n\n", true);
+    std::vector<gpt_vocab::id> response_inp = ::llama_tokenize(vocab, "### Response:\n\n", false);
+    embd_inp.insert(embd_inp.end(), instruct_inp.begin(), instruct_inp.end());
+
+    if(!params.prompt.empty()) {
+        std::vector<gpt_vocab::id> param_inp = ::llama_tokenize(vocab, params.prompt, true);
+        embd_inp.insert(embd_inp.end(), prompt_inp.begin(), prompt_inp.end());
+        embd_inp.insert(embd_inp.end(), param_inp.begin(), param_inp.end());
+        embd_inp.insert(embd_inp.end(), response_inp.begin(), response_inp.end());
+
+    }
+
     // fprintf(stderr, "\n");
     // fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
     // fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
@@ -871,13 +877,6 @@ int main(int argc, char ** argv) {
     // }
     // fprintf(stderr, "\n");
 
-    std::vector<gpt_vocab::id> instruct_inp = ::llama_tokenize(vocab, " Below is an instruction that describes a task. Write a response that appropriately completes the request.\n\n", true);
-    std::vector<gpt_vocab::id> prompt_inp = ::llama_tokenize(vocab, "### Instruction:\n\n", true);
-    std::vector<gpt_vocab::id> response_inp = ::llama_tokenize(vocab, "### Response:\n\n", false);
-
-    embd_inp.insert(embd_inp.end(), instruct_inp.begin(), instruct_inp.end());
-
-
     if (params.interactive) {
 #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
         struct sigaction sigint_action;
diff --git a/utils.cpp b/utils.cpp
index d739b5d489239..39d091c7942b9 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -24,9 +24,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         } else if (arg == "-t" || arg == "--threads") {
             params.n_threads = std::stoi(argv[++i]);
         } else if (arg == "-p" || arg == "--prompt") {
+            params.interactive = false;
+            params.interactive_start = false;
             params.prompt = argv[++i];
 
         } else if (arg == "-f" || arg == "--file") {
+            params.interactive = false;
+            params.interactive_start = false;
+
             std::ifstream file(argv[++i]);
 
             std::copy(std::istreambuf_iterator<char>(file),
diff --git a/utils.h b/utils.h
index 021120b0513c7..2a843371a35e0 100644
--- a/utils.h
+++ b/utils.h
@@ -12,28 +12,29 @@
 // CLI argument parsing
 //
 
+// The default parameters
 struct gpt_params {
     int32_t seed = -1; // RNG seed
     int32_t n_threads = std::min(4, (int32_t) std::thread::hardware_concurrency());
     int32_t n_predict = 128; // new tokens to predict
     int32_t repeat_last_n = 64; // last n tokens to penalize
-    int32_t n_ctx = 512; //context size
+    int32_t n_ctx = 2048; //context size
 
     // sampling parameters
     int32_t top_k = 40;
     float   top_p = 0.95f;
-    float   temp  = 0.80f;
+    float   temp  = 0.10f;
     float   repeat_penalty = 1.30f;
 
     int32_t n_batch = 8; // batch size for prompt processing
 
-    std::string model = "models/lamma-7B/ggml-model.bin"; // model path
+    std::string model = "ggml-alpaca-7b-q4.bin"; // model path
     std::string prompt;
 
-    bool use_color = false; // use color to distinguish generations and inputs
+    bool use_color = true; // use color to distinguish generations and inputs
 
-    bool interactive = false; // interactive mode
-    bool interactive_start = false; // reverse prompt immediately
+    bool interactive = true; // interactive mode
+    bool interactive_start = true; // reverse prompt immediately
 
     std::string antiprompt = ""; // string upon seeing which more user input is prompted
 };
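For reference, the block that PATCH 1/3 moves ahead of the interactive setup builds the Alpaca-style instruct prompt by concatenating four tokenized pieces. Below is a minimal standalone sketch of that assembly order; it is not part of the patch. The tokenize() helper and the token_id alias are made up for illustration: in chat.cpp the real ::llama_tokenize() maps text to std::vector<gpt_vocab::id> using the model vocabulary.

    // Toy illustration only: tokenize() assigns one fresh integer id per
    // whitespace-separated word so the assembly order can be shown without a model.
    #include <cstdio>
    #include <string>
    #include <vector>

    using token_id = int;

    static std::vector<token_id> tokenize(const std::string & text) {
        std::vector<token_id> out;
        token_id next_id = 0;
        bool in_word = false;
        for (char c : text) {
            if (c == ' ' || c == '\n') { in_word = false; continue; }
            if (!in_word) { out.push_back(next_id++); in_word = true; }
        }
        return out;
    }

    int main() {
        const std::string user_prompt = "Explain what an alpaca is.";

        // Same order as chat.cpp after the patch: preamble, then
        // "### Instruction:", the user prompt, then "### Response:" so that
        // generation starts right where the answer should begin.
        std::vector<token_id> embd_inp;

        const auto instruct_inp = tokenize(" Below is an instruction that describes a task. "
                                           "Write a response that appropriately completes the request.\n\n");
        const auto prompt_inp   = tokenize("### Instruction:\n\n");
        const auto response_inp = tokenize("### Response:\n\n");

        embd_inp.insert(embd_inp.end(), instruct_inp.begin(), instruct_inp.end());

        if (!user_prompt.empty()) {
            const auto param_inp = tokenize(user_prompt);
            embd_inp.insert(embd_inp.end(), prompt_inp.begin(), prompt_inp.end());
            embd_inp.insert(embd_inp.end(), param_inp.begin(), param_inp.end());
            embd_inp.insert(embd_inp.end(), response_inp.begin(), response_inp.end());
        }

        std::printf("assembled %zu prompt tokens\n", embd_inp.size());
        return 0;
    }

Note that the instruction and response sections are only appended when a prompt was actually given, so with no -p or -f argument only the preamble is pre-loaded and the program falls through to interactive mode.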
From f3a80f482ebc790cdb4b8bd5878ffaeddeccf34d Mon Sep 17 00:00:00 2001
From: Tindell Lockett
Date: Tue, 21 Mar 2023 12:40:12 -0400
Subject: [PATCH 2/3] When not in interactive mode, end after all text is
 generated

---
 chat.cpp | 11 ++++++++---
 1 file changed, 8 insertions(+), 3 deletions(-)

diff --git a/chat.cpp b/chat.cpp
index aff0235e5ca3e..9e702c68c1a7e 100644
--- a/chat.cpp
+++ b/chat.cpp
@@ -1075,9 +1075,14 @@ int main(int argc, char ** argv) {
 
         // end of text token
         if (embd.back() == 2) {
-            // fprintf(stderr, " [end of text]\n");
-            is_interacting = true;
-            continue;
+            if (params.interactive) {
+                is_interacting = true;
+                continue;
+            } else {
+                printf("\n");
+                fprintf(stderr, " [end of text]\n");
+                break;
+            }
         }
     }
 
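PATCH 2/3 keys the shutdown off the end-of-text token (id 2 in chat.cpp): in interactive runs control goes back to the user, in non-interactive runs the program prints a final newline and stops. The following is a minimal sketch of that branch, not part of the patch; sample_next_token() is a made-up stand-in for the real sampling loop.

    // Toy illustration only: the fake sampler emits a few tokens and then the
    // end-of-text id (2), mirroring how chat.cpp detects the end of a reply.
    #include <cstdio>
    #include <vector>

    static int sample_next_token(int step) {
        return step < 5 ? 100 + step : 2; // pretend the model finishes after 5 tokens
    }

    int main(int argc, char **) {
        const bool interactive = argc > 1; // any argument: pretend we are interactive
        const int  eot         = 2;        // end-of-text token id used in chat.cpp

        std::vector<int> embd;
        bool is_interacting = false;

        for (int step = 0; step < 32; ++step) {
            embd.push_back(sample_next_token(step));

            if (embd.back() == eot) {
                if (interactive) {
                    // Interactive mode: the user gets the next turn, the loop keeps running.
                    is_interacting = true;
                    continue;
                }
                // Non-interactive mode: finish cleanly once generation is done.
                std::printf("\n");
                std::fprintf(stderr, " [end of text]\n");
                break;
            }
        }

        if (is_interacting) {
            std::printf("[end of text] control would return to the user here\n");
        }
        return 0;
    }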
From 222dfdd97bf989518a25ee121ab477a3fb6f74bd Mon Sep 17 00:00:00 2001
From: Tindell Lockett
Date: Tue, 21 Mar 2023 12:58:03 -0400
Subject: [PATCH 3/3] disabled color in non-interactive mode.

---
 chat.cpp  | 1 -
 utils.cpp | 3 +++
 2 files changed, 3 insertions(+), 1 deletion(-)

diff --git a/chat.cpp b/chat.cpp
index 9e702c68c1a7e..20759315dc6d9 100644
--- a/chat.cpp
+++ b/chat.cpp
@@ -866,7 +866,6 @@ int main(int argc, char ** argv) {
         embd_inp.insert(embd_inp.end(), prompt_inp.begin(), prompt_inp.end());
         embd_inp.insert(embd_inp.end(), param_inp.begin(), param_inp.end());
         embd_inp.insert(embd_inp.end(), response_inp.begin(), response_inp.end());
-
     }
 
     // fprintf(stderr, "\n");
diff --git a/utils.cpp b/utils.cpp
index 39d091c7942b9..420fc26374307 100644
--- a/utils.cpp
+++ b/utils.cpp
@@ -26,11 +26,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
         } else if (arg == "-p" || arg == "--prompt") {
             params.interactive = false;
             params.interactive_start = false;
+            params.use_color = false;
+
             params.prompt = argv[++i];
 
         } else if (arg == "-f" || arg == "--file") {
             params.interactive = false;
             params.interactive_start = false;
+            params.use_color = false;
 
             std::ifstream file(argv[++i]);
 
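Taken together, the series leaves gpt_params defaulting to interactive, colored output with a 2048-token context and temperature 0.10, while passing a prompt with -p or -f switches interactive mode and color off. Below is a condensed, self-contained sketch of that end state, not part of the patches; params_t and the argument loop are simplified stand-ins for gpt_params (utils.h) and gpt_params_parse() (utils.cpp), and only the fields this series touches are reproduced.

    // Toy illustration only: distilled defaults and flag handling after PATCH 1-3.
    #include <cstdint>
    #include <cstdio>
    #include <fstream>
    #include <iterator>
    #include <string>

    struct params_t {
        int32_t     n_ctx             = 2048;                    // context size
        float       temp              = 0.10f;                   // sampling temperature
        std::string model             = "ggml-alpaca-7b-q4.bin"; // model path
        std::string prompt;
        bool        use_color         = true;                    // color output
        bool        interactive       = true;                    // interactive mode
        bool        interactive_start = true;                    // reverse prompt immediately
    };

    int main(int argc, char ** argv) {
        params_t params;

        for (int i = 1; i < argc; i++) {
            const std::string arg = argv[i];
            if (arg == "-p" && i + 1 < argc) {
                // Giving a prompt on the command line turns off interactive mode and color.
                params.interactive = params.interactive_start = params.use_color = false;
                params.prompt = argv[++i];
            } else if (arg == "-f" && i + 1 < argc) {
                // Reading the prompt from a file behaves the same way.
                params.interactive = params.interactive_start = params.use_color = false;
                std::ifstream file(argv[++i]);
                params.prompt.assign(std::istreambuf_iterator<char>(file),
                                     std::istreambuf_iterator<char>());
            }
        }

        std::printf("interactive=%d color=%d ctx=%d prompt=\"%s\"\n",
                    params.interactive, params.use_color, (int) params.n_ctx,
                    params.prompt.c_str());
        return 0;
    }

With no arguments the defaults keep the original chat-style behaviour; with -p "..." or -f file.txt the program runs the single prompt to completion and exits, which is what PATCH 2/3 relies on when it breaks out of the generation loop.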