Skip to content

Commit 745aa53

Browse files
authored
llama : deprecate llama_kv_self_ API (#14030)
* llama : deprecate llama_kv_self_ API ggml-ci * llama : allow llama_memory_(nullptr) ggml-ci * memory : add flag for optional data clear in llama_memory_clear ggml-ci
1 parent 487a5e0 commit 745aa53

File tree

34 files changed

+206
-127
lines changed

34 files changed

+206
-127
lines changed

common/common.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -934,7 +934,7 @@ struct common_init_result common_init_from_params(common_params & params) {
934934
return iparams;
935935
}
936936

937-
if (params.ctx_shift && !llama_kv_self_can_shift(lctx)) {
937+
if (params.ctx_shift && !llama_memory_can_shift(llama_get_memory(lctx))) {
938938
LOG_WRN("%s: KV cache shifting is not supported for this context, disabling KV cache shifting\n", __func__);
939939
params.ctx_shift = false;
940940
}
@@ -1041,7 +1041,7 @@ struct common_init_result common_init_from_params(common_params & params) {
10411041
if (llama_model_has_decoder(model)) {
10421042
llama_decode(lctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) params.n_batch)));
10431043
}
1044-
llama_kv_self_clear(lctx);
1044+
llama_memory_clear(llama_get_memory(lctx), true);
10451045
llama_synchronize(lctx);
10461046
llama_perf_context_reset(lctx);
10471047
llama_set_warmup(lctx, false);

