From ced65a56b007d12145418e456d2e1073716857c5 Mon Sep 17 00:00:00 2001
From: randxie
Date: Tue, 27 Jun 2023 22:28:31 +0800
Subject: [PATCH 1/3] convert checks in llama_load_session_file to throw and
 handle them

---
 llama.cpp | 24 ++++++++++++++----------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index 2482bdd18d2e7..adc2f11e385e9 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3336,7 +3336,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
     return nread;
 }
 
-bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+void llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
     llama_file file(path_session, "rb");
 
     // sanity checks
@@ -3345,16 +3345,14 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
         const uint32_t version = file.read_u32();
 
         if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
-            fprintf(stderr, "%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
-            return false;
+            throw std::runtime_error(format("%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version));
         }
 
         llama_hparams session_hparams;
         file.read_raw(&session_hparams, sizeof(llama_hparams));
 
         if (session_hparams != ctx->model.hparams) {
-            fprintf(stderr, "%s : model hparams didn't match from session file!\n", __func__);
-            return false;
+            throw std::runtime_error(format("%s : model hparams didn't match from session file!\n", __func__));
         }
     }
 
@@ -3363,8 +3361,7 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
         const uint32_t n_token_count = file.read_u32();
 
         if (n_token_count > n_token_capacity) {
-            fprintf(stderr, "%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
-            return false;
+            throw std::runtime_error(format("%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity));
        }
 
         file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
@@ -3377,8 +3374,7 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
         const size_t n_state_size_max = llama_get_state_size(ctx);
 
         if (n_state_size_cur > n_state_size_max) {
-            fprintf(stderr, "%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
-            return false;
+            throw std::runtime_error(format("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur));
         }
 
         std::vector<uint8_t> state_data(n_state_size_max);
@@ -3386,8 +3382,16 @@ bool llama_load_session_file(struct llama_context * ctx, const char * path_sessi
 
         llama_set_state_data(ctx, state_data.data());
     }
+}
 
-    return true;
+bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+    try {
+        llama_load_session_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
+        return true;
+    } catch (const std::exception & err) {
+        fprintf(stderr, "error loading session file: %s\n", err.what());
+        return false;
+    }
 }
 
 bool llama_save_session_file(struct llama_context * ctx, const char * path_session, const llama_token * tokens, size_t n_token_count) {

From 67be3fc743c674e369e7da3a20053dd9754ef3d8 Mon Sep 17 00:00:00 2001
From: randxie
Date: Wed, 28 Jun 2023 00:40:29 +0800
Subject: [PATCH 2/3] make llama_load_session_file_internal static

---
 llama.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/llama.cpp b/llama.cpp
index adc2f11e385e9..feca4d9d27abd 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3336,7 +3336,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
     return nread;
 }
 
-void llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+static void llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
     llama_file file(path_session, "rb");
 
     // sanity checks

From fcb0a77b13de647e63ab1c2cf3614a70bb708f43 Mon Sep 17 00:00:00 2001
From: randxie
Date: Thu, 29 Jun 2023 07:49:31 +0800
Subject: [PATCH 3/3] address feedback to avoid using exceptions

---
 llama.cpp | 17 ++++++++++-------
 1 file changed, 10 insertions(+), 7 deletions(-)

diff --git a/llama.cpp b/llama.cpp
index feca4d9d27abd..5341516c27f65 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -3336,7 +3336,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
     return nread;
 }
 
-static void llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
+static bool llama_load_session_file_internal(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
     llama_file file(path_session, "rb");
 
     // sanity checks
@@ -3345,14 +3345,16 @@ static void llama_load_session_file_internal(struct llama_context * ctx, const c
         const uint32_t version = file.read_u32();
 
         if (magic != LLAMA_SESSION_MAGIC || version != LLAMA_SESSION_VERSION) {
-            throw std::runtime_error(format("%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version));
+            fprintf(stderr, "%s : unknown (magic, version) for session file: %08x, %08x\n", __func__, magic, version);
+            return false;
         }
 
         llama_hparams session_hparams;
         file.read_raw(&session_hparams, sizeof(llama_hparams));
 
         if (session_hparams != ctx->model.hparams) {
-            throw std::runtime_error(format("%s : model hparams didn't match from session file!\n", __func__));
+            fprintf(stderr, "%s : model hparams didn't match from session file!\n", __func__);
+            return false;
         }
     }
 
@@ -3361,7 +3363,8 @@ static void llama_load_session_file_internal(struct llama_context * ctx, const c
         const uint32_t n_token_count = file.read_u32();
 
         if (n_token_count > n_token_capacity) {
-            throw std::runtime_error(format("%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity));
+            fprintf(stderr, "%s : token count in session file exceeded capacity! %u > %zu\n", __func__, n_token_count, n_token_capacity);
+            return false;
         }
 
         file.read_raw(tokens_out, sizeof(llama_token) * n_token_count);
@@ -3374,7 +3377,8 @@ static void llama_load_session_file_internal(struct llama_context * ctx, const c
         const size_t n_state_size_max = llama_get_state_size(ctx);
 
         if (n_state_size_cur > n_state_size_max) {
-            throw std::runtime_error(format("%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur));
+            fprintf(stderr, "%s : the state size in session file is too big! max %zu, got %zu\n", __func__, n_state_size_max, n_state_size_cur);
+            return false;
         }
 
         std::vector<uint8_t> state_data(n_state_size_max);
@@ -3386,8 +3390,7 @@ static void llama_load_session_file_internal(struct llama_context * ctx, const c
 
 bool llama_load_session_file(struct llama_context * ctx, const char * path_session, llama_token * tokens_out, size_t n_token_capacity, size_t * n_token_count_out) {
     try {
-        llama_load_session_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
-        return true;
+        return llama_load_session_file_internal(ctx, path_session, tokens_out, n_token_capacity, n_token_count_out);
     } catch (const std::exception & err) {
         fprintf(stderr, "error loading session file: %s\n", err.what());
         return false;
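
For anyone testing the series, a minimal calling sketch against the public API as it stands after PATCH 3/3 follows. It is illustrative only: the helper name try_restore_session and the 1024-token capacity are assumptions, not part of the patch; the llama_load_session_file signature is taken verbatim from the hunks above.

    #include <cstdio>
    #include <vector>
    #include "llama.h"

    // Hypothetical helper (not in this patch): restore a session saved with
    // llama_save_session_file and trim the buffer to the tokens actually read.
    static bool try_restore_session(llama_context * ctx, const char * path,
                                    std::vector<llama_token> & tokens) {
        tokens.resize(1024); // caller-chosen capacity (assumption)
        size_t n_token_count = 0;
        if (!llama_load_session_file(ctx, path, tokens.data(), tokens.size(),
                                     &n_token_count)) {
            // After PATCH 3/3 every validation failure logs to stderr and
            // returns false instead of throwing, so a boolean check suffices.
            return false;
        }
        tokens.resize(n_token_count);
        return true;
    }

Note that the public wrapper keeps its try/catch even after the validation checks went back to returning false, so anything thrown lower down (for example by llama_file reads) still surfaces to callers as a false return rather than crossing the C API boundary.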