diff --git a/common/common.cpp b/common/common.cpp index 18ffb4e738aee..736d8899a5eee 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -582,41 +582,6 @@ std::string string_from(const struct llama_context * ctx, const std::vector common_get_hf_file(const std::string &, cons #endif // LLAMA_USE_CURL -// -// Batch utils -// - -void common_batch_clear(struct llama_batch & batch) { - batch.n_tokens = 0; -} - -void common_batch_add( - struct llama_batch & batch, - llama_token id, - llama_pos pos, - const std::vector & seq_ids, - bool logits) { - GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded"); - - batch.token [batch.n_tokens] = id; - batch.pos [batch.n_tokens] = pos; - batch.n_seq_id[batch.n_tokens] = seq_ids.size(); - for (size_t i = 0; i < seq_ids.size(); ++i) { - batch.seq_id[batch.n_tokens][i] = seq_ids[i]; - } - batch.logits [batch.n_tokens] = logits; - - batch.n_tokens++; -} - // // Token utils // diff --git a/common/common.h b/common/common.h index 1c0f199774976..197108be0ebba 100644 --- a/common/common.h +++ b/common/common.h @@ -516,7 +516,6 @@ void string_process_escapes(std::string & input); std::string string_from(bool value); std::string string_from(const std::vector & values); std::string string_from(const struct llama_context * ctx, const std::vector & tokens); -std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch); // // Filesystem utils @@ -566,19 +565,6 @@ std::pair common_get_hf_file( // clear LoRA adapters from context, then apply new list of adapters void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora); -// -// Batch utils -// - -void common_batch_clear(struct llama_batch & batch); - -void common_batch_add( - struct llama_batch & batch, - llama_token id, - llama_pos pos, - const std::vector & seq_ids, - bool logits); - // // Token utils // diff --git a/common/speculative.cpp b/common/speculative.cpp index ccad70fa9ed85..f16b2c6e36d6c 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -14,7 +14,7 @@ struct common_speculative { struct llama_context * ctx; struct common_sampler * smpl; - llama_batch batch; + llama_batch_ext_ptr batch; llama_tokens prompt; }; @@ -23,7 +23,7 @@ struct common_speculative * common_speculative_init( auto * result = new common_speculative { /* .ctx = */ ctx_dft, /* .smpl = */ nullptr, - /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), + /* .batch = */ llama_batch_ext_ptr(ctx_dft), /* .prompt = */ {}, }; @@ -69,8 +69,6 @@ void common_speculative_free(struct common_speculative * spec) { common_sampler_free(spec->smpl); - llama_batch_free(spec->batch); - delete spec; } @@ -206,40 +204,40 @@ llama_tokens common_speculative_gen_draft( } // prepare a batch to evaluate any new tokens in the prompt - common_batch_clear(batch); + batch.clear(); for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) { //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); - common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false); + batch.add_text(prompt_tgt[i], i - i_start, 0, false); prompt.push_back(prompt_tgt[i]); } // we should rarely end-up here during normal decoding - if (batch.n_tokens > 0) { + if (batch.n_tokens() > 0) { //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); - llama_decode(ctx, batch); + llama_decode_ext(ctx, batch.get()); } const llama_pos n_past = prompt.size(); LOG_DBG("%s: n_past = %d\n", __func__, n_past); - common_batch_clear(batch); - common_batch_add (batch, id_last, n_past, { 0 }, true); + batch.clear(); + batch.add_text(id_last, n_past, 0, true); prompt.push_back(id_last); //LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str()); - llama_decode(ctx, batch); + llama_decode_ext(ctx, batch.get()); common_sampler_reset(smpl); // sample n_draft tokens from the draft model for (int i = 0; i < params.n_draft; ++i) { - common_batch_clear(batch); + batch.clear(); common_sampler_sample(smpl, ctx, 0, true); @@ -266,10 +264,10 @@ llama_tokens common_speculative_gen_draft( break; } - common_batch_add(batch, id, n_past + i + 1, { 0 }, true); + batch.add_text( id, n_past + i + 1, 0, true); // evaluate the drafted tokens on the draft model - llama_decode(ctx, batch); + llama_decode_ext(ctx, batch.get()); prompt.push_back(id); } diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 430e8be512653..331ec88852733 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -59,24 +59,17 @@ int main(int argc, char ** argv) { const int32_t n_kv_max = llama_n_ctx(ctx); - llama_batch batch = llama_batch_init(n_kv_max, 0, 1); + llama_batch_ext * batch = llama_batch_ext_init(ctx); // decode in batches of ctx_params.n_batch tokens - auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) { - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); - - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; - - const int ret = llama_decode(ctx, batch_view); + auto decode_helper = [](llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) { + const int32_t n_batch_tokens = llama_batch_ext_get_n_tokens(batch); + for (int32_t i = 0; i < (int32_t) n_batch_tokens; i += n_batch) { + const int32_t n_tokens = std::min(n_batch, (int32_t) (n_batch_tokens - i)); + + llama_batch_ext_ptr batch_view = llama_batch_ext_ptr(llama_batch_ext_get_view(batch, i, n_tokens)); + + const int ret = llama_decode_ext(ctx, batch_view.get()); if (ret != 0) { LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); return false; @@ -91,7 +84,8 @@ int main(int argc, char ** argv) { // warm up { for (int i = 0; i < 16; ++i) { - common_batch_add(batch, 0, i, { 0 }, false); + const llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false); } if (!decode_helper(ctx, batch, ctx_params.n_batch)) { @@ -121,14 +115,14 @@ int main(int argc, char ** argv) { continue; } - common_batch_clear(batch); + llama_batch_ext_clear(batch); for (int i = 0; i < pp; ++i) { for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) { - common_batch_add(batch, 0, i, { j }, false); + llama_batch_ext_add_text(batch, 0, i, &j, 1, false); } } - batch.logits[batch.n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch); const auto t_pp_start = ggml_time_us(); @@ -150,10 +144,10 @@ int main(int argc, char ** argv) { const auto t_tg_start = ggml_time_us(); for (int i = 0; i < tg; ++i) { - common_batch_clear(batch); + llama_batch_ext_clear(batch); for (int j = 0; j < pl; ++j) { - common_batch_add(batch, 0, pp + i, { j }, true); + llama_batch_ext_add_text(batch, 0, pp + i, &j, 1, true); } if (!decode_helper(ctx, batch, ctx_params.n_batch)) { @@ -191,7 +185,7 @@ int main(int argc, char ** argv) { LOG("\n"); llama_perf_context_print(ctx); - llama_batch_free(batch); + llama_batch_ext_free(batch); llama_free(ctx); llama_model_free(model); diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 21b95ef5e4e83..204544f6f38cd 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -102,7 +102,7 @@ int main(int argc, char ** argv) { // create a llama_batch // we use this object to submit token data for decoding - llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel); + llama_batch_ext * batch = llama_batch_ext_init(ctx); std::vector seq_ids(n_parallel, 0); for (int32_t i = 0; i < n_parallel; ++i) { @@ -111,12 +111,12 @@ int main(int argc, char ** argv) { // evaluate the initial prompt for (size_t i = 0; i < tokens_list.size(); ++i) { - common_batch_add(batch, tokens_list[i], i, seq_ids, false); + llama_batch_ext_add_text(batch, tokens_list[i], i, seq_ids.data(), seq_ids.size(), false); } - GGML_ASSERT(batch.n_tokens == (int) tokens_list.size()); + GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) tokens_list.size()); if (llama_model_has_encoder(model)) { - if (llama_encode(ctx, batch)) { + if (llama_encode_ext(ctx, batch)) { LOG_ERR("%s : failed to eval\n", __func__); return 1; } @@ -126,14 +126,14 @@ int main(int argc, char ** argv) { decoder_start_token_id = llama_vocab_bos(vocab); } - common_batch_clear(batch); - common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false); + llama_batch_ext_clear(batch); + llama_batch_ext_add_text(batch, decoder_start_token_id, 0, seq_ids.data(), seq_ids.size(), false); } // llama_decode will output logits only for the last token of the prompt - batch.logits[batch.n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch); - if (llama_decode(ctx, batch) != 0) { + if (llama_decode_ext(ctx, batch) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -155,16 +155,16 @@ int main(int argc, char ** argv) { // remember the batch index of the last token for each parallel sequence // we need this to determine which logits to sample from - std::vector i_batch(n_parallel, batch.n_tokens - 1); + std::vector i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1); - int n_cur = batch.n_tokens; + int n_cur = llama_batch_ext_get_n_tokens(batch); int n_decode = 0; const auto t_main_start = ggml_time_us(); while (n_cur <= n_predict) { // prepare the next batch - common_batch_clear(batch); + llama_batch_ext_clear(batch); // sample the next token for each parallel sequence / stream for (int32_t i = 0; i < n_parallel; ++i) { @@ -193,23 +193,23 @@ int main(int argc, char ** argv) { streams[i] += common_token_to_piece(ctx, new_token_id); - i_batch[i] = batch.n_tokens; + i_batch[i] = llama_batch_ext_get_n_tokens(batch); // push this new token for next evaluation - common_batch_add(batch, new_token_id, n_cur, { i }, true); + llama_batch_ext_add_text(batch, new_token_id, n_cur, &i, 1, true); n_decode += 1; } // all streams are finished - if (batch.n_tokens == 0) { + if (llama_batch_ext_get_n_tokens(batch) == 0) { break; } n_cur += 1; // evaluate the current batch with the transformer model - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch)) { LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); return 1; } @@ -234,7 +234,7 @@ int main(int argc, char ** argv) { fprintf(stderr, "\n"); - llama_batch_free(batch); + llama_batch_ext_free(batch); llama_sampler_free(smpl); llama_free(ctx); diff --git a/examples/cvector-generator/cvector-generator.cpp b/examples/cvector-generator/cvector-generator.cpp index 2a907155010cb..5b7a42025d8b4 100644 --- a/examples/cvector-generator/cvector-generator.cpp +++ b/examples/cvector-generator/cvector-generator.cpp @@ -343,7 +343,8 @@ static bool cb_eval(struct ggml_tensor * t, bool ask, void * user_data) { static bool get_hidden_layers(llama_context * ctx, std::vector & tokens) { llama_kv_self_clear(ctx); - if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { + auto batch = llama_batch_ext_ptr::init_from_text(ctx, tokens.data(), tokens.size(), 0, 0, true); + if (llama_decode_ext(ctx, batch.get())) { fprintf(stderr, "%s : failed to eval\n", __func__); return false; } diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 6f08904159fd5..91b4579f2c747 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -26,56 +26,45 @@ static std::vector split_lines(const std::string & s, const std::st return lines; } -static void batch_add_seq(llama_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { - size_t n_tokens = tokens.size(); - for (size_t i = 0; i < n_tokens; i++) { - common_batch_add(batch, tokens[i], i, { seq_id }, true); - } -} - -static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) { +static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm) { const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); - const struct llama_model * model = llama_get_model(ctx); + const llama_model * model = llama_get_model(ctx); // clear previous kv_cache values (irrelevant for embeddings) llama_kv_self_clear(ctx); + const int n_tokens = llama_batch_ext_get_n_tokens(batch); + // run model - LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); + LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, n_tokens, n_seq); if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { // encoder-only model - if (llama_encode(ctx, batch) < 0) { + if (llama_encode_ext(ctx, batch) < 0) { LOG_ERR("%s : failed to encode\n", __func__); } } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { // decoder-only model - if (llama_decode(ctx, batch) < 0) { + if (llama_decode_ext(ctx, batch) < 0) { LOG_ERR("%s : failed to decode\n", __func__); } } - for (int i = 0; i < batch.n_tokens; i++) { - if (!batch.logits[i]) { - continue; - } - - const float * embd = nullptr; - int embd_pos = 0; - - if (pooling_type == LLAMA_POOLING_TYPE_NONE) { - // try to get token embeddings - embd = llama_get_embeddings_ith(ctx, i); - embd_pos = i; + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + for (int i = 0; i < n_tokens; i++) { + const float * embd = llama_get_embeddings_ith(ctx, i); GGML_ASSERT(embd != NULL && "failed to get token embeddings"); - } else { - // try to get sequence embeddings - supported only when pooling_type is not NONE - embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - embd_pos = batch.seq_id[i][0]; - GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); + + float * out = output + i * n_embd; + common_embd_normalize(embd, out, n_embd, embd_norm); } + } else { + for (int s = 0; s < n_seq; s++) { + const float * embd = llama_get_embeddings_seq(ctx, s); + GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); - float * out = output + embd_pos * n_embd; - common_embd_normalize(embd, out, n_embd, embd_norm); + float * out = output + s * n_embd; + common_embd_normalize(embd, out, n_embd, embd_norm); + } } } @@ -171,7 +160,7 @@ int main(int argc, char ** argv) { // initialize batch const int n_prompts = prompts.size(); - struct llama_batch batch = llama_batch_init(n_batch, 0, 1); + llama_batch_ext_ptr batch(ctx); // count number of embeddings int n_embd_count = 0; @@ -198,22 +187,21 @@ int main(int argc, char ** argv) { const uint64_t n_toks = inp.size(); // encode if at capacity - if (batch.n_tokens + n_toks > n_batch) { - float * out = emb + e * n_embd; - batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); - e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s; + if (batch.n_tokens() + n_toks > n_batch) { + batch_decode(ctx, batch.get(), emb + e * n_embd, s, n_embd, params.embd_normalize); + batch.clear(); + + e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens() : s; s = 0; - common_batch_clear(batch); } // add to batch - batch_add_seq(batch, inp, s); + batch.add_seq(inp, 0, s, true); s += 1; } // final batch - float * out = emb + e * n_embd; - batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); + batch_decode(ctx, batch.get(), emb + e * n_embd, s, n_embd, params.embd_normalize); if (params.embd_out.empty()) { LOG("\n"); @@ -319,7 +307,6 @@ int main(int argc, char ** argv) { llama_perf_context_print(ctx); // clean up - llama_batch_free(batch); llama_backend_free(); return 0; diff --git a/examples/eval-callback/eval-callback.cpp b/examples/eval-callback/eval-callback.cpp index fb188f5a9e132..4253b5ca4193e 100644 --- a/examples/eval-callback/eval-callback.cpp +++ b/examples/eval-callback/eval-callback.cpp @@ -134,7 +134,8 @@ static bool run(llama_context * ctx, const common_params & params) { std::vector tokens = common_tokenize(ctx, params.prompt, add_bos); - if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { + auto batch = llama_batch_ext_ptr::init_from_text(ctx, tokens.data(), tokens.size(), 0, 0, true); + if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return false; } diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index f7db7861c1ad5..e8753f2163e01 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -1,6 +1,7 @@ #include "arg.h" #include "common.h" #include "llama.h" +#include "llama-cpp.h" #include #include @@ -13,10 +14,10 @@ static std::vector> encode(llama_context * ctx, const std::ve const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); - llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1); + llama_batch_ext_ptr batch(ctx); for (uint64_t i = 0; i < sentences.size(); i++) { - common_batch_clear(batch); + batch.clear(); const std::string input_string = instruction + sentences[i]; @@ -41,7 +42,7 @@ static std::vector> encode(llama_context * ctx, const std::ve // add input to batch (this increments n_tokens) for (int32_t j = 0; j < n_toks; j++) { - common_batch_add(batch, inputs[j], j, { 0 }, j >= n_inst); + batch.add_text(inputs[j], j, 0, j >= n_inst); } // clear previous kv_cache values (irrelevant for embeddings) @@ -50,7 +51,7 @@ static std::vector> encode(llama_context * ctx, const std::ve llama_set_causal_attn(ctx, false); // run model - llama_decode(ctx, batch); + llama_decode_ext(ctx, batch.get()); // get embedding dimensions uint64_t n_embd = llama_model_n_embd(model); @@ -89,8 +90,6 @@ static std::vector> encode(llama_context * ctx, const std::ve #endif } - llama_batch_free(batch); - return result; } @@ -106,25 +105,25 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std llama_set_embeddings(ctx, false); llama_set_causal_attn(ctx, true); - llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1); + llama_batch_ext_ptr batch(ctx); std::vector inputs = common_tokenize(vocab, prompt, false, true); int32_t i_current_token = 0; while (true) { - common_batch_clear(bat); + batch.clear(); { const int32_t n_inputs = inputs.size(); for (int32_t i = 0; i < n_inputs; i++) { - common_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1); + batch.add_text(inputs[i], i_current_token++, 0, i == n_inputs - 1); } } inputs.clear(); - llama_decode(ctx, bat); + llama_decode_ext(ctx, batch.get()); - llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1); + llama_token token = llama_sampler_sample(smpl, ctx, batch.n_tokens() - 1); if (token == eos_token) { break; @@ -145,8 +144,6 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std std::printf("\n"); } - llama_batch_free(bat); - return result; } diff --git a/examples/imatrix/imatrix.cpp b/examples/imatrix/imatrix.cpp index 31b675e8f90b9..55cecad84d97a 100644 --- a/examples/imatrix/imatrix.cpp +++ b/examples/imatrix/imatrix.cpp @@ -497,7 +497,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { // clear the KV cache llama_kv_self_clear(ctx); - llama_batch batch = llama_batch_init(n_batch, 0, 1); + llama_batch_ext * batch = llama_batch_ext_init(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -511,14 +511,15 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { tokens[batch_start] = llama_vocab_bos(vocab); } - common_batch_clear(batch); + llama_batch_ext_clear(batch); for (int i = 0; i < batch_size; i++) { - common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); + const llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, tokens[batch_start + i], j*n_batch + i, &seq_id, 1, true); } - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch)) { LOG_ERR("%s : failed to eval\n", __func__); - llama_batch_free(batch); + llama_batch_ext_free(batch); return false; } @@ -531,7 +532,7 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params) { } } - llama_batch_free(batch); + llama_batch_ext_free(batch); const auto t_end = std::chrono::high_resolution_clock::now(); diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index 4e2f7b7270003..9d5c129581c6f 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -353,7 +353,8 @@ int main(int argc, char ** argv) { LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) { + auto batch = llama_batch_ext_ptr::init_from_text(ctx, &embd[i], n_eval, n_past, 0, true); + if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return 1; } diff --git a/examples/llama-bench/llama-bench.cpp b/examples/llama-bench/llama-bench.cpp index cbcbfcee861ee..e956baf15d263 100644 --- a/examples/llama-bench/llama-bench.cpp +++ b/examples/llama-bench/llama-bench.cpp @@ -1427,7 +1427,7 @@ struct sql_printer : public printer { } }; -static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_threads) { +static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_batch, int n_threads) { llama_set_n_threads(ctx, n_threads, n_threads); const llama_model * model = llama_get_model(ctx); @@ -1444,14 +1444,15 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_batch, int n_th for (int i = 1; i < n_tokens; i++) { tokens[i] = std::rand() % n_vocab; } - llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens)); + auto batch = llama_batch_ext_ptr::init_from_text(ctx, tokens.data(), n_tokens, n_past + n_processed, 0, true); + llama_decode_ext(ctx, batch.get()); n_processed += n_tokens; } llama_synchronize(ctx); } -static void test_gen(llama_context * ctx, int n_gen, int n_threads) { +static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) { llama_set_n_threads(ctx, n_threads, n_threads); const llama_model * model = llama_get_model(ctx); @@ -1461,7 +1462,8 @@ static void test_gen(llama_context * ctx, int n_gen, int n_threads) { llama_token token = llama_vocab_get_add_bos(vocab) ? llama_vocab_bos(vocab) : std::rand() % n_vocab; for (int i = 0; i < n_gen; i++) { - llama_decode(ctx, llama_batch_get_one(&token, 1)); + auto batch = llama_batch_ext_ptr::init_from_text(ctx, &token, 1, n_past + i, 0, true); + llama_decode_ext(ctx, batch.get()); llama_synchronize(ctx); token = std::rand() % n_vocab; } @@ -1608,13 +1610,13 @@ int main(int argc, char ** argv) { fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup prompt run\n", params_idx, params_count); } //test_prompt(ctx, std::min(t.n_batch, std::min(t.n_prompt, 32)), 0, t.n_batch, t.n_threads); - test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads); + test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads); } if (t.n_gen > 0) { if (params.progress) { fprintf(stderr, "llama-bench: benchmark %d/%zu: warmup generation run\n", params_idx, params_count); } - test_gen(ctx, 1, t.n_threads); + test_gen(ctx, 1, 0, t.n_threads); } for (int i = 0; i < params.reps; i++) { @@ -1627,14 +1629,14 @@ int main(int argc, char ** argv) { fprintf(stderr, "llama-bench: benchmark %d/%zu: prompt run %d/%d\n", params_idx, params_count, i + 1, params.reps); } - test_prompt(ctx, t.n_prompt, t.n_batch, t.n_threads); + test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads); } if (t.n_gen > 0) { if (params.progress) { fprintf(stderr, "llama-bench: benchmark %d/%zu: generation run %d/%d\n", params_idx, params_count, i + 1, params.reps); } - test_gen(ctx, t.n_gen, t.n_threads); + test_gen(ctx, t.n_gen, t.n_prompt, t.n_threads); } uint64_t t_ns = get_time_ns() - t_start; diff --git a/examples/llama.android/app/src/main/AndroidManifest.xml b/examples/llama.android/app/src/main/AndroidManifest.xml index 41a358a299154..de94ee3cd1faf 100644 --- a/examples/llama.android/app/src/main/AndroidManifest.xml +++ b/examples/llama.android/app/src/main/AndroidManifest.xml @@ -3,6 +3,8 @@ xmlns:tools="http://schemas.android.com/tools"> + + (context_pointer); const auto model = reinterpret_cast(model_pointer); - const auto batch = reinterpret_cast(batch_pointer); + const auto batch = reinterpret_cast(batch_pointer); const int n_ctx = llama_n_ctx(context); @@ -186,19 +186,20 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( for (nri = 0; nri < nr; nri++) { LOGi("Benchmark prompt processing (pp)"); - common_batch_clear(*batch); + llama_batch_ext_clear(batch); const int n_tokens = pp; for (i = 0; i < n_tokens; i++) { - common_batch_add(*batch, 0, i, { 0 }, false); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false); } - batch->logits[batch->n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch); llama_kv_self_clear(context); const auto t_pp_start = ggml_time_us(); - if (llama_decode(context, *batch) != 0) { - LOGi("llama_decode() failed during prompt processing"); + if (llama_decode_ext(context, batch) != 0) { + LOGi("llama_decode_ext() failed during prompt processing"); } const auto t_pp_end = ggml_time_us(); @@ -210,14 +211,15 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( const auto t_tg_start = ggml_time_us(); for (i = 0; i < tg; i++) { - common_batch_clear(*batch); + llama_batch_ext_clear(batch); for (j = 0; j < pl; j++) { - common_batch_add(*batch, 0, i, { j }, true); + llama_seq_id seq_id = j; + llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, true); } - LOGi("llama_decode() text generation: %d", i); - if (llama_decode(context, *batch) != 0) { - LOGi("llama_decode() failed during text generation"); + LOGi("llama_decode_ext() text generation: %d", i); + if (llama_decode_ext(context, batch) != 0) { + LOGi("llama_decode_ext() failed during text generation"); } } @@ -271,33 +273,9 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model( extern "C" JNIEXPORT jlong JNICALL -Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, jint embd, jint n_seq_max) { - - // Source: Copy of llama.cpp:llama_batch_init but heap-allocated. - - llama_batch *batch = new llama_batch { - 0, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - nullptr, - }; - - if (embd) { - batch->embd = (float *) malloc(sizeof(float) * n_tokens * embd); - } else { - batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens); - } - - batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens); - batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens); - batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * n_tokens); - for (int i = 0; i < n_tokens; ++i) { - batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); - } - batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens); +Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jlong context_pointer) { + const auto context = reinterpret_cast(context_pointer); + llama_batch_ext * batch = llama_batch_ext_init(context); return reinterpret_cast(batch); } @@ -305,9 +283,7 @@ Java_android_llama_cpp_LLamaAndroid_new_1batch(JNIEnv *, jobject, jint n_tokens, extern "C" JNIEXPORT void JNICALL Java_android_llama_cpp_LLamaAndroid_free_1batch(JNIEnv *, jobject, jlong batch_pointer) { - //llama_batch_free(*reinterpret_cast(batch_pointer)); - const auto batch = reinterpret_cast(batch_pointer); - delete batch; + llama_batch_ext_free(reinterpret_cast(batch_pointer)); } extern "C" @@ -355,7 +331,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init( const auto text = env->GetStringUTFChars(jtext, 0); const auto context = reinterpret_cast(context_pointer); - const auto batch = reinterpret_cast(batch_pointer); + const auto batch = reinterpret_cast(batch_pointer); bool parse_special = (format_chat == JNI_TRUE); const auto tokens_list = common_tokenize(context, text, true, parse_special); @@ -363,7 +339,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init( auto n_ctx = llama_n_ctx(context); auto n_kv_req = tokens_list.size() + n_len; - LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", n_len, n_ctx, n_kv_req); + LOGi("n_len = %d, n_ctx = %d, n_kv_req = %d", (int) n_len, (int) n_ctx, (int) n_kv_req); if (n_kv_req > n_ctx) { LOGe("error: n_kv_req > n_ctx, the required KV cache size is not big enough"); @@ -373,23 +349,24 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init( LOGi("token: `%s`-> %d ", common_token_to_piece(context, id).c_str(), id); } - common_batch_clear(*batch); + llama_batch_ext_clear(batch); // evaluate the initial prompt for (auto i = 0; i < tokens_list.size(); i++) { - common_batch_add(*batch, tokens_list[i], i, { 0 }, false); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, tokens_list[i], i, &seq_id, 1, false); } // llama_decode will output logits only for the last token of the prompt - batch->logits[batch->n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch); - if (llama_decode(context, *batch) != 0) { - LOGe("llama_decode() failed"); + if (llama_decode_ext(context, batch) != 0) { + LOGe("llama_decode_ext() failed"); } env->ReleaseStringUTFChars(jtext, text); - return batch->n_tokens; + return llama_batch_ext_get_n_tokens(batch); } extern "C" @@ -404,7 +381,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( jobject intvar_ncur ) { const auto context = reinterpret_cast(context_pointer); - const auto batch = reinterpret_cast(batch_pointer); + const auto batch = reinterpret_cast(batch_pointer); const auto sampler = reinterpret_cast(sampler_pointer); const auto model = llama_get_model(context); const auto vocab = llama_model_get_vocab(model); @@ -433,13 +410,14 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop( new_token = env->NewStringUTF(""); } - common_batch_clear(*batch); - common_batch_add(*batch, new_token_id, n_cur, { 0 }, true); + llama_batch_ext_clear(batch); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch, new_token_id, n_cur, &seq_id, 1, true); env->CallVoidMethod(intvar_ncur, la_int_var_inc); - if (llama_decode(context, *batch) != 0) { - LOGe("llama_decode() returned null"); + if (llama_decode_ext(context, batch) != 0) { + LOGe("llama_decode_ext() returned null"); } return new_token; diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt index b964d93e37819..f58f7431a3ca6 100644 --- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt +++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt @@ -45,7 +45,7 @@ class LLamaAndroid { private external fun free_context(context: Long) private external fun backend_init(numa: Boolean) private external fun backend_free() - private external fun new_batch(nTokens: Int, embd: Int, nSeqMax: Int): Long + private external fun new_batch(context: Long): Long private external fun free_batch(batch: Long) private external fun new_sampler(): Long private external fun free_sampler(sampler: Long) @@ -102,7 +102,7 @@ class LLamaAndroid { val context = new_context(model) if (context == 0L) throw IllegalStateException("new_context() failed") - val batch = new_batch(512, 0, 1) + val batch = new_batch(context) if (batch == 0L) throw IllegalStateException("new_batch() failed") val sampler = new_sampler() diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index f6e31abc93c09..c4ddaf9bc51c3 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -5,35 +5,19 @@ enum LlamaError: Error { case couldNotInitializeContext } -func llama_batch_clear(_ batch: inout llama_batch) { - batch.n_tokens = 0 -} - -func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) { - batch.token [Int(batch.n_tokens)] = id - batch.pos [Int(batch.n_tokens)] = pos - batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count) - for i in 0.. - private var batch: llama_batch + private var batch: OpaquePointer private var tokens_list: [llama_token] var is_done: Bool = false /// This variable is used to store temporarily invalid cchars private var temporary_invalid_cchars: [CChar] - var n_len: Int32 = 1024 + var n_len: Int32 = 128 var n_cur: Int32 = 0 var n_decode: Int32 = 0 @@ -42,7 +26,7 @@ actor LlamaContext { self.model = model self.context = context self.tokens_list = [] - self.batch = llama_batch_init(512, 0, 1) + self.batch = llama_batch_ext_init(context) self.temporary_invalid_cchars = [] let sparams = llama_sampler_chain_default_params() self.sampling = llama_sampler_chain_init(sparams) @@ -53,7 +37,7 @@ actor LlamaContext { deinit { llama_sampler_free(sampling) - llama_batch_free(batch) + llama_batch_ext_free(batch) llama_model_free(model) llama_free(context) llama_backend_free() @@ -111,7 +95,7 @@ actor LlamaContext { } func get_n_tokens() -> Int32 { - return batch.n_tokens; + return llama_batch_ext_get_n_tokens(batch) } func completion_init(text: String) { @@ -133,25 +117,25 @@ actor LlamaContext { print(String(cString: token_to_piece(token: id) + [0])) } - llama_batch_clear(&batch) + llama_batch_ext_clear(batch) for i1 in 0.. String { var new_token_id: llama_token = 0 - new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1) + new_token_id = llama_sampler_sample(sampling, context, llama_batch_ext_get_n_tokens(batch) - 1) if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len { print("\n") @@ -178,13 +162,13 @@ actor LlamaContext { print(new_token_str) // tokens_list.append(new_token_id) - llama_batch_clear(&batch) - llama_batch_add(&batch, new_token_id, n_cur, [0], true) + llama_batch_ext_clear(batch) + llama_batch_ext_add_text(batch, new_token_id, n_cur, [llama_seq_id(0)], 1, true) n_decode += 1 n_cur += 1 - if llama_decode(context, batch) != 0 { + if llama_decode_ext(context, batch) != 0 { print("failed to evaluate llama!") } @@ -201,21 +185,21 @@ actor LlamaContext { for _ in 0.. pos; - std::vector n_seq_id; - std::vector seq_id_0; - std::vector seq_ids; - std::vector logits; - llama_batch batch; - decode_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { - pos .resize(n_tokens); - n_seq_id.resize(n_tokens); - seq_ids .resize(n_tokens + 1); - logits .resize(n_tokens); - seq_id_0.resize(1); - seq_id_0[0] = seq_id; - seq_ids [n_tokens] = nullptr; - batch = { - /*n_tokens =*/ n_tokens, - /*tokens =*/ nullptr, - /*embd =*/ embd, - /*pos =*/ pos.data(), - /*n_seq_id =*/ n_seq_id.data(), - /*seq_id =*/ seq_ids.data(), - /*logits =*/ logits.data(), - }; - for (int i = 0; i < n_tokens; i++) { - batch.pos [i] = pos_0 + i; - batch.n_seq_id[i] = 1; - batch.seq_id [i] = seq_id_0.data(); - batch.logits [i] = false; - } - } -}; - static int eval_text(gemma3_context & ctx, std::string input, bool logits_last = false) { llama_tokens tokens = common_tokenize(ctx.lctx, input, false, true); - common_batch_clear(ctx.batch); + ctx.batch.clear(); for (llama_token & t : tokens) { - common_batch_add(ctx.batch, t, ctx.n_past++, {0}, false); + ctx.batch.add_text(t, ctx.n_past++, 0, false); } if (logits_last) { - ctx.batch.logits[ctx.batch.n_tokens - 1] = true; + llama_batch_ext_set_output_last(ctx.batch.get()); } // LOG("eval_text (n_tokens = %d): %s\n", (int)tokens.size(), input.c_str()); - if (llama_decode(ctx.lctx, ctx.batch)) { + if (llama_decode_ext(ctx.lctx, ctx.batch.get())) { LOG_ERR("Failed to decode text\n"); return 1; } @@ -179,8 +147,9 @@ static int eval_image(gemma3_context & ctx, std::string & fname) { int64_t t1 = ggml_time_ms(); eval_text(ctx, ""); llama_set_causal_attn(ctx.lctx, false); - decode_embd_batch batch_img(image_embd_v.data(), n_tokens, ctx.n_past, 0); - if (llama_decode(ctx.lctx, batch_img.batch)) { + llama_batch_ext_ptr batch_img = llama_batch_ext_ptr::init_from_embd( + ctx.lctx, image_embd_v.data(), n_tokens, n_embd, ctx.n_past, 0); + if (llama_decode_ext(ctx.lctx, batch_img.get())) { LOG_ERR("failed to decode image\n"); return 1; } @@ -210,9 +179,9 @@ static int generate_response(gemma3_context & ctx, common_sampler * smpl, int n_ fflush(stdout); // eval the token - common_batch_clear(ctx.batch); - common_batch_add(ctx.batch, token_id, ctx.n_past++, {0}, true); - if (llama_decode(ctx.lctx, ctx.batch)) { + ctx.batch.clear(); + ctx.batch.add_text(token_id, ctx.n_past++, 0, true); + if (llama_decode_ext(ctx.lctx, ctx.batch.get())) { LOG_ERR("failed to decode token\n"); return 1; } diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp index 40aa0876f24a7..884547fcb831a 100644 --- a/examples/llava/llava-cli.cpp +++ b/examples/llava/llava-cli.cpp @@ -20,7 +20,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector n_batch) { n_eval = n_batch; } - if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) { + auto batch = llama_batch_ext_ptr::init_from_text(ctx_llama, &tokens[i], n_eval, *n_past, 0, true); + if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); return false; } diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 518aad3f1f70b..f88e4e7a800b9 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -2,6 +2,7 @@ #include "llava.h" #include "llama.h" +#include "llama-cpp.h" #include #include @@ -438,39 +439,6 @@ bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, co return true; } -struct llava_embd_batch { - std::vector pos; - std::vector n_seq_id; - std::vector seq_id_0; - std::vector seq_ids; - std::vector logits; - llama_batch batch; - llava_embd_batch(float * embd, int32_t n_tokens, llama_pos pos_0, llama_seq_id seq_id) { - pos .resize(n_tokens); - n_seq_id.resize(n_tokens); - seq_ids .resize(n_tokens + 1); - logits .resize(n_tokens); - seq_id_0.resize(1); - seq_id_0[0] = seq_id; - seq_ids [n_tokens] = nullptr; - batch = { - /*n_tokens =*/ n_tokens, - /*tokens =*/ nullptr, - /*embd =*/ embd, - /*pos =*/ pos.data(), - /*n_seq_id =*/ n_seq_id.data(), - /*seq_id =*/ seq_ids.data(), - /*logits =*/ logits.data(), - }; - for (int i = 0; i < n_tokens; i++) { - batch.pos [i] = pos_0 + i; - batch.n_seq_id[i] = 1; - batch.seq_id [i] = seq_id_0.data(); - batch.logits [i] = false; - } - } -}; - bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) { int n_embd = llama_model_n_embd(llama_get_model(ctx_llama)); @@ -480,8 +448,8 @@ bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_ n_eval = n_batch; } float * embd = image_embed->embed+i*n_embd; - llava_embd_batch llava_batch = llava_embd_batch(embd, n_eval, *n_past, 0); - if (llama_decode(ctx_llama, llava_batch.batch)) { + auto batch = llama_batch_ext_ptr::init_from_embd(ctx_llama, embd, n_eval, n_embd, 0, 0); + if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return false; } diff --git a/examples/llava/minicpmv-cli.cpp b/examples/llava/minicpmv-cli.cpp index 12f536cf5cfff..7aadca9489ab5 100644 --- a/examples/llava/minicpmv-cli.cpp +++ b/examples/llava/minicpmv-cli.cpp @@ -101,7 +101,8 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector n_batch) { n_eval = n_batch; } - if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval))) { + auto batch = llama_batch_ext_ptr::init_from_text(ctx_llama, &tokens[i], n_eval, *n_past, 0, true); + if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); return false; } diff --git a/examples/llava/qwen2vl-cli.cpp b/examples/llava/qwen2vl-cli.cpp index 132a7da543c2a..327831c2434f7 100644 --- a/examples/llava/qwen2vl-cli.cpp +++ b/examples/llava/qwen2vl-cli.cpp @@ -66,17 +66,20 @@ static bool qwen2vl_eval_image_embed(llama_context * ctx_llama, const struct lla memcpy(&batch_mrope_pos[n_eval * 2], &mrope_pos[img_tokens * 2 + processed], n_eval * sizeof(llama_pos)); memcpy(&batch_mrope_pos[n_eval * 3], &mrope_pos[img_tokens * 3 + processed], n_eval * sizeof(llama_pos)); - llama_batch batch = { - int32_t(n_eval), // n_tokens - nullptr, // token - (image_embed->embed+i*n_embd), // embed - batch_mrope_pos.data(), // pos - nullptr, // n_seq_id - nullptr, // seq_id - nullptr, // logits - }; - - if (llama_decode(ctx_llama, batch)) { + // tranpose from layout 0123012301230123 to 0000111122223333 + // TODO @ngxson : this is a low-effort solution, generated with the help of LLM; we should improve this in the future + std::vector batch_mrope_pos_T(n_eval * 4); + for (int r = 0; r < 4; r++) { + for (int c = 0; c < n_eval; c++) { + batch_mrope_pos_T[c*4 + r] = batch_mrope_pos[r*n_eval + c]; + } + } + + float * batch_embd = image_embed->embed+i*n_embd; + const llama_pos * pos = batch_mrope_pos_T.data(); + auto batch = llama_batch_ext_ptr::init_from_embd(ctx_llama, batch_embd, n_eval, n_embd, pos, 0); + + if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return false; } @@ -95,16 +98,15 @@ static bool eval_tokens(struct llama_context * ctx_llama, std::vector n_batch) { n_eval = n_batch; } - auto batch = llama_batch_get_one(&tokens[i], n_eval); - // TODO: add mrope pos ids somewhere else - pos.resize(batch.n_tokens * 4); - std::fill(pos.begin(), pos.end(), 0); - for (int j = 0; j < batch.n_tokens * 3; j ++) { - pos[j] = *st_pos_id + (j % batch.n_tokens); + + llama_batch_ext_ptr batch(ctx_llama); + for (int j = 0; j < n_eval; j++) { + llama_token token = tokens[i + j]; + batch.add_text(token, *st_pos_id + i + j, 0, false); } - batch.pos = pos.data(); + llama_batch_ext_set_output_last(batch.get()); - if (llama_decode(ctx_llama, batch)) { + if (llama_decode_ext(ctx_llama, batch.get())) { LOG_ERR("%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past); return false; } diff --git a/examples/lookahead/lookahead.cpp b/examples/lookahead/lookahead.cpp index 7df20aee17046..c884238d6ecbe 100644 --- a/examples/lookahead/lookahead.cpp +++ b/examples/lookahead/lookahead.cpp @@ -92,8 +92,10 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // eval the prompt - llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1)); - llama_decode(ctx, llama_batch_get_one(&inp.back(), 1)); + auto batch0 = llama_batch_ext_ptr::init_from_text(ctx, inp.data(), n_input - 1, 0, 0, true); + auto batch1 = llama_batch_ext_ptr::init_from_text(ctx, &inp.back(), 1, n_input - 1, 0, true); + llama_decode_ext(ctx, batch0.get()); + llama_decode_ext(ctx, batch1.get()); for (int s = 1; s < W + G + 1; ++s) { llama_kv_self_seq_cp(ctx, 0, s, -1, -1); @@ -115,7 +117,7 @@ int main(int argc, char ** argv) { // seq_id == 0 : the current input token // seq_id [1, W] : tokens from the past N - 1 Jacobi iterations // seq_id [W + 1, W + G] : verification n-grams - llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); + llama_batch_ext * batch = llama_batch_ext_init(ctx); // target model sampling context struct common_sampler * smpl = common_sampler_init(model, params.sampling); @@ -204,10 +206,10 @@ int main(int argc, char ** argv) { // V V V V V V // id { - common_batch_clear(batch); + llama_batch_ext_clear(batch); // current token - first token of the first level - common_batch_add(batch, id, n_past, seq_id_all, true); + llama_batch_ext_add_text(batch, id, n_past, seq_id_all.data(), seq_id_all.size(), true); // verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation { @@ -230,9 +232,10 @@ int main(int argc, char ** argv) { const llama_token t = ngrams_observed.tokens[idx + j]; ngrams_cur[g].tokens [j + 1] = t; - ngrams_cur[g].i_batch[j + 1] = batch.n_tokens; + ngrams_cur[g].i_batch[j + 1] = llama_batch_ext_get_n_tokens(batch); - common_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true); + llama_seq_id seq_id = W + 1 + g; + llama_batch_ext_add_text(batch, t, n_past + j + 1, &seq_id, 1, true); } } } @@ -244,18 +247,20 @@ int main(int argc, char ** argv) { seq_id_look[j] = i + j + 1; } - common_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false); + llama_batch_ext_add_text(batch, tokens_j[0][i], n_past + i, + seq_id_look.data(), seq_id_look.size(), false); } // fill the rest of the levels for (int j = 1; j < N - 1; j++) { for (int i = 0; i < W; i++) { - common_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2); + llama_seq_id seq_id = i + 1; + llama_batch_ext_add_text(batch, tokens_j[j][i], n_past + j + i, &seq_id, 1, j == N - 2); } } } - if (llama_decode(ctx, batch) != 0) { + if (llama_decode_ext(ctx, batch) != 0) { LOG_ERR("\n\n%s: llama_decode failed - increase KV cache size\n", __func__); return 1; } @@ -475,7 +480,7 @@ int main(int argc, char ** argv) { llama_kv_cache_view_free(&kvc_view); - llama_batch_free(batch); + llama_batch_ext_free(batch); llama_backend_free(); diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 4ae93b2a5ed15..9aa1bff74037c 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -5,6 +5,7 @@ #include "sampling.h" #include "log.h" #include "llama.h" +#include "llama-cpp.h" #include #include @@ -91,8 +92,10 @@ int main(int argc, char ** argv){ const auto t_enc_start = ggml_time_us(); - llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1)); - llama_decode(ctx, llama_batch_get_one(&inp.back(), 1)); + auto batch0 = llama_batch_ext_ptr::init_from_text(ctx, inp.data(), n_input - 1, 0, 0, true); + auto batch1 = llama_batch_ext_ptr::init_from_text(ctx, &inp.back(), 1, n_input - 1, 0, true); + llama_decode_ext(ctx, batch0.get()); + llama_decode_ext(ctx, batch1.get()); const auto t_enc_end = ggml_time_us(); @@ -108,7 +111,7 @@ int main(int argc, char ** argv){ std::vector draft; - llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1); + llama_batch_ext_ptr batch_tgt(ctx); // debug struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 1); @@ -194,8 +197,8 @@ int main(int argc, char ** argv){ // clean the cache of draft tokens that weren't accepted llama_kv_self_seq_rm(ctx, 0, n_past, -1); - common_batch_clear(batch_tgt); - common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true); + batch_tgt.clear(); + batch_tgt.add_text(draft[0], n_past, 0, true); // Draft already contains a single token sampled from the model: GGML_ASSERT(draft.size() == 1); @@ -205,13 +208,13 @@ int main(int argc, char ** argv){ common_ngram_cache_draft(inp, draft, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, ngram_cache_context, ngram_cache_dynamic, ngram_cache_static); for (size_t i = 1; i < draft.size(); ++i) { - common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); + batch_tgt.add_text(draft[i], n_past + i, 0, true); } t_draft_us += ggml_time_us() - t_start_draft_us; n_drafted += draft.size() - 1; - llama_decode(ctx, batch_tgt); + llama_decode_ext(ctx, batch_tgt.get()); ++n_past; draft.erase(draft.begin()); @@ -243,8 +246,6 @@ int main(int argc, char ** argv){ common_sampler_free(smpl); - llama_batch_free(batch_tgt); - llama_backend_free(); LOG("\n\n"); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index fd7410a646c69..84bfefba2a13c 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -548,7 +548,8 @@ int main(int argc, char ** argv) { int enc_input_size = embd_inp.size(); llama_token * enc_input_buf = embd_inp.data(); - if (llama_encode(ctx, llama_batch_get_one(enc_input_buf, enc_input_size))) { + auto batch = llama_batch_ext_ptr::init_from_text(ctx, enc_input_buf, enc_input_size, 0, 0, true); + if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return 1; } @@ -668,7 +669,8 @@ int main(int argc, char ** argv) { LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval))) { + auto batch = llama_batch_ext_ptr::init_from_text(ctx, &embd[i], n_eval, n_past, 0, true); + if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); return 1; } diff --git a/examples/parallel/parallel.cpp b/examples/parallel/parallel.cpp index 588632f0432b2..ee713ab416645 100644 --- a/examples/parallel/parallel.cpp +++ b/examples/parallel/parallel.cpp @@ -6,6 +6,7 @@ #include "sampling.h" #include "log.h" #include "llama.h" +#include "llama-cpp.h" #include #include @@ -157,8 +158,6 @@ int main(int argc, char ** argv) { LOG_INF("\n\n"); - const int n_ctx = llama_n_ctx(ctx); - std::vector clients(n_clients); for (size_t i = 0; i < clients.size(); ++i) { auto & client = clients[i]; @@ -174,7 +173,7 @@ int main(int argc, char ** argv) { // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time - llama_batch batch = llama_batch_init(n_ctx, 0, 1); + llama_batch_ext_ptr batch(ctx); int32_t n_total_prompt = 0; int32_t n_total_gen = 0; @@ -192,10 +191,10 @@ int main(int argc, char ** argv) { LOG_INF("%s: Evaluating the system prompt ...\n", __func__); for (int32_t i = 0; i < n_tokens_system; ++i) { - common_batch_add(batch, tokens_system[i], i, { 0 }, false); + batch.add_text(tokens_system[i], i, 0, false); } - if (llama_decode(ctx, batch) != 0) { + if (llama_decode_ext(ctx, batch.get()) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -216,7 +215,7 @@ int main(int argc, char ** argv) { common_kv_cache_dump_view_seqs(kvc_view, 40); } - common_batch_clear(batch); + batch.clear(); // decode any currently ongoing sequences for (auto & client : clients) { @@ -224,14 +223,15 @@ int main(int argc, char ** argv) { continue; } - client.i_batch = batch.n_tokens; + client.i_batch = batch.n_tokens(); - common_batch_add(batch, client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, { client.id + 1 }, true); + llama_seq_id seq_id = client.id + 1; + batch.add_text(client.sampled, n_tokens_system + client.n_prompt + client.n_decoded, seq_id, true); client.n_decoded += 1; } - if (batch.n_tokens == 0) { + if (batch.n_tokens() == 0) { // all sequences have ended - clear the entire KV cache for (int i = 1; i <= n_clients; ++i) { llama_kv_self_seq_rm(ctx, i, -1, -1); @@ -243,7 +243,7 @@ int main(int argc, char ** argv) { } // insert new sequences for decoding - if (cont_batching || batch.n_tokens == 0) { + if (cont_batching || batch.n_tokens() == 0) { for (auto & client : clients) { if (client.seq_id == -1 && g_seq_id < n_seq) { client.seq_id = g_seq_id; @@ -262,17 +262,18 @@ int main(int argc, char ** argv) { tokens_prompt = common_tokenize(ctx, client.prompt, false); for (size_t i = 0; i < tokens_prompt.size(); ++i) { - common_batch_add(batch, tokens_prompt[i], i + n_tokens_system, { client.id + 1 }, false); + llama_seq_id seq_id = client.id + 1; + batch.add_text(tokens_prompt[i], i + n_tokens_system, seq_id, false); } // extract the logits only for the last token - if (batch.n_tokens > 0) { - batch.logits[batch.n_tokens - 1] = true; + if (batch.n_tokens() > 0) { + llama_batch_ext_set_output_last(batch.get()); } client.n_prompt = tokens_prompt.size(); client.n_decoded = 0; - client.i_batch = batch.n_tokens - 1; + client.i_batch = batch.n_tokens() - 1; LOG_INF("\033[31mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id); @@ -286,14 +287,15 @@ int main(int argc, char ** argv) { } } - if (batch.n_tokens == 0) { + if (batch.n_tokens() == 0) { break; } // process in chunks of params.n_batch int32_t n_batch = params.n_batch; - for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) { + int32_t n_tokens_in_batch = batch.n_tokens(); + for (int32_t i = 0; i < (int32_t) n_tokens_in_batch; i += n_batch) { // experiment: process in powers of 2 //if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) { // n_batch /= 2; @@ -301,19 +303,11 @@ int main(int argc, char ** argv) { // continue; //} - const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i)); - - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; + const int32_t n_tokens = std::min(n_batch, (int32_t) (n_tokens_in_batch - i)); - const int ret = llama_decode(ctx, batch_view); + llama_batch_ext * batch_view = llama_batch_ext_get_view(batch.get(), i, n_tokens); + const int ret = llama_decode_ext(ctx, batch_view); + llama_batch_ext_free(batch_view); if (ret != 0) { if (n_batch == 1 || ret < 0) { // if you get here, it means the KV cache is full - try increasing it via the context size @@ -417,8 +411,6 @@ int main(int argc, char ** argv) { // TODO: print sampling/grammar timings for all clients llama_perf_context_print(ctx); - llama_batch_free(batch); - llama_backend_free(); LOG("\n\n"); diff --git a/examples/passkey/passkey.cpp b/examples/passkey/passkey.cpp index ea3a6c1fca3ee..94ede1c5a91c2 100644 --- a/examples/passkey/passkey.cpp +++ b/examples/passkey/passkey.cpp @@ -2,6 +2,7 @@ #include "common.h" #include "log.h" #include "llama.h" +#include "llama-cpp.h" #include #include @@ -122,7 +123,7 @@ int main(int argc, char ** argv) { LOG_INF("prompt tokens: %d\n", n_tokens_all); //LOG_INF("prompt: %s\n", params.prompt.c_str()); - llama_batch batch = llama_batch_init(params.n_batch, 0, 1); + llama_batch_ext_ptr batch(ctx); int n_past = 0; @@ -140,17 +141,17 @@ int main(int argc, char ** argv) { n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; } - common_batch_clear(batch); + batch.clear(); for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { - common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); + batch.add_text(tokens_list[i + j], n_past++, 0, false); } if (i + n_batch >= n_tokens_all) { - batch.logits[batch.n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch.get()); } - if (llama_decode(ctx, batch) != 0) { + if (llama_decode_ext(ctx, batch.get()) != 0) { LOG_INF("%s: llama_decode() failed\n", __func__); return 1; } @@ -174,17 +175,17 @@ int main(int argc, char ** argv) { n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1; - common_batch_clear(batch); + batch.clear(); for (int j = 0; j < n_batch && i + j < n_tokens_all; j++) { - common_batch_add(batch, tokens_list[i + j], n_past++, { 0 }, false); + batch.add_text(tokens_list[i + j], n_past++, 0, false); } if (i + n_batch >= n_tokens_all) { - batch.logits[batch.n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch.get()); } - if (llama_decode(ctx, batch) != 0) { + if (llama_decode_ext(ctx, batch.get()) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -223,7 +224,7 @@ int main(int argc, char ** argv) { while (n_cur <= n_len) { // sample the next token { - const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1); + const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens() - 1); // is it an end of generation? if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len) { @@ -237,16 +238,17 @@ int main(int argc, char ** argv) { n_decode += 1; // prepare the next batch - common_batch_clear(batch); + batch.clear(); // push this new token for next evaluation - common_batch_add(batch, new_token_id, n_past++, { 0 }, true); + llama_seq_id seq_id = 0; + llama_batch_ext_add_text(batch.get(), new_token_id, n_past++, &seq_id, 1, true); } n_cur += 1; // evaluate the current batch with the transformer model - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); return 1; } @@ -266,8 +268,6 @@ int main(int argc, char ** argv) { llama_sampler_free(smpl); - llama_batch_free(batch); - llama_free(ctx); llama_model_free(model); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 8c413f7d66e6d..d0fbc3f571734 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -363,21 +363,20 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params // clear the KV cache llama_kv_self_clear(ctx); - llama_batch batch = llama_batch_init(n_batch, 0, 1); + llama_batch_ext_ptr batch(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); - common_batch_clear(batch); + batch.clear(); for (int i = 0; i < batch_size; i++) { - common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); + batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true); } //LOG_DBG(" Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch); - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch.get())) { //LOG_ERR("%s : failed to eval\n", __func__); - llama_batch_free(batch); return {tokens, -1, logit_history, prob_history}; } @@ -397,8 +396,6 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params } } - llama_batch_free(batch); - const auto t_end = std::chrono::high_resolution_clock::now(); if (i == 0) { @@ -504,7 +501,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0); GGML_ASSERT(params.n_ctx == n_seq * n_ctx); - llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1); + llama_batch_ext_ptr batch(ctx); std::vector logits; if (num_batches > 1) { @@ -555,7 +552,7 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & int n_outputs = 0; - batch.n_tokens = 0; + batch.clear(); for (int seq = 0; seq < n_seq_batch; seq++) { int seq_start = batch_start + seq*n_ctx; @@ -568,22 +565,18 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & } for (int k = 0; k < batch_size; ++k) { - const int idx = seq*n_ctx + k; - batch.token [idx] = tokens[seq_start + k]; - batch.pos [idx] = j*n_batch + k; - batch.n_seq_id[idx] = 1; - batch.seq_id [idx][0] = seq; - batch.logits [idx] = batch.pos[idx] >= first ? 1 : 0; - - n_outputs += batch.logits[idx] != 0; + const llama_pos pos = j*n_batch + k; + bool output = pos >= first; + batch.add_text(tokens[seq_start + k], pos, seq, output); + + n_outputs += output ? 1 : 0; } - batch.n_tokens += batch_size; // restore the original token in case it was set to BOS tokens[seq_start] = token_org; } - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch.get())) { LOG_INF("%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; } @@ -653,42 +646,18 @@ static results_perplexity perplexity(llama_context * ctx, const common_params & LOG_ERR("Unexpected negative standard deviation of log(prob)\n"); } - llama_batch_free(batch); - return {tokens, ppl, logit_history, prob_history}; } -static bool decode_helper(llama_context * ctx, llama_batch & batch, std::vector & batch_logits, int n_batch, int n_vocab) { - int prev_outputs = 0; - for (int i = 0; i < (int) batch.n_tokens; i += n_batch) { - const int n_tokens = std::min(n_batch, batch.n_tokens - i); - - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; - - const int ret = llama_decode(ctx, batch_view); - if (ret != 0) { - LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret); - return false; - } - - int n_outputs = 0; - for (int i = 0; i < n_tokens; ++i) { - n_outputs += batch_view.logits[i] != 0; - } - - memcpy(batch_logits.data() + size_t(prev_outputs)*n_vocab, llama_get_logits(ctx), size_t(n_outputs)*n_vocab*sizeof(float)); - - prev_outputs += n_outputs; +static bool decode_helper(llama_context * ctx, llama_batch_ext_ptr & batch, std::vector & batch_logits, size_t n_outputs, int n_vocab) { + const int ret = llama_decode_ext(ctx, batch.get()); + if (ret != 0) { + LOG_ERR("failed to decode the batch, ret = %d\n", ret); + return false; } + memcpy(batch_logits.data(), llama_get_logits(ctx), n_outputs*n_vocab*sizeof(float)); + return true; } @@ -856,14 +825,12 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { double acc = 0.0f; const int n_ctx = llama_n_ctx(ctx); - const int n_batch = params.n_batch; - const int n_vocab = llama_vocab_n_tokens(vocab); const int max_tasks_per_batch = 32; const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); - llama_batch batch = llama_batch_init(n_ctx, 0, 4); + llama_batch_ext_ptr batch(ctx); std::vector tok_logits(n_vocab); // TODO: this could be made smaller; it's currently the worst-case size @@ -879,7 +846,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { size_t i1 = i0; size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch - common_batch_clear(batch); + batch.clear(); // batch as much tasks as possible into the available context // each task has 4 unique sequence ids - one for each ending @@ -895,9 +862,10 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { } for (size_t i = 0; i < hs_cur.common_prefix; ++i) { - common_batch_add(batch, hs_cur.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }, false); + std::vector seq_ids = { s0 + 0, s0 + 1, s0 + 2, s0 + 3 }; + batch.add_text(hs_cur.seq_tokens[0][i], i, seq_ids, false); } - batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix + llama_batch_ext_set_output_last(batch.get()); n_logits += 1; for (int s = 0; s < 4; ++s) { @@ -905,7 +873,8 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { // TODO: don't evaluate the last token of each sequence for (size_t i = hs_cur.common_prefix; i < seq_tokens_size; ++i) { const bool needs_logits = i < seq_tokens_size - 1; - common_batch_add(batch, hs_cur.seq_tokens[s][i], i, { s0 + s }, needs_logits); + llama_seq_id seq_id = s0 + s; + batch.add_text(hs_cur.seq_tokens[s][i], i, seq_id, needs_logits); n_logits += needs_logits; } } @@ -927,7 +896,7 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { llama_kv_self_clear(ctx); // decode all tasks [i0, i1) - if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { + if (!decode_helper(ctx, batch, batch_logits, i_logits, n_vocab)) { LOG_ERR("%s: llama_decode() failed\n", __func__); return; } @@ -992,8 +961,6 @@ static void hellaswag_score(llama_context * ctx, const common_params & params) { i0 = i1 - 1; } - llama_batch_free(batch); - LOG("\n"); } @@ -1140,14 +1107,12 @@ static void winogrande_score(llama_context * ctx, const common_params & params) LOG_INF("%s : calculating winogrande score over selected tasks.\n", __func__); const int n_ctx = llama_n_ctx(ctx); - const int n_batch = params.n_batch; - const int n_vocab = llama_vocab_n_tokens(vocab); const int max_tasks_per_batch = 128; const int max_seq = std::min(2*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); - llama_batch batch = llama_batch_init(n_ctx, 0, 2); + llama_batch_ext_ptr batch(ctx); std::vector tok_logits(n_vocab); // TODO: this could be made smaller; it's currently the worst-case size @@ -1166,7 +1131,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) size_t i1 = i0; size_t i_logits = 0; - common_batch_clear(batch); + batch.clear(); while (n_cur + (int) data[i1].required_tokens <= n_ctx) { int n_logits = 0; @@ -1176,15 +1141,16 @@ static void winogrande_score(llama_context * ctx, const common_params & params) } for (size_t i = 0; i < data[i1].common_prefix; ++i) { - common_batch_add(batch, data[i1].seq_tokens[0][i], i, { s0 + 0, s0 + 1 }, false); + std::vector seq_ids{ s0 + 0, s0 + 1 }; + batch.add_text(data[i1].seq_tokens[0][i], i, seq_ids, false); } - batch.logits[batch.n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch.get()); n_logits += 1; for (int s = 0; s < 2; ++s) { // TODO: end before the last token, no need to predict past the end of the sequences for (size_t i = data[i1].common_prefix; i < data[i1].seq_tokens[s].size(); ++i) { - common_batch_add(batch, data[i1].seq_tokens[s][i], i, { s0 + s }, true); + batch.add_text(data[i1].seq_tokens[s][i], i, s0 + s, true); n_logits += 1; } } @@ -1206,7 +1172,7 @@ static void winogrande_score(llama_context * ctx, const common_params & params) llama_kv_self_clear(ctx); // decode all tasks [i0, i1) - if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { + if (!decode_helper(ctx, batch, batch_logits, i_logits, n_vocab)) { LOG_ERR("%s: llama_decode() failed\n", __func__); return; } @@ -1494,14 +1460,12 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par LOG("\ntask\tacc_norm\n"); const int n_ctx = llama_n_ctx(ctx); - const int n_batch = params.n_batch; - const int n_vocab = llama_vocab_n_tokens(vocab); const int max_tasks_per_batch = 32; const int max_seq = std::min(4*max_tasks_per_batch, (int) llama_n_seq_max(ctx)); - llama_batch batch = llama_batch_init(n_ctx, 0, max_seq); + llama_batch_ext_ptr batch(ctx); std::vector tok_logits(n_vocab); std::vector batch_logits(size_t(n_ctx)*n_vocab); @@ -1521,7 +1485,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par size_t i1 = i0; size_t i_logits = 0; // this tells us how many logits were needed before this point in the batch - common_batch_clear(batch); + batch.clear(); // batch as much tasks as possible into the available context // each task has 4 unique sequence ids - one for each ending @@ -1540,13 +1504,14 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par if (int(batch_indeces.size()) != num_answers) { batch_indeces.resize(num_answers); } - for (int s = 0; s < num_answers; ++s) batch_indeces[s] = s0 + s; + for (int s = 0; s < num_answers; ++s) { + batch_indeces[s] = s0 + s; + } for (size_t i = 0; i < cur_task.common_prefix; ++i) { - //llama_batch_add(batch, cur_task.seq_tokens[0][i], i, { s0 + 0, s0 + 1, s0 + 2, s0 + 3}, false); - common_batch_add(batch, cur_task.seq_tokens[0][i], i, batch_indeces, false); + batch.add_text(cur_task.seq_tokens[0][i], i, batch_indeces, false); } - batch.logits[batch.n_tokens - 1] = true; // we need logits for the last token of the common prefix + llama_batch_ext_set_output_last(batch.get()); // we need logits for the last token of the common prefix n_logits += 1; for (int s = 0; s < int(cur_task.seq_tokens.size()); ++s) { @@ -1554,7 +1519,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par // TODO: don't evaluate the last token of each sequence for (size_t i = cur_task.common_prefix; i < seq_tokens_size; ++i) { const bool needs_logits = i < seq_tokens_size - 1; - common_batch_add(batch, cur_task.seq_tokens[s][i], i, { s0 + s }, needs_logits); + batch.add_text(cur_task.seq_tokens[s][i], i, s0 + s, needs_logits); n_logits += needs_logits; } } @@ -1578,7 +1543,7 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par llama_kv_self_clear(ctx); // decode all tasks [i0, i1) - if (!decode_helper(ctx, batch, batch_logits, n_batch, n_vocab)) { + if (!decode_helper(ctx, batch, batch_logits, i_logits, n_vocab)) { LOG_ERR("%s: llama_decode() failed\n", __func__); return; } @@ -1653,8 +1618,6 @@ static void multiple_choice_score(llama_context * ctx, const common_params & par i0 = i1 - 1; } - llama_batch_free(batch); - if (n_done < 100 && (params.multiple_choice_tasks != 0 && params.multiple_choice_tasks < (size_t)n_task)) return; float p = 1.f*n_correct/n_done; @@ -1767,7 +1730,7 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { // clear the KV cache llama_kv_self_clear(ctx); - llama_batch batch = llama_batch_init(n_batch, 0, 1); + llama_batch_ext_ptr batch(ctx); for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; @@ -1781,14 +1744,13 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { tokens[batch_start] = llama_vocab_bos(vocab); } - common_batch_clear(batch); + batch.clear(); for (int i = 0; i < batch_size; i++) { - common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); + batch.add_text(tokens[batch_start + i], j*n_batch + i, 0, true); } - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch.get())) { LOG_ERR("%s : failed to eval\n", __func__); - llama_batch_free(batch); return; } @@ -1801,8 +1763,6 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { } } - llama_batch_free(batch); - const auto t_end = std::chrono::high_resolution_clock::now(); if (i == 0) { diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 0efe20d4b3f5d..00617e059f412 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -74,40 +74,32 @@ static std::vector chunk_file(const std::string & filename, int chunk_siz return chunks; } -static void batch_add_seq(llama_batch & batch, const std::vector & tokens, llama_seq_id seq_id) { - size_t n_tokens = tokens.size(); - for (size_t i = 0; i < n_tokens; i++) { - common_batch_add(batch, tokens[i], i, { seq_id }, true); - } -} +static void batch_decode(llama_context * ctx, llama_batch_ext * batch, float * output, int n_seq, int n_embd, int embd_norm = 2) { + const llama_model * model = llama_get_model(ctx); -static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) { // clear previous kv_cache values (irrelevant for embeddings) llama_kv_self_clear(ctx); // run model - LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq); - if (llama_decode(ctx, batch) < 0) { - LOG_ERR("%s : failed to decode\n", __func__); - } - - for (int i = 0; i < batch.n_tokens; i++) { - if (!batch.logits[i]) { - continue; + LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, llama_batch_ext_get_n_tokens(batch), n_seq); + if (llama_model_has_encoder(model) && !llama_model_has_decoder(model)) { + // encoder-only model + if (llama_encode_ext(ctx, batch) < 0) { + LOG_ERR("%s : failed to encode\n", __func__); } - - // try to get sequence embeddings - supported only when pooling_type is not NONE - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - if (embd == NULL) { - LOG_ERR("%s: failed to get embeddings for token %d\n", __func__, i); - continue; - } + } else if (!llama_model_has_encoder(model) && llama_model_has_decoder(model)) { + // decoder-only model + if (llama_decode_ext(ctx, batch) < 0) { + LOG_ERR("%s : failed to decode\n", __func__); } + } - float * out = output + batch.seq_id[i][0] * n_embd; - common_embd_normalize(embd, out, n_embd, 2); + for (int s = 0; s < n_seq; s++) { + const float * embd = llama_get_embeddings_seq(ctx, s); + GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); + + float * out = output + s * n_embd; + common_embd_normalize(embd, out, n_embd, embd_norm); } } @@ -214,7 +206,7 @@ int main(int argc, char ** argv) { // initialize batch const int n_chunks = chunks.size(); - struct llama_batch batch = llama_batch_init(n_batch, 0, 1); + llama_batch_ext_ptr batch(ctx); // allocate output const int n_embd = llama_model_n_embd(model); @@ -231,22 +223,21 @@ int main(int argc, char ** argv) { const uint64_t n_toks = inp.size(); // encode if at capacity - if (batch.n_tokens + n_toks > n_batch) { - float * out = emb + p * n_embd; - batch_decode(ctx, batch, out, s, n_embd); - common_batch_clear(batch); + if (batch.n_tokens() + n_toks > n_batch) { + batch_decode(ctx, batch.get(), emb + p * n_embd, s, n_embd); + batch.clear(); + p += s; s = 0; } // add to batch - batch_add_seq(batch, inp, s); + batch.add_seq(inp, 0, s, true); s += 1; } // final batch - float * out = emb + p * n_embd; - batch_decode(ctx, batch, out, s, n_embd); + batch_decode(ctx, batch.get(), emb + p * n_embd, s, n_embd); // save embeddings to chunks for (int i = 0; i < n_chunks; i++) { @@ -255,7 +246,7 @@ int main(int argc, char ** argv) { chunks[i].tokens.clear(); } - struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1); + llama_batch_ext_ptr query_batch(ctx); // start loop, receive query and return top k similar chunks based on cosine similarity std::string query; @@ -264,12 +255,12 @@ int main(int argc, char ** argv) { std::getline(std::cin, query); std::vector query_tokens = common_tokenize(ctx, query, true); - batch_add_seq(query_batch, query_tokens, 0); + batch.add_seq(query_tokens, 0, 0, true); std::vector query_emb(n_embd, 0); - batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd); + batch_decode(ctx, query_batch.get(), query_emb.data(), 1, n_embd); - common_batch_clear(query_batch); + query_batch.clear(); // compute cosine similarities { @@ -299,6 +290,5 @@ int main(int argc, char ** argv) { llama_perf_context_print(ctx); // clean up - llama_batch_free(query_batch); llama_backend_free(); } diff --git a/examples/run/run.cpp b/examples/run/run.cpp index 462a6d151933e..876a5a4c0d254 100644 --- a/examples/run/run.cpp +++ b/examples/run/run.cpp @@ -640,6 +640,7 @@ class LlamaData { std::vector messages; // TODO: switch to common_chat_msg std::list msg_strs; std::vector fmtted; + llama_pos n_past = 0; int init(Opt & opt) { model = initialize_model(opt); @@ -950,10 +951,10 @@ static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt } // Check if we have enough space in the context to evaluate this batch -static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) { +static int check_context_size(const llama_context_ptr & ctx, const llama_batch_ext_ptr & batch) { const int n_ctx = llama_n_ctx(ctx.get()); const int n_ctx_used = llama_kv_self_used_cells(ctx.get()); - if (n_ctx_used + batch.n_tokens > n_ctx) { + if (n_ctx_used + batch.n_tokens() > n_ctx) { printf(LOG_COL_DEFAULT "\n"); printe("context size exceeded\n"); return 1; @@ -991,15 +992,17 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str } // prepare a batch for the prompt - llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size()); + auto batch = llama_batch_ext_ptr::init_from_text(llama_data.context.get(), tokens.data(), tokens.size(), llama_data.n_past, 0, true); llama_token new_token_id; while (true) { check_context_size(llama_data.context, batch); - if (llama_decode(llama_data.context.get(), batch)) { + if (llama_decode_ext(llama_data.context.get(), batch.get())) { printe("failed to decode\n"); return 1; } + llama_data.n_past += batch.n_tokens(); + // sample the next token, check is it an end of generation? new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1); if (llama_vocab_is_eog(vocab, new_token_id)) { @@ -1014,7 +1017,7 @@ static int generate(LlamaData & llama_data, const std::string & prompt, std::str print_word_and_concatenate_to_response(piece, response); // prepare the next batch with the sampled token - batch = llama_batch_get_one(&new_token_id, 1); + batch.clear(); } printf(LOG_COL_DEFAULT); diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 760ebbbf08788..03b1e7ccfabe4 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -48,15 +48,11 @@ int main(int argc, char ** argv) { auto tokens = common_tokenize(ctx, params.prompt, true); // prepare the batch - llama_batch batch = llama_batch_init(tokens.size(), 0, 1); - for (size_t i = 0; i < tokens.size(); i++) { - common_batch_add(batch, tokens[i], i, {0}, false); - } - batch.logits[batch.n_tokens - 1] = true; // generate next token + auto batch = llama_batch_ext_ptr::init_from_text(ctx, tokens.data(), tokens.size(), 0, 0, true); // evaluate prompt - llama_decode(ctx, batch); - n_past += batch.n_tokens; + llama_decode_ext(ctx, batch.get()); + n_past += batch.n_tokens(); // save state (rng, logits, embedding and kv_cache) to file { @@ -83,12 +79,12 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result0 += next_token_str; - common_batch_clear(batch); - common_batch_add(batch, next_token, n_past, {0}, true); + batch.clear(); + batch.add_text(next_token, 0, 0, true); - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch.get())) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); - llama_batch_free(batch); + llama_batch_ext_free(batch.get()); return 1; } n_past += 1; @@ -135,12 +131,12 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result1 += next_token_str; - common_batch_clear(batch); - common_batch_add(batch, next_token, n_past, {0}, true); + batch.clear(); + batch.add_text(next_token, 0, 0, true); - if (llama_decode(ctx2, batch)) { + if (llama_decode_ext(ctx2, batch.get())) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); - llama_batch_free(batch); + llama_batch_ext_free(batch.get()); return 1; } n_past += 1; @@ -216,12 +212,12 @@ int main(int argc, char ** argv) { printf("%s", next_token_str.c_str()); result2 += next_token_str; - common_batch_clear(batch); - common_batch_add(batch, next_token, n_past, {1}, true); + batch.clear(); + batch.add_text(next_token, 0, 1, true); - if (llama_decode(ctx3, batch)) { + if (llama_decode_ext(ctx3, batch.get())) { fprintf(stderr, "\n%s : failed to evaluate\n", __func__); - llama_batch_free(batch); + llama_batch_ext_free(batch.get()); return 1; } n_past += 1; @@ -233,7 +229,7 @@ int main(int argc, char ** argv) { llama_sampler_free(smpl2); llama_sampler_free(smpl3); - llama_batch_free(batch); + llama_batch_ext_free(batch.get()); if (result0 != result2) { fprintf(stderr, "\n%s : error : the seq restore generation is different\n", __func__); diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 18caa9127662d..36ff66c2c3f0f 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -1229,7 +1229,7 @@ struct server_slot { // only used for completion/embedding/infill/rerank server_task_type task_type = SERVER_TASK_TYPE_COMPLETION; - llama_batch batch_spec = {}; + llama_batch_ext_ptr batch_spec; llama_context * ctx = nullptr; llama_context * ctx_dft = nullptr; @@ -1253,7 +1253,7 @@ struct server_slot { int32_t n_past = 0; int32_t n_decoded = 0; int32_t n_remaining = -1; - int32_t i_batch = -1; + int32_t i_batch = -1; // TODO: remove and use only sequence-based sampling int32_t n_predict = -1; // TODO: disambiguate from params.n_predict // n_prompt_tokens may not be equal to prompt_tokens.size(), because prompt maybe truncated @@ -1801,7 +1801,7 @@ struct server_context { llama_context_params cparams_dft; - llama_batch batch = {}; + llama_batch_ext_ptr batch; bool clean_kv_cache = true; bool add_bos_token = true; @@ -1834,11 +1834,7 @@ struct server_context { common_speculative_free(slot.spec); slot.spec = nullptr; - - llama_batch_free(slot.batch_spec); } - - llama_batch_free(batch); } bool load_model(const common_params & params) { @@ -1929,9 +1925,10 @@ struct server_context { slot.ctx = ctx; slot.n_ctx = n_ctx_slot; slot.n_predict = params_base.n_predict; + slot.batch_spec = llama_batch_ext_ptr(ctx); if (model_dft) { - slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); + slot.batch_spec.clear(); slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); if (slot.ctx_dft == nullptr) { @@ -1956,20 +1953,14 @@ struct server_context { slot.reset(); - slots.push_back(slot); + slots.push_back(std::move(slot)); } default_generation_settings_for_props = slots[0].to_json(); // the update_slots() logic will always submit a maximum of n_batch or n_parallel tokens // note that n_batch can be > n_ctx (e.g. for non-causal attention models such as BERT where the KV cache is not used) - { - const int32_t n_batch = llama_n_batch(ctx); - - // only a single seq_id per token is needed - batch = llama_batch_init(std::max(n_batch, params_base.n_parallel), 0, 1); - } - + batch = llama_batch_ext_ptr(ctx); metrics.init(); } @@ -2102,9 +2093,7 @@ struct server_context { } if (slot.ctx_dft) { - llama_batch_free(slot.batch_spec); - - slot.batch_spec = llama_batch_init(slot.params.speculative.n_max + 1, 0, 1); + slot.batch_spec.clear(); } slot.state = SLOT_STATE_STARTED; @@ -2412,7 +2401,7 @@ struct server_context { queue_results.send(std::move(res)); } - void send_embedding(const server_slot & slot, const llama_batch & batch) { + void send_embedding(const server_slot & slot) { auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; @@ -2421,33 +2410,40 @@ struct server_context { const int n_embd = llama_model_n_embd(model); - std::vector embd_res(n_embd, 0.0f); + const llama_seq_id seq_id = slot.id; - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; - } + std::vector embd_res(n_embd, 0.0f); - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - } + if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { + const float * embd = llama_get_embeddings_seq(ctx, seq_id); if (embd == NULL) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + SLT_ERR(slot, "failed to get sequence embeddings, seq_id = %d\n", seq_id); res->embedding.push_back(std::vector(n_embd, 0.0f)); - continue; } - // normalize only when there is pooling // TODO: configurable - if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { - common_embd_normalize(embd, embd_res.data(), n_embd, 2); - res->embedding.push_back(embd_res); - } else { - res->embedding.push_back({ embd, embd + n_embd }); - } + common_embd_normalize(embd, embd_res.data(), n_embd, 2); + res->embedding.push_back(embd_res); + } else { + GGML_ABORT("embeddings without pooling is not supported yet"); + //for (int i = 0; i < batch.n_tokens(); ++i) { + // auto tok = batch.tokens[i]; + // if (!tok.logits || tok.seq_id != slot.id) { + // continue; + // } + + // const float * embd = llama_get_embeddings_ith(ctx, tok.seq_id); + // if (embd == NULL) { + // SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", tok.token, tok.seq_id); + + // res->embedding.push_back(std::vector(n_embd, 0.0f)); + // continue; + // } + + // res->embedding.push_back({ embd, embd + n_embd }); + //} } SLT_DBG(slot, "%s", "sending embeddings\n"); @@ -2455,29 +2451,20 @@ struct server_context { queue_results.send(std::move(res)); } - void send_rerank(const server_slot & slot, const llama_batch & batch) { + void send_rerank(const server_slot & slot) { auto res = std::make_unique(); res->id = slot.id_task; res->index = slot.index; res->n_tokens = slot.n_prompt_tokens; - for (int i = 0; i < batch.n_tokens; ++i) { - if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { - continue; - } - - const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - } - - if (embd == NULL) { - SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); + const llama_seq_id seq_id = slot.id; - res->score = -1e6; - continue; - } + const float * embd = llama_get_embeddings_seq(ctx, seq_id); + if (embd == NULL) { + SLT_ERR(slot, "failed to get sequence embeddings, seq_id = %d\n", seq_id); + res->score = -1e6; + } else { res->score = embd[0]; } @@ -2863,7 +2850,7 @@ struct server_context { } // start populating the batch for this iteration - common_batch_clear(batch); + batch.clear(); // track if given slot can be batched with slots already in the batch server_slot * slot_batched = nullptr; @@ -2885,9 +2872,9 @@ struct server_context { continue; } - slot.i_batch = batch.n_tokens; + slot.i_batch = batch.n_tokens(); - common_batch_add(batch, slot.sampled, slot.n_past, { slot.id }, true); + batch.add_text(slot.sampled, slot.n_past, slot.id, true); slot.n_past += 1; @@ -2904,7 +2891,7 @@ struct server_context { int32_t n_ubatch = llama_n_ubatch(ctx); // next, batch any pending prompts without exceeding n_batch - if (params_base.cont_batching || batch.n_tokens == 0) { + if (params_base.cont_batching || batch.n_tokens() == 0) { for (auto & slot : slots) { // check if we can batch this slot with the previous one if (slot.is_processing()) { @@ -3070,7 +3057,7 @@ struct server_context { // non-causal tasks require to fit the entire prompt in the physical batch if (slot.is_non_causal()) { // cannot fit the prompt in the current batch - will try next iter - if (batch.n_tokens + slot.n_prompt_tokens > n_batch) { + if (batch.n_tokens() + slot.n_prompt_tokens > n_batch) { continue; } } @@ -3090,11 +3077,11 @@ struct server_context { slot.cache_tokens.resize(slot.n_past); // add prompt tokens for processing in the current batch - while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens < n_batch) { + while (slot.n_past < slot.n_prompt_tokens && batch.n_tokens() < n_batch) { // without pooling, we want to output the embeddings for all the tokens in the batch const bool need_embd = slot.task_type == SERVER_TASK_TYPE_EMBEDDING && llama_pooling_type(slot.ctx) == LLAMA_POOLING_TYPE_NONE; - common_batch_add(batch, prompt_tokens[slot.n_past], slot.n_past, { slot.id }, need_embd); + batch.add_text(prompt_tokens[slot.n_past], slot.n_past, slot.id, need_embd); if (slot.params.cache_prompt) { slot.cache_tokens.push_back(prompt_tokens[slot.n_past]); @@ -3104,13 +3091,14 @@ struct server_context { slot.n_past++; } - SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", slot.n_past, batch.n_tokens, (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); + SLT_INF(slot, "prompt processing progress, n_past = %d, n_tokens = %d, progress = %f\n", + slot.n_past, batch.n_tokens(), (float) slot.n_prompt_tokens_processed / slot.n_prompt_tokens); // entire prompt has been processed if (slot.n_past == slot.n_prompt_tokens) { slot.state = SLOT_STATE_DONE_PROMPT; - GGML_ASSERT(batch.n_tokens > 0); + GGML_ASSERT(batch.n_tokens() > 0); common_sampler_reset(slot.smpl); @@ -3120,27 +3108,27 @@ struct server_context { } // extract the logits only for the last token - batch.logits[batch.n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch.get()); slot.n_decoded = 0; - slot.i_batch = batch.n_tokens - 1; + slot.i_batch = batch.n_tokens() - 1; - SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens); + SLT_INF(slot, "prompt done, n_past = %d, n_tokens = %d\n", slot.n_past, batch.n_tokens()); } } - if (batch.n_tokens >= n_batch) { + if (batch.n_tokens() >= n_batch) { break; } } } - if (batch.n_tokens == 0) { + if (batch.n_tokens() == 0) { SRV_WRN("%s", "no tokens to decode\n"); return; } - SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens); + SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens()); if (slot_batched) { // make sure we're in the right embedding mode @@ -3150,20 +3138,12 @@ struct server_context { } // process the created batch of tokens - for (int32_t i = 0; i < batch.n_tokens; i += n_batch) { - const int32_t n_tokens = std::min(n_batch, batch.n_tokens - i); - - llama_batch batch_view = { - n_tokens, - batch.token + i, - nullptr, - batch.pos + i, - batch.n_seq_id + i, - batch.seq_id + i, - batch.logits + i, - }; + for (int32_t i = 0; i < batch.n_tokens(); i += n_batch) { + const int32_t n_tokens = std::min(n_batch, batch.n_tokens() - i); + + llama_batch_ext_ptr batch_view(llama_batch_ext_get_view(batch.get(), i, n_tokens)); - const int ret = llama_decode(ctx, batch_view); + const int ret = llama_decode_ext(ctx, batch_view.get()); metrics.on_decoded(slots); if (ret != 0) { @@ -3194,14 +3174,14 @@ struct server_context { if (slot.state == SLOT_STATE_DONE_PROMPT) { if (slot.task_type == SERVER_TASK_TYPE_EMBEDDING) { // prompt evaluated for embedding - send_embedding(slot, batch_view); + send_embedding(slot); slot.release(); slot.i_batch = -1; continue; // continue loop of slots } if (slot.task_type == SERVER_TASK_TYPE_RERANK) { - send_rerank(slot, batch_view); + send_rerank(slot); slot.release(); slot.i_batch = -1; continue; // continue loop of slots @@ -3298,16 +3278,16 @@ struct server_context { } // construct the speculation batch - common_batch_clear(slot.batch_spec); - common_batch_add (slot.batch_spec, id, slot.n_past, { slot.id }, true); + slot.batch_spec.clear(); + slot.batch_spec.add_text(id, slot.n_past, slot.id, true); for (size_t i = 0; i < draft.size(); ++i) { - common_batch_add(slot.batch_spec, draft[i], slot.n_past + 1 + i, { slot.id }, true); + slot.batch_spec.add_text(draft[i], slot.n_past + 1 + i, slot.id, true); } - SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens); + SLT_DBG(slot, "decoding speculative batch, size = %d\n", slot.batch_spec.n_tokens()); - llama_decode(ctx, slot.batch_spec); + llama_decode_ext(ctx, slot.batch_spec.get()); // the accepted tokens from the speculation const auto ids = common_sampler_sample_and_accept_n(slot.smpl, ctx, draft); @@ -4164,6 +4144,11 @@ int main(int argc, char ** argv) { return; } + if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + res_error(res, format_error_response("Pooling type 'none' is not yet supported. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); + return; + } + // for the shape of input/content, see tokenize_input_prompts() json prompt; if (body.count("input") != 0) { @@ -4258,6 +4243,11 @@ int main(int argc, char ** argv) { return; } + if (llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) { + res_error(res, format_error_response("Pooling type 'none' cannot be used with reranking. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST)); + return; + } + const json body = json::parse(req.body); // TODO: implement diff --git a/examples/server/tests/unit/test_embedding.py b/examples/server/tests/unit/test_embedding.py index 8b0eb42b0926f..889a759aea934 100644 --- a/examples/server/tests/unit/test_embedding.py +++ b/examples/server/tests/unit/test_embedding.py @@ -88,13 +88,19 @@ def test_embedding_pooling_none(): res = server.make_request("POST", "/embeddings", data={ "input": "hello hello hello", }) - assert res.status_code == 200 - assert 'embedding' in res.body[0] - assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special - # make sure embedding vector is not normalized - for x in res.body[0]['embedding']: - assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON + # /embeddings does not support pooling type 'none' + assert res.status_code == 400 + assert "error" in res.body + + # TODO: re-enable when we figure out how to support pooling type 'none' + #assert res.status_code == 200 + #assert 'embedding' in res.body[0] + #assert len(res.body[0]['embedding']) == 5 # 3 text tokens + 2 special + + ## make sure embedding vector is not normalized + #for x in res.body[0]['embedding']: + # assert abs(sum([x ** 2 for x in x]) - 1) > EPSILON def test_embedding_pooling_none_oai(): diff --git a/examples/simple-chat/simple-chat.cpp b/examples/simple-chat/simple-chat.cpp index 84f4159737260..1425d2b114438 100644 --- a/examples/simple-chat/simple-chat.cpp +++ b/examples/simple-chat/simple-chat.cpp @@ -1,4 +1,5 @@ #include "llama.h" +#include "llama-cpp.h" #include #include #include @@ -108,19 +109,22 @@ int main(int argc, char ** argv) { } // prepare a batch for the prompt - llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); + llama_pos n_past = 0; + auto batch = llama_batch_ext_ptr::init_from_text(ctx, prompt_tokens.data(), prompt_tokens.size(), n_past, 0, true); + n_past += batch.n_tokens(); + llama_token new_token_id; while (true) { // check if we have enough space in the context to evaluate this batch int n_ctx = llama_n_ctx(ctx); int n_ctx_used = llama_kv_self_used_cells(ctx); - if (n_ctx_used + batch.n_tokens > n_ctx) { + if (n_ctx_used + batch.n_tokens() > n_ctx) { printf("\033[0m\n"); fprintf(stderr, "context size exceeded\n"); exit(0); } - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch.get())) { GGML_ABORT("failed to decode\n"); } @@ -144,9 +148,13 @@ int main(int argc, char ** argv) { response += piece; // prepare the next batch with the sampled token - batch = llama_batch_get_one(&new_token_id, 1); + batch.clear(); + batch.add_text(new_token_id, n_past, 0, true); + n_past++; } + llama_batch_ext_free(batch.get()); + return response; }; diff --git a/examples/simple/simple.cpp b/examples/simple/simple.cpp index 10e79a0a69eeb..90bae0aa8614b 100644 --- a/examples/simple/simple.cpp +++ b/examples/simple/simple.cpp @@ -1,4 +1,5 @@ #include "llama.h" +#include "llama-cpp.h" #include #include #include @@ -143,7 +144,7 @@ int main(int argc, char ** argv) { // prepare a batch for the prompt - llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); + auto batch = llama_batch_ext_ptr::init_from_text(ctx, prompt_tokens.data(), prompt_tokens.size(), 0, 0, true); // main loop @@ -151,14 +152,14 @@ int main(int argc, char ** argv) { int n_decode = 0; llama_token new_token_id; - for (int n_pos = 0; n_pos + batch.n_tokens < n_prompt + n_predict; ) { + for (int n_pos = 0; n_pos + batch.n_tokens() < n_prompt + n_predict; ) { // evaluate the current batch with the transformer model - if (llama_decode(ctx, batch)) { + if (llama_decode_ext(ctx, batch.get())) { fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); return 1; } - n_pos += batch.n_tokens; + n_pos += batch.n_tokens(); // sample the next token { @@ -180,7 +181,8 @@ int main(int argc, char ** argv) { fflush(stdout); // prepare the next batch with the sampled token - batch = llama_batch_get_one(&new_token_id, 1); + batch.clear(); + batch.add_text(new_token_id, n_pos, 0, true); n_decode += 1; } diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index a5d2bc9d09de7..5981d85304f33 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -4,6 +4,7 @@ #include "speculative.h" #include "log.h" #include "llama.h" +#include "llama-cpp.h" #include #include @@ -113,7 +114,8 @@ int main(int argc, char ** argv) { struct common_sampler * smpl = common_sampler_init(model_tgt, params.sampling); // eval the prompt - llama_decode(ctx_tgt, llama_batch_get_one(inp.data(), inp.size() - 1)); + auto batch = llama_batch_ext_ptr::init_from_text(ctx_tgt, inp.data(), inp.size() - 1, 0, 0, true); + llama_decode_ext(ctx_tgt, batch.get()); // note: keep the last token separate! llama_token id_last = inp.back(); @@ -132,7 +134,7 @@ int main(int argc, char ** argv) { struct common_speculative * spec = common_speculative_init(ctx_dft); - llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); + llama_batch_ext_ptr batch_tgt(ctx_tgt); const auto t_enc_end = ggml_time_us(); @@ -151,8 +153,8 @@ int main(int argc, char ** argv) { //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str()); // always have a token to evaluate from before - id_last - common_batch_clear(batch_tgt); - common_batch_add (batch_tgt, id_last, n_past++, { 0 }, true); + batch_tgt.clear(); + batch_tgt.add_text(id_last, n_past++, 0, true); // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] { @@ -162,12 +164,12 @@ int main(int argc, char ** argv) { } for (size_t i = 0; i < draft.size(); ++i) { - common_batch_add(batch_tgt, draft[i], n_past + i, { 0 }, true); + batch_tgt.add_text(draft[i], n_past + i, 0, true); } //LOG_DBG("target batch: %s\n", string_from(ctx_tgt, batch_tgt).c_str()); - llama_decode(ctx_tgt, batch_tgt); + llama_decode_ext(ctx_tgt, batch_tgt.get()); } // sample from the full target batch and return the accepted tokens based on the target sampler diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 627d01bbcb5ad..d61a173408a15 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -3,6 +3,7 @@ #include "sampling.h" #include "log.h" #include "llama.h" +#include "llama-cpp.h" #include #include @@ -45,7 +46,6 @@ int main(int argc, char ** argv) { } common_init(); - if (params.speculative.model.empty()) { LOG_ERR("%s: --model-draft is required\n", __func__); return 1; @@ -166,9 +166,12 @@ int main(int argc, char ** argv) { const auto t_enc_start = ggml_time_us(); // eval the prompt with both models - llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1)); - llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1)); - llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input)); + auto batch0 = llama_batch_ext_ptr::init_from_text(ctx_tgt, inp.data(), n_input - 1, 0, 0, true); + auto batch1 = llama_batch_ext_ptr::init_from_text(ctx_tgt, &inp.back(), 1, n_input - 1, 0, true); + auto batch2 = llama_batch_ext_ptr::init_from_text(ctx_dft, inp.data(), n_input , 0, 0, true); + llama_decode_ext(ctx_tgt, batch0.get()); + llama_decode_ext(ctx_tgt, batch1.get()); + llama_decode_ext(ctx_dft, batch2.get()); const auto t_enc_end = ggml_time_us(); @@ -199,8 +202,8 @@ int main(int argc, char ** argv) { drafts[s].smpl = common_sampler_init(model_dft, params.sampling); } - llama_batch batch_dft = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); - llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, n_seq_dft); + llama_batch_ext_ptr batch_dft(ctx_dft); + llama_batch_ext_ptr batch_tgt(ctx_tgt); const auto t_dec_start = ggml_time_us(); @@ -441,12 +444,12 @@ int main(int argc, char ** argv) { drafts[0].dists.push_back(std::vector()); drafts[0].i_batch_tgt.push_back(0); - common_batch_clear(batch_dft); - common_batch_add (batch_dft, token_id, n_past_dft, { 0 }, true); + batch_dft.clear(); + batch_dft.add_text(token_id, n_past_dft, 0, true); llama_kv_self_seq_rm(ctx_dft, 0, n_past_dft, -1); // LOG_DBG("dft batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_dft, batch_dft).c_str()); - llama_decode(ctx_dft, batch_dft); + llama_decode_ext(ctx_dft, batch_dft.get()); ++n_past_dft; } @@ -471,12 +474,19 @@ int main(int argc, char ** argv) { drafts[0].drafting = true; drafts[0].i_batch_dft = 0; - common_batch_clear(batch_tgt); - common_batch_add (batch_tgt, drafts[0].tokens[0], n_past_tgt, { 0 }, true); + struct batch_info { + llama_token id; + llama_pos pos; + std::vector seq_id; + }; + + std::vector batch_tgt_data; + + batch_tgt_data.push_back({ drafts[0].tokens[0], n_past_tgt, {0} }); // sample n_draft tokens from the draft model using tree-based sampling for (int i = 0; i < n_draft; ++i) { - batch_dft.n_tokens = 0; + batch_dft.clear(); for (int s = 0; s < n_seq_dft; ++s) { drafts[s].skip = false; @@ -507,11 +517,10 @@ int main(int argc, char ** argv) { llama_kv_self_seq_cp(ctx_dft, s, n_seq_cur, -1, -1); // all previous tokens from this branch are now also part of the new branch - for (int t = 0; t < batch_tgt.n_tokens; ++t) { - for (int p = 0; p < batch_tgt.n_seq_id[t]; ++p) { - if (batch_tgt.seq_id[t][p] == s) { - batch_tgt.seq_id[t][batch_tgt.n_seq_id[t]] = n_seq_cur; - batch_tgt.n_seq_id[t]++; + for (int t = 0; t < (int) batch_tgt_data.size(); ++t) { + for (int p = 0; p < (int) batch_tgt_data[t].seq_id.size(); ++p) { + if (batch_tgt_data[t].seq_id[p] == s) { + batch_tgt_data[t].seq_id.push_back(n_seq_cur); break; } } @@ -553,32 +562,30 @@ int main(int argc, char ** argv) { drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size}); // add unique drafted tokens to the target batch - drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens); + drafts[s].i_batch_tgt.push_back(batch_tgt_data.size()); - common_batch_add(batch_tgt, id, n_past_tgt + i + 1, { s }, true); + batch_tgt_data.push_back({ id, n_past_tgt + i + 1, { s }}); // add the token to the batch for batched decoding with the draft model - drafts[s].i_batch_dft = batch_dft.n_tokens; - - common_batch_add(batch_dft, id, n_past_cur, { s }, true); + drafts[s].i_batch_dft = batch_dft.add_text(id, n_past_cur, s, true); - if (batch_tgt.n_tokens > n_draft) { + if (batch_tgt_data.size() > (size_t) n_draft) { drafts[s].drafting = false; } } } // no sequence is drafting anymore - if (batch_dft.n_tokens == 0) { + if (batch_dft.n_tokens() == 0) { break; } // evaluate the drafted tokens on the draft model - llama_decode(ctx_dft, batch_dft); + llama_decode_ext(ctx_dft, batch_dft.get()); ++n_past_cur; ++n_drafted; - if (batch_tgt.n_tokens > n_draft) { + if (batch_tgt_data.size() > (size_t) n_draft) { break; } } @@ -590,8 +597,15 @@ int main(int argc, char ** argv) { llama_kv_self_seq_cp(ctx_tgt, 0, s, -1, -1); } + batch_tgt.clear(); + for (int i = 0; i < (int) batch_tgt_data.size(); ++i) { + const auto & data = batch_tgt_data[i]; + + batch_tgt.add_text(data.id, data.pos, data.seq_id, true); + } + // LOG_DBG("target batch: %s\n", LOG_BATCH_TOSTR_PRETTY(ctx_tgt, batch_tgt).c_str()); - llama_decode(ctx_tgt, batch_tgt); + llama_decode_ext(ctx_tgt, batch_tgt.get()); ++n_past_tgt; } @@ -634,8 +648,6 @@ int main(int argc, char ** argv) { common_sampler_free(drafts[s].smpl); } - llama_batch_free(batch_dft); - llama_backend_free(); LOG("\n\n"); diff --git a/examples/tts/tts.cpp b/examples/tts/tts.cpp index 4cc42e1674ccc..dac097859fd41 100644 --- a/examples/tts/tts.cpp +++ b/examples/tts/tts.cpp @@ -826,7 +826,7 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 // create a llama_batch // we use this object to submit token data for decoding - llama_batch batch = llama_batch_init(std::max(prompt_inp.size(), (size_t) n_parallel), 0, n_parallel); + llama_batch_ext_ptr batch(ctx_ttc); std::vector seq_ids(n_parallel, 0); for (int32_t i = 0; i < n_parallel; ++i) { @@ -835,14 +835,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 // evaluate the initial prompt for (size_t i = 0; i < prompt_inp.size(); ++i) { - common_batch_add(batch, prompt_inp[i], i, seq_ids, false); + batch.add_text(prompt_inp[i], i, seq_ids, false); } - GGML_ASSERT(batch.n_tokens == (int) prompt_inp.size()); + GGML_ASSERT(batch.n_tokens() == (int) prompt_inp.size()); // llama_decode will output logits only for the last token of the prompt - batch.logits[batch.n_tokens - 1] = true; + llama_batch_ext_set_output_last(batch.get()); - if (llama_decode(ctx_ttc, batch) != 0) { + if (llama_decode_ext(ctx_ttc, batch.get()) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } @@ -861,16 +861,16 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 // remember the batch index of the last token for each parallel sequence // we need this to determine which logits to sample from - std::vector i_batch(n_parallel, batch.n_tokens - 1); + std::vector i_batch(n_parallel, batch.n_tokens() - 1); - int n_past = batch.n_tokens; + int n_past = batch.n_tokens(); int n_decode = 0; bool next_token_uses_guide_token = true; while (n_decode <= n_predict) { // prepare the next batch - common_batch_clear(batch); + batch.clear(); // sample the next token for each parallel sequence / stream for (int32_t i = 0; i < n_parallel; ++i) { @@ -926,14 +926,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 //LOG_CNT("%d", i); } - i_batch[i] = batch.n_tokens; + i_batch[i] = batch.n_tokens(); // push this new token for next evaluation - common_batch_add(batch, new_token_id, n_past, { i }, true); + batch.add_text(new_token_id, n_past, i, true); } // all streams are finished - if (batch.n_tokens == 0) { + if (batch.n_tokens() == 0) { break; } @@ -941,14 +941,12 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 n_past += 1; // evaluate the current batch with the transformer model - if (llama_decode(ctx_ttc, batch)) { + if (llama_decode_ext(ctx_ttc, batch.get())) { LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1); return 1; } } - llama_batch_free(batch); - LOG("\n"); LOG_INF("%s: time for decoder: %.3f ms\n", __func__, (ggml_time_us() - t_dec_start) / 1000.0f); } @@ -1016,14 +1014,14 @@ lovely<|t_0.56|><|code_start|><|634|><|596|><|1766|><|1556|><|1306|><|1285|><|14 const int n_codes = codes.size(); - llama_batch batch = llama_batch_init(n_codes, 0, 1); + llama_batch_ext_ptr batch(ctx_cts); for (size_t i = 0; i < codes.size(); ++i) { - common_batch_add(batch, codes[i], i, { 0 }, true); // TODO: all logits? + batch.add_text(codes[i], i, 0, true); // TODO: all logits? } - GGML_ASSERT(batch.n_tokens == n_codes); + GGML_ASSERT(batch.n_tokens() == n_codes); - if (llama_decode(ctx_cts, batch) != 0) { + if (llama_decode_ext(ctx_cts, batch.get()) != 0) { LOG_ERR("%s: llama_decode() failed\n", __func__); return 1; } diff --git a/include/llama-cpp.h b/include/llama-cpp.h index 8f6368177de09..5a4d316cbda4b 100644 --- a/include/llama-cpp.h +++ b/include/llama-cpp.h @@ -5,6 +5,7 @@ #endif #include +#include #include "llama.h" @@ -24,7 +25,102 @@ struct llama_adapter_lora_deleter { void operator()(llama_adapter_lora * adapter) { llama_adapter_lora_free(adapter); } }; +struct llama_batch_ext_deleter { + void operator()(llama_batch_ext * batch) { llama_batch_ext_free(batch); } +}; + typedef std::unique_ptr llama_model_ptr; typedef std::unique_ptr llama_context_ptr; typedef std::unique_ptr llama_sampler_ptr; typedef std::unique_ptr llama_adapter_lora_ptr; + +struct llama_batch_ext_ptr : std::unique_ptr { + llama_batch_ext_ptr() : std::unique_ptr() {} + llama_batch_ext_ptr(struct llama_context * ctx) : std::unique_ptr(llama_batch_ext_init(ctx)) {} + llama_batch_ext_ptr(llama_batch_ext * batch) : std::unique_ptr(batch) {} + + // Convenience C++ wrapper to create a batch from text tokens, without worrying about manually freeing it + // First token will be at position pos0 + // The sequence ID will be fixed to seq_id + // If output_last is true, the last token will have output set + static llama_batch_ext_ptr init_from_text(struct llama_context * ctx, + llama_token * tokens, + int32_t n_tokens, + llama_pos pos0, + llama_seq_id seq_id, + bool output_last) { + llama_batch_ext * batch = llama_batch_ext_init(ctx); + for (int32_t i = 0; i < n_tokens; i++) { + llama_batch_ext_add_text(batch, tokens[i], pos0 + i, &seq_id, 1, false); + } + if (output_last) { + // TODO: somehow return the output ID + llama_batch_ext_set_output_last(batch); + } + return llama_batch_ext_ptr(batch); + } + + // Convenience C++ wrapper to create a batch from text embeddings, without worrying about manually freeing it + static llama_batch_ext_ptr init_from_embd(struct llama_context * ctx, + const float * embd, + size_t n_tokens, + size_t n_embd, + const llama_pos * pos, + llama_seq_id seq_id) { + return llama_batch_ext_ptr(llama_batch_ext_init_from_embd(ctx, embd, n_tokens, n_embd, pos, seq_id)); + } + + // Wrapper to create an embeddings batch with starting position pos0, only used when n_pos_per_tokens == 1 + static llama_batch_ext_ptr init_from_embd(struct llama_context * ctx, + const float * embd, + size_t n_tokens, + size_t n_embd, + llama_pos pos0, + llama_seq_id seq_id) { + std::vector pos(n_tokens); + for (size_t i = 0; i < n_tokens; i++) { + pos[i] = pos0 + i; + } + return llama_batch_ext_ptr(llama_batch_ext_init_from_embd(ctx, embd, n_tokens, n_embd, pos.data(), seq_id)); + } + + // Wrapper to add a single token to the batch, support multiple sequence IDs + int32_t add_text(llama_token token, llama_pos pos, const std::vector & seq_id, bool output_last) { + int32_t output_id = -1; + llama_batch_ext_add_text(this->get(), token, pos, seq_id.data(), seq_id.size(), false); + if (output_last) { + output_id = llama_batch_ext_set_output_last(this->get()); + } + return output_id; + } + + // Wrapper to add a single token to the batch (single sequence ID) + int32_t add_text(llama_token token, llama_pos pos, llama_seq_id seq_id, bool output_last) { + int32_t output_id = -1; + llama_batch_ext_add_text(this->get(), token, pos, &seq_id, 1, false); + if (output_last) { + output_id = llama_batch_ext_set_output_last(this->get()); + } + return output_id; + } + + // Return output ID of the last token. Position starts from pos0 + int32_t add_seq(std::vector & tokens, llama_pos pos0, llama_seq_id seq_id, bool output_last) { + int32_t output_id = -1; + for (size_t i = 0; i < tokens.size(); i++) { + llama_batch_ext_add_text(this->get(), tokens[i], pos0 + i, &seq_id, 1, false); + } + if (output_last) { + output_id = llama_batch_ext_set_output_last(this->get()); + } + return output_id; + } + + void clear() { + llama_batch_ext_clear(this->get()); + } + + int32_t n_tokens() const { + return llama_batch_ext_get_n_tokens(this->get()); + } +}; diff --git a/include/llama.h b/include/llama.h index 25a9f82784df2..4b346433854f2 100644 --- a/include/llama.h +++ b/include/llama.h @@ -235,6 +235,9 @@ extern "C" { typedef bool (*llama_progress_callback)(float progress, void * user_data); // Input data for llama_decode + // + // WARN: This struct is DEPRECATED and will be removed in the future, use llama_batch_ext instead + // // A llama_batch object can contain input about one or many sequences // The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens // @@ -258,6 +261,10 @@ extern "C" { int8_t * logits; // TODO: rename this to "output" } llama_batch; + // Input data for llama_decode / llama_encode + // It can contain text tokens and embeddings for one or many sequences + struct llama_batch_ext; + enum llama_model_kv_override_type { LLAMA_KV_OVERRIDE_TYPE_INT, LLAMA_KV_OVERRIDE_TYPE_FLOAT, @@ -476,6 +483,7 @@ extern "C" { LLAMA_API const struct llama_vocab * llama_model_get_vocab(const struct llama_model * model); LLAMA_API enum llama_rope_type llama_model_rope_type(const struct llama_model * model); + LLAMA_API uint32_t llama_n_pos_per_token(const struct llama_model * model); LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model); LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); @@ -892,9 +900,9 @@ extern "C" { // // NOTE: this is a helper function to facilitate transition to the new batch API - avoid using it // - LLAMA_API struct llama_batch llama_batch_get_one( + DEPRECATED(LLAMA_API struct llama_batch llama_batch_get_one( llama_token * tokens, - int32_t n_tokens); + int32_t n_tokens), "use llama_batch_ext API instead"); // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens // Each token can be assigned up to n_seq_max sequence ids @@ -903,13 +911,80 @@ extern "C" { // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token // The rest of the llama_batch members are allocated with size n_tokens // All members are left uninitialized - LLAMA_API struct llama_batch llama_batch_init( - int32_t n_tokens, - int32_t embd, - int32_t n_seq_max); + DEPRECATED(LLAMA_API struct llama_batch llama_batch_init( + int32_t n_tokens, + int32_t embd, + int32_t n_seq_max), "use llama_batch_ext_init instead"); // Frees a batch of tokens allocated with llama_batch_init() - LLAMA_API void llama_batch_free(struct llama_batch batch); + DEPRECATED(LLAMA_API void llama_batch_free(struct llama_batch batch), + "use llama_batch_ext API instead"); + + // Allocates a batch of tokens on the heap that can hold a maximum of n_tokens + // Each token can be assigned up to n_seq_max sequence ids + // The batch has to be freed with llama_batch_ext_free() + LLAMA_API struct llama_batch_ext * llama_batch_ext_init(struct llama_context * ctx); + + // Same with llama_batch_init, but initializes the batch with the provided raw embeddings + // Size of embd should be n_tokens * n_embd + // Size of pos should be n_tokens * n_pos_per_token + // If one token has multiple pos, the pos must follow the order: 000011112222... + // n_embd is the number of embeddings per token, can be obtained from llama_model_n_embd() + // The sequence ID will be fixed to seq_id + // The batch has to be freed with llama_batch_ext_free() + LLAMA_API struct llama_batch_ext * llama_batch_ext_init_from_embd( + struct llama_context * ctx, + const float * embd, + size_t n_tokens, + size_t n_embd, + const llama_pos * pos, + llama_seq_id seq_id); + + // Get the number of tokens in the batch + LLAMA_API int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch); + + // Add text tokens to the batch + // Return values: + // -1 : not enough space in the batch + // -2 : embd is already set, cannot add text tokens + // otherwise, returns the output ID + LLAMA_API int32_t llama_batch_ext_add_text( + struct llama_batch_ext * batch, + llama_token token, + llama_pos pos, + const llama_seq_id * seq_ids, + size_t n_seq_ids, + bool output); + + // Set output (logits/embeddings) for the token in the ith sequence + // If pos == -1, output will be set for the all tokens + // Return values: + // -1 : the token is not in the batch + // otherwise, returns the output ID + LLAMA_API int32_t llama_batch_ext_set_output( + struct llama_batch_ext * batch, + llama_pos pos, + llama_seq_id seq_id); + + // Set output (logits/embeddings) for the last added token + // Return values: + // -1 : the batch is empty + // otherwise, returns the output ID + LLAMA_API int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch); + + // Get a "view" from a number of tokens offset + // Return returned batch must be freed with llama_batch_ext_free() + LLAMA_API struct llama_batch_ext * llama_batch_ext_get_view( + struct llama_batch_ext * batch, + int32_t offset, + int32_t n_tokens); + + // Remove everything from the batch + LLAMA_API void llama_batch_ext_clear(struct llama_batch_ext * batch); + + // Frees a batch of tokens allocated with llama_batch_ext_init() + // If this is a view, the original batch is not freed + LLAMA_API void llama_batch_ext_free(struct llama_batch_ext * batch); // Processes a batch of tokens with the ecoder part of the encoder-decoder model. // Stores the encoder output internally for later use by the decoder cross-attention layers. @@ -919,13 +994,21 @@ extern "C" { struct llama_context * ctx, struct llama_batch batch); + LLAMA_API int32_t llama_encode_ext( + struct llama_context * ctx, + struct llama_batch_ext * batch); + // Positive return values does not mean a fatal error, but rather a warning. // 0 - success // 1 - could not find a KV slot for the batch (try reducing the size of the batch or increase the context) // < 0 - error. the KV cache state is restored to the state before this call LLAMA_API int32_t llama_decode( struct llama_context * ctx, - struct llama_batch batch); + struct llama_batch batch); + + LLAMA_API int32_t llama_decode_ext( + struct llama_context * ctx, + struct llama_batch_ext * batch); // Set the number of threads used for decoding // n_threads is the number of threads used for generation (single token) diff --git a/src/llama-batch.cpp b/src/llama-batch.cpp index 01d5ca57fd82b..d7b232e9c7734 100644 --- a/src/llama-batch.cpp +++ b/src/llama-batch.cpp @@ -1,4 +1,5 @@ #include "llama-batch.h" +#include "llama-graph.h" #include #include @@ -189,7 +190,7 @@ llama_ubatch llama_sbatch::split_seq(size_t n_ubatch) { return ubatch; } -void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool simple_split, bool logits_all) { +void llama_sbatch::from_batch(const llama_batch_ext & batch, size_t n_embd, bool simple_split, bool logits_all) { GGML_ASSERT(batch.n_tokens >= 0); this->batch = &batch; this->n_embd = n_embd; @@ -273,46 +274,61 @@ void llama_sbatch::from_batch(const llama_batch & batch, size_t n_embd, bool sim ); } -llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0) { - batch = in_batch; - GGML_ASSERT(batch.n_tokens > 0); - if (!batch.pos) { - pos.resize(batch.n_tokens); - for (int32_t i = 0; i < batch.n_tokens; i++) { +llama_batch_allocr::llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0) { + batch = new llama_batch_ext{ + /*n_tokens =*/ in_batch.n_tokens, + /*max_tokens =*/ in_batch.n_tokens, + /*n_pos_per_token =*/ 1, + /*is_view =*/ false, + /*tokens =*/ in_batch.token, + /*embd =*/ in_batch.embd, + /*pos =*/ in_batch.pos, + /*n_seq_id =*/ in_batch.n_seq_id, + /*seq_id =*/ in_batch.seq_id, + /*logits =*/ in_batch.logits, + }; + GGML_ASSERT(batch->n_tokens > 0); + if (!in_batch.pos) { + pos.resize(batch->n_tokens); + for (int32_t i = 0; i < batch->n_tokens; i++) { pos[i] = i + p0; } - batch.pos = pos.data(); + batch->pos = pos.data(); } - if (!batch.n_seq_id) { - n_seq_id.resize(batch.n_tokens); - for (int32_t i = 0; i < batch.n_tokens; i++) { + if (!batch->n_seq_id) { + n_seq_id.resize(batch->n_tokens); + for (int32_t i = 0; i < batch->n_tokens; i++) { n_seq_id[i] = seq_id_0.size(); } - batch.n_seq_id = n_seq_id.data(); + batch->n_seq_id = n_seq_id.data(); } - if (!batch.seq_id) { - seq_id.resize(batch.n_tokens + 1); - seq_id[batch.n_tokens] = NULL; - for (int32_t i = 0; i < batch.n_tokens; i++) { + if (!batch->seq_id) { + seq_id.resize(batch->n_tokens + 1); + seq_id[batch->n_tokens] = NULL; + for (int32_t i = 0; i < batch->n_tokens; i++) { seq_id[i] = seq_id_0.data(); } - batch.seq_id = seq_id.data(); + batch->seq_id = seq_id.data(); } - if (!batch.logits) { - logits.resize(batch.n_tokens); + if (!batch->logits) { + logits.resize(batch->n_tokens); logits[logits.size() - 1] = true; - batch.logits = logits.data(); + batch->logits = logits.data(); } } +llama_batch_allocr::~llama_batch_allocr() { + delete batch; +} + // // interface implementation // struct llama_batch llama_batch_get_one( - llama_token * tokens, - int32_t n_tokens) { - return { + llama_token * tokens, + int32_t n_tokens) { + return llama_batch{ /*n_tokens =*/ n_tokens, /*tokens =*/ tokens, /*embd =*/ nullptr, @@ -323,6 +339,168 @@ struct llama_batch llama_batch_get_one( }; } +static struct llama_batch_ext * llama_batch_ext_init_impl(int32_t n_tokens_alloc, int32_t n_embd, int32_t n_seq_max, int32_t n_pos_per_token) { + llama_batch_ext * batch = new llama_batch_ext{ + /*n_tokens =*/ 0, + /*max_tokens =*/ n_tokens_alloc, + /*n_pos_per_token =*/ n_pos_per_token, + /*is_view =*/ false, + /*tokens =*/ nullptr, + /*embd =*/ nullptr, + /*pos =*/ nullptr, + /*n_seq_id =*/ nullptr, + /*seq_id =*/ nullptr, + /*logits =*/ nullptr, + }; + + if (n_embd) { + batch->embd = (float *) malloc(sizeof(float) * n_tokens_alloc * n_embd); + } else { + batch->token = (llama_token *) malloc(sizeof(llama_token) * n_tokens_alloc); + } + + batch->pos = (llama_pos *) malloc(sizeof(llama_pos) * n_tokens_alloc * MAX_POS_PER_TOKEN); + batch->n_seq_id = (int32_t *) malloc(sizeof(int32_t) * n_tokens_alloc); + batch->seq_id = (llama_seq_id **) malloc(sizeof(llama_seq_id *) * (n_tokens_alloc + 1)); + for (int i = 0; i < n_tokens_alloc; ++i) { + batch->seq_id[i] = (llama_seq_id *) malloc(sizeof(llama_seq_id) * n_seq_max); + } + batch->seq_id[n_tokens_alloc] = nullptr; + + batch->logits = (int8_t *) malloc(sizeof(int8_t) * n_tokens_alloc); + + return batch; +} + +struct llama_batch_ext * llama_batch_ext_init(struct llama_context * ctx) { + int32_t n_pos_per_token = llama_n_pos_per_token(llama_get_model(ctx)); + return llama_batch_ext_init_impl(llama_n_batch(ctx), 0, llama_n_seq_max(ctx), n_pos_per_token); +} + +struct llama_batch_ext * llama_batch_ext_init_from_embd( + struct llama_context * ctx, + const float * embd, + size_t n_tokens, + size_t n_embd, + const llama_pos * pos, + llama_seq_id seq_id) { + int32_t n_pos_per_token = llama_n_pos_per_token(llama_get_model(ctx)); + struct llama_batch_ext * batch = llama_batch_ext_init_impl(n_tokens, n_embd, 1, n_pos_per_token); + memcpy(batch->embd, embd, n_tokens * n_embd * sizeof(float)); + memcpy(batch->pos, pos, n_tokens * n_pos_per_token * sizeof(llama_pos)); + for (size_t i = 0; i < n_tokens; i++) { + batch->n_seq_id[i] = 1; + batch->seq_id [i][0] = seq_id; + } + return batch; +} + +int32_t llama_batch_ext_get_n_tokens(const struct llama_batch_ext * batch) { + return batch->n_tokens; +} + +int32_t llama_batch_ext_add_text( + struct llama_batch_ext * batch, + llama_token token, + llama_pos pos, + const llama_seq_id * seq_ids, + size_t n_seq_ids, + bool output) { + if (batch->n_tokens + 1 > batch->max_tokens) { + return -1; // llama_batch size exceeded + } + if (batch->embd) { + return -2; // embd is already set, cannot add text tokens + } + const int32_t output_id = batch->n_tokens; + batch->token [output_id] = token; + batch->n_seq_id[output_id] = n_seq_ids; + batch->logits [output_id] = output; + for (int32_t i = 0; i < batch->n_pos_per_token; i++) { + // TODO: this is only used by qwen2vl for now, and text tokens only have 3 pos, the last is set to 0; we should improve this code in the future + batch->pos[output_id * batch->n_pos_per_token + i] = i < 3 ? pos : 0; + } + batch->n_seq_id[output_id] = n_seq_ids; + for (size_t j = 0; j < n_seq_ids; j++) { + batch->seq_id[batch->n_tokens][j] = seq_ids[j]; + } + batch->n_tokens++; + return output_id; +} + +int32_t llama_batch_ext_set_output( + struct llama_batch_ext * batch, + llama_pos pos, + llama_seq_id seq_id) { + for (int32_t i = 0; i < batch->n_tokens; i++) { + // find the token having seq_id + for (int32_t j = 0; j < batch->n_seq_id[i]; j++) { + if (batch->seq_id[i][j] == seq_id) { + // found the sequence + if (pos == -1 || pos == batch->pos[i]) { + batch->logits[i] = true; + return i; + } + } + } + } + return -1; // not found +} + +int32_t llama_batch_ext_set_output_last(struct llama_batch_ext * batch) { + if (batch->n_tokens == 0) { + return -1; + } + const int32_t output_id = batch->n_tokens - 1; + batch->logits[output_id] = true; + return output_id; +} + +void llama_batch_ext_clear(struct llama_batch_ext * batch) { + batch->n_tokens = 0; +} + +struct llama_batch_ext * llama_batch_ext_get_view( + struct llama_batch_ext * batch, + int32_t offset, + int32_t n_tokens) { + if (batch->embd) { + return nullptr; // not yet supported + } + llama_batch_ext * batch_view = new llama_batch_ext{ + /*n_tokens =*/ n_tokens, + /*max_tokens =*/ n_tokens, + /*n_pos_per_token =*/ batch->n_pos_per_token, + /*is_view =*/ true, + /*tokens =*/ batch->token + offset, + /*embd =*/ nullptr, + /*pos =*/ batch->pos + offset * batch->n_pos_per_token, + /*n_seq_id =*/ batch->n_seq_id + offset, + /*seq_id =*/ batch->seq_id + offset, + /*logits =*/ batch->logits + offset, + }; + return batch_view; +} + +void llama_batch_ext_free(struct llama_batch_ext * batch) { + // do not free the members if it's a view + if (!batch->is_view) { + if (batch->token) free(batch->token); + if (batch->embd) free(batch->embd); + if (batch->pos) free(batch->pos); + if (batch->n_seq_id) free(batch->n_seq_id); + if (batch->seq_id) { + for (int i = 0; batch->seq_id[i] != nullptr; ++i) { + free(batch->seq_id[i]); + } + free(batch->seq_id); + } + if (batch->logits) free(batch->logits); + } + delete batch; +} + +// deprecated struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_t n_seq_max) { llama_batch batch = { /*n_tokens =*/ 0, @@ -353,6 +531,7 @@ struct llama_batch llama_batch_init(int32_t n_tokens_alloc, int32_t embd, int32_ return batch; } +// deprecated void llama_batch_free(struct llama_batch batch) { if (batch.token) free(batch.token); if (batch.embd) free(batch.embd); diff --git a/src/llama-batch.h b/src/llama-batch.h index f1df40d27086e..6671cdcd76df1 100644 --- a/src/llama-batch.h +++ b/src/llama-batch.h @@ -5,6 +5,33 @@ #include #include +// Input data for llama_decode / llama_encode +// A llama_batch_ext object can contain input about one or many sequences +// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens +// +// - token : the token ids of the input (used when embd is NULL) +// - embd : token embeddings (i.e. float vector of size n_embd) (used when token is NULL) +// - pos : the positions of the respective token in the sequence +// (if set to NULL, the token position will be tracked automatically by llama_decode) +// - seq_id : the sequence to which the respective token belongs +// (if set to NULL, the sequence ID will be assumed to be 0) +// - logits : if zero, the logits (and/or the embeddings) for the respective token will not be output +// (if set to NULL, only the logits for last token will be returned) +// +struct llama_batch_ext { + int32_t n_tokens; + int32_t max_tokens; + int32_t n_pos_per_token = 1; + bool is_view; + + llama_token * token; + float * embd; + llama_pos * pos; // if multi pos per token: 000011112222... + int32_t * n_seq_id; + llama_seq_id ** seq_id; + int8_t * logits; // TODO: rename this to "output" +}; + // very similar to llama_batch, // but has more metadata about sequences struct llama_ubatch { @@ -47,7 +74,7 @@ struct llama_sbatch { std::vector out_ids; std::vector seq; - const llama_batch * batch = nullptr; + const llama_batch_ext * batch = nullptr; // buffers for the ubatch std::vector ubatch_token; @@ -70,12 +97,12 @@ struct llama_sbatch { // sequence-wise split llama_ubatch split_seq(size_t n_ubatch); - void from_batch(const llama_batch & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); + void from_batch(const llama_batch_ext & batch, size_t n_embd, bool simple_split = false, bool logits_all = false); }; // temporary allocate memory for the input batch if needed struct llama_batch_allocr { - struct llama_batch batch; + struct llama_batch_ext * batch; std::array seq_id_0 = { 0 }; // default sequence id std::vector pos; @@ -84,5 +111,7 @@ struct llama_batch_allocr { std::vector logits; // optionally fulfill the batch returned by llama_batch_get_one - llama_batch_allocr(struct llama_batch in_batch, llama_pos p0); + llama_batch_allocr(struct llama_batch & in_batch, llama_pos p0); + + ~llama_batch_allocr(); }; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index aa363df6356ea..c23854138ccf8 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -4,6 +4,7 @@ #include "llama-io.h" #include "llama-mmap.h" #include "llama-model.h" +#include "llama-batch.h" #include "llama-kv-cache.h" #include @@ -1017,16 +1018,26 @@ bool llama_context::apply_adapter_cvec( } int llama_context::encode(llama_batch & inp_batch) { + // temporary allocate memory and convert llama_batch to llama_batch_ext + // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); + return encode(*batch_allocr.batch); +} + +int llama_context::decode(llama_batch & inp_batch) { + // temporary allocate memory and convert llama_batch to llama_batch_ext + // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences + llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); + return decode(*batch_allocr.batch); +} + +int llama_context::encode(llama_batch_ext & inp_batch) { if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; } - // temporary allocate memory for the input batch if needed - // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); - - const llama_batch & batch = batch_allocr.batch; + llama_batch_ext & batch = inp_batch; const int32_t n_tokens = batch.n_tokens; const auto & hparams = model.hparams; @@ -1181,17 +1192,13 @@ int llama_context::encode(llama_batch & inp_batch) { return 0; } -int llama_context::decode(llama_batch & inp_batch) { +int llama_context::decode(llama_batch_ext & inp_batch) { if (inp_batch.n_tokens == 0) { LLAMA_LOG_ERROR("%s: n_tokens == 0\n", __func__); return -1; } - // temporary allocate memory for the input batch if needed - // TODO: this is incorrect for multiple sequences because pos_max() is the maximum across all sequences - llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->pos_max() + 1); - - const llama_batch & batch = batch_allocr.batch; + llama_batch_ext & batch = inp_batch; const auto & vocab = model.vocab; const auto & hparams = model.hparams; @@ -2767,26 +2774,30 @@ size_t llama_state_seq_load_file(llama_context * ctx, const char * filepath, lla /// +// deprecated int32_t llama_encode( - llama_context * ctx, - llama_batch batch) { - const int ret = ctx->encode(batch); - if (ret != 0) { - LLAMA_LOG_ERROR("%s: failed to encode, ret = %d\n", __func__, ret); - } - - return ret; + struct llama_context * ctx, + struct llama_batch inp_batch) { + return ctx->encode(inp_batch); } +// deprecated int32_t llama_decode( - llama_context * ctx, - llama_batch batch) { - const int ret = ctx->decode(batch); - if (ret != 0) { - LLAMA_LOG_ERROR("%s: failed to decode, ret = %d\n", __func__, ret); - } + struct llama_context * ctx, + struct llama_batch inp_batch) { + return ctx->decode(inp_batch); +} + +int32_t llama_encode_ext( + struct llama_context * ctx, + struct llama_batch_ext * inp_batch) { + return ctx->encode(*inp_batch); +} - return ret; +int32_t llama_decode_ext( + struct llama_context * ctx, + struct llama_batch_ext * inp_batch) { + return ctx->decode(*inp_batch); } // diff --git a/src/llama-context.h b/src/llama-context.h index 04facb544cb1a..fff1b9e537b76 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -82,9 +82,13 @@ struct llama_context { int32_t il_start, int32_t il_end); + // deprecated int encode(llama_batch & inp_batch); int decode(llama_batch & inp_batch); + int encode(llama_batch_ext & inp_batch); + int decode(llama_batch_ext & inp_batch); + // // state save/load // diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 0bd40174438cc..477120ce8ac5c 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -2,6 +2,7 @@ #include "llama-impl.h" #include "llama-batch.h" +#include "llama-model.h" #include "llama-cparams.h" #include "llama-kv-cache.h" @@ -565,6 +566,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : hparams (params.hparams), cparams (params.cparams), ubatch (params.ubatch), + n_pos_per_token (llama_n_pos_per_token(params.arch)), n_embd (hparams.n_embd), n_layer (hparams.n_layer), n_rot (hparams.n_rot), @@ -602,10 +604,6 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : res (std::make_unique()) { } -int64_t llm_graph_context::n_pos_per_token() const { - return arch == LLM_ARCH_QWEN2VL ? 4 : 1; -} - void llm_graph_context::cb(ggml_tensor * cur, const char * name, int il) const { if (cb_func) { cb_func(ubatch, cur, name, il); @@ -1003,11 +1001,11 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { } ggml_tensor * llm_graph_context::build_inp_pos() const { - auto inp = std::make_unique(n_pos_per_token()); + auto inp = std::make_unique(n_pos_per_token); auto & cur = inp->pos; - cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token()); + cur = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens*n_pos_per_token); ggml_set_input(cur); res->add_input(std::move(inp)); diff --git a/src/llama-graph.h b/src/llama-graph.h index bdf19ed015e35..adcca07fe81c7 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -10,6 +10,8 @@ #include #include +#define MAX_POS_PER_TOKEN 4 + struct ggml_cgraph; struct ggml_context; struct ggml_tensor; @@ -355,6 +357,7 @@ struct llm_graph_context { const llama_cparams & cparams; const llama_ubatch & ubatch; + const int64_t n_pos_per_token; const int64_t n_embd; const int64_t n_layer; const int64_t n_rot; @@ -402,8 +405,6 @@ struct llm_graph_context { llm_graph_context(const llm_graph_params & params); - int64_t n_pos_per_token() const; - void cb(ggml_tensor * cur, const char * name, int il) const; // diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 0ae754154b069..54697c3bb908c 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -6075,6 +6075,11 @@ struct llm_build_qwen2vl : public llm_graph_context { // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); + // TODO @ngxson : transpose layout 0000111122223333 to 0123012301230123, we should improve this in the future + inp_pos = ggml_reshape_2d(ctx0, inp_pos, n_tokens, n_pos_per_token); + inp_pos = ggml_cont(ctx0, ggml_transpose(ctx0, inp_pos)); + inp_pos = ggml_reshape_1d(ctx0, inp_pos, n_pos_per_token * n_tokens); + auto * inp_attn = build_attn_inp_kv_unified(); int sections[4]; @@ -12057,6 +12062,14 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { return LLAMA_ROPE_TYPE_NONE; } +uint32_t llama_n_pos_per_token(llm_arch arch) { + return arch == LLM_ARCH_QWEN2VL ? 4 : 1; +} + +uint32_t llama_n_pos_per_token(const struct llama_model * model) { + return llama_n_pos_per_token(model->arch); +} + float llama_model_rope_freq_scale_train(const llama_model * model) { return model->hparams.rope_freq_scale_train; } diff --git a/src/llama-model.h b/src/llama-model.h index a9da1215abbfd..7320201de7a11 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -401,3 +401,5 @@ const char * llm_type_name(llm_type type); // For internal test use // TODO: remove const std::vector> & llama_internal_get_tensor_map(const llama_model * model); + +uint32_t llama_n_pos_per_token(llm_arch arch);