kv-cache : simplify the interface #13660


Merged: 2 commits merged into master from gg/kv-cache-simplify-part1 on May 21, 2025

Conversation

@ggerganov (Member) commented on May 20, 2025

First part of some KV cache interface refactoring and simplification.

Public API changes

  • Deprecate llama_kv_self_n_tokens
  • Deprecate llama_kv_self_used_cells (a sketch of both deprecations follows below)
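As a rough illustration, the deprecations could be expressed in llama.h along these lines. llama.h does provide a DEPRECATED wrapper macro, but the replacement hints in the messages below are illustrative assumptions, not the verbatim header:

    // Sketch only: the DEPRECATED macro exists in llama.h, but the exact
    // deprecation messages here are assumptions.
    DEPRECATED(LLAMA_API int32_t llama_kv_self_n_tokens(const struct llama_context * ctx),
            "use per-sequence KV cache state queries instead");

    DEPRECATED(LLAMA_API int32_t llama_kv_self_used_cells(const struct llama_context * ctx),
            "use per-sequence KV cache state queries instead");

Callers that still use these functions keep compiling but receive a compiler deprecation warning, which allows the functions to be removed in a later release.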

Internal llama_kv_cache changes

  • Remove llama_kv_cache::get_n_tokens()
  • Remove llama_kv_cache::get_used_cells()
  • Remove llama_kv_cache::get_pos_max()
  • Add the notion of n_seq_max to the KV cache objects. This will be needed later to improve the data structures that track per-sequence information (see the sketch after this list).
  • Remove unused type_k and type_v members
  • Rename padding -> n_pad for consistency
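To make the n_seq_max and n_pad changes concrete, here is a minimal sketch of the resulting cache declaration. The exact parameter list and member layout are assumptions based on the bullet points above, not the verbatim llama.cpp code:

    // Illustrative sketch (assumed shape), not the actual llama.cpp declaration.
    #include <cstdint>

    struct llama_model; // forward declaration

    class llama_kv_cache_unified {
    public:
        llama_kv_cache_unified(
                const llama_model & model,
                bool     v_trans,
                bool     offload,
                uint32_t kv_size,
                uint32_t n_seq_max,  // new: upper bound on the number of tracked sequences
                uint32_t n_pad);     // renamed from `padding` for consistency

    private:
        uint32_t n_seq_max = 1; // later used to size per-sequence tracking structures
        uint32_t n_pad     = 32;
        // note: the previously unused type_k/type_v members are removed
    };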

Other changes

  • llama_decode() now verifies that if the input batch has pos == NULL, it also has seq_id == NULL, so that all tokens can be assigned automatically to seq_id == 0, starting from the maximum position currently in the cache. This fixes/prevents an edge case where a batch with pos == NULL that also contained tokens with seq_id != 0 would be assigned incorrect positions by the llama_batch_allocr (see the sketch after this list).
  • Remove some KV-cache-related fields (such as "used cells" and "tokens count") from the server's /metrics endpoint. These are too internal and implementation-specific to be exposed publicly.
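A minimal sketch of the new consistency check in llama_decode(), assuming llama.cpp's internal logging conventions; the exact wording and placement are not taken from this PR:

    // Sketch of the pos/seq_id consistency check (assumed form): when positions
    // are not provided, sequence ids must not be provided either, because all
    // tokens are then implicitly assigned to seq_id == 0, continuing from the
    // max position currently stored in the cache.
    if (batch.pos == nullptr && batch.seq_id != nullptr) {
        LLAMA_LOG_ERROR("%s: pos == NULL, but seq_id != NULL\n", __func__);
        return -1;
    }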

@ggerganov ggerganov force-pushed the gg/kv-cache-simplify-part1 branch from ef880b3 to a91b15f on May 20, 2025 16:54
@ggerganov ggerganov marked this pull request as ready for review May 20, 2025 17:06
@ggerganov ggerganov requested a review from ngxson as a code owner May 20, 2025 17:06
@ggerganov ggerganov requested a review from slaren May 20, 2025 17:09
@@ -283,7 +283,7 @@ llama_batch_allocr::llama_batch_allocr(struct llama_batch in_batch, llama_pos p0
     if (!batch.pos) {
         pos.resize(batch.n_tokens);
         for (int32_t i = 0; i < batch.n_tokens; i++) {
-            pos[i] = i + p0;
+            pos[i] = p0 + i + 1;
A reviewer (Member) commented:
This change is a bit confusing to me. With p0 I would understand "position zero", but now this parameter seems to mean "previous max pos" instead.

@ggerganov (Member Author) replied:
I was focused on simplifying the call site by absorbing the + 1 into the function, but you are right that this makes the parameter's meaning more confusing. Changed back to the previous version and also added an assert for p0.
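For clarity, a sketch of what the reverted llama_batch_allocr code with the added assert plausibly looks like; the exact assert condition is an assumption:

    // Sketch (assumed): p0 keeps its original "starting position" meaning.
    if (!batch.pos) {
        GGML_ASSERT(p0 >= 0); // the new assert on p0 (exact condition assumed)
        pos.resize(batch.n_tokens);
        for (int32_t i = 0; i < batch.n_tokens; i++) {
            pos[i] = i + p0; // reverted to the original formulation
        }
    }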

@ggerganov ggerganov force-pushed the gg/kv-cache-simplify-part1 branch from 0e096d4 to e987482 on May 21, 2025 11:55
@ggerganov ggerganov merged commit 797f2ac into master May 21, 2025
51 of 53 checks passed
@ggerganov ggerganov deleted the gg/kv-cache-simplify-part1 branch May 21, 2025 12:11
infil00p pushed a commit to baseweight/llama.cpp that referenced this pull request May 22, 2025
* kv-cache : simplify the interface

ggml-ci

* context : revert llama_batch_allocr position change

ggml-ci