@@ -1280,6 +1280,7 @@ struct llama_kv_cache {
     // cannot be freely changed after a slot has been allocated.
     uint32_t head = 0;
     uint32_t size = 0;
+    uint32_t used = 0; // used cells (i.e. at least one seq_id)
 
     // computed before each graph build
     uint32_t n = 0;
@@ -1504,6 +1505,7 @@ static bool llama_kv_cache_init(
 
     cache.head = 0;
     cache.size = n_ctx;
+    cache.used = 0;
 
     cache.cells.clear();
     cache.cells.resize(n_ctx);
@@ -1605,6 +1607,8 @@ static bool llama_kv_cache_find_slot(
         }
     }
 
+    cache.used += n_tokens;
+
     return true;
 }
 
@@ -1625,6 +1629,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
         cache.cells[i].seq_id.clear();
     }
     cache.head = 0;
+    cache.used = 0;
 }
 
 static void llama_kv_cache_seq_rm(
@@ -1647,14 +1652,17 @@ static void llama_kv_cache_seq_rm(
                 continue;
             }
             if (cache.cells[i].seq_id.empty()) {
+                // keep count of the number of used cells
+                if (cache.cells[i].pos >= 0) cache.used--;
+
                 cache.cells[i].pos = -1;
                 if (new_head == cache.size) new_head = i;
             }
         }
     }
 
     // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cache.size) cache.head = new_head;
+    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
 }
 
 static void llama_kv_cache_seq_cp(
@@ -1680,6 +1688,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
 
     for (uint32_t i = 0; i < cache.size; ++i) {
         if (!cache.cells[i].has_seq_id(seq_id)) {
+            if (cache.cells[i].pos >= 0) cache.used--;
             cache.cells[i].pos = -1;
             cache.cells[i].seq_id.clear();
             if (new_head == cache.size) new_head = i;
@@ -1690,7 +1699,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
     }
 
     // If we freed up a slot, set head to it so searching can start there.
-    if (new_head != cache.size) cache.head = new_head;
+    if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
 }
 
 static void llama_kv_cache_seq_shift(
@@ -1711,6 +1720,7 @@ static void llama_kv_cache_seq_shift(
             cache.cells[i].delta += delta;
 
             if (cache.cells[i].pos < 0) {
+                if (!cache.cells[i].seq_id.empty()) cache.used--;
                 cache.cells[i].pos = -1;
                 cache.cells[i].seq_id.clear();
                 if (new_head == cache.size) new_head = i;
@@ -5469,6 +5479,12 @@ static int llama_decode_internal(
         batch.seq_id = seq_id_arr.data();
     }
 
+    // if we have enough unused cells before the current head ->
+    //   better to start searching from the beginning of the cache, hoping to fill it
+    if (kv_self.head > kv_self.used + 2*n_tokens) {
+        kv_self.head = 0;
+    }
+
     if (!llama_kv_cache_find_slot(kv_self, batch)) {
         return 1;
     }
@@ -5479,7 +5495,7 @@
     //kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA?
     kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));
 
-    //printf("kv_self.n = %d\n", kv_self.n);
+    //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
 
     ggml_allocr_reset(lctx.alloc);
 
@@ -8789,8 +8805,107 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
     }
 }
 
+struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
+    struct llama_kv_cache_view result = {
+        /*.n_cells            = */ 0,
+        /*.n_max_seq          = */ n_max_seq,
+        /*.token_count        = */ 0,
+        /*.used_cells         = */ llama_get_kv_cache_used_cells(ctx),
+        /*.max_contiguous     = */ 0,
+        /*.max_contiguous_idx = */ -1,
+        /*.cells              = */ nullptr,
+        /*.cells_sequences    = */ nullptr,
+    };
+    return result;
+}
+
+void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
+    if (view->cells != nullptr) {
+        free(view->cells);
+        view->cells = nullptr;
+    }
+    if (view->cells_sequences != nullptr) {
+        free(view->cells_sequences);
+        view->cells_sequences = nullptr;
+    }
+}
+
+void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
+    if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
+        view->n_cells = int32_t(ctx->kv_self.size);
+        void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
+        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
+        view->cells = (struct llama_kv_cache_view_cell *)p;
+        p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_max_seq * view->n_cells);
+        GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
+        view->cells_sequences = (llama_seq_id *)p;
+    }
+
+    const std::vector<llama_kv_cell> & kv_cells = ctx->kv_self.cells;
+    llama_kv_cache_view_cell * c_curr = view->cells;
+    llama_seq_id * cs_curr = view->cells_sequences;
+    int32_t used_cells = 0;
+    int32_t token_count = 0;
+    int32_t curr_contig_idx = -1;
+    uint32_t max_contig = 0;
+    int32_t max_contig_idx = -1;
+
+    for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) {
+        const size_t curr_size = kv_cells[i].seq_id.size();
+        token_count += curr_size;
+        c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
+
+        if (curr_size > 0) {
+            if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
+                max_contig = i - curr_contig_idx;
+                max_contig_idx = curr_contig_idx;
+            }
+            curr_contig_idx = -1;
+        } else if (curr_contig_idx < 0) {
+            curr_contig_idx = i;
+        }
+
+        int seq_idx = 0;
+        for (const llama_seq_id it : kv_cells[i].seq_id) {
+            if (seq_idx >= view->n_max_seq) {
+                break;
+            }
+            cs_curr[seq_idx] = it;
+            seq_idx++;
+        }
+        if (seq_idx != 0) {
+            used_cells++;
+        }
+        for (; seq_idx < view->n_max_seq; seq_idx++) {
+            cs_curr[seq_idx] = -1;
+        }
+    }
+    if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
+        max_contig_idx = curr_contig_idx;
+        max_contig = kv_cells.size() - curr_contig_idx;
+    }
+    view->max_contiguous = max_contig;
+    view->max_contiguous_idx = max_contig_idx;
+    view->token_count = token_count;
+    view->used_cells = used_cells;
+    if (uint32_t(used_cells) != ctx->kv_self.used) {
+        LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
+            __func__, ctx->kv_self.used, used_cells);
+    }
+}
+
 int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
-    return ctx->kv_self.head;
+    int result = 0;
+
+    for (uint32_t i = 0; i < ctx->kv_self.size; i++) {
+        result += ctx->kv_self.cells[i].seq_id.size();
+    }
+
+    return result;
+}
+
+int llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
+    return ctx->kv_self.used;
 }
 
 void llama_kv_cache_clear(struct llama_context * ctx) {
@@ -8960,10 +9075,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
         const size_t kv_buf_size = kv_self.buf.size;
         const uint32_t kv_head = kv_self.head;
         const uint32_t kv_size = kv_self.size;
+        const uint32_t kv_used = kv_self.used;
 
         data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
         data_ctx->write(&kv_head, sizeof(kv_head));
         data_ctx->write(&kv_size, sizeof(kv_size));
+        data_ctx->write(&kv_used, sizeof(kv_used));
 
         if (kv_buf_size) {
             const size_t elt_size = ggml_element_size(kv_self.k);
@@ -9086,10 +9203,12 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
         size_t kv_buf_size;
         uint32_t kv_head;
         uint32_t kv_size;
+        uint32_t kv_used;
 
         memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
         memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head);
         memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
+        memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used);
 
         if (kv_buf_size) {
             GGML_ASSERT(kv_self.buf.size == kv_buf_size);
@@ -9124,6 +9243,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
 
         ctx->kv_self.head = kv_head;
         ctx->kv_self.size = kv_size;
+        ctx->kv_self.used = kv_used;
 
         ctx->kv_self.cells.resize(kv_size);
 
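For context, the following is a minimal usage sketch (not part of the diff) showing how the KV cache view API and occupancy counters added above might be called from application code. It assumes the matching declarations in llama.h from the same change, a valid llama_context * ctx that has already decoded some tokens, and a caller-chosen n_max_seq bound on sequence ids per cell:

// Minimal usage sketch (not part of the diff). Assumes llama.h declares the
// view API added by this change and that ctx has already decoded some tokens.
#include "llama.h"
#include <cstdio>

static void print_kv_cache_stats(const struct llama_context * ctx, int32_t n_max_seq) {
    // Cheap counters maintained directly by the cache (no allocations).
    printf("tokens in cache: %d, used cells: %d\n",
            llama_get_kv_cache_token_count(ctx), llama_get_kv_cache_used_cells(ctx));

    // Detailed per-cell view: init once, update whenever fresh data is needed, free when done.
    struct llama_kv_cache_view view = llama_kv_cache_view_init(ctx, n_max_seq);
    llama_kv_cache_view_update(ctx, &view);

    printf("cells: %d, tokens: %d, used cells: %d, largest empty run: %d (starting at cell %d)\n",
            view.n_cells, view.token_count, view.used_cells,
            view.max_contiguous, view.max_contiguous_idx);

    llama_kv_cache_view_free(&view);
}

Note that llama_kv_cache_view_init only fills in the cheap counters; llama_kv_cache_view_update (re)allocates the cells and cells_sequences buffers as needed and recomputes token_count, used_cells and the largest contiguous run of empty cells, while llama_kv_cache_view_free releases those buffers.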