From abafd01ec852307d5f57c4010474ba870b2a3f22 Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Thu, 5 Oct 2023 16:06:55 -0600 Subject: [PATCH 1/3] kv cache slot search improvements --- llama.cpp | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/llama.cpp b/llama.cpp index 08d6c162a5d7c..75916e54df8eb 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1316,8 +1316,8 @@ static bool llama_kv_cache_find_slot( while (true) { if (cache.head + n_tokens > n_ctx) { + n_tested += cache.size - cache.head; cache.head = 0; - n_tested += n_ctx - cache.head; continue; } @@ -1368,6 +1368,9 @@ static void llama_kv_cache_tokens_rm(struct llama_kv_cache & cache, int32_t c0, cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); } + + // Searching for a free slot can start here since we know it will be empty. + cache.head = uint32_t(c0); } static void llama_kv_cache_seq_rm( @@ -1375,6 +1378,8 @@ static void llama_kv_cache_seq_rm( llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + uint32_t new_head = cache.size; + if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); @@ -1383,9 +1388,13 @@ static void llama_kv_cache_seq_rm( cache.cells[i].seq_id.erase(seq_id); if (cache.cells[i].seq_id.empty()) { cache.cells[i].pos = -1; + if (new_head == cache.size) new_head = i; } } } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.size) cache.head = new_head; } static void llama_kv_cache_seq_cp( @@ -1397,6 +1406,8 @@ static void llama_kv_cache_seq_cp( if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); + cache.head = 0; + for (uint32_t i = 0; i < cache.size; ++i) { if (cache.cells[i].has_seq_id(seq_id_src) && cache.cells[i].pos >= p0 && cache.cells[i].pos < p1) { cache.cells[i].seq_id.insert(seq_id_dst); @@ -1405,12 +1416,18 @@ static void llama_kv_cache_seq_cp( } static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id seq_id) { + uint32_t new_head = cache.size; + for (uint32_t i = 0; i < cache.size; ++i) { if (!cache.cells[i].has_seq_id(seq_id)) { cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); + if (new_head == cache.size) new_head = i; } } + + // If we freed up a slot, set head to it so searching can start there. + if (new_head != cache.size) cache.head = new_head; } static void llama_kv_cache_seq_shift( @@ -1419,6 +1436,8 @@ static void llama_kv_cache_seq_shift( llama_pos p0, llama_pos p1, llama_pos delta) { + uint32_t new_head = cache.size; + if (p0 < 0) p0 = 0; if (p1 < 0) p1 = std::numeric_limits::max(); @@ -1428,12 +1447,17 @@ static void llama_kv_cache_seq_shift( if (cache.cells[i].pos < 0) { cache.cells[i].pos = -1; cache.cells[i].seq_id.clear(); + if (new_head == cache.size) new_head = i; } else { cache.has_shift = true; cache.cells[i].delta = delta; } } } + + // If we freed up a slot, set head to it so searching can start there. + // Otherwise we just start the next search from the beginning. + cache.head = new_head != cache.size ? new_head : 0; } // @@ -4454,10 +4478,6 @@ static int llama_decode_internal( batch.seq_id = seq_id.data(); } - // we always start to search for a free slot from the start of the cache - // TODO: better strategies can be implemented - kv_self.head = 0; - if (!llama_kv_cache_find_slot(kv_self, batch)) { return 1; } From 3144563db19a246dc3607c62e18e7465789a33ea Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Fri, 6 Oct 2023 06:23:49 -0600 Subject: [PATCH 2/3] Use n_ctx in kv find slot for consistency --- llama.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index 75916e54df8eb..ce4d68f38f5bd 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1316,7 +1316,7 @@ static bool llama_kv_cache_find_slot( while (true) { if (cache.head + n_tokens > n_ctx) { - n_tested += cache.size - cache.head; + n_tested += n_ctx - cache.head; cache.head = 0; continue; } From 465b8f4fc0ae54ddbbbe891ade101f9a17ef30f8 Mon Sep 17 00:00:00 2001 From: KerfuffleV2 Date: Fri, 6 Oct 2023 07:33:10 -0600 Subject: [PATCH 3/3] Ensure kv cache head points to a valid slot in llama_decode internal Add some comments to prevent dumb people (like me) from getting confused. --- llama.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/llama.cpp b/llama.cpp index ce4d68f38f5bd..bf640bc020519 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1044,6 +1044,9 @@ struct llama_kv_cell { struct llama_kv_cache { bool has_shift = false; + // Note: The value of head isn't only used to optimize searching + // for a free KV slot. llama_decode_internal also uses it, so it + // cannot be freely changed after a slot has been allocated. uint32_t head = 0; uint32_t size = 0; @@ -1301,6 +1304,8 @@ static bool llama_kv_cache_init( // find an empty slot of size "n_tokens" in the cache // updates the cache head +// Note: On success, it's important that cache.head points +// to the first cell of the slot. static bool llama_kv_cache_find_slot( struct llama_kv_cache & cache, const struct llama_batch & batch) { @@ -4563,8 +4568,12 @@ static int llama_decode_internal( #endif // update the kv ring buffer - lctx.kv_self.head += n_tokens; lctx.kv_self.has_shift = false; + lctx.kv_self.head += n_tokens; + // Ensure kv cache head points to a valid index. + if (lctx.kv_self.head >= lctx.kv_self.size) { + lctx.kv_self.head = 0; + } #ifdef GGML_PERF // print timing information per ggml operation (for debugging purposes)