From 8ed3c3fe2a2fef759069cf13d9bd2eb115cc0bc4 Mon Sep 17 00:00:00 2001 From: xaedes Date: Fri, 14 Apr 2023 02:38:10 +0200 Subject: [PATCH 1/5] reserve correct size for logits --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 33ee4fbb59474..21e4d73fb41de 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1787,7 +1787,7 @@ struct llama_context * llama_init_from_file( if (params.logits_all) { ctx->logits.reserve(hparams.n_ctx*hparams.n_vocab); } else { - ctx->logits.reserve(hparams.n_ctx); + ctx->logits.reserve(hparams.n_vocab); } if (params.embedding){ From 8288b36749c23be61c7bb3f567c86e95c27c6d0e Mon Sep 17 00:00:00 2001 From: xaedes Date: Fri, 14 Apr 2023 03:51:34 +0200 Subject: [PATCH 2/5] add functions to get and set the whole llama state: including rng, logits, embedding and kv_cache --- llama.cpp | 124 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ llama.h | 12 ++++++ 2 files changed, 136 insertions(+) diff --git a/llama.cpp b/llama.cpp index 21e4d73fb41de..f047252eedb75 100644 --- a/llama.cpp +++ b/llama.cpp @@ -27,6 +27,7 @@ #include #include #include +#include #define LLAMA_USE_SCRATCH #define LLAMA_MAX_SCRATCH_BUFFERS 16 @@ -2248,3 +2249,126 @@ const char * llama_print_system_info(void) { std::vector>& llama_internal_get_tensor_map(struct llama_context * ctx) { return ctx->model.tensors_by_name; } + +// Returns the size of the state +size_t llama_get_state_size(struct llama_context * ctx) { + const size_t s_bool = sizeof(int32_t); + // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. + // for reference, std::mt19937(1337) serializes to 6701 bytes. + const size_t s_rng_size = sizeof(size_t); + const size_t s_rng = 64*1024; + const size_t s_logits_capacity = sizeof(size_t); + const size_t s_logits_size = sizeof(size_t); + const size_t s_logits = ctx->logits.capacity() * sizeof(float); + const size_t s_embedding_size = sizeof(size_t); + const size_t s_embedding = ctx->embedding.size() * sizeof(float); + const size_t s_kv_size = sizeof(size_t); + const size_t s_kv_ntok = sizeof(int); + const size_t s_kv = llama_get_kv_cache_size(ctx); + const size_t s_total = ( + + s_rng_size + + s_rng + + s_logits_capacity + + s_logits_size + + s_logits + + s_embedding_size + + s_embedding + + s_kv_size + + s_kv_ntok + + s_kv + ); + return s_total; +} + +// Copies the state to the specified destination address +size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { + std::stringstream rng_ss; + rng_ss << ctx->rng; + const size_t rng_size = rng_ss.str().size(); + char rng_buf[64*1024]; + memset(&rng_buf[0], 0, 64*1024); + memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); + const int32_t has_evaluated_once = ctx->has_evaluated_once ? 1 : 0; + const int32_t logits_all = ctx->logits_all ? 1 : 0; + const size_t logits_capacity = ctx->logits.capacity(); + const size_t logits_size = ctx->logits.size(); + const size_t embedding_size = ctx->embedding.size(); + const size_t kv_size = llama_get_kv_cache_size(ctx); + const int kv_ntok = llama_get_kv_cache_token_count(ctx); + + uint8_t * out = dest; + memcpy(out, &rng_size, sizeof(size_t)); out += sizeof(size_t); + memcpy(out, &rng_buf[0], 64*1024); out += 64*1024; + memcpy(out, &logits_capacity, sizeof(size_t)); out += sizeof(size_t); + memcpy(out, &logits_size, sizeof(size_t)); out += sizeof(size_t); + if (logits_size) { + memcpy(out, ctx->logits.data(), logits_size * sizeof(float)); + } + out += logits_capacity * sizeof(float); + memcpy(out, &embedding_size, sizeof(size_t)); out += sizeof(size_t); + if (embedding_size) { + memcpy(out, ctx->embedding.data(), embedding_size * sizeof(float)); out += embedding_size * sizeof(float); + } + memcpy(out, &kv_size, sizeof(size_t)); out += sizeof(size_t); + memcpy(out, &kv_ntok, sizeof(int)); out += sizeof(int); + if (kv_size) { + memcpy(out, llama_get_kv_cache(ctx), kv_size); out += kv_size; + } + const size_t written = out - dest; + const size_t expected = llama_get_state_size(ctx); + LLAMA_ASSERT(written == expected); + return written; +} + +// Copies the state to the specified destination address +size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { + size_t rng_size; + char rng_buf[64*1024]; + std::stringstream rng_ss; + + const uint8_t * in = src; + memcpy(&rng_size, in, sizeof(size_t)); in += sizeof(size_t); + memcpy(&rng_buf[0], in, 64*1024); in += 64*1024; + rng_ss.str(std::string(&rng_buf[0], rng_size)); + rng_ss >> ctx->rng; + LLAMA_ASSERT(rng_ss.fail() == false); + + int32_t has_evaluated_once; + int32_t logits_all; + size_t logits_capacity; + size_t logits_size; + size_t embedding_size; + size_t kv_size; + int kv_ntok; + + memcpy(&logits_capacity, in, sizeof(size_t)); in += sizeof(size_t); + memcpy(&logits_size, in, sizeof(size_t)); in += sizeof(size_t); + LLAMA_ASSERT(ctx->logits.capacity() == logits_capacity); + if (logits_size) { + ctx->logits.resize(logits_size); + memcpy(ctx->logits.data(), in, logits_size * sizeof(float)); + } + in += logits_capacity * sizeof(float); + memcpy(&embedding_size, in, sizeof(size_t)); in += sizeof(size_t); + LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size); + if (embedding_size) { + memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float)); + in += embedding_size * sizeof(float); + } + memcpy(&kv_size, in, sizeof(size_t)); in += sizeof(size_t); + memcpy(&kv_ntok, in, sizeof(int)); in += sizeof(int); + if (kv_size) { + LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size); + void * k_data = ctx->model.kv_self.k->data; // remember data pointers + void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy + memcpy(ctx->model.kv_self.buf.addr, in, kv_size); + ctx->model.kv_self.k->data = k_data; // restore correct data pointers + ctx->model.kv_self.v->data = v_data; + in += kv_size; + } + ctx->model.kv_self.n = kv_ntok; + const size_t nread = in - src; + const size_t expected = llama_get_state_size(ctx); + LLAMA_ASSERT(nread == expected); + return nread; +} diff --git a/llama.h b/llama.h index e95ff73b8df1d..85aa4f2b612ab 100644 --- a/llama.h +++ b/llama.h @@ -129,6 +129,18 @@ extern "C" { size_t n_size, int n_token_count); + // Returns the size in bytes of the state (rng, logits, embedding and kv_cache) + LLAMA_API size_t llama_get_state_size(struct llama_context * ctx); + + // Copies the state to the specified destination address. + // Destination needs to have allocated enough memory. + // Returns the number of bytes copied + LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest); + + // Set the state reading from the specified address + // Returns the number of bytes read + LLAMA_API size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src); + // Run the llama inference to obtain the logits and probabilities for the next token. // tokens + n_tokens is the provided batch of new tokens to process // n_past is the number of tokens to use from previous eval calls From 9d26580b849d33e8a4088a689f510b00c9f20170 Mon Sep 17 00:00:00 2001 From: xaedes Date: Fri, 21 Apr 2023 17:23:53 +0200 Subject: [PATCH 3/5] remove unused variables --- llama.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/llama.cpp b/llama.cpp index f047252eedb75..32df8ac883778 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2288,8 +2288,6 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { char rng_buf[64*1024]; memset(&rng_buf[0], 0, 64*1024); memcpy(&rng_buf[0], rng_ss.str().data(), rng_ss.str().size()); - const int32_t has_evaluated_once = ctx->has_evaluated_once ? 1 : 0; - const int32_t logits_all = ctx->logits_all ? 1 : 0; const size_t logits_capacity = ctx->logits.capacity(); const size_t logits_size = ctx->logits.size(); const size_t embedding_size = ctx->embedding.size(); @@ -2333,8 +2331,6 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { rng_ss >> ctx->rng; LLAMA_ASSERT(rng_ss.fail() == false); - int32_t has_evaluated_once; - int32_t logits_all; size_t logits_capacity; size_t logits_size; size_t embedding_size; From 1c51e1f324aa3519f1801f049dc3e60ad9d812bc Mon Sep 17 00:00:00 2001 From: xaedes Date: Fri, 21 Apr 2023 17:32:39 +0200 Subject: [PATCH 4/5] remove trailing whitespace --- llama.cpp | 14 +++++++------- llama.h | 4 ++-- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/llama.cpp b/llama.cpp index 32df8ac883778..6bc5199d462c2 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2253,10 +2253,10 @@ std::vector>& llama_internal_get_te // Returns the size of the state size_t llama_get_state_size(struct llama_context * ctx) { const size_t s_bool = sizeof(int32_t); - // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. + // we don't know size of rng until we actually serialize it. so reserve more than enough memory for its serialized state. // for reference, std::mt19937(1337) serializes to 6701 bytes. - const size_t s_rng_size = sizeof(size_t); - const size_t s_rng = 64*1024; + const size_t s_rng_size = sizeof(size_t); + const size_t s_rng = 64*1024; const size_t s_logits_capacity = sizeof(size_t); const size_t s_logits_size = sizeof(size_t); const size_t s_logits = ctx->logits.capacity() * sizeof(float); @@ -2300,7 +2300,7 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { memcpy(out, &logits_capacity, sizeof(size_t)); out += sizeof(size_t); memcpy(out, &logits_size, sizeof(size_t)); out += sizeof(size_t); if (logits_size) { - memcpy(out, ctx->logits.data(), logits_size * sizeof(float)); + memcpy(out, ctx->logits.data(), logits_size * sizeof(float)); } out += logits_capacity * sizeof(float); memcpy(out, &embedding_size, sizeof(size_t)); out += sizeof(size_t); @@ -2342,13 +2342,13 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { LLAMA_ASSERT(ctx->logits.capacity() == logits_capacity); if (logits_size) { ctx->logits.resize(logits_size); - memcpy(ctx->logits.data(), in, logits_size * sizeof(float)); + memcpy(ctx->logits.data(), in, logits_size * sizeof(float)); } in += logits_capacity * sizeof(float); memcpy(&embedding_size, in, sizeof(size_t)); in += sizeof(size_t); LLAMA_ASSERT(ctx->embedding.capacity() == embedding_size); if (embedding_size) { - memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float)); + memcpy(ctx->embedding.data(), in, embedding_size * sizeof(float)); in += embedding_size * sizeof(float); } memcpy(&kv_size, in, sizeof(size_t)); in += sizeof(size_t); @@ -2357,7 +2357,7 @@ size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { LLAMA_ASSERT(ctx->model.kv_self.buf.size == kv_size); void * k_data = ctx->model.kv_self.k->data; // remember data pointers void * v_data = ctx->model.kv_self.v->data; // because their value is stored in buf and overwritten by memcpy - memcpy(ctx->model.kv_self.buf.addr, in, kv_size); + memcpy(ctx->model.kv_self.buf.addr, in, kv_size); ctx->model.kv_self.k->data = k_data; // restore correct data pointers ctx->model.kv_self.v->data = v_data; in += kv_size; diff --git a/llama.h b/llama.h index 85aa4f2b612ab..f68a0cb403b21 100644 --- a/llama.h +++ b/llama.h @@ -132,8 +132,8 @@ extern "C" { // Returns the size in bytes of the state (rng, logits, embedding and kv_cache) LLAMA_API size_t llama_get_state_size(struct llama_context * ctx); - // Copies the state to the specified destination address. - // Destination needs to have allocated enough memory. + // Copies the state to the specified destination address. + // Destination needs to have allocated enough memory. // Returns the number of bytes copied LLAMA_API size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest); From 456aedc4615ed5eff33809b2aa1d5f41329e443b Mon Sep 17 00:00:00 2001 From: xaedes Date: Fri, 21 Apr 2023 17:49:52 +0200 Subject: [PATCH 5/5] fix comment --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 6bc5199d462c2..9a773d201de95 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2318,7 +2318,7 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dest) { return written; } -// Copies the state to the specified destination address +// Sets the state reading from the specified source address size_t llama_set_state_data(struct llama_context * ctx, const uint8_t * src) { size_t rng_size; char rng_buf[64*1024];