Feature: Added api for getting/setting the complete state: rng, logits, embedding and kv_cache #1105
+133
−1
I have implemented functions for getting and setting the rest of the model state.
It includes: random number generator state, logits, embedding and kv_cache.
It was necessary to store the logits so that we can eval tokens, save state, restart the program, load state, and then sample. With just restoring kv_cache, the sampling did not have access to the required logits and indeed segfaulted on the initially empty logits vector.
The initial capacity of the logits vector was reserved with a wrong value. As a result, the capacity changed after the first evaluation, in which the logits vector is actually resized. I fixed this bug because it propagated to the state size, resulting in unnecessary changes.
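To illustrate the capacity issue, here is a minimal sketch. It assumes the state size is derived from the vector's capacity; `logits_bytes` and `size_is_stable` are hypothetical helpers for illustration, not functions from this PR.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical size calculation that depends on capacity, not on size.
size_t logits_bytes(const std::vector<float> & logits) {
    return logits.capacity() * sizeof(float);
}

// Reserve `reserved` elements up front, then resize to `needed` elements,
// as the first evaluation would. Returns true if the reported byte count
// is the same before and after the resize.
bool size_is_stable(size_t reserved, size_t needed) {
    std::vector<float> logits;
    logits.reserve(reserved);
    const size_t before = logits_bytes(logits);
    logits.resize(needed);
    return logits_bytes(logits) == before;
}
```

When the reserved value matches what the first evaluation needs, the resize does not reallocate and the reported size stays the same. When it is too small, the vector grows and the reported size changes.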
The random number generator state is also included to ensure consistent sampling results.
Since the internal state of the rng is more than just the seed, it is serialized using the standard C++ API for this purpose, by streaming it into a string stream. For simplicity I did not add further logic to parse and compress the serialized rng state.
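As a rough sketch of this approach: the standard random engines support `operator<<` and `operator>>`, which write and read the full engine state as text. The example assumes a `std::mt19937` engine; `save_rng` and `load_rng` are hypothetical names used only for illustration.

```cpp
#include <cassert>
#include <random>
#include <sstream>
#include <string>

// Serialize the full engine state (not just the seed) as text.
std::string save_rng(const std::mt19937 & rng) {
    std::stringstream ss;
    ss << rng;
    return ss.str();
}

// Restore an engine from text produced by save_rng.
// Returns false if the text could not be parsed.
bool load_rng(std::mt19937 & rng, const std::string & text) {
    std::stringstream ss(text);
    ss >> rng;
    return !ss.fail();
}
```

The text form is larger than the raw state, which is the trade-off accepted above by not parsing or compressing it further.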
For completeness I also stored the embedding vector.
Because the whole state is not in one contiguous memory buffer, I decided on an output pointer parameter to get the state data.
The user is responsible for allocating the memory that the state is written to. To support this, the required number of bytes can be requested.
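The calling pattern could look roughly like the sketch below: query the size, allocate, then let the library copy the scattered pieces into the caller's buffer. Everything here is a simplified stand-in. `toy_state`, `toy_state_size`, `toy_copy_state` and `toy_set_state` are hypothetical names, the byte layout is an assumption, and the rng state is left out for brevity.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical stand-in for the model context; the pieces live in separate buffers.
struct toy_state {
    std::vector<float>   logits;
    std::vector<float>   embedding;
    std::vector<uint8_t> kv_cache;
};

// Append a length-prefixed copy of v at out; returns the advanced pointer.
template <typename T>
uint8_t * write_vec(uint8_t * out, const std::vector<T> & v) {
    const uint64_t n = v.size();
    std::memcpy(out, &n, sizeof(n));
    out += sizeof(n);
    if (n > 0) {
        std::memcpy(out, v.data(), n * sizeof(T));
        out += n * sizeof(T);
    }
    return out;
}

// Read a length-prefixed vector from in; returns the advanced pointer.
template <typename T>
const uint8_t * read_vec(const uint8_t * in, std::vector<T> & v) {
    uint64_t n = 0;
    std::memcpy(&n, in, sizeof(n));
    in += sizeof(n);
    v.resize(n);
    if (n > 0) {
        std::memcpy(v.data(), in, n * sizeof(T));
        in += n * sizeof(T);
    }
    return in;
}

// Number of bytes the caller must allocate before calling toy_copy_state.
size_t toy_state_size(const toy_state & s) {
    return 3 * sizeof(uint64_t)
         + s.logits.size()    * sizeof(float)
         + s.embedding.size() * sizeof(float)
         + s.kv_cache.size();
}

// Copy all pieces into the caller-provided buffer; returns bytes written.
size_t toy_copy_state(const toy_state & s, uint8_t * dest) {
    uint8_t * out = dest;
    out = write_vec(out, s.logits);
    out = write_vec(out, s.embedding);
    out = write_vec(out, s.kv_cache);
    return static_cast<size_t>(out - dest);
}

// Restore all pieces from a buffer filled by toy_copy_state; returns bytes read.
size_t toy_set_state(toy_state & s, const uint8_t * src) {
    const uint8_t * in = src;
    in = read_vec(in, s.logits);
    in = read_vec(in, s.embedding);
    in = read_vec(in, s.kv_cache);
    return static_cast<size_t>(in - src);
}
```

Letting the caller own the buffer keeps allocation policy out of the library, at the cost of one extra call to learn the size.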