Add ability to cancel model loading #4462

Merged · 31 commits · Dec 22, 2023

Commits (31)
9abe2e4
llama : Add ability to cancel model load
crasm Dec 14, 2023
3425e62
llama : Add test for model load cancellation
crasm Dec 14, 2023
4b1f70c
Fix bool return in llama_model_load, remove std::ignore use
crasm Dec 14, 2023
1160de3
Update llama.cpp
ggerganov Dec 17, 2023
32ebd52
Fail test if model file is missing
crasm Dec 17, 2023
cb8a4be
Merge branch 'cancel-model-load' of github.com:crasm/llama.cpp into c…
crasm Dec 17, 2023
2796953
Revert "Fail test if model file is missing"
crasm Dec 17, 2023
068e7c4
Add test-model-load-cancel to Makefile
crasm Dec 18, 2023
fe6a6fb
Revert "Revert "Fail test if model file is missing""
crasm Dec 18, 2023
6bba341
Simplify .gitignore for tests, clang-tidy fixes
crasm Dec 18, 2023
fd9d247
Label all ctest tests
crasm Dec 18, 2023
4b63355
ci : ctest uses -L main
crasm Dec 18, 2023
aed3cf8
Attempt at writing ctest_with_model
crasm Dec 18, 2023
f80ff4d
ci : get ci/run.sh working with test-model-load-cancel
crasm Dec 19, 2023
121b04d
ci : restrict .github/workflows/build.yml ctest to -L main
crasm Dec 19, 2023
1e79625
update requirements.txt
crasm Dec 19, 2023
9809314
Disable test-model-load-cancel in make
crasm Dec 19, 2023
9a056ed
Remove venv before creation
crasm Dec 20, 2023
293d16f
Restructure requirements.txt
crasm Dec 20, 2023
267cfa4
Merge commit 'c50e40016394f124b97ce39da48148b1f6c01833' into cancel-m…
crasm Dec 20, 2023
a0eab1e
Make per-python-script requirements work alone
crasm Dec 20, 2023
ca122dc
Add comment
crasm Dec 20, 2023
ba46057
Merge remote-tracking branch 'upstream/master' into cancel-model-load
crasm Dec 20, 2023
b853df4
Add convert-persimmon-to-gguf.py to new requirements.txt scheme
crasm Dec 20, 2023
c9a6de8
Add check-requirements.sh script and GitHub workflow
crasm Dec 21, 2023
e86b8cd
Remove shellcheck installation step from workflow
crasm Dec 21, 2023
bdfe4ba
Add nocleanup special arg
crasm Dec 21, 2023
6bc7411
Merge remote-tracking branch 'upstream' into cancel-model-load
crasm Dec 21, 2023
e438257
Fix merge
crasm Dec 21, 2023
f607e53
reset to upstream/master
crasm Dec 22, 2023
5f2ee1c
Redo changes for cancelling model load
crasm Dec 22, 2023

llama.cpp (46 changes: 33 additions & 13 deletions)
@@ -2372,7 +2372,8 @@ struct llama_model_loader {
}
}

void load_all_data(struct ggml_context * ctx, llama_progress_callback progress_callback, void * progress_callback_user_data, ggml_backend_buffer_t buf_mmap, llama_mlock * lmlock) const {
// Returns false if cancelled by progress_callback
bool load_all_data(struct ggml_context * ctx, llama_progress_callback progress_callback, void * progress_callback_user_data, ggml_backend_buffer_t buf_mmap, llama_mlock * lmlock) const {
size_t size_data = 0;

for (int i = 0; i < gguf_get_n_tensors(ctx_gguf); i++) {
@@ -2404,7 +2405,9 @@ struct llama_model_loader {
GGML_ASSERT(cur); // unused tensors should have been caught by load_data already

if (progress_callback) {
progress_callback((float) size_done / size_data, progress_callback_user_data);
if (!progress_callback((float) size_done / size_data, progress_callback_user_data)) {
return false;
}
}

const size_t offs = file_offset(ggml_get_name(cur));
@@ -2466,8 +2469,11 @@ struct llama_model_loader {
}

if (progress_callback) {
progress_callback(1.0f, progress_callback_user_data);
// Even though the model is done loading, we still honor
// cancellation since we need to free allocations.
return progress_callback(1.0f, progress_callback_user_data);
}
return true;
}
};

@@ -3044,7 +3050,8 @@ static void llm_load_print_meta(llama_model_loader & ml, llama_model & model) {
if (vocab.linefeed_id != -1) { LLAMA_LOG_INFO( "%s: LF token = %d '%s'\n", __func__, vocab.linefeed_id, vocab.id_to_token[vocab.linefeed_id].text.c_str() ); }
}

static void llm_load_tensors(
// Returns false if cancelled by progress_callback
static bool llm_load_tensors(
llama_model_loader & ml,
llama_model & model,
int n_gpu_layers,
@@ -3722,16 +3729,20 @@ static void llm_load_tensors(
model.tensors_by_name.emplace_back(ggml_get_name(cur), cur);
}

ml.load_all_data(ctx, progress_callback, progress_callback_user_data, buf_mmap, use_mlock ? &model.mlock_mmap : NULL);
if (!ml.load_all_data(ctx, progress_callback, progress_callback_user_data, buf_mmap, use_mlock ? &model.mlock_mmap : NULL)) {
return false;
}

model.mapping = std::move(ml.mapping);
Review thread on this line:

Contributor Author (crasm):

@slaren do you know if this line will be a problem? Since it doesn't get run if the above returns early

Contributor:

> @slaren do you know if this line will be a problem? Since it doesn't get run if the above returns early

The progress callback should only be called if it loaded successfully, I think. Would be weird to run it with 1.0 if the model load actually failed.

Member:

Skipping the mapping move should be fine.

Contributor Author (crasm):

The case I'm trying to avoid is:

  1. The model has fully loaded
  2. The client has received some information that the model load should be cancelled (e.g. on another thread)
  3. llama.cpp calls the progress_callback with 1.0f
  4. The client returns false (independent of the 1.0f value itself)
  5. llama.cpp ignores the callback return value and returns a valid llama_model *
  6. Now the client has to free the model, even though it did not expect one to be returned
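
To make the scenario above concrete, here is a minimal client-side sketch; it is not part of this PR's diff, and the flag and callback names are illustrative assumptions. Only the callback signature comes from this change.

// Hypothetical client code illustrating the scenario discussed above.
// A cancellation flag is flipped from another thread; the callback reports it.
#include <atomic>

static std::atomic<bool> g_cancel_load{false}; // set to true elsewhere to cancel

// Matches the new signature: bool (*llama_progress_callback)(float, void *).
static bool on_load_progress(float progress, void * /*user_data*/) {
    (void) progress;
    // Returning false, even on the final 1.0f call, asks llama.cpp to abort,
    // free its allocations, and make llama_load_model_from_file return NULL,
    // so the client is never handed a model it no longer wants.
    return !g_cancel_load.load();
}

This is why the final progress_callback(1.0f, ...) return value is still honored in load_all_data above: if the client has already decided to cancel, returning a valid llama_model * would force it to free a model it did not ask for.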


// loading time will be recalculate after the first eval, so
// we take page faults deferred by mmap() into consideration
model.t_load_us = ggml_time_us() - model.t_start_us;
return true;
}

static bool llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) {
// Returns 0 on success, -1 on error, and -2 on cancellation via llama_progress_callback
static int llama_model_load(const std::string & fname, llama_model & model, const llama_model_params & params) {
try {
llama_model_loader ml(fname, params.use_mmap, params.kv_overrides);

@@ -3749,19 +3760,21 @@ static bool llama_model_load(const std::string & fname, llama_model & model, con

if (params.vocab_only) {
LLAMA_LOG_INFO("%s: vocab only - skipping tensors\n", __func__);
return true;
return 0;
}

llm_load_tensors(
if (!llm_load_tensors(
ml, model, params.n_gpu_layers, params.main_gpu, params.tensor_split, params.use_mlock,
params.progress_callback, params.progress_callback_user_data
);
)) {
return -2;
}
} catch (const std::exception & err) {
LLAMA_LOG_ERROR("error loading model: %s\n", err.what());
return false;
return -1;
}

return true;
return 0;
}

//
@@ -9141,11 +9154,18 @@ struct llama_model * llama_load_model_from_file(
LLAMA_LOG_INFO("\n");
}
}
return true;
};
}

if (!llama_model_load(path_model, *model, params)) {
LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
int status = llama_model_load(path_model, *model, params);
GGML_ASSERT(status <= 0);
if (status < 0) {
if (status == -1) {
LLAMA_LOG_ERROR("%s: failed to load model\n", __func__);
} else if (status == -2) {
LLAMA_LOG_INFO("%s: cancelled model load\n", __func__);
}
delete model;
return nullptr;
}

llama.h (6 changes: 4 additions & 2 deletions)
@@ -127,7 +127,7 @@ extern "C" {
bool sorted;
} llama_token_data_array;

typedef void (*llama_progress_callback)(float progress, void *ctx);
typedef bool (*llama_progress_callback)(float progress, void *ctx);

// Input data for llama_decode
// A llama_batch object can contain input about one or many sequences
@@ -180,7 +180,9 @@ extern "C" {
int32_t main_gpu; // the GPU that is used for scratch and small tensors
const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)

// called with a progress value between 0 and 1, pass NULL to disable
// Called with a progress value between 0.0 and 1.0. Pass NULL to disable.
// If the provided progress_callback returns true, model loading continues.
// If it returns false, model loading is immediately aborted.
llama_progress_callback progress_callback;

// context pointer passed to the progress callback
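
For reference, a hedged caller-side sketch of the new callback contract follows; it is not part of this PR. The model path, the cancellation flag, and the messages are illustrative assumptions, while llama_backend_init, llama_model_default_params, llama_load_model_from_file, llama_free_model, and llama_backend_free are existing llama.cpp API calls, and the progress_callback fields are the ones documented above.

// Illustrative usage of the cancellable progress callback (assumed example).
#include <atomic>
#include <cstdio>
#include "llama.h"

static std::atomic<bool> g_cancel{false}; // e.g. set from a UI or watchdog thread

static bool on_progress(float progress, void * /*user_data*/) {
    fprintf(stderr, "\rloading: %3.0f%%", progress * 100.0f);
    return !g_cancel.load(); // returning false aborts the load
}

int main() {
    llama_backend_init(false /*numa*/);

    llama_model_params mparams = llama_model_default_params();
    mparams.progress_callback           = on_progress;
    mparams.progress_callback_user_data = nullptr;

    // If the callback ever returns false, loading is aborted and NULL is returned.
    llama_model * model = llama_load_model_from_file("model.gguf", mparams);
    if (model == NULL) {
        fprintf(stderr, "\nmodel load failed or was cancelled\n");
    } else {
        // ... use the model ...
        llama_free_model(model);
    }

    llama_backend_free();
    return 0;
}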