Commit a91b15f

kv-cache : simplify the interface
ggml-ci
1 parent a4090d1 commit a91b15f

9 files changed (+87, -153 lines)


examples/simple-chat/simple-chat.cpp (2 additions, 2 deletions)

@@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
     auto generate = [&](const std::string & prompt) {
         std::string response;
 
-        const bool is_first = llama_kv_self_used_cells(ctx) == 0;
+        const bool is_first = llama_kv_self_seq_pos_max(ctx, 0) == 0;
 
         // tokenize the prompt
         const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
@@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
         while (true) {
             // check if we have enough space in the context to evaluate this batch
             int n_ctx = llama_n_ctx(ctx);
-            int n_ctx_used = llama_kv_self_used_cells(ctx);
+            int n_ctx_used = llama_kv_self_seq_pos_max(ctx, 0);
             if (n_ctx_used + batch.n_tokens > n_ctx) {
                 printf("\033[0m\n");
                 fprintf(stderr, "context size exceeded\n");

include/llama.h (4 additions, 2 deletions)

@@ -610,10 +610,12 @@ extern "C" {
 
     // Returns the number of tokens in the KV cache (slow, use only for debug)
     // If a KV cell has multiple sequences assigned to it, it will be counted multiple times
-    LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx);
+    DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx),
+        "Use llama_kv_self_seq_pos_max() instead");
 
     // Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
-    LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx);
+    DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
+        "Use llama_kv_self_seq_pos_max() instead");
 
     // Clear the KV cache - both cell info is erased and KV data is zeroed
     LLAMA_API void llama_kv_self_clear(
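
For callers of the now-deprecated helpers, a minimal migration sketch (an editor's illustration, not code from this commit; the helper name is hypothetical) following the same pattern as the simple-chat change above, assuming a single sequence with id 0:

// Migration sketch: track single-sequence context usage via the sequence's
// maximum position instead of the deprecated cell counters.
#include "llama.h"

static bool batch_fits_in_context(llama_context * ctx, const llama_batch & batch) {
    const int n_ctx      = llama_n_ctx(ctx);
    const int n_ctx_used = llama_kv_self_seq_pos_max(ctx, 0); // was: llama_kv_self_used_cells(ctx)

    return n_ctx_used + batch.n_tokens <= n_ctx;
}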

src/llama-batch.cpp (1 addition, 1 deletion)

@@ -283,7 +283,7 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
     if (!batch.pos) {
         pos.resize(batch.n_tokens);
         for (int32_t i = 0; i < batch.n_tokens; i++) {
-            pos[i] = i + p0;
+            pos[i] = p0 + i + 1;
         }
         batch.pos = pos.data();
    }
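
Read together with the llama-context.cpp change below, the extra +1 reflects a new meaning for p0 when positions are auto-assigned: the caller now passes the last occupied position of the sequence (seq_pos_max) rather than the first free position (the old get_pos_max() + 1), and the allocator adds the 1 itself. A small sketch of the convention, using a hypothetical helper that is not part of this commit:

// Sketch of the new p0 convention: the cache holds positions 0..p0, so
// auto-assigned positions continue at p0 + 1.
#include <cstdint>
#include <vector>

static std::vector<int32_t> auto_assign_positions(int32_t p0, int32_t n_tokens) {
    std::vector<int32_t> pos(n_tokens);
    for (int32_t i = 0; i < n_tokens; i++) {
        pos[i] = p0 + i + 1;
    }
    return pos;
}

// Example: with positions 0..9 already cached, p0 == 9 and a 3-token batch is
// assigned positions 10, 11, 12, matching what the old convention produced for
// a single sequence.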

src/llama-context.cpp (35 additions, 4 deletions)

@@ -857,11 +857,17 @@ int llama_context::decode(llama_batch & inp_batch) {
         return -1;
     }
 
+    if (!inp_batch.pos) {
+        if (inp_batch.seq_id) {
+            LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
+            return -1;
+        }
+    }
+
     llama_kv_cache * kv_self = static_cast<llama_kv_cache *>(memory.get());
 
     // temporary allocate memory for the input batch if needed
-    // TODO: this is incorrect for multiple sequences because get_pos_max() is the maximum across all sequences
-    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->get_pos_max() + 1);
+    llama_batch_allocr batch_allocr(inp_batch, inp_batch.pos ? -1 : kv_self->seq_pos_max(0));
 
     const llama_batch & batch = batch_allocr.batch;
 
@@ -2292,22 +2298,47 @@ int32_t llama_apply_adapter_cvec(
 // kv cache
 //
 
+// deprecated
 int32_t llama_kv_self_n_tokens(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
         return 0;
     }
 
-    return kv->get_n_tokens();
+    int32_t res = 0;
+
+    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+        const llama_pos p0 = kv->seq_pos_min(s);
+        const llama_pos p1 = kv->seq_pos_max(s);
+
+        if (p0 >= 0) {
+            res += (p1 - p0) + 1;
+        }
+    }
+
+    return res;
 }
 
+// deprecated
+// note: this is the same as above - will be removed anyway, so it's ok
 int32_t llama_kv_self_used_cells(const llama_context * ctx) {
     const auto * kv = ctx->get_kv_self();
     if (!kv) {
         return 0;
     }
 
-    return kv->get_used_cells();
+    int32_t res = 0;
+
+    for (uint32_t s = 0; s < ctx->get_cparams().n_seq_max; s++) {
+        const llama_pos p0 = kv->seq_pos_min(s);
+        const llama_pos p1 = kv->seq_pos_max(s);
+
+        if (p0 >= 0) {
+            res += (p1 - p0) + 1;
+        }
+    }
+
+    return res;
 }
 
 void llama_kv_self_clear(llama_context * ctx) {
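
The deprecated counters now derive their result from per-sequence position ranges rather than cell bookkeeping. A standalone sketch of the same formula (the helper and its input type are hypothetical, for illustration only):

// Each sequence that holds any cells (pos_min >= 0) contributes
// pos_max - pos_min + 1 tokens to the total.
#include <cstdint>
#include <utility>
#include <vector>

static int32_t count_cached_tokens(const std::vector<std::pair<int32_t, int32_t>> & seq_ranges) {
    int32_t res = 0;
    for (const auto & [p0, p1] : seq_ranges) { // (pos_min, pos_max) per sequence
        if (p0 >= 0) {                         // negative pos_min marks an empty sequence
            res += (p1 - p0) + 1;
        }
    }
    return res;
}

// Example: sequences holding positions 0..9 and 5..14 give 10 + 10 = 20.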

src/llama-kv-cache.cpp (24 additions, 87 deletions)

@@ -30,13 +30,14 @@ llama_kv_cache_unified::llama_kv_cache_unified(
         bool v_trans,
         bool offload,
         uint32_t kv_size,
-        uint32_t padding,
+        uint32_t n_seq_max,
+        uint32_t n_pad,
         uint32_t n_swa,
-        llama_swa_type swa_type) : model(model), hparams(model.hparams), v_trans(v_trans), padding(padding), n_swa(n_swa), swa_type(swa_type) {
-    GGML_ASSERT(kv_size % padding == 0 && "kv_size must be a multiple of padding");
+        llama_swa_type swa_type) :
+    model(model), hparams(model.hparams), v_trans(v_trans),
+    n_seq_max(n_seq_max), n_pad(n_pad), n_swa(n_swa), swa_type(swa_type) {
 
-    this->type_k = type_k;
-    this->type_v = type_v;
+    GGML_ASSERT(kv_size % n_pad == 0);
 
     // create a context for each buffer type
     std::map<ggml_backend_buffer_type_t, ggml_context *> ctx_map;
@@ -129,8 +130,8 @@ llama_kv_cache_unified::llama_kv_cache_unified(
     const size_t memory_size_k = size_k_bytes();
     const size_t memory_size_v = size_v_bytes();
 
-    LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6d cells, %3d layers), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
-            (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(),
+    LLAMA_LOG_INFO("%s: size = %7.2f MiB (%6u cells, %3d layers, %2u seqs), K (%s): %7.2f MiB, V (%s): %7.2f MiB\n", __func__,
+            (float)(memory_size_k + memory_size_v) / (1024.0f * 1024.0f), kv_size, (int) layers.size(), n_seq_max,
             ggml_type_name(type_k), (float)memory_size_k / (1024.0f * 1024.0f),
             ggml_type_name(type_v), (float)memory_size_v / (1024.0f * 1024.0f));
 }
@@ -442,7 +443,7 @@ bool llama_kv_cache_unified::update(llama_context & lctx) {
 void llama_kv_cache_unified::defrag_sched(float thold) {
     // - do not defrag small contexts (i.e. < 2048 tokens)
     // - count the padding towards the number of used tokens
-    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + padding)/n)) : 0.0f;
+    const float fragmentation = n >= 2048 ? std::max(0.0f, 1.0f - (float(used + n_pad)/n)) : 0.0f;
 
     // queue defragmentation for next llama_kv_cache_update
     if (fragmentation > thold) {
@@ -558,7 +559,7 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
     // a heuristic, to avoid attending the full cache if it is not yet utilized
     // after enough generations, the benefit from this heuristic disappears
     // if we start defragmenting the cache, the benefit from this will be more important
-    n = std::min(size, std::max(padding, GGML_PAD(cell_max(), padding)));
+    n = std::min(size, std::max(n_pad, GGML_PAD(cell_max(), n_pad)));
 
 #ifdef FIND_SLOT_DEBUG
     LLAMA_LOG_WARN("end: n = %5d, used = %5d, head = %5d, n_swa = %5d\n", n, used, head, n_swa);
@@ -567,20 +568,6 @@ bool llama_kv_cache_unified::find_slot(const llama_ubatch & ubatch) {
     return true;
 }
 
-int32_t llama_kv_cache_unified::get_n_tokens() const {
-    int32_t result = 0;
-
-    for (uint32_t i = 0; i < size; i++) {
-        result += cells[i].seq_id.size();
-    }
-
-    return result;
-}
-
-int32_t llama_kv_cache_unified::get_used_cells() const {
-    return used;
-}
-
 bool llama_kv_cache_unified::get_can_shift() const {
     return true;
 }
@@ -802,16 +789,6 @@ void llama_kv_cache_unified::set_input_pos_bucket(ggml_tensor * dst, const llama
     }
 }
 
-llama_pos llama_kv_cache_unified::get_pos_max() const {
-    llama_pos pos_max = -1;
-
-    for (const auto & cell : cells) {
-        pos_max = std::max(pos_max, cell.pos);
-    }
-
-    return pos_max;
-}
-
 size_t llama_kv_cache_unified::total_size() const {
     size_t size = 0;
 
@@ -1501,11 +1478,8 @@ bool llama_kv_cache_unified::state_read_meta(llama_io_read_i & io, uint32_t cell
             llama_seq_id seq_id;
             io.read_to(&seq_id, sizeof(seq_id));
 
-            // TODO: llama_kv_cache_unified should have a notion of max sequences
-            //if (seq_id < 0 || (uint32_t) seq_id >= llama_n_seq_max(ctx)) {
-            if (seq_id < 0) {
-                //LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, llama_n_seq_max(ctx));
-                LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, inf)\n", __func__, seq_id);
+            if (seq_id < 0 || (uint32_t) seq_id >= n_seq_max) {
+                LLAMA_LOG_ERROR("%s: invalid seq_id, %d is out of range [0, %u)\n", __func__, seq_id, n_seq_max);
                 return false;
             }
 
@@ -1655,17 +1629,17 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
         ggml_type type_v,
         bool v_trans,
         bool offload,
-        uint32_t kv_size,
         bool swa_full,
+        uint32_t kv_size,
         uint32_t n_seq_max,
         uint32_t n_batch,
-        uint32_t padding) : hparams(model.hparams) {
+        uint32_t n_pad) : hparams(model.hparams) {
     llama_kv_cache_unified::layer_filter_cb filter_base = [&](int32_t il) { return !model.hparams.is_swa(il); };
     llama_kv_cache_unified::layer_filter_cb filter_swa = [&](int32_t il) { return model.hparams.is_swa(il); };
 
     const uint32_t size_base = kv_size;
 
-    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, padding));
+    uint32_t size_swa = std::min(size_base, GGML_PAD(hparams.n_swa*n_seq_max + n_batch, n_pad));
 
     // when using full-size SWA cache, we set the SWA cache size to be equal to the base cache size and disable pruning
     if (swa_full) {
@@ -1680,14 +1654,14 @@ llama_kv_cache_unified_iswa::llama_kv_cache_unified_iswa(
 
     kv_base = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_base), type_k, type_v,
-            v_trans, offload, size_base, padding,
+            v_trans, offload, size_base, n_seq_max, n_pad,
             0, LLAMA_SWA_TYPE_NONE);
 
     LLAMA_LOG_INFO("%s: creating SWA KV cache, size = %u cells\n", __func__, size_swa);
 
     kv_swa = std::make_unique<llama_kv_cache_unified>(
             model, std::move(filter_swa), type_k, type_v,
-            v_trans, offload, size_swa, padding,
+            v_trans, offload, size_swa, n_seq_max, n_pad,
             hparams.n_swa, hparams.swa_type);
 }
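
A worked example of the SWA cache sizing above, using assumed values rather than anything from this commit: with hparams.n_swa = 4096, n_seq_max = 2, n_batch = 512 and n_pad = 32, GGML_PAD(4096*2 + 512, 32) = GGML_PAD(8704, 32) = 8704, so size_swa = min(size_base, 8704). When swa_full is set, the comment in the hunk notes that the SWA cache is instead sized like the base cache and pruning is disabled.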

@@ -1810,18 +1784,6 @@ bool llama_kv_cache_unified_iswa::find_slot(const llama_ubatch & batch) {
     return res;
 }
 
-int32_t llama_kv_cache_unified_iswa::get_n_tokens() const {
-    return kv_base->get_n_tokens();
-}
-
-int32_t llama_kv_cache_unified_iswa::get_used_cells() const {
-    return kv_base->get_used_cells();
-}
-
-llama_pos llama_kv_cache_unified_iswa::get_pos_max() const {
-    return kv_base->get_pos_max();
-}
-
 bool llama_kv_cache_unified_iswa::get_can_shift() const {
     return kv_base->get_size() == kv_swa->get_size();
 }
@@ -1853,19 +1815,17 @@ llama_kv_cache_recurrent::llama_kv_cache_recurrent(
         ggml_type type_k,
         ggml_type type_v,
         bool offload,
-        uint32_t kv_size) : hparams(model.hparams) {
+        uint32_t kv_size,
+        uint32_t n_seq_max) : hparams(model.hparams), n_seq_max(n_seq_max) {
     const int32_t n_layer = hparams.n_layer;
 
-    LLAMA_LOG_INFO("%s: kv_size = %d, type_k = '%s', type_v = '%s', n_layer = %d\n",
-            __func__, kv_size, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
+    LLAMA_LOG_INFO("%s: kv_size = %u, n_seq_max = %u, type_k = '%s', type_v = '%s', n_layer = %d\n",
+            __func__, kv_size, n_seq_max, ggml_type_name(type_k), ggml_type_name(type_v), n_layer);
 
     head = 0;
     size = kv_size;
     used = 0;
 
-    this->type_k = type_k;
-    this->type_v = type_v;
-
     cells.clear();
     cells.resize(kv_size);
 
@@ -2203,8 +2163,8 @@ void llama_kv_cache_recurrent::commit() {
     pending.ranges.clear();
 }
 
-bool llama_kv_cache_recurrent::update(llama_context & lctx) {
-    GGML_UNUSED(lctx);
+bool llama_kv_cache_recurrent::update(llama_context & ctx) {
+    GGML_UNUSED(ctx);
     return false;
 }
 
@@ -2265,7 +2225,7 @@ bool llama_kv_cache_recurrent::find_slot(
             if (seq_id < 0 || (uint32_t) seq_id >= size) {
                 // too big seq_id
                 // TODO: would it be possible to resize the cache instead?
-                LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%d Try using a bigger --parallel value\n", __func__, seq_id, size);
+                LLAMA_LOG_ERROR("%s: seq_id=%d >= n_seq_max=%u Try using a bigger --parallel value\n", __func__, seq_id, n_seq_max);
                 return false;
             }
             if (j > 0) {
@@ -2408,29 +2368,6 @@ bool llama_kv_cache_recurrent::find_slot(
     return n >= n_seqs;
 }
 
-int32_t llama_kv_cache_recurrent::get_n_tokens() const {
-    int32_t result = 0;
-
-    for (uint32_t i = 0; i < size; i++) {
-        result += cells[i].seq_id.size();
-    }
-
-    return result;
-}
-
-int32_t llama_kv_cache_recurrent::get_used_cells() const {
-    return used;
-}
-
-llama_pos llama_kv_cache_recurrent::get_pos_max() const {
-    llama_pos pos_max = -1;
-    for (const auto & cell : cells) {
-        pos_max = std::max(pos_max, cell.pos);
-    }
-
-    return pos_max;
-}
-
 bool llama_kv_cache_recurrent::get_can_shift() const {
     return false;
 }