common/speculative.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ llama_tokens common_speculative_gen_draft(
144144
auto & smpl = spec->smpl;
145145
auto & prompt = spec->prompt;
146146

147+
auto * mem = llama_get_memory(ctx);
148+
147149
int reuse_i = 0;
148150
int reuse_n = 0;
149151

@@ -173,7 +175,7 @@ llama_tokens common_speculative_gen_draft(
173175
result.reserve(params.n_draft);
174176

175177
if (reuse_n == 0) {
176-
llama_kv_self_clear(ctx);
178+
llama_memory_clear(mem, false);
177179

178180
prompt.clear();
179181
} else {
@@ -192,14 +194,14 @@ llama_tokens common_speculative_gen_draft(
192194
}
193195

194196
if (reuse_i > 0) {
195-
llama_kv_self_seq_rm (ctx, 0, 0, reuse_i);
196-
llama_kv_self_seq_add(ctx, 0, reuse_i, -1, -reuse_i);
197+
llama_memory_seq_rm (mem, 0, 0, reuse_i);
198+
llama_memory_seq_add(mem, 0, reuse_i, -1, -reuse_i);
197199

198200
prompt.erase(prompt.begin(), prompt.begin() + reuse_i);
199201
}
200202

201203
if (reuse_n < (int) prompt.size()) {
202-
llama_kv_self_seq_rm (ctx, 0, reuse_n, -1);
204+
llama_memory_seq_rm (mem, 0, reuse_n, -1);
203205

204206
prompt.erase(prompt.begin() + reuse_n, prompt.end());
205207
}

examples/batched.swift/Sources/main.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ if llama_decode(context, batch) != 0 {
116116
}
117117

118118
for i in 1 ..< n_parallel {
119-
llama_kv_self_seq_cp(context, 0, Int32(i), 0, batch.n_tokens)
119+
llama_memory_seq_cp(llama_get_memory(context), 0, Int32(i), 0, batch.n_tokens)
120120
}
121121

122122
if n_parallel > 1 {

examples/embedding/embedding.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
3737
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
3838

3939
// clear previous kv_cache values (irrelevant for embeddings)
40-
llama_kv_self_clear(ctx);
40+
llama_memory_clear(llama_get_memory(ctx), true);
4141

4242
// run model
4343
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

examples/gritlm/gritlm.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
4545
}
4646

4747
// clear previous kv_cache values (irrelevant for embeddings)
48-
llama_kv_self_clear(ctx);
48+
llama_memory_clear(llama_get_memory(ctx), true);
4949
llama_set_embeddings(ctx, true);
5050
llama_set_causal_attn(ctx, false);
5151

@@ -102,7 +102,7 @@ static std::string generate(llama_context * ctx, llama_sampler * smpl, const std
102102

103103
llama_token eos_token = llama_vocab_eos(vocab);
104104

105-
llama_kv_self_clear(ctx);
105+
llama_memory_clear(llama_get_memory(ctx), true);
106106
llama_set_embeddings(ctx, false);
107107
llama_set_causal_attn(ctx, true);
108108

examples/llama.android/llama/src/main/cpp/llama-android.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
194194
}
195195

196196
batch->logits[batch->n_tokens - 1] = true;
197-
llama_kv_self_clear(context);
197+
llama_memory_clear(llama_get_memory(context), false);
198198

199199
const auto t_pp_start = ggml_time_us();
200200
if (llama_decode(context, *batch) != 0) {
@@ -206,7 +206,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
206206

207207
LOGi("Benchmark text generation (tg)");
208208

209-
llama_kv_self_clear(context);
209+
llama_memory_clear(llama_get_memory(context), false);
210210
const auto t_tg_start = ggml_time_us();
211211
for (i = 0; i < tg; i++) {
212212

@@ -223,7 +223,7 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
223223

224224
const auto t_tg_end = ggml_time_us();
225225

226-
llama_kv_self_clear(context);
226+
llama_memory_clear(llama_get_memory(context), false);
227227

228228
const auto t_pp = double(t_pp_end - t_pp_start) / 1000000.0;
229229
const auto t_tg = double(t_tg_end - t_tg_start) / 1000000.0;
@@ -448,5 +448,5 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
448448
extern "C"
449449
JNIEXPORT void JNICALL
450450
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
451-
llama_kv_self_clear(reinterpret_cast<llama_context *>(context));
451+
llama_memory_clear(llama_get_memory(reinterpret_cast<llama_context *>(context)), true);
452452
}

examples/llama.swiftui/llama.cpp.swift/LibLlama.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ actor LlamaContext {
210210
}
211211
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
212212

213-
llama_kv_self_clear(context)
213+
llama_memory_clear(llama_get_memory(context), false)
214214

215215
let t_pp_start = DispatchTime.now().uptimeNanoseconds / 1000;
216216

@@ -223,7 +223,7 @@ actor LlamaContext {
223223

224224
// bench text generation
225225

226-
llama_kv_self_clear(context)
226+
llama_memory_clear(llama_get_memory(context), false)
227227

228228
let t_tg_start = DispatchTime.now().uptimeNanoseconds / 1000;
229229

@@ -242,7 +242,7 @@ actor LlamaContext {
242242

243243
let t_tg_end = DispatchTime.now().uptimeNanoseconds / 1000;
244244

245-
llama_kv_self_clear(context)
245+
llama_memory_clear(llama_get_memory(context), false)
246246

247247
let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
248248
let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
@@ -292,7 +292,7 @@ actor LlamaContext {
292292
func clear() {
293293
tokens_list.removeAll()
294294
temporary_invalid_cchars.removeAll()
295-
llama_kv_self_clear(context)
295+
llama_memory_clear(llama_get_memory(context), true)
296296
}
297297

298298
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {

examples/lookahead/lookahead.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@ int main(int argc, char ** argv) {
6060
llama_model * model = llama_init.model.get();
6161
llama_context * ctx = llama_init.context.get();
6262

63+
auto * mem = llama_get_memory(ctx);
64+
6365
const llama_vocab * vocab = llama_model_get_vocab(model);
6466

6567
// Tokenize the prompt
@@ -94,7 +96,7 @@ int main(int argc, char ** argv) {
9496
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1));
9597

9698
for (int s = 1; s < W + G + 1; ++s) {
97-
llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
99+
llama_memory_seq_cp(mem, 0, s, -1, -1);
98100
}
99101

100102
const auto t_enc_end = ggml_time_us();
@@ -427,17 +429,17 @@ int main(int argc, char ** argv) {
427429

428430
// KV cache management
429431
// if no verification token matched, we simply remove all cells from this batch -> no fragmentation
430-
llama_kv_self_seq_rm(ctx, -1, n_past, -1);
432+
llama_memory_seq_rm(mem, -1, n_past, -1);
431433

432434
if (seq_id_best != 0) {
433435
// if a verification token matched, we keep the best sequence and remove the rest
434436
// this leads to some KV cache fragmentation
435-
llama_kv_self_seq_keep(ctx, seq_id_best);
436-
llama_kv_self_seq_cp (ctx, seq_id_best, 0, -1, -1);
437-
llama_kv_self_seq_rm (ctx, seq_id_best, -1, -1);
437+
llama_memory_seq_keep(mem, seq_id_best);
438+
llama_memory_seq_cp (mem, seq_id_best, 0, -1, -1);
439+
llama_memory_seq_rm (mem, seq_id_best, -1, -1);
438440

439441
for (int s = 1; s < W + G + 1; ++s) {
440-
llama_kv_self_seq_cp(ctx, 0, s, -1, -1);
442+
llama_memory_seq_cp(mem, 0, s, -1, -1);
441443
}
442444
}
443445
}

examples/lookup/lookup.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ int main(int argc, char ** argv){
181181

182182
// KV cache management
183183
// clean the cache of draft tokens that weren't accepted
184-
llama_kv_self_seq_rm(ctx, 0, n_past, -1);
184+
llama_memory_seq_rm(llama_get_memory(ctx), 0, n_past, -1);
185185

186186
common_batch_clear(batch_tgt);
187187
common_batch_add(batch_tgt, draft[0], n_past, { 0 }, true);

examples/parallel/parallel.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,8 @@ int main(int argc, char ** argv) {
194194
llama_model * model = llama_init.model.get();
195195
llama_context * ctx = llama_init.context.get();
196196

197+
auto * mem = llama_get_memory(ctx);
198+
197199
const llama_vocab * vocab = llama_model_get_vocab(model);
198200

199201
// load the prompts from an external file if there are any
@@ -259,7 +261,7 @@ int main(int argc, char ** argv) {
259261

260262
// assign the system KV cache to all parallel sequences
261263
for (int32_t i = 1; i <= n_clients; ++i) {
262-
llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
264+
llama_memory_seq_cp(mem, 0, i, -1, -1);
263265
}
264266

265267
LOG_INF("\n");
@@ -286,9 +288,9 @@ int main(int argc, char ** argv) {
286288
if (batch.n_tokens == 0) {
287289
// all sequences have ended - clear the entire KV cache
288290
for (int i = 1; i <= n_clients; ++i) {
289-
llama_kv_self_seq_rm(ctx, i, -1, -1);
291+
llama_memory_seq_rm(mem, i, -1, -1);
290292
// but keep the system prompt
291-
llama_kv_self_seq_cp(ctx, 0, i, -1, -1);
293+
llama_memory_seq_cp(mem, 0, i, -1, -1);
292294
}
293295

294296
LOG_INF("%s: clearing the KV cache\n", __func__);
@@ -447,8 +449,8 @@ int main(int argc, char ** argv) {
447449
}
448450

449451
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
450-
llama_kv_self_seq_rm(ctx, client.id + 1, -1, -1);
451-
llama_kv_self_seq_cp(ctx, 0, client.id + 1, -1, -1);
452+
llama_memory_seq_rm(mem, client.id + 1, -1, -1);
453+
llama_memory_seq_cp(mem, 0, client.id + 1, -1, -1);
452454

453455
const auto t_main_end = ggml_time_us();
454456

examples/passkey/passkey.cpp

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -126,17 +126,19 @@ int main(int argc, char ** argv) {
126126

127127
int n_past = 0;
128128

129+
auto * mem = llama_get_memory(ctx);
130+
129131
// fill the KV cache
130132
for (int i = 0; i < n_ctx; i += n_batch) {
131133
if (i > 0 && n_grp > 1) {
132134
// if SelfExtend is enabled, we compress the position from the last batch by a factor of n_grp
133135
const int ib = i/n_batch - 1;
134136
const int bd = n_batch_grp*(n_grp - 1);
135137

136-
llama_kv_self_seq_add(ctx, 0, n_past - n_batch, n_past, ib*bd);
137-
llama_kv_self_seq_div(ctx, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
138+
llama_memory_seq_add(mem, 0, n_past - n_batch, n_past, ib*bd);
139+
llama_memory_seq_div(mem, 0, n_past - n_batch + ib*bd, n_past + ib*bd, n_grp);
138140

139-
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
141+
n_past = llama_memory_seq_pos_max(mem, 0) + 1;
140142
}
141143

142144
common_batch_clear(batch);
@@ -166,10 +168,10 @@ int main(int argc, char ** argv) {
166168

167169
LOG_INF("%s: shifting KV cache with %d\n", __func__, n_discard);
168170

169-
llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
170-
llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
171+
llama_memory_seq_rm (mem, 0, n_keep , n_keep + n_discard);
172+
llama_memory_seq_add(mem, 0, n_keep + n_discard, n_ctx, -n_discard);
171173

172-
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
174+
n_past = llama_memory_seq_pos_max(mem, 0) + 1;
173175

174176
common_batch_clear(batch);
175177

@@ -195,10 +197,10 @@ int main(int argc, char ** argv) {
195197
if (n_discard > 0) {
196198
LOG_INF("%s: shifting KV cache with %d to free space for the answer\n", __func__, n_discard);
197199

198-
llama_kv_self_seq_rm (ctx, 0, n_keep , n_keep + n_discard);
199-
llama_kv_self_seq_add(ctx, 0, n_keep + n_discard, n_ctx, -n_discard);
200+
llama_memory_seq_rm (mem, 0, n_keep , n_keep + n_discard);
201+
llama_memory_seq_add(mem, 0, n_keep + n_discard, n_ctx, -n_discard);
200202

201-
n_past = llama_kv_self_seq_pos_max(ctx, 0) + 1;
203+
n_past = llama_memory_seq_pos_max(mem, 0) + 1;
202204
}
203205
}
204206

examples/retrieval/retrieval.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
8383

8484
static void batch_process(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
8585
// clear previous kv_cache values (irrelevant for embeddings)
86-
llama_kv_self_clear(ctx);
86+
llama_memory_clear(llama_get_memory(ctx), false);
8787

8888
// run model
8989
LOG_INF("%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);

examples/save-load-state/save-load-state.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ int main(int argc, char ** argv) {
196196
fprintf(stderr, "%s : seq 0 copied, %zd bytes\n", __func__, ncopy);
197197

198198
// erase whole kv
199-
llama_kv_self_clear(ctx3);
199+
llama_memory_clear(llama_get_memory(ctx3), true);
200200
fprintf(stderr, "%s : kv cache cleared\n", __func__);
201201

202202
// restore kv into seq 1

examples/simple-chat/simple-chat.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ int main(int argc, char ** argv) {
9898
auto generate = [&](const std::string & prompt) {
9999
std::string response;
100100

101-
const bool is_first = llama_kv_self_seq_pos_max(ctx, 0) == 0;
101+
const bool is_first = llama_memory_seq_pos_max(llama_get_memory(ctx), 0) == 0;
102102

103103
// tokenize the prompt
104104
const int n_prompt_tokens = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, is_first, true);
@@ -113,7 +113,7 @@ int main(int argc, char ** argv) {
113113
while (true) {
114114
// check if we have enough space in the context to evaluate this batch
115115
int n_ctx = llama_n_ctx(ctx);
116-
int n_ctx_used = llama_kv_self_seq_pos_max(ctx, 0);
116+
int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx), 0);
117117
if (n_ctx_used + batch.n_tokens > n_ctx) {
118118
printf("\033[0m\n");
119119
fprintf(stderr, "context size exceeded\n");

examples/speculative-simple/speculative-simple.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ int main(int argc, char ** argv) {
217217
{
218218
LOG_DBG("clear kv cache from any extra tokens, n_past = %d\n", n_past);
219219

220-
llama_kv_self_seq_rm(ctx_tgt, 0, n_past, -1);
220+
llama_memory_seq_rm(llama_get_memory(ctx_tgt), 0, n_past, -1);
221221
}
222222

223223
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {

0 commit comments

Comments
 (0)