@@ -102,6 +102,9 @@ struct llama_context {
     // decode output (2-dimensional array: [n_tokens][n_vocab])
     std::vector<float> logits;
     bool logits_all = false;
+
+    // work buffer for transformer evaluation
+    std::vector<uint8_t> buf_eval;
 };
 
 struct llama_context_params llama_context_default_params() {
@@ -627,27 +630,19 @@ static bool llama_eval_internal(
     const int n_rot   = hparams.n_embd/hparams.n_head;
 
     auto & mem_per_token = lctx.mem_per_token;
+    auto & buf_eval      = lctx.buf_eval;
 
-    // TODO: fix this hardcoded size
-    static size_t buf_size = 512u*1024*1024;
-    static void * buf = malloc(buf_size);
+    if (mem_per_token*(n_past + N + 16) > buf_eval.size()) {
+        const size_t buf_size_new = 1.618*buf_eval.size();
 
-    if (mem_per_token > 0 && mem_per_token*N > buf_size) {
-        const size_t buf_size_new = 1.3*(mem_per_token*N); // add 30% to account for ggml object overhead
-        //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_size, buf_size_new);
+        //fprintf(stderr, "\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, buf_eval.size(), buf_size_new);
 
-        // reallocate
-        buf_size = buf_size_new;
-        buf = realloc(buf, buf_size);
-        if (buf == nullptr) {
-            fprintf(stderr, "%s: failed to allocate %zu bytes\n", __func__, buf_size);
-            return false;
-        }
+        buf_eval.resize(buf_size_new);
     }
 
     struct ggml_init_params params = {
-        /*.mem_size   =*/ buf_size,
-        /*.mem_buffer =*/ buf,
+        /*.mem_size   =*/ buf_eval.size(),
+        /*.mem_buffer =*/ buf_eval.data(),
     };
 
     struct ggml_context * ctx0 = ggml_init(params);
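
The fixed 512 MB static buffer, shared by every llama_context and never freed, is replaced by a per-context std::vector<uint8_t> that grows geometrically: whenever the projected need, mem_per_token*(n_past + N + 16), exceeds the current capacity, the buffer is scaled by roughly 1.618 (the golden ratio). A minimal sketch of that growth policy in isolation, assuming the buffer was seeded with a non-zero size at init as the patch does; the helper name and the loop are illustrative, not part of the patch (the patch applies a single 1.618x step per eval, which suffices because the check runs on every call):

    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Illustrative helper (hypothetical name): grow a work buffer by the
    // golden ratio until it covers the projected need.
    static void grow_eval_buffer(std::vector<uint8_t> & buf, size_t needed) {
        size_t sz = buf.size(); // must be non-zero for the loop to make progress
        while (sz < needed) {
            sz = (size_t) (1.618*sz);
        }
        if (sz > buf.size()) {
            // unlike the old realloc path, an allocation failure now throws
            // std::bad_alloc instead of returning false from the eval
            buf.resize(sz);
        }
    }
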
@@ -832,10 +827,11 @@ static bool llama_eval_internal(
         memcpy(logits_out.data(), (float *) ggml_get_data(inpL) + (n_vocab*(N-1)), sizeof(float)*n_vocab);
     }
 
-    if (mem_per_token == 0) {
-        mem_per_token = ggml_used_mem(ctx0)/N;
+    if (N == 1) {
+        mem_per_token = ggml_used_mem(ctx0)/(n_past + N);
     }
-    //fprintf(stderr, "used_mem = %zu\n", ggml_used_mem(ctx0));
+
+    //fprintf(stderr, "\nused_mem = %zu, %zu MB\n", ggml_used_mem(ctx0), ggml_used_mem(ctx0)/1024/1024);
 
     ggml_free(ctx0);
 
@@ -1416,6 +1412,8 @@ struct llama_context * llama_init_from_file(
         return nullptr;
     }
 
+    ctx->buf_eval.resize(512u*1024u*1024u);
+
     return ctx;
 }
 
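Each new context seeds buf_eval with the same 512 MB the old static buffer started at, so the golden-ratio growth rule always has a non-zero size to scale from. Since the storage is owned by the llama_context rather than by a function-local static, it is released when the context is freed, and separate contexts no longer share one eval buffer (which, presumably, also removes a hazard when multiple contexts are used concurrently).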