Feature: Added api for getting/setting the complete state: rng, logits, embedding and kv_cache #1105

Merged: 5 commits, Apr 22, 2023

Collaborator

@xaedes commented Apr 21, 2023

I have implemented functions for getting and setting the rest of the model state.
It includes: random number generator state, logits, embedding and kv_cache.

The logits had to be stored so that we can eval tokens, save the state, restart the program, load the state and then sample.
With only the kv_cache restored, sampling had no access to the required logits and segfaulted on the initially empty logits vector.

The initial capacity of the logits vector was reserved with a wrong value. As a result, the capacity changed after the first evaluation, in which the logits vector is actually resized. I fixed this bug because it propagated to the state size, causing it to change unnecessarily.
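
To illustrate the capacity issue, here is a minimal sketch using only `std::vector`. The helper name `capacity_changes` is hypothetical and not part of this PR, and it assumes that the state size is derived from the vector's capacity, as the description above suggests.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: returns true if resizing to n_needed changes the
// capacity that was set up by reserving n_reserved elements beforehand.
bool capacity_changes(size_t n_reserved, size_t n_needed) {
    std::vector<float> logits;
    logits.reserve(n_reserved);              // initial reservation
    const size_t before = logits.capacity();
    logits.resize(n_needed);                 // what the first evaluation does
    return logits.capacity() != before;
}
```

If the reservation already covers the size used by the first evaluation, the resize does not reallocate and the capacity stays stable. If the reservation is too small, the capacity grows, and any size computed from it grows as well.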

The random number generator state is also included to ensure consistent sampling results.
Since the internal state of the rng is more than just the seed, it is serialized using the standard C++ API for this purpose, by streaming it into a string stream. For simplicity I did not add further logic to parse and compress the serialized rng state.
For completeness I also stored the embedding vector.
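
The streaming approach for the rng can be sketched as follows. This is an illustration, not the code from this PR: the engine type `std::mt19937` is an assumption, and the helper names `rng_to_string` and `rng_from_string` are hypothetical.

```cpp
#include <cassert>
#include <random>
#include <sstream>
#include <string>

// Serialize the complete engine state (not just the seed) as text.
// Assumption: the engine is std::mt19937.
std::string rng_to_string(const std::mt19937 & rng) {
    std::stringstream ss;
    ss << rng;
    return ss.str();
}

// Restore an engine from text produced by rng_to_string.
void rng_from_string(std::mt19937 & rng, const std::string & text) {
    std::stringstream ss(text);
    ss >> rng;
    assert(!ss.fail()); // the text must hold a valid mt19937 state
}
```

An engine restored this way compares equal to the original and continues the same sequence, which is what makes sampling reproducible after a save and load. The textual form is larger than a packed binary encoding, which matches the remark above that the serialized state is not compressed.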

Because the whole state is not in one contiguous memory buffer, I decided on an output pointer parameter to get the state data.
The user is responsible for allocating the memory that the state is written to. To support this, the required number of bytes can be queried.
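
The query-then-copy pattern described above might look like the sketch below. Everything in it is hypothetical: `toy_state`, `toy_get_state_size` and `toy_copy_state` are stand-ins for illustration, and the length-prefixed layout is an assumption rather than the format used by this PR.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>

// Hypothetical stand-in for a context whose state lives in several buffers.
struct toy_state {
    std::string        rng;       // serialized rng text
    std::vector<float> logits;
    std::vector<float> embedding;
};

// First call: the number of bytes the caller must allocate.
size_t toy_get_state_size(const toy_state & s) {
    return sizeof(uint64_t) + s.rng.size()
         + sizeof(uint64_t) + s.logits.size() * sizeof(float)
         + sizeof(uint64_t) + s.embedding.size() * sizeof(float);
}

// Writes a 64-bit byte count followed by the bytes; returns the new cursor.
static uint8_t * write_block(uint8_t * out, const void * data, size_t n_bytes) {
    const uint64_t n = n_bytes;
    std::memcpy(out, &n, sizeof(n));
    out += sizeof(n);
    if (n_bytes > 0) {
        std::memcpy(out, data, n_bytes);
    }
    return out + n_bytes;
}

// Second call: gathers the separate buffers into caller-owned memory.
// Returns the number of bytes written.
size_t toy_copy_state(const toy_state & s, uint8_t * dest) {
    assert(dest != nullptr);
    uint8_t * out = dest;
    out = write_block(out, s.rng.data(),       s.rng.size());
    out = write_block(out, s.logits.data(),    s.logits.size() * sizeof(float));
    out = write_block(out, s.embedding.data(), s.embedding.size() * sizeof(float));
    return static_cast<size_t>(out - dest);
}
```

A caller would first ask for the size, allocate a buffer such as `std::vector<uint8_t> buf(toy_get_state_size(s))`, and then pass `buf.data()` to the copy function. Keeping allocation on the caller's side lets the same buffer be reused or written directly to a file.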

@xaedes force-pushed the state_persistence branch from cab1fe0 to 1c51e1f on April 21, 2023 15:36
Member

@ggerganov left a comment

Great job!

@ggerganov merged commit b6e7f9b into ggml-org:master Apr 22, 2023

// Returns the size of the state
size_t llama_get_state_size(struct llama_context * ctx) {
    const size_t s_bool = sizeof(int32_t);
Member

s_bool is unused - is this expected?

Collaborator Author

Oh, I missed that during cleanup of unused stuff.
At one point during implementation I was saving the bool flags from llama_context, but removed that because it didn't make much sense.
