Generate: New Cache abstraction and Attention Sinks support
#26681
Merged
Changes from all commits
Commits (35)
d200a73  Draft version of new KV Caching (tomaarsen)
ffd7ba4  Address numerous PR suggestions (tomaarsen)
be0f917  Implement the SinkCache through backward+forward rotations (tomaarsen)
e9ffd60  Integrate (Sink)Cache with Llama FA2 (tomaarsen)
c0e327c  Set use_legacy_cache=True as default, allows for test passes (tomaarsen)
8565e9d  Move from/to_legacy_cache to ...Model class (tomaarsen)
4c3bf9a  Undo unnecessary newline change (tomaarsen)
1b9ec3d  Remove copy utility from deprecated OpenLlama (tomaarsen)
c3c4d5a  Match import style (tomaarsen)
a40037d  manual rebase with main (gante)
2e66bc6  Cache class working with generate (#1) (gante)
ac766e3  move import (gante)
3490e1e  add default to model_kwargs.get('use_legacy_cache') (gante)
3520c47  correct failing test (gante)
3534699  Apply suggestions from code review (gante)
f4ced8a  apply PR suggestions (gante)
f6e7d2e  fix failing test (gante)
1510746  Apply suggestions from code review (gante)
00f373b  PR comments (gante)
89ffc8d  tmp commit (gante)
5aa4573  add docstrings (gante)
6675c20  more tests, more docstrings, add to docs (gante)
7bf1fe0  derp (gante)
2cd20a4  tmp commit (gante)
4d87439  tmp dbg (gante)
7f0fc57  more dbg (gante)
7389b6b  fix beam search bug (gante)
69085bf  cache can be a list of tuples in some models (gante)
ebd223b  fix group beam search (gante)
03fa241  all but sinkcache integration tests (gante)
a9fe510  fix sink cache and add hard integration test (gante)
e370d33  now also compatible with input_embeds input (gante)
e7a6df7  PR comments (gante)
4bff583  add Cache support to Phi+FA2 (gante)
ee60b1c  make fixup (gante)
@@ -0,0 +1,298 @@
from typing import Any, Dict, List, Optional, Tuple

import torch


class Cache:
    """
    Base, abstract class for all caches. The actual data structure is specific to each subclass.
    """

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
                cache to be created.

        Return:
            A tuple containing the updated key and value states.
        """
        raise NotImplementedError("Make sure to implement `update` in a subclass.")

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        raise NotImplementedError("Make sure to implement `get_seq_length` in a subclass.")
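The `Cache` base class above only fixes the `update`/`get_seq_length` interface; everything else is left to subclasses. As a minimal sketch of how a new cache type could be built on top of it (illustrative only, not part of this PR), here is a naive sliding-window cache that keeps just the most recent `max_length` positions per layer:

    class WindowCache(Cache):
        """Illustrative sketch, not part of this PR: keep only the last `max_length` positions per layer."""

        def __init__(self, max_length: int) -> None:
            self.max_length = max_length
            self.key_cache: List[torch.Tensor] = []
            self.value_cache: List[torch.Tensor] = []

        def update(
            self,
            key_states: torch.Tensor,
            value_states: torch.Tensor,
            layer_idx: int,
            cache_kwargs: Optional[Dict[str, Any]] = None,
        ) -> Tuple[torch.Tensor, torch.Tensor]:
            # Append the new states for this layer, then crop to the window along the sequence axis
            if len(self.key_cache) <= layer_idx:
                self.key_cache.append(key_states)
                self.value_cache.append(value_states)
            else:
                self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
                self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
            self.key_cache[layer_idx] = self.key_cache[layer_idx][:, :, -self.max_length :]
            self.value_cache[layer_idx] = self.value_cache[layer_idx][:, :, -self.max_length :]
            return self.key_cache[layer_idx], self.value_cache[layer_idx]

        def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
            if len(self.key_cache) <= layer_idx:
                return 0
            return self.key_cache[layer_idx].shape[-2]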
class DynamicCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.
    """

    def __init__(self) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
        """
        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
        sequence length.
        """
        if layer_idx < len(self):
            return (self.key_cache[layer_idx], self.value_cache[layer_idx])
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
        keys and values
        """
        for layer_idx in range(len(self)):
            yield (self.key_cache[layer_idx], self.value_cache[layer_idx])

    def __len__(self):
        """
        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
        to the number of layers in the model.
        """
        return len(self.key_cache)

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. No additional arguments are used in `DynamicCache`.

        Return:
            A tuple containing the updated key and value states.
        """
        # Update the number of seen tokens
        if layer_idx == 0:
            self.seen_tokens += key_states.shape[-2]

        # Update the cache
        if len(self.key_cache) <= layer_idx:
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)
        else:
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        if len(self.key_cache) <= layer_idx:
            return 0
        return self.key_cache[layer_idx].shape[-2]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))

    def to_legacy_cache(self) -> Tuple[Tuple[torch.Tensor], Tuple[torch.Tensor]]:
        """Converts the `DynamicCache` instance into its equivalent in the legacy cache format."""
        legacy_cache = ()
        for layer_idx in range(len(self)):
            legacy_cache += ((self.key_cache[layer_idx], self.value_cache[layer_idx]),)
        return legacy_cache

    @classmethod
    def from_legacy_cache(cls, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None) -> "DynamicCache":
        """Converts a cache in the legacy cache format into an equivalent `DynamicCache`."""
        cache = cls()
        if past_key_values is not None:
            for layer_idx in range(len(past_key_values)):
                key_states, value_states = past_key_values[layer_idx]
                cache.update(key_states, value_states, layer_idx)
        return cache
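In case it helps review, here is a small usage sketch for `DynamicCache` (not from the PR; the tensor shapes are made up, following the documented `[batch_size, num_heads, seq_len, head_dim]` layout):

    # Two layers of a legacy tuple-of-tuples cache: batch_size=1, num_heads=4, seq_len=5, head_dim=8
    legacy = tuple((torch.randn(1, 4, 5, 8), torch.randn(1, 4, 5, 8)) for _ in range(2))

    cache = DynamicCache.from_legacy_cache(legacy)
    print(len(cache), cache.get_seq_length())  # 2 layers, 5 cached tokens
    print(cache[0][0].shape)                   # backwards-compatible indexing -> torch.Size([1, 4, 5, 8])

    # Append one generated token's key/value states for layer 0
    new_k, new_v = torch.randn(1, 4, 1, 8), torch.randn(1, 4, 1, 8)
    k, v = cache.update(new_k, new_v, layer_idx=0)
    print(k.shape)                             # torch.Size([1, 4, 6, 8])

    # Round-trip back to the legacy format expected by older code paths
    legacy_again = cache.to_legacy_cache()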
class SinkCache(Cache):
    """
    A cache that, as described in the [Attention Sinks paper](https://arxiv.org/abs/2309.17453), allows the model to
    generate beyond the length of its context window without losing fluency in the conversation. As it discards past
    tokens, the model will lose the ability to generate tokens that depend on the context that was discarded.

    It stores the Key and Value states as a list of tensors, one for each layer. The expected shape for each tensor is
    `[batch_size, num_heads, seq_len, head_dim]`.

    Parameters:
        window_length (`int`):
            The length of the context window.
        num_sink_tokens (`int`):
            The number of sink tokens. See the original paper for more information.
    """

    def __init__(self, window_length: int, num_sink_tokens: int) -> None:
        self.key_cache: List[torch.Tensor] = []
        self.value_cache: List[torch.Tensor] = []
        self.window_length = window_length
        self.num_sink_tokens = num_sink_tokens
        self.cos_sin_cache = {}
        self.seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen

    @staticmethod
    def _rotate_half(x):
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_key_rotary_pos_emb(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> torch.Tensor:
        rotated_key_states = (key_states * cos) + (self._rotate_half(key_states) * sin)
        return rotated_key_states

    def _get_rerotation_cos_sin(
        self, key_states: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if key_states.shape[-2] not in self.cos_sin_cache:
            # Upcast to float32 temporarily for better accuracy
            cos = cos.to(torch.float32)
            sin = sin.to(torch.float32)

            # Compute the cos and sin required for back- and forward-rotating to one position earlier in the sequence
            original_cos = cos[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_cos = cos[self.num_sink_tokens : -key_states.shape[-2]]
            original_sin = sin[self.num_sink_tokens + key_states.shape[-2] :]
            shifted_sin = sin[self.num_sink_tokens : -key_states.shape[-2]]
            rerotation_cos = original_cos * shifted_cos + original_sin * shifted_sin
            rerotation_sin = -original_sin * shifted_cos + original_cos * shifted_sin

            self.cos_sin_cache[key_states.shape[-2]] = (
                rerotation_cos.to(key_states.dtype).unsqueeze(0),
                rerotation_sin.to(key_states.dtype).unsqueeze(0),
            )
        return self.cos_sin_cache[key_states.shape[-2]]
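A note on the math in `_get_rerotation_cos_sin` (my reading, not text from the PR): `original_*` are the RoPE cos/sin at a kept key's current cached position, and `shifted_*` are the cos/sin at the earlier position it lands on once `key_states.shape[-2]` tokens are evicted. The coefficients computed above are exactly the angle-difference identities, so `_apply_key_rotary_pos_emb` ends up rotating each kept key by the angle difference theta_shift - theta_orig:

    \cos(\theta_{\text{shift}} - \theta_{\text{orig}}) = \cos\theta_{\text{shift}}\cos\theta_{\text{orig}} + \sin\theta_{\text{shift}}\sin\theta_{\text{orig}}
    \sin(\theta_{\text{shift}} - \theta_{\text{orig}}) = \sin\theta_{\text{shift}}\cos\theta_{\text{orig}} - \cos\theta_{\text{shift}}\sin\theta_{\text{orig}}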
    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
        if len(self.key_cache) <= layer_idx:
            return 0
        cache_length = self.key_cache[layer_idx].shape[-2]
        return min(cache_length, self.window_length - 1)
    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        layer_idx: int,
        cache_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.

        Parameters:
            key_states (`torch.Tensor`):
                The new key states to cache.
            value_states (`torch.Tensor`):
                The new value states to cache.
            layer_idx (`int`):
                The index of the layer to cache the states for.
            cache_kwargs (`Dict[str, Any]`, `optional`):
                Additional arguments for the cache subclass. The following arguments can be used in `SinkCache`: `sin`,
                `cos` and `partial_rotation_size`. These arguments are used with models using RoPE, to recompute the
                rotation as the tokens are shifted.

        Return:
            A tuple containing the updated key and value states.
        """
        # Optional kwargs for `SinkCache` -- needed on models using RoPE. `partial_rotation_size` is used on models
        # with partially rotated position embeddings, like Phi or Persimmon.
        sin = cache_kwargs.get("sin")
        cos = cache_kwargs.get("cos")
        partial_rotation_size = cache_kwargs.get("partial_rotation_size")
        using_rope = cos is not None and sin is not None

        # Update the number of seen tokens
        if layer_idx == 0:
            self.seen_tokens += key_states.shape[-2]

        # [bsz, num_heads, seq_len, head_dim]
        if len(self.key_cache) <= layer_idx:
            # Empty cache
            self.key_cache.append(key_states)
            self.value_cache.append(value_states)

        elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
            # Growing cache
            self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
            self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)

        else:
            # Shifting cache
            keys_to_keep = self.key_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + key_states.shape[-2] :
            ]

            # On RoPE models, we need to recompute the Key rotation as the tokens are shifted
            if using_rope:
                rerotation_cos, rerotation_sin = self._get_rerotation_cos_sin(key_states, cos, sin)
                if partial_rotation_size is not None:
                    keys_to_keep, keys_pass = (
                        keys_to_keep[..., :partial_rotation_size],
                        keys_to_keep[..., partial_rotation_size:],
                    )
                keys_to_keep = self._apply_key_rotary_pos_emb(keys_to_keep, rerotation_cos, rerotation_sin)
                if partial_rotation_size is not None:
                    keys_to_keep = torch.cat((keys_to_keep, keys_pass), dim=-1)

            # Concatenate sink tokens, shifted & rotated tokens (if needed), and new tokens
            sink_keys = self.key_cache[layer_idx][:, :, : self.num_sink_tokens]
            self.key_cache[layer_idx] = torch.cat([sink_keys, keys_to_keep, key_states], dim=-2)

            sink_values = self.value_cache[layer_idx][:, :, : self.num_sink_tokens]
            values_to_keep = self.value_cache[layer_idx][
                :, :, -self.window_length + self.num_sink_tokens + value_states.shape[-2] :
            ]
            self.value_cache[layer_idx] = torch.cat([sink_values, values_to_keep, value_states], dim=-2)

        return self.key_cache[layer_idx], self.value_cache[layer_idx]

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        for layer_idx in range(len(self.key_cache)):
            device = self.key_cache[layer_idx].device
            self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
            device = self.value_cache[layer_idx].device
            self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
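Finally, tying this back to the PR title: a `SinkCache` is meant to be created by the user and consumed during generation for long outputs at roughly constant memory. The `generate`-side wiring is not in this file, so the exact call below (passing the cache through `past_key_values`) is an assumption based on the rest of the PR, and the checkpoint name is only an example:

    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "meta-llama/Llama-2-7b-chat-hf"  # example checkpoint; any RoPE decoder-only model integrated in this PR
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

    inputs = tokenizer("Tell me a very long story.", return_tensors="pt").to(model.device)

    # 4 attention-sink tokens plus a rolling window of recent tokens; older entries are evicted and the
    # remaining keys are re-rotated, so generation can continue past the context length at fixed memory.
    cache = SinkCache(window_length=1024, num_sink_tokens=4)

    # Assumed interface: the cache object is handed to `generate` via `past_key_values`
    gen_out = model.generate(**inputs, do_sample=False, max_new_tokens=2000, past_key_values=cache)
    print(tokenizer.decode(gen_out[0], skip_special_tokens=True))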