
llama : add llama_batch_ext #11875


Open: wants to merge 61 commits into master from xsn/private_batch_api
Changes from all commits (61 commits)
4ed4fe7
first proposal for private llama_batch
ngxson Feb 13, 2025
f2e59a8
rework, targeting llama-server
ngxson Feb 14, 2025
17d3658
move to llama_batch_ext
ngxson Feb 15, 2025
85ef80c
server : use llama_batch_ext
ngxson Feb 15, 2025
aed4a8e
fix server
ngxson Feb 16, 2025
4bf7ca3
llama_decode_ext
ngxson Feb 24, 2025
a1b1dea
Merge branch 'master' into xsn/private_batch_api
ngxson Feb 24, 2025
f0ffd81
adapt common
ngxson Mar 1, 2025
9e75c49
Merge branch 'master' into xsn/private_batch_api
ngxson Mar 1, 2025
40989f4
correct llama_decode_ext
ngxson Mar 1, 2025
1170135
llama_batch_ext_add_text
ngxson Mar 1, 2025
1d6ba97
remove token_info API
ngxson Mar 1, 2025
46596ca
apply various in places
ngxson Mar 1, 2025
17f954c
Merge branch 'master' into xsn/private_batch_api
ngxson Mar 13, 2025
86973cb
fix merge errors
ngxson Mar 13, 2025
4aabf4e
return output ID from llama_batch_ext_add/set
ngxson Mar 13, 2025
47086fa
apply to the rest
ngxson Mar 13, 2025
9fb2d81
fix common_batch missing seq_id
ngxson Mar 13, 2025
65f0184
compile ok
ngxson Mar 13, 2025
c3dd790
fix llama_batch_ext_init_from_text
ngxson Mar 13, 2025
04f8641
rm redundant llama_batch_ext_set_output_last
ngxson Mar 13, 2025
54566ad
correct comment
ngxson Mar 13, 2025
bfdddbc
bring back mistakenly deleted llama_batch_init/free
ngxson Mar 13, 2025
5e6a6d4
fix llama-run n_past
ngxson Mar 14, 2025
3294036
fix gemma3-cli
ngxson Mar 14, 2025
07d84fa
fix missing n_past in various places
ngxson Mar 14, 2025
ba79369
fix llama_batch_ext_init_from_embd
ngxson Mar 14, 2025
a363251
qwen2vl: use llama_batch_ext_set_pos
ngxson Mar 14, 2025
8e7714f
fix compile
ngxson Mar 14, 2025
eaffba0
llama_batch_ext_ptr::from_text/embd
ngxson Mar 14, 2025
116b9a1
rename to init_from_text
ngxson Mar 14, 2025
624a683
fix compile
ngxson Mar 14, 2025
de788e0
Update examples/tts/tts.cpp
ngxson Mar 17, 2025
eab5606
Apply suggestions from code review
ngxson Mar 17, 2025
dc4bb64
Merge branch 'master' into xsn/private_batch_api
ngxson Mar 18, 2025
7a3c178
speculative : adapt to new llama API
ggerganov Mar 18, 2025
23d7407
Merge pull request #15 from ggml-org/xsn/private_batch_api
ngxson Mar 19, 2025
b0db7fc
android : adapt to new API
ggerganov Mar 19, 2025
96ca6e8
swift : adapt to new API
ggerganov Mar 19, 2025
32c2c41
android : fix permission
ngxson Mar 19, 2025
6f54ee6
retrieval : avoid common_batch
ggerganov Mar 19, 2025
8b80d68
embedding : avoid common_batch
ggerganov Mar 19, 2025
76fd7d6
perplexity : avoid common_batch
ggerganov Mar 20, 2025
8a23b4a
server : avoid common_batch
ggerganov Mar 20, 2025
b8b1732
server : remove old commented code [no ci]
ggerganov Mar 20, 2025
bd51d63
Merge pull request #16 from ggml-org/xsn/private_batch_api_pooling_none
ngxson Mar 20, 2025
30f1db9
remove C API llama_batch_ext_init_from_text
ngxson Mar 20, 2025
c5a0176
Merge branch 'master' into xsn/private_batch_api
ngxson Mar 21, 2025
2134cab
add cpp batch.add_text wrapper
ngxson Mar 21, 2025
2cec1cf
move various places to batch.add_text
ngxson Mar 21, 2025
3802ff2
add batch.clear() and batch.n_tokens()
ngxson Mar 21, 2025
e8827a6
Merge branch 'master' into xsn/private_batch_api
ngxson Mar 23, 2025
a9efdbb
qwen2vl: fix mrope position
ngxson Mar 23, 2025
1434c2c
Merge branch 'master' into xsn/private_batch_api
ngxson Mar 25, 2025
d18a79e
llama_batch_ext_init with ctx
ngxson Mar 25, 2025
c4fea7f
fix qwzn2vl mrope position input
ngxson Mar 25, 2025
42062cc
fix build
ngxson Mar 25, 2025
56e82d0
fix server
ngxson Mar 25, 2025
50fb396
server: fix batch_spec
ngxson Mar 25, 2025
8ec0ff9
fix embeddings and retrieval
ngxson Mar 27, 2025
c1f4a78
correct output_id for llama-cpp header
ngxson Mar 27, 2025
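
The commits above replace the public llama_batch struct and the common_batch_* helpers with an opaque llama_batch_ext handle driven through llama_encode_ext/llama_decode_ext. As a rough orientation before the diffs, here is a minimal sketch of the new C-API decode loop, assembled only from calls that appear in this PR; the helper name decode_prompt is made up for illustration and the signatures may still change:

    // Sketch only: built from the llama_batch_ext calls used in this PR's diffs.
    #include "llama.h"

    static int decode_prompt(llama_context * ctx, const llama_token * tokens, int n_tokens) {
        llama_batch_ext * batch = llama_batch_ext_init(ctx);   // opaque batch, sized from the context

        const llama_seq_id seq_id = 0;
        for (int i = 0; i < n_tokens; ++i) {
            // token, position, sequence ids, number of sequence ids, request logits for this token
            llama_batch_ext_add_text(batch, tokens[i], i, &seq_id, 1, false);
        }
        llama_batch_ext_set_output_last(batch);                // logits only for the last token

        const int ret = llama_decode_ext(ctx, batch);          // replaces llama_decode(ctx, llama_batch)
        llama_batch_ext_free(batch);                           // replaces llama_batch_free
        return ret;
    }

Common code and several examples use the C++ wrapper llama_batch_ext_ptr instead, which wraps the same calls with RAII semantics (see the sketch after the common/speculative.cpp diff).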
68 changes: 4 additions & 64 deletions common/common.cpp
@@ -582,41 +582,6 @@ std::string string_from(const struct llama_context * ctx, const std::vector<llam
return buf.str();
}

std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch) {
std::stringstream buf;

buf << "[ ";

bool first = true;
for (int i = 0; i < batch.n_tokens; ++i) {
if (!first) {
buf << ", ";
} else {
first = false;
}

auto detokenized = common_token_to_piece(ctx, batch.token[i]);

detokenized.erase(
std::remove_if(
detokenized.begin(),
detokenized.end(),
[](const unsigned char c) { return !std::isprint(c); }),
detokenized.end());

buf << "\n" << std::to_string(i)
<< ", token '" << detokenized << "'"
<< ", pos " << std::to_string(batch.pos[i])
<< ", n_seq_id " << std::to_string(batch.n_seq_id[i])
<< ", seq_id " << std::to_string(batch.seq_id[i][0])
<< ", logits " << std::to_string(batch.logits[i]);
}

buf << " ]";

return buf.str();
}

void string_process_escapes(std::string & input) {
std::size_t input_len = input.length();
std::size_t output_idx = 0;
@@ -1051,7 +1016,8 @@ struct common_init_result common_init_from_params(common_params & params) {
}

if (llama_model_has_encoder(model)) {
llama_encode(lctx, llama_batch_get_one(tmp.data(), tmp.size()));
auto batch = llama_batch_ext_ptr::init_from_text(lctx, tmp.data(), tmp.size(), 0, 0, true);
llama_encode_ext(lctx, batch.get());
llama_token decoder_start_token_id = llama_model_decoder_start_token(model);
if (decoder_start_token_id == LLAMA_TOKEN_NULL) {
decoder_start_token_id = bos;
@@ -1060,7 +1026,8 @@ struct common_init_result common_init_from_params(common_params & params) {
tmp.push_back(decoder_start_token_id);
}
if (llama_model_has_decoder(model)) {
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
auto batch = llama_batch_ext_ptr::init_from_text(lctx, tmp.data(), std::min(tmp.size(), (size_t) params.n_batch), 0, 0, true);
llama_decode_ext(lctx, batch.get());
}
llama_kv_self_clear(lctx);
llama_synchronize(lctx);
@@ -1609,33 +1576,6 @@ std::pair<std::string, std::string> common_get_hf_file(const std::string &, cons

#endif // LLAMA_USE_CURL

//
// Batch utils
//

void common_batch_clear(struct llama_batch & batch) {
batch.n_tokens = 0;
}

void common_batch_add(
struct llama_batch & batch,
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits) {
GGML_ASSERT(batch.seq_id[batch.n_tokens] && "llama_batch size exceeded");

batch.token [batch.n_tokens] = id;
batch.pos [batch.n_tokens] = pos;
batch.n_seq_id[batch.n_tokens] = seq_ids.size();
for (size_t i = 0; i < seq_ids.size(); ++i) {
batch.seq_id[batch.n_tokens][i] = seq_ids[i];
}
batch.logits [batch.n_tokens] = logits;

batch.n_tokens++;
}

//
// Token utils
//
14 changes: 0 additions & 14 deletions common/common.h
@@ -516,7 +516,6 @@ void string_process_escapes(std::string & input);
std::string string_from(bool value);
std::string string_from(const std::vector<int> & values);
std::string string_from(const struct llama_context * ctx, const std::vector<llama_token> & tokens);
std::string string_from(const struct llama_context * ctx, const struct llama_batch & batch);

//
// Filesystem utils
@@ -566,19 +565,6 @@ std::pair<std::string, std::string> common_get_hf_file(
// clear LoRA adapters from context, then apply new list of adapters
void common_set_adapter_lora(struct llama_context * ctx, std::vector<common_adapter_lora_info> & lora);

//
// Batch utils
//

void common_batch_clear(struct llama_batch & batch);

void common_batch_add(
struct llama_batch & batch,
llama_token id,
llama_pos pos,
const std::vector<llama_seq_id> & seq_ids,
bool logits);

//
// Token utils
//
26 changes: 12 additions & 14 deletions common/speculative.cpp
@@ -14,7 +14,7 @@ struct common_speculative {
struct llama_context * ctx;
struct common_sampler * smpl;

llama_batch batch;
llama_batch_ext_ptr batch;
llama_tokens prompt;
};

@@ -23,7 +23,7 @@ struct common_speculative * common_speculative_init(
auto * result = new common_speculative {
/* .ctx = */ ctx_dft,
/* .smpl = */ nullptr,
/* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1),
/* .batch = */ llama_batch_ext_ptr(ctx_dft),
/* .prompt = */ {},
};

@@ -69,8 +69,6 @@ void common_speculative_free(struct common_speculative * spec) {

common_sampler_free(spec->smpl);

llama_batch_free(spec->batch);

delete spec;
}

@@ -206,40 +204,40 @@ llama_tokens common_speculative_gen_draft(
}

// prepare a batch to evaluate any new tokens in the prompt
common_batch_clear(batch);
batch.clear();

for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) {
//LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]);
common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false);
batch.add_text(prompt_tgt[i], i - i_start, 0, false);

prompt.push_back(prompt_tgt[i]);
}

// we should rarely end-up here during normal decoding
if (batch.n_tokens > 0) {
if (batch.n_tokens() > 0) {
//LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str());

llama_decode(ctx, batch);
llama_decode_ext(ctx, batch.get());
}

const llama_pos n_past = prompt.size();

LOG_DBG("%s: n_past = %d\n", __func__, n_past);

common_batch_clear(batch);
common_batch_add (batch, id_last, n_past, { 0 }, true);
batch.clear();
batch.add_text(id_last, n_past, 0, true);

prompt.push_back(id_last);

//LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx, prompt).c_str());

llama_decode(ctx, batch);
llama_decode_ext(ctx, batch.get());

common_sampler_reset(smpl);

// sample n_draft tokens from the draft model
for (int i = 0; i < params.n_draft; ++i) {
common_batch_clear(batch);
batch.clear();

common_sampler_sample(smpl, ctx, 0, true);

Expand All @@ -266,10 +264,10 @@ llama_tokens common_speculative_gen_draft(
break;
}

common_batch_add(batch, id, n_past + i + 1, { 0 }, true);
batch.add_text( id, n_past + i + 1, 0, true);

// evaluate the drafted tokens on the draft model
llama_decode(ctx, batch);
llama_decode_ext(ctx, batch.get());

prompt.push_back(id);
}
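
For comparison with the C API, a condensed sketch of the llama_batch_ext_ptr wrapper pattern used in this file; the header name and the argument meanings are inferred from the calls in the diff above and may not match the final API exactly:

    // Sketch only: the RAII wrapper pattern from common/speculative.cpp above.
    #include "llama-cpp.h"   // assumption: the C++ helper header that defines llama_batch_ext_ptr

    static void decode_one(llama_context * ctx, llama_token id_last, llama_pos n_past) {
        llama_batch_ext_ptr batch(ctx);              // owns a llama_batch_ext created for this context

        batch.clear();
        batch.add_text(id_last, n_past, /*seq_id=*/0, /*output=*/true);

        if (batch.n_tokens() > 0) {
            llama_decode_ext(ctx, batch.get());      // the wrapper frees the batch on destruction
        }
    }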
40 changes: 17 additions & 23 deletions examples/batched-bench/batched-bench.cpp
@@ -59,24 +59,17 @@ int main(int argc, char ** argv) {

const int32_t n_kv_max = llama_n_ctx(ctx);

llama_batch batch = llama_batch_init(n_kv_max, 0, 1);
llama_batch_ext * batch = llama_batch_ext_init(ctx);

// decode in batches of ctx_params.n_batch tokens
auto decode_helper = [](llama_context * ctx, llama_batch & batch, int32_t n_batch) {
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));

llama_batch batch_view = {
n_tokens,
batch.token + i,
nullptr,
batch.pos + i,
batch.n_seq_id + i,
batch.seq_id + i,
batch.logits + i,
};

const int ret = llama_decode(ctx, batch_view);
auto decode_helper = [](llama_context * ctx, llama_batch_ext * batch, int32_t n_batch) {
const int32_t n_batch_tokens = llama_batch_ext_get_n_tokens(batch);
for (int32_t i = 0; i < (int32_t) n_batch_tokens; i += n_batch) {
const int32_t n_tokens = std::min(n_batch, (int32_t) (n_batch_tokens - i));

llama_batch_ext_ptr batch_view = llama_batch_ext_ptr(llama_batch_ext_get_view(batch, i, n_tokens));

const int ret = llama_decode_ext(ctx, batch_view.get());
if (ret != 0) {
LOG_ERR("failed to decode the batch, n_batch = %d, ret = %d\n", n_batch, ret);
return false;
@@ -91,7 +84,8 @@ int main(int argc, char ** argv) {
// warm up
{
for (int i = 0; i < 16; ++i) {
common_batch_add(batch, 0, i, { 0 }, false);
const llama_seq_id seq_id = 0;
llama_batch_ext_add_text(batch, 0, i, &seq_id, 1, false);
}

if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
@@ -121,14 +115,14 @@ int main(int argc, char ** argv) {
continue;
}

common_batch_clear(batch);
llama_batch_ext_clear(batch);

for (int i = 0; i < pp; ++i) {
for (int j = 0; j < (is_pp_shared ? 1 : pl); ++j) {
common_batch_add(batch, 0, i, { j }, false);
llama_batch_ext_add_text(batch, 0, i, &j, 1, false);
}
}
batch.logits[batch.n_tokens - 1] = true;
llama_batch_ext_set_output_last(batch);

const auto t_pp_start = ggml_time_us();

@@ -150,10 +144,10 @@ int main(int argc, char ** argv) {
const auto t_tg_start = ggml_time_us();

for (int i = 0; i < tg; ++i) {
common_batch_clear(batch);
llama_batch_ext_clear(batch);

for (int j = 0; j < pl; ++j) {
common_batch_add(batch, 0, pp + i, { j }, true);
llama_batch_ext_add_text(batch, 0, pp + i, &j, 1, true);
}

if (!decode_helper(ctx, batch, ctx_params.n_batch)) {
Expand Down Expand Up @@ -191,7 +185,7 @@ int main(int argc, char ** argv) {
LOG("\n");
llama_perf_context_print(ctx);

llama_batch_free(batch);
llama_batch_ext_free(batch);

llama_free(ctx);
llama_model_free(model);
32 changes: 16 additions & 16 deletions examples/batched/batched.cpp
@@ -102,7 +102,7 @@ int main(int argc, char ** argv) {

// create a llama_batch
// we use this object to submit token data for decoding
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t) n_parallel), 0, n_parallel);
llama_batch_ext * batch = llama_batch_ext_init(ctx);

std::vector<llama_seq_id> seq_ids(n_parallel, 0);
for (int32_t i = 0; i < n_parallel; ++i) {
@@ -111,12 +111,12 @@ int main(int argc, char ** argv) {

// evaluate the initial prompt
for (size_t i = 0; i < tokens_list.size(); ++i) {
common_batch_add(batch, tokens_list[i], i, seq_ids, false);
llama_batch_ext_add_text(batch, tokens_list[i], i, seq_ids.data(), seq_ids.size(), false);
}
GGML_ASSERT(batch.n_tokens == (int) tokens_list.size());
GGML_ASSERT(llama_batch_ext_get_n_tokens(batch) == (int) tokens_list.size());

if (llama_model_has_encoder(model)) {
if (llama_encode(ctx, batch)) {
if (llama_encode_ext(ctx, batch)) {
LOG_ERR("%s : failed to eval\n", __func__);
return 1;
}
Expand All @@ -126,14 +126,14 @@ int main(int argc, char ** argv) {
decoder_start_token_id = llama_vocab_bos(vocab);
}

common_batch_clear(batch);
common_batch_add(batch, decoder_start_token_id, 0, seq_ids, false);
llama_batch_ext_clear(batch);
llama_batch_ext_add_text(batch, decoder_start_token_id, 0, seq_ids.data(), seq_ids.size(), false);
}

// llama_decode will output logits only for the last token of the prompt
batch.logits[batch.n_tokens - 1] = true;
llama_batch_ext_set_output_last(batch);

if (llama_decode(ctx, batch) != 0) {
if (llama_decode_ext(ctx, batch) != 0) {
LOG_ERR("%s: llama_decode() failed\n", __func__);
return 1;
}
@@ -155,16 +155,16 @@ int main(int argc, char ** argv) {

// remember the batch index of the last token for each parallel sequence
// we need this to determine which logits to sample from
std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
std::vector<int32_t> i_batch(n_parallel, llama_batch_ext_get_n_tokens(batch) - 1);
Inline review comment (project member):

I think the next step along this refactoring is to remove all usages of i_batch from the examples. The i_batch is the index we use to extract logits from the ith token in the batch, but this pattern is quite cumbersome and not very intuitive. In order to avoid this pattern, we have to introduce a new API call for sampling a token from a sequence:

    LLAMA_API llama_token llama_sampler_sample_seq(struct llama_sampler * smpl, struct llama_context * ctx, llama_seq_id seq_id);

This should be enough to replace most or all usages of i_batch. We can do this in a next PR.
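
As a purely hypothetical illustration of that suggestion (llama_sampler_sample_seq does not exist in this PR), the sampling loop below could then drop the i_batch bookkeeping entirely:

    // Hypothetical sketch (not part of this PR): assumes the proposed
    // llama_sampler_sample_seq(smpl, ctx, seq_id) is added in a follow-up PR.
    for (int32_t i = 0; i < n_parallel; ++i) {
        // sample the next token for parallel sequence i directly by sequence id
        const llama_token new_token_id = llama_sampler_sample_seq(smpl, ctx, i);
        // ...push new_token_id into the next batch as before
    }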


int n_cur = batch.n_tokens;
int n_cur = llama_batch_ext_get_n_tokens(batch);
int n_decode = 0;

const auto t_main_start = ggml_time_us();

while (n_cur <= n_predict) {
// prepare the next batch
common_batch_clear(batch);
llama_batch_ext_clear(batch);

// sample the next token for each parallel sequence / stream
for (int32_t i = 0; i < n_parallel; ++i) {
@@ -193,23 +193,23 @@ int main(int argc, char ** argv) {

streams[i] += common_token_to_piece(ctx, new_token_id);

i_batch[i] = batch.n_tokens;
i_batch[i] = llama_batch_ext_get_n_tokens(batch);

// push this new token for next evaluation
common_batch_add(batch, new_token_id, n_cur, { i }, true);
llama_batch_ext_add_text(batch, new_token_id, n_cur, &i, 1, true);

n_decode += 1;
}

// all streams are finished
if (batch.n_tokens == 0) {
if (llama_batch_ext_get_n_tokens(batch) == 0) {
break;
}

n_cur += 1;

// evaluate the current batch with the transformer model
if (llama_decode(ctx, batch)) {
if (llama_decode_ext(ctx, batch)) {
LOG_ERR("%s : failed to eval, return code %d\n", __func__, 1);
return 1;
}
Expand All @@ -234,7 +234,7 @@ int main(int argc, char ** argv) {

fprintf(stderr, "\n");

llama_batch_free(batch);
llama_batch_ext_free(batch);

llama_sampler_free(smpl);
llama_free(ctx);