@@ -1280,6 +1280,7 @@ struct llama_kv_cache {
     // cannot be freely changed after a slot has been allocated.
     uint32_t head = 0;
     uint32_t size = 0;
+    uint32_t used = 0; // used cells (i.e. at least one seq_id)
 
     // computed before each graph build
     uint32_t n = 0;
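Aside on intent: `head` is where the slot search starts, `size` is the total cell count, and the new `used` counter tracks how many cells currently hold at least one sequence (equivalently, have `pos >= 0`). A hypothetical debug helper, not part of this patch, sketching the invariant the bookkeeping below maintains:

static void llama_kv_cache_check_used(const struct llama_kv_cache & cache) {
    // 'used' should equal the number of cells holding at least one sequence,
    // i.e. cells whose pos is non-negative (hypothetical check, not in the patch)
    uint32_t count = 0;
    for (uint32_t i = 0; i < cache.size; ++i) {
        if (cache.cells[i].pos >= 0) {
            count++;
        }
    }
    GGML_ASSERT(count == cache.used && "kv cache 'used' counter out of sync");
}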
@@ -1504,6 +1505,7 @@ static bool llama_kv_cache_init(
 
     cache.head = 0;
     cache.size = n_ctx;
+    cache.used = 0;
 
     cache.cells.clear();
     cache.cells.resize(n_ctx);
@@ -1605,6 +1607,8 @@ static bool llama_kv_cache_find_slot(
         }
     }
 
+    cache.used += n_tokens;
+
     return true;
 }
 
@@ -1625,6 +1629,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
         cache.cells[i].seq_id.clear();
     }
     cache.head = 0;
+    cache.used = 0;
 }
 
 static void llama_kv_cache_seq_rm(
@@ -1647,14 +1652,17 @@ static void llama_kv_cache_seq_rm(
                 continue;
             }
             if (cache.cells[i].seq_id.empty()) {
+                // keep count of the number of used cells
+                if (cache.cells[i].pos >= 0) cache.used--;
+
                 cache.cells[i].pos = -1;
                 if (new_head == cache.size) new_head = i;
             }
         }
     }
 
     // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cache.size) cache.head = new_head;
+    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
 }
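The added `new_head < cache.head` guard means `head` only ever moves backward here. A worked example with hypothetical numbers:

// head = 10, and this seq_rm call frees only cell 20.
// Before this patch: head jumps forward to 20, skipping any cells in
//   [10, 20) that earlier calls had already freed.
// With the guard: 20 < 10 is false, so head stays at 10; freeing cell 4
//   instead would set head = 4, since 4 < 10.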
 
 static void llama_kv_cache_seq_cp(
@@ -1680,6 +1688,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
 
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (!cache.cells[i].has_seq_id(seq_id)) {
+            if (cache.cells[i].pos >= 0) cache.used--;
             cache.cells[i].pos = -1;
             cache.cells[i].seq_id.clear();
             if (new_head == cache.size) new_head = i;
@@ -1690,7 +1699,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
     }
 
     // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cache.size) cache.head = new_head;
+    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
 }
 
 static void llama_kv_cache_seq_shift(
@@ -1711,6 +1720,7 @@ static void llama_kv_cache_seq_shift(
             cache.cells[i].delta += delta;
 
             if (cache.cells[i].pos < 0) {
+                if (!cache.cells[i].seq_id.empty()) cache.used--;
                 cache.cells[i].pos = -1;
                 cache.cells[i].seq_id.clear();
                 if (new_head == cache.size) new_head = i;
@@ -5469,6 +5479,12 @@ static int llama_decode_internal(
         batch.seq_id = seq_id_arr.data();
     }
 
+    // if we have enough unused cells before the current head ->
+    //   better to start searching from the beginning of the cache, hoping to fill it
+    if (kv_self.head > kv_self.used + 2*n_tokens) {
+        kv_self.head = 0;
+    }
+
     if (!llama_kv_cache_find_slot(kv_self, batch)) {
         return 1;
     }
@@ -5479,7 +5495,7 @@ static int llama_decode_internal(
     // kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32));   // TODO: this might be better for CUDA?
     kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));
 
-    // printf("kv_self.n = %d\n", kv_self.n);
+    // printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
     ggml_allocr_reset(lctx.alloc);
 
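A worked example of the head-reset heuristic added above, with hypothetical numbers:

// Suppose earlier seq_rm/seq_keep calls emptied most of the cache but left
// head = 100, used = 20, n_tokens = 8. Then head (100) > used + 2*n_tokens
// (36): even if all 20 used cells sat below head, at least 80 free cells
// would remain there, so restarting the search at 0 lets find_slot reuse
// them instead of fragmenting toward the end of the cache.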
@@ -8789,8 +8805,107 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
     }
 }
 
+struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
+    struct llama_kv_cache_view result = {
+        /* .n_cells            = */ 0,
+        /* .n_max_seq          = */ n_max_seq,
+        /* .token_count        = */ 0,
+        /* .used_cells         = */ llama_get_kv_cache_used_cells(ctx),
+        /* .max_contiguous     = */ 0,
+        /* .max_contiguous_idx = */ -1,
+        /* .cells              = */ nullptr,
+        /* .cells_sequences    = */ nullptr,
+    };
+    return result;
+}
+
+void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
+    if (view->cells != nullptr) {
+        free(view->cells);
+        view->cells = nullptr;
+    }
+    if (view->cells_sequences != nullptr) {
+        free(view->cells_sequences);
+        view->cells_sequences = nullptr;
+    }
+}
+
+void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
+    if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
+        view->n_cells = int32_t(ctx->kv_self.size);
+        void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
+        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
+        view->cells = (struct llama_kv_cache_view_cell *)p;
+        p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_max_seq * view->n_cells);
+        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
+        view->cells_sequences = (llama_seq_id *)p;
+    }
+
+    const std::vector<llama_kv_cell> & kv_cells = ctx->kv_self.cells;
+    llama_kv_cache_view_cell * c_curr = view->cells;
+    llama_seq_id * cs_curr = view->cells_sequences;
+    int32_t used_cells = 0;
+    int32_t token_count = 0;
+    int32_t curr_contig_idx = -1;
+    uint32_t max_contig = 0;
+    int32_t max_contig_idx = -1;
+
+    for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) {
+        const size_t curr_size = kv_cells[i].seq_id.size();
+        token_count += curr_size;
+        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
+
+        if (curr_size > 0) {
+            if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
+                max_contig = i - curr_contig_idx;
+                max_contig_idx = curr_contig_idx;
+            }
+            curr_contig_idx = -1;
+        } else if (curr_contig_idx < 0) {
+            curr_contig_idx = i;
+        }
+
+        int seq_idx = 0;
+        for (const llama_seq_id it : kv_cells[i].seq_id) {
+            if (seq_idx >= view->n_max_seq) {
+                break;
+            }
+            cs_curr[seq_idx] = it;
+            seq_idx++;
+        }
+        if (seq_idx != 0) {
+            used_cells++;
+        }
+        for (; seq_idx < view->n_max_seq; seq_idx++) {
+            cs_curr[seq_idx] = -1;
+        }
+    }
+    if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
+        max_contig_idx = curr_contig_idx;
+        max_contig = kv_cells.size() - curr_contig_idx;
+    }
+    view->max_contiguous = max_contig;
+    view->max_contiguous_idx = max_contig_idx;
+    view->token_count = token_count;
+    view->used_cells = used_cells;
+    if (uint32_t(used_cells) != ctx->kv_self.used) {
+        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
+            __func__, ctx->kv_self.used, used_cells);
+    }
+}
+
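A minimal usage sketch for the new view API, assuming a valid `llama_context * ctx`; the `n_max_seq` value of 4 is an arbitrary choice, and unused entries of `cells_sequences` are -1, per the update loop above:

struct llama_kv_cache_view view = llama_kv_cache_view_init(ctx, 4);

llama_kv_cache_view_update(ctx, &view); // refresh after the cache changes

for (int32_t i = 0; i < view.n_cells; i++) {
    const llama_seq_id * seqs = view.cells_sequences + i*view.n_max_seq;
    if (seqs[0] < 0) {
        continue; // cell i holds no sequences
    }
    for (int32_t j = 0; j < view.n_max_seq && seqs[j] >= 0; j++) {
        // cell i stores position view.cells[i].pos for sequence seqs[j]
    }
}

llama_kv_cache_view_free(&view);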
 int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
-    return ctx->kv_self.head;
+    int result = 0;
+
+    for (uint32_t i = 0; i < ctx->kv_self.size; i++) {
+        result += ctx->kv_self.cells[i].seq_id.size();
+    }
+
+    return result;
+}
+
+int llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
+    return ctx->kv_self.used;
 }
 
 void llama_kv_cache_clear(struct llama_context * ctx) {
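Note that the two getters now measure different things: a cell shared by several sequences (e.g. after `llama_kv_cache_seq_cp`) counts once as a used cell but once per sequence as a token. A hypothetical state to illustrate:

// cell 0: seq_id = {0, 1}   -> 2 tokens, 1 used cell
// cell 1: seq_id = {0}      -> 1 token,  1 used cell
// cell 2: empty (pos == -1) -> 0 tokens, 0 used cells
//
// llama_get_kv_cache_token_count(ctx) == 3   (sums seq_id.size() per cell)
// llama_get_kv_cache_used_cells(ctx)  == 2   (reads kv_self.used)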
@@ -8960,10 +9075,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
         const size_t   kv_buf_size = kv_self.buf.size;
         const uint32_t kv_head     = kv_self.head;
         const uint32_t kv_size     = kv_self.size;
+        const uint32_t kv_used     = kv_self.used;
 
         data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
         data_ctx->write(&kv_head,     sizeof(kv_head));
         data_ctx->write(&kv_size,     sizeof(kv_size));
+        data_ctx->write(&kv_used,     sizeof(kv_used));
 
         if (kv_buf_size) {
             const size_t elt_size = ggml_element_size(kv_self.k);
@@ -9086,10 +9203,12 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
         size_t   kv_buf_size;
         uint32_t kv_head;
         uint32_t kv_size;
+        uint32_t kv_used;
 
         memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
         memcpy(&kv_head,     inp, sizeof(kv_head));     inp += sizeof(kv_head);
         memcpy(&kv_size,     inp, sizeof(kv_size));     inp += sizeof(kv_size);
+        memcpy(&kv_used,     inp, sizeof(kv_used));     inp += sizeof(kv_used);
 
         if (kv_buf_size) {
             GGML_ASSERT(kv_self.buf.size == kv_buf_size);
@@ -9124,6 +9243,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
 
     ctx->kv_self.head = kv_head;
     ctx->kv_self.size = kv_size;
+    ctx->kv_self.used = kv_used;
 
     ctx->kv_self.cells.resize(kv_size);
 
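For reference, the KV section of the serialized state now begins with the header below, as implied by the write/read order above; state blobs produced before this change lack the new field and will no longer deserialize correctly:

// size_t   kv_buf_size
// uint32_t kv_head
// uint32_t kv_size
// uint32_t kv_used     // new in this patch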