Commit 6b0a742
llama : KV cache view API + better KV cache management (#4170)
* llama : keep track of used KV cells + better KV cache management
* llama : zero KV cache used upon clear (ggml-ci)
* llama : allow exporting a view of the KV cache (#4180)
* Allow exporting a view of the KV cache
* Allow dumping the sequences per cell in common
* Track max contiguous cells value and position as well
* Fix max contiguous empty cells index calculation; make dump functions deal with lengths or sequence counts > 10 better
* Fix off-by-one error in dump_kv_cache_view
* Add doc comments for KV cache view functions; eliminate the cell sequence struct and use llama_seq_id directly; minor cleanups
* common : add -dkvc arg for enabling KV cache dumps

Co-authored-by: Kerfuffle <[email protected]>
1 parent d103d93 commit 6b0a742
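For orientation, here is a minimal usage sketch (not part of the diff) showing how the pieces added in this commit fit together. It assumes an already initialized llama_context * ctx and links against common for the dump helper; the value 4 for the maximum sequences per cell is an arbitrary choice.

    // create a view that records at most 4 sequence ids per cell
    struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, 4);

    // after each decode step: refresh the snapshot and print it
    llama_kv_cache_view_update(ctx, &kvc_view);
    dump_kv_cache_view_seqs(kvc_view, 40);   // per-sequence dump from common/common.cpp, 40 cells per row

    // aggregate counters are also available without a view
    const int n_tokens_cached = llama_get_kv_cache_token_count(ctx);
    const int n_cells_used    = llama_get_kv_cache_used_cells(ctx);

    // the view owns heap buffers; release them when done
    llama_kv_cache_view_free(&kvc_view);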

File tree: 5 files changed (+277, -7 lines)

common/common.cpp (+79)
@@ -12,6 +12,7 @@
 #include <regex>
 #include <sstream>
 #include <string>
+#include <unordered_map>
 #include <unordered_set>
 #include <vector>
 #include <cinttypes>
@@ -495,6 +496,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
             params.chatml = true;
         } else if (arg == "--infill") {
            params.infill = true;
+        } else if (arg == "-dkvc" || arg == "--dump-kv-cache") {
+            params.dump_kv_cache = true;
         } else if (arg == "--multiline-input") {
             params.multiline_input = true;
         } else if (arg == "--simple-io") {
@@ -835,6 +838,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
 #endif // GGML_USE_CUBLAS
 #endif
     printf("  --verbose-prompt          print prompt before generation\n");
+    printf("  -dkvc, --dump-kv-cache\n");
+    printf("                            verbose print of the KV cache\n");
     printf("  --simple-io               use basic IO for better compatibility in subprocesses and limited consoles\n");
     printf("  --lora FNAME              apply LoRA adapter (implies --no-mmap)\n");
     printf("  --lora-scaled FNAME S     apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
@@ -1386,3 +1391,77 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
     fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
     fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
 }
+
+//
+// KV cache utils
+//
+
+void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
+    static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
+
+    printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
+        view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
+
+    llama_kv_cache_view_cell * c_curr = view.cells;
+    llama_seq_id * cs_curr = view.cells_sequences;
+
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
+        if (i % row_size == 0) {
+            printf("\n%5d: ", i);
+        }
+        int seq_count = 0;
+        for (int j = 0; j < view.n_max_seq; j++) {
+            if (cs_curr[j] >= 0) { seq_count++; }
+        }
+        putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
+    }
+
+    printf("\n=== Done dumping\n");
+}
+
+void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
+    static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
+
+    printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
+        view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
+
+    std::unordered_map<llama_seq_id, size_t> seqs;
+    llama_kv_cache_view_cell * c_curr = view.cells;
+    llama_seq_id * cs_curr = view.cells_sequences;
+
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
+        for (int j = 0; j < view.n_max_seq; j++) {
+            if (cs_curr[j] < 0) { continue; }
+            if (seqs.find(cs_curr[j]) == seqs.end()) {
+                if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
+                seqs[cs_curr[j]] = seqs.size();
+            }
+        }
+        if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
+    }
+
+    printf("=== Sequence legend: ");
+    for (const auto & it : seqs) {
+        printf("%zu=%d, ", it.second, it.first);
+    }
+    printf("'+'=other sequence ids");
+
+    c_curr = view.cells;
+    cs_curr = view.cells_sequences;
+    for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
+        if (i % row_size == 0) {
+            printf("\n%5d: ", i);
+        }
+        for (int j = 0; j < view.n_max_seq; j++) {
+            if (cs_curr[j] >= 0) {
+                const auto & it = seqs.find(cs_curr[j]);
+                putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
+            } else {
+                putchar('.');
+            }
+        }
+        putchar(' ');
+    }
+
+    printf("\n=== Done dumping\n");
+}
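To make the cell encoding in dump_kv_cache_view concrete: each cell prints as one character from slot_chars, indexed by the number of sequences that reference the cell and clamped to the final '+' for counts past the end of the table. A tiny standalone illustration of that clamping (not part of the commit):

    #include <algorithm>
    #include <cstdio>
    #include <initializer_list>

    int main() {
        static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
        // 0 sequences -> '.', 1 -> '1', 3 -> '3', anything past the table -> '+'
        for (size_t seq_count : {size_t(0), size_t(1), size_t(3), size_t(100)}) {
            putchar(slot_chars[std::min(sizeof(slot_chars) - 2, seq_count)]);
        }
        putchar('\n'); // prints ".13+"
    }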

common/common.h (+11)
@@ -122,6 +122,7 @@ struct gpt_params {
     bool numa           = false; // attempt optimizations that help on some NUMA systems
     bool verbose_prompt = false; // print prompt tokens before generation
     bool infill         = false; // use infill mode
+    bool dump_kv_cache  = false; // dump the KV cache contents for debugging purposes

     // multimodal models (see examples/llava)
     std::string mmproj = ""; // path to multimodal projector
@@ -218,3 +219,13 @@ std::string get_sortable_timestamp();
 void dump_non_result_info_yaml(
     FILE * stream, const gpt_params & params, const llama_context * lctx,
     const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
+
+//
+// KV cache utils
+//
+
+// Dump the KV cache view with the number of sequences per cell.
+void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
+
+// Dump the KV cache view showing individual sequences in each cell (long output).
+void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);

examples/parallel/parallel.cpp (+9)
@@ -113,6 +113,8 @@ int main(int argc, char ** argv) {
     // insert new requests as soon as the previous one is done
     const bool cont_batching = params.cont_batching;

+    const bool dump_kv_cache = params.dump_kv_cache;
+
 #ifndef LOG_DISABLE_LOGS
     log_set_target(log_filename_generator("parallel", "log"));
     LOG_TEE("Log start\n");
@@ -172,6 +174,8 @@ int main(int argc, char ** argv) {
     int32_t n_total_gen  = 0;
     int32_t n_cache_miss = 0;

+    struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, n_clients);
+
     const auto t_main_start = ggml_time_us();

     LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__);
@@ -201,6 +205,11 @@ int main(int argc, char ** argv) {
     LOG_TEE("Processing requests ...\n\n");

     while (true) {
+        if (dump_kv_cache) {
+            llama_kv_cache_view_update(ctx, &kvc_view);
+            dump_kv_cache_view_seqs(kvc_view, 40);
+        }
+
         llama_batch_clear(batch);

         // decode any currently ongoing sequences

llama.cpp (+124, -4)
@@ -1280,6 +1280,7 @@ struct llama_kv_cache {
     // cannot be freely changed after a slot has been allocated.
     uint32_t head = 0;
     uint32_t size = 0;
+    uint32_t used = 0; // used cells (i.e. at least one seq_id)

     // computed before each graph build
     uint32_t n = 0;
@@ -1504,6 +1505,7 @@ static bool llama_kv_cache_init(

     cache.head = 0;
     cache.size = n_ctx;
+    cache.used = 0;

     cache.cells.clear();
     cache.cells.resize(n_ctx);
@@ -1605,6 +1607,8 @@ static bool llama_kv_cache_find_slot(
         }
     }

+    cache.used += n_tokens;
+
     return true;
 }

@@ -1625,6 +1629,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
         cache.cells[i].seq_id.clear();
     }
     cache.head = 0;
+    cache.used = 0;
 }

 static void llama_kv_cache_seq_rm(
@@ -1647,14 +1652,17 @@ static void llama_kv_cache_seq_rm(
                 continue;
             }
             if (cache.cells[i].seq_id.empty()) {
+                // keep count of the number of used cells
+                if (cache.cells[i].pos >= 0) cache.used--;
+
                 cache.cells[i].pos = -1;
                 if (new_head == cache.size) new_head = i;
             }
         }
     }

     // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cache.size) cache.head = new_head;
+    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
 }

 static void llama_kv_cache_seq_cp(
@@ -1680,6 +1688,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id

     for (uint32_t i = 0; i < cache.size; ++i) {
         if (!cache.cells[i].has_seq_id(seq_id)) {
+            if (cache.cells[i].pos >= 0) cache.used--;
             cache.cells[i].pos = -1;
             cache.cells[i].seq_id.clear();
             if (new_head == cache.size) new_head = i;
@@ -1690,7 +1699,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
     }

     // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cache.size) cache.head = new_head;
+    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
 }

 static void llama_kv_cache_seq_shift(
@@ -1711,6 +1720,7 @@ static void llama_kv_cache_seq_shift(
             cache.cells[i].delta += delta;

             if (cache.cells[i].pos < 0) {
+                if (!cache.cells[i].seq_id.empty()) cache.used--;
                 cache.cells[i].pos = -1;
                 cache.cells[i].seq_id.clear();
                 if (new_head == cache.size) new_head = i;
@@ -5469,6 +5479,12 @@ static int llama_decode_internal(
         batch.seq_id = seq_id_arr.data();
     }

+    // if we have enough unused cells before the current head ->
+    //   better to start searching from the beginning of the cache, hoping to fill it
+    if (kv_self.head > kv_self.used + 2*n_tokens) {
+        kv_self.head = 0;
+    }
+
     if (!llama_kv_cache_find_slot(kv_self, batch)) {
         return 1;
     }
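As a worked example of the new search heuristic above: with kv_self.used = 16, kv_self.head = 100 and an incoming batch of n_tokens = 8, the check 100 > 16 + 2*8 holds, so head is reset to 0 and llama_kv_cache_find_slot scans from the start of the cache, where the freed cells are likely to be, instead of appending past position 100.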
@@ -5479,7 +5495,7 @@
     //kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32));   // TODO: this might be better for CUDA?
     kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));

-    //printf("kv_self.n = %d\n", kv_self.n);
+    //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);

     ggml_allocr_reset(lctx.alloc);

@@ -8789,8 +8805,107 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
     }
 }

+struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
+    struct llama_kv_cache_view result = {
+        /*.n_cells            = */ 0,
+        /*.n_max_seq          = */ n_max_seq,
+        /*.token_count        = */ 0,
+        /*.used_cells         = */ llama_get_kv_cache_used_cells(ctx),
+        /*.max_contiguous     = */ 0,
+        /*.max_contiguous_idx = */ -1,
+        /*.cells              = */ nullptr,
+        /*.cells_sequences    = */ nullptr,
+    };
+    return result;
+}
+
+void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
+    if (view->cells != nullptr) {
+        free(view->cells);
+        view->cells = nullptr;
+    }
+    if (view->cells_sequences != nullptr) {
+        free(view->cells_sequences);
+        view->cells_sequences = nullptr;
+    }
+}
+
+void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
+    if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
+        view->n_cells = int32_t(ctx->kv_self.size);
+        void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
+        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
+        view->cells = (struct llama_kv_cache_view_cell *)p;
+        p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_max_seq * view->n_cells);
+        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
+        view->cells_sequences = (llama_seq_id *)p;
+    }
+
+    const std::vector<llama_kv_cell> & kv_cells = ctx->kv_self.cells;
+    llama_kv_cache_view_cell * c_curr = view->cells;
+    llama_seq_id * cs_curr = view->cells_sequences;
+    int32_t used_cells = 0;
+    int32_t token_count = 0;
+    int32_t curr_contig_idx = -1;
+    uint32_t max_contig = 0;
+    int32_t max_contig_idx = -1;
+
+    for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) {
+        const size_t curr_size = kv_cells[i].seq_id.size();
+        token_count += curr_size;
+        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
+
+        if (curr_size > 0) {
+            if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
+                max_contig = i - curr_contig_idx;
+                max_contig_idx = curr_contig_idx;
+            }
+            curr_contig_idx = -1;
+        } else if (curr_contig_idx < 0) {
+            curr_contig_idx = i;
+        }
+
+        int seq_idx = 0;
+        for (const llama_seq_id it : kv_cells[i].seq_id) {
+            if (seq_idx >= view->n_max_seq) {
+                break;
+            }
+            cs_curr[seq_idx] = it;
+            seq_idx++;
+        }
+        if (seq_idx != 0) {
+            used_cells++;
+        }
+        for (; seq_idx < view->n_max_seq; seq_idx++) {
+            cs_curr[seq_idx] = -1;
+        }
+    }
+    if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
+        max_contig_idx = curr_contig_idx;
+        max_contig = kv_cells.size() - curr_contig_idx;
+    }
+    view->max_contiguous = max_contig;
+    view->max_contiguous_idx = max_contig_idx;
+    view->token_count = token_count;
+    view->used_cells = used_cells;
+    if (uint32_t(used_cells) != ctx->kv_self.used) {
+        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
+            __func__, ctx->kv_self.used, used_cells);
+    }
+}
+
 int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
-    return ctx->kv_self.head;
+    int result = 0;
+
+    for (uint32_t i = 0; i < ctx->kv_self.size; i++) {
+        result += ctx->kv_self.cells[i].seq_id.size();
+    }
+
+    return result;
+}
+
+int llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
+    return ctx->kv_self.used;
 }

 void llama_kv_cache_clear(struct llama_context * ctx) {
void llama_kv_cache_clear(struct llama_context * ctx) {
@@ -8960,10 +9075,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
     const size_t   kv_buf_size = kv_self.buf.size;
     const uint32_t kv_head     = kv_self.head;
     const uint32_t kv_size     = kv_self.size;
+    const uint32_t kv_used     = kv_self.used;

     data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
     data_ctx->write(&kv_head,     sizeof(kv_head));
     data_ctx->write(&kv_size,     sizeof(kv_size));
+    data_ctx->write(&kv_used,     sizeof(kv_used));

     if (kv_buf_size) {
         const size_t elt_size = ggml_element_size(kv_self.k);
@@ -9086,10 +9203,12 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
     size_t   kv_buf_size;
     uint32_t kv_head;
     uint32_t kv_size;
+    uint32_t kv_used;

     memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
     memcpy(&kv_head,     inp, sizeof(kv_head));     inp += sizeof(kv_head);
     memcpy(&kv_size,     inp, sizeof(kv_size));     inp += sizeof(kv_size);
+    memcpy(&kv_used,     inp, sizeof(kv_used));     inp += sizeof(kv_used);

     if (kv_buf_size) {
         GGML_ASSERT(kv_self.buf.size == kv_buf_size);
@@ -9124,6 +9243,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {

     ctx->kv_self.head = kv_head;
     ctx->kv_self.size = kv_size;
+    ctx->kv_self.used = kv_used;

     ctx->kv_self.cells.resize(kv_size);
