
kv cache slot search improvements #3493


Merged · merged 3 commits on Oct 6, 2023
Changes from all commits
llama.cpp: 41 changes (35 additions & 6 deletions)
@@ -1044,6 +1044,9 @@ struct llama_kv_cell {
 struct llama_kv_cache {
     bool has_shift = false;
 
+    // Note: The value of head isn't only used to optimize searching
+    // for a free KV slot. llama_decode_internal also uses it, so it
+    // cannot be freely changed after a slot has been allocated.
     uint32_t head = 0;
     uint32_t size = 0;
 
@@ -1301,6 +1304,8 @@ static bool llama_kv_cache_init(
 
 // find an empty slot of size "n_tokens" in the cache
 // updates the cache head
+// Note: On success, it's important that cache.head points
+// to the first cell of the slot.
 static bool llama_kv_cache_find_slot(
            struct llama_kv_cache & cache,
         const struct llama_batch & batch) {
@@ -1316,8 +1321,8 @@ static bool llama_kv_cache_find_slot(
 
     while (true) {
         if (cache.head + n_tokens > n_ctx) {
+            n_tested += n_ctx - cache.head;
             cache.head = 0;
-            n_tested += n_ctx - cache.head;
             continue;
         }
 
@@ -1368,13 +1373,18 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0,
         cache.cells[i].pos = -1;
         cache.cells[i].seq_id.clear();
     }
+
+    // Searching for a free slot can start here since we know it will be empty.
+    cache.head = uint32_t(c0);
 }
 
 static void llama_kv_cache_seq_rm(
         struct llama_kv_cache & cache,
                  llama_seq_id seq_id,
                     llama_pos p0,
                     llama_pos p1) {
+    uint32_t new_head = cache.size;
+
Review comment from the PR author (Collaborator):
The logic for these few functions follows basically the same pattern: we set new_head to the first freed slot, if there is one. new_head = cache.size is a safe but invalid sentinel value that we know will fit in uint32_t.

This could maybe be improved by only changing cache.head if it 1) doesn't already point at a free slot, and 2) points at an index greater than the one that just got freed. The idea would be to maximize use of slots near the beginning of the cache. I'm not sure doing this is really worth the complexity, though; a rough sketch follows.
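For illustration only, here is a minimal sketch of that fancier policy as a hypothetical helper (the name and factoring are assumptions, not part of this PR); each of the cache-manipulation functions below would call it instead of assigning cache.head directly:

// Hypothetical helper (not in this PR): only move head if it isn't already
// on a free cell, and only move it backward toward the start of the cache.
// A cell is free when it holds no sequence, i.e. its pos is -1.
static void llama_kv_cache_suggest_head(struct llama_kv_cache & cache, uint32_t freed) {
    const bool head_is_free = cache.head < cache.size && cache.cells[cache.head].pos < 0;
    if (!head_is_free && freed < cache.head) {
        cache.head = freed;
    }
}

Compared to the simple version, this would keep allocations clustered near the start of the cache, at the cost of an extra branch and a cell inspection on every free.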

These changes (even in the simple form) will slow down the cache manipulation functions a little, but I think that's worth it: searching for a slot is probably the most common operation.


     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
@@ -1383,9 +1393,13 @@ static void llama_kv_cache_seq_rm(
             cache.cells[i].seq_id.erase(seq_id);
             if (cache.cells[i].seq_id.empty()) {
                 cache.cells[i].pos = -1;
+                if (new_head == cache.size) new_head = i;
             }
         }
     }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
 }
 
 static void llama_kv_cache_seq_cp(
@@ -1397,6 +1411,8 @@ static void llama_kv_cache_seq_cp(
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
+    cache.head = 0;
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) {
             cache.cells[i].seq_id.insert(seq_id_dst);
@@ -1405,12 +1421,18 @@
 }
 
 static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) {
+    uint32_t new_head = cache.size;
+
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (!cache.cells[i].has_seq_id(seq_id)) {
             cache.cells[i].pos = -1;
             cache.cells[i].seq_id.clear();
+            if (new_head == cache.size) new_head = i;
         }
     }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    if (new_head != cache.size) cache.head = new_head;
 }
 
 static void llama_kv_cache_seq_shift(
@@ -1419,6 +1441,8 @@
                     llama_pos p0,
                     llama_pos p1,
                     llama_pos delta) {
+    uint32_t new_head = cache.size;
+
     if (p0 < 0) p0 = 0;
     if (p1 < 0) p1 = std::numeric_limits<llama_pos>::max();
 
@@ -1428,12 +1452,17 @@
             if (cache.cells[i].pos < 0) {
                 cache.cells[i].pos = -1;
                 cache.cells[i].seq_id.clear();
+                if (new_head == cache.size) new_head = i;
             } else {
                 cache.has_shift = true;
                 cache.cells[i].delta = delta;
             }
         }
     }
+
+    // If we freed up a slot, set head to it so searching can start there.
+    // Otherwise we just start the next search from the beginning.
+    cache.head = new_head != cache.size ? new_head : 0;
 }
 
 //
@@ -4454,10 +4483,6 @@ static int llama_decode_internal(
         batch.seq_id = seq_id.data();
     }
 
-    // we always start to search for a free slot from the start of the cache
-    // TODO: better strategies can be implemented
-    kv_self.head = 0;
-
     if (!llama_kv_cache_find_slot(kv_self, batch)) {
         return 1;
     }
@@ -4543,8 +4568,12 @@
 #endif
 
     // update the kv ring buffer
-    lctx.kv_self.head += n_tokens;
     lctx.kv_self.has_shift = false;
+    lctx.kv_self.head += n_tokens;
+    // Ensure kv cache head points to a valid index.
+    if (lctx.kv_self.head >= lctx.kv_self.size) {
+        lctx.kv_self.head = 0;
+    }
 
 #ifdef GGML_PERF
     // print timing information per ggml operation (for debugging purposes)