kv-cache : simplify the interface #13660

Merged 2 commits on May 21, 2025

4 changes: 2 additions & 2 deletions examples/simple-chat/simple-chat.cpp
@@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
auto generate = [&](const std::string & prompt) {
std::string response;

const bool is_first = llama_kv_self_used_cells(ctx) == 0;
const bool is_first = llama_kv_self_seq_pos_max(ctx, 0) == 0;

// tokenize the prompt
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
@@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
while (true) {
// check if we have enough space in the context to evaluate this batch
int n_ctx = llama_n_ctx(ctx);
int n_ctx_used = llama_kv_self_used_cells(ctx);
int n_ctx_used = llama_kv_self_seq_pos_max(ctx, 0);
if (n_ctx_used + batch.n_tokens > n_ctx) {
printf("\033[0m\n");
fprintf(stderr, "context size exceeded\n");
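Note that the replacement is not a strict rename: llama_kv_self_used_cells() reported how many KV cells are occupied, while llama_kv_self_seq_pos_max(ctx, 0) reports the highest position stored for sequence 0. For the single contiguous sequence used here the two values differ by at most one, which this bounds check tolerates. A minimal sketch of the updated check as a standalone helper (sequence 0 is an assumption of the sketch, not something the API enforces):

```cpp
#include "llama.h"

// Sketch only: single-sequence context-budget check built on the new call.
// Assumes the whole conversation lives in sequence 0, as in simple-chat.
static bool batch_fits(llama_context * ctx, const llama_batch & batch) {
    const int n_ctx      = llama_n_ctx(ctx);                  // total context size
    const int n_ctx_used = llama_kv_self_seq_pos_max(ctx, 0); // highest position stored for seq 0
    return n_ctx_used + batch.n_tokens <= n_ctx;
}
```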
6 changes: 4 additions & 2 deletions include/llama.h
@@ -610,10 +610,12 @@ extern "C" {

// Returns the number of tokens in the KV cache (slow, use only for debug)
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx),
"Use llama_kv_self_seq_pos_max() instead");

// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
"Use llama_kv_self_seq_pos_max() instead");

// Clear the KV cache - both cell info is erased and KV data is zeroed
LLAMA_API void llama_kv_self_clear(
4 changes: 3 additions & 1 deletion src/llama-batch.cpp
@@ -1,5 +1,6 @@
#include "llama-batch.h"

#include <cassert>
#include <cstring>
#include <algorithm>

@@ -281,9 +282,10 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
batch = in_batch;
GGML_ASSERT(batch.n_tokens > 0);
if (!batch.pos) {
assert(p0 >= 0);
pos.resize(batch.n_tokens);
for (int32_t i = 0; i < batch.n_tokens; i++) {
pos[i] = i + p0;
pos[i] = p0 + i;
}
batch.pos = pos.data();
}
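For reference, the behavior this hunk pins down when the caller leaves batch.pos unset: positions become a contiguous run starting at p0, and p0 must now be non-negative (hence the new assert). A standalone sketch of that semantics, not library code:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical illustration of the auto-fill: with batch.pos == NULL the
// allocator writes p0, p0+1, ..., p0+n-1. p0 == -1 is only ever passed when
// the batch already carries explicit positions, so the assert must hold here.
static std::vector<int32_t> make_positions(int32_t p0, int32_t n_tokens) {
    assert(p0 >= 0);
    std::vector<int32_t> pos(n_tokens);
    for (int32_t i = 0; i < n_tokens; i++) {
        pos[i] = p0 + i;
    }
    return pos;
}
```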
39 changes: 35 additions & 4 deletions src/llama-context.cpp
@@ -857,11 +857,17 @@ int llama_context::decode(llama_batch & inp_batch) {
return -1;
}

if (!inp_batch.pos) {
if (inp_batch.seq_id) {
LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
return -1;
}
}

llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());

// temporary allocate memory for the input batch if needed
// TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0) + 1);

const llama_batch & batch = batch_allocr.batch;

@@ -2292,22 +2298,47 @@ int32_t llama_apply_adapter_cvec(
// kv cache
//

// deprecated
int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
const auto * kv = ctx->get_kv_self();
if (!kv) {
return 0;
}

return kv->get_n_tokens();
int32_t res = 0;

for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
const llama_pos p0 = kv->seq_pos_min(s);
const llama_pos p1 = kv->seq_pos_max(s);

if (p0 >= 0) {
res += (p1 - p0) + 1;
}
}

return res;
}

// deprecated
// note: this is the same as above - will be removed anyway, so it's ok
int32_t llama_kv_self_used_cells(const llama_context * ctx) {
const auto * kv = ctx->get_kv_self();
if (!kv) {
return 0;
}

return kv->get_used_cells();
int32_t res = 0;

for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
const llama_pos p0 = kv->seq_pos_min(s);
const llama_pos p1 = kv->seq_pos_max(s);

if (p0 >= 0) {
res += (p1 - p0) + 1;
}
}

return res;
}

void llama_kv_self_clear(llama_context * ctx) {
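Both deprecated entry points now return the same value: the summed positional span of all sequences (seq_pos_max − seq_pos_min + 1 for each non-empty sequence) rather than a count of occupied cells, so callers that relied on the cell count should migrate to the per-sequence queries. A hedged sketch of the same computation through the public API — the llama_kv_self_seq_pos_min accessor and the negative-when-empty convention are assumptions here:

```cpp
#include "llama.h"

// Hypothetical equivalent of the reimplemented counters, written against the
// public API. Assumes llama_kv_self_seq_pos_min()/_max() are available and
// that pos_min is negative for a sequence with no cells.
static int32_t kv_span_total(llama_context * ctx, uint32_t n_seq_max) {
    int32_t res = 0;
    for (uint32_t s = 0; s < n_seq_max; s++) {
        const llama_pos p0 = llama_kv_self_seq_pos_min(ctx, (llama_seq_id) s);
        const llama_pos p1 = llama_kv_self_seq_pos_max(ctx, (llama_seq_id) s);
        if (p0 >= 0) {
            res += (p1 - p0) + 1; // positional span, e.g. [0,9] and [5,14] sum to 20
        }
    }
    return res;
}
```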
111 changes: 24 additions & 87 deletions src/llama-kv-cache.cpp
@@ -30,13 +30,14 @@ llama_kv_cache_unified::llama_kv_cache_unified(
bool v_trans,
bool offload,
uint32_t kv_size,
uint32_t padding,
uint32_t n_seq_max,
uint32_t n_pad,
uint32_t n_swa,
llama_swa_type swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding), n_swa(n_swa), swa_type(swa_type) {
GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");
llama_swa_type swa_type) :
model(model), hparams(model.hparams), v_trans(v_trans),
n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {

this->type_k = type_k;
this->type_v = type_v;
GGML_ASSERT(kv_size % n_pad == 0);

// create a context for each buffer type
std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
@@ -129,8 +130,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
const size_t memory_size_k = size_k_bytes();
const size_t memory_size_v = size_v_bytes();

LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6d cells, %3d layers), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(),
LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
(float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
}
@@ -442,7 +443,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
void llama_kv_cache_unified::defrag_sched(float thold) {
// - do not defrag small contexts (i.e. < 2048 tokens)
// - count the padding towards the number of used tokens
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f;
const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f;

// queue defragmentation for next llama_kv_cache_update
if (fragmentation > thold) {
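The rename from padding to n_pad leaves the heuristic itself unchanged; for intuition, a worked example with purely illustrative numbers:

```cpp
// Illustrative numbers only:
//   n = 4096 cells in the active window, used = 3000, n_pad = 32
//   fragmentation = 1.0f - (3000 + 32)/4096.0f ≈ 0.26
// a threshold (thold) of 0.1 would queue a defrag on the next update;
// windows with n < 2048 are never defragmented by this path.
```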
@@ -558,7 +559,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
// a heuristic, to avoid attending the full cache if it is not yet utilized
// after enough generations, the benefit from this heuristic disappears
// if we start defragmenting the cache, the benefit from this will be more important
n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));
n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad)));

#ifdef FIND_SLOT_DEBUG
LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
@@ -567,20 +568,6 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
return true;
}
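With the renamed n_pad, the attended portion of the cache is still rounded up to a multiple of the padding; a worked example with illustrative numbers:

```cpp
// Illustrative numbers only:
//   cell_max() = 70 occupied cells, n_pad = 32, cache size = 4096
//   GGML_PAD(70, 32)           = 96   // round up to a multiple of n_pad
//   n = min(4096, max(32, 96)) = 96
// so attention covers only the first 96 cells instead of the whole cache.
```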

int32_t llama_kv_cache_unified::get_n_tokens() const {
int32_t result = 0;

for (uint32_t i = 0; i < size; i++) {
result += cells[i].seq_id.size();
}

return result;
}

int32_t llama_kv_cache_unified::get_used_cells() const {
return used;
}

bool llama_kv_cache_unified::get_can_shift() const {
return true;
}
@@ -802,16 +789,6 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
}
}

llama_pos llama_kv_cache_unified::get_pos_max() const {
llama_pos pos_max = -1;

for (const auto & cell : cells) {
pos_max = std::max(pos_max, cell.pos);
}

return pos_max;
}

size_t llama_kv_cache_unified::total_size() const {
size_t size = 0;

@@ -1501,11 +1478,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
llama_seq_id seq_id;
io.read_to(&seq_id, sizeof(seq_id));

// TODO: llama_kv_cache_unified should have a notion of max sequences
//if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
if (seq_id < 0) {
//LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
return false;
}

@@ -1655,17 +1629,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
ggml_type type_v,
bool v_trans,
bool offload,
uint32_t kv_size,
bool swa_full,
uint32_t kv_size,
uint32_t n_seq_max,
uint32_t n_batch,
uint32_t padding) : hparams(model.hparams) {
uint32_t n_pad) : hparams(model.hparams) {
llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };

const uint32_t size_base = kv_size;

uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));

// when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning
if (swa_full) {

kv_base = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_base), type_k, type_v,
v_trans, offload, size_base, padding,
v_trans, offload, size_base, n_seq_max, n_pad,
0, LLAMA_SWA_TYPE_NONE);

LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);

kv_swa = std::make_unique<llama_kv_cache_unified>(
model, std::move(filter_swa), type_k, type_v,
v_trans, offload, size_swa, padding,
v_trans, offload, size_swa, n_seq_max, n_pad,
hparams.n_swa, hparams.swa_type);
}
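The SWA cache size therefore scales with the window size, the number of sequences, and the batch size, rounded up to the renamed n_pad; a worked example with illustrative values:

```cpp
// Illustrative values only:
//   hparams.n_swa = 1024, n_seq_max = 2, n_batch = 512, n_pad = 32, kv_size = 8192
//   GGML_PAD(1024*2 + 512, 32) = GGML_PAD(2560, 32) = 2560
//   size_swa = min(8192, 2560) = 2560
// with swa_full == true, size_swa is instead set to the base size (8192) and
// pruning is disabled, as the comment above notes.
```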

@@ -1810,18 +1784,6 @@ bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
return res;
}

int32_t llama_kv_cache_unified_iswa::get_n_tokens() const {
return kv_base->get_n_tokens();
}

int32_t llama_kv_cache_unified_iswa::get_used_cells() const {
return kv_base->get_used_cells();
}

llama_pos llama_kv_cache_unified_iswa::get_pos_max() const {
return kv_base->get_pos_max();
}

bool llama_kv_cache_unified_iswa::get_can_shift() const {
return kv_base->get_size() == kv_swa->get_size();
}
@@ -1853,19 +1815,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
ggml_type type_k,
ggml_type type_v,
bool offload,
uint32_t kv_size) : hparams(model.hparams) {
uint32_t kv_size,
uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
const int32_t n_layer = hparams.n_layer;

LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
__func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
__func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);

head = 0;
size = kv_size;
used = 0;

this->type_k = type_k;
this->type_v = type_v;

cells.clear();
cells.resize(kv_size);

@@ -2203,8 +2163,8 @@ void llama_kv_cache_recurrent::commit() {
pending.ranges.clear();
}

bool llama_kv_cache_recurrent::update(llama_context & lctx) {
GGML_UNUSED(lctx);
bool llama_kv_cache_recurrent::update(llama_context & ctx) {
GGML_UNUSED(ctx);
return false;
}

@@ -2265,7 +2225,7 @@ bool llama_kv_cache_recurrent::find_slot(
if (seq_id < 0 || (uint32_t) seq_id >= size) {
// too big seq_id
// TODO: would it be possible to resize the cache instead?
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
return false;
}
if (j > 0) {
@@ -2408,29 +2368,6 @@ bool llama_kv_cache_recurrent::find_slot(
return n >= n_seqs;
}

int32_t llama_kv_cache_recurrent::get_n_tokens() const {
int32_t result = 0;

for (uint32_t i = 0; i < size; i++) {
result += cells[i].seq_id.size();
}

return result;
}

int32_t llama_kv_cache_recurrent::get_used_cells() const {
return used;
}

llama_pos llama_kv_cache_recurrent::get_pos_max() const {
llama_pos pos_max = -1;
for (const auto & cell : cells) {
pos_max = std::max(pos_max, cell.pos);
}

return pos_max;
}

bool llama_kv_cache_recurrent::get_can_shift() const {
return false;
}