
Commit 2e66bc6

gante and tomaarsen committed
Cache class working with generate (#1)
* Draft version of new KV Caching

  This should allow Attention Sinks (https://github.com/tomaarsen/attention_sinks) / StreamingLLM (https://arxiv.org/abs/2309.17453) to be easily implemented in a third-party library or in transformers directly

* Address numerous PR suggestions

  1. Move layer_idx from cache to ...Attention. Removes confusing set_layer_idx magic.
  2. Always convert past_key_values to Cache instance at the start of ...Attention, removes all other isinstance calls.
  3. Remove __bool__ and __getitem__ magic as they're confusing.
  4. past_key_values.update(key, value, idx) now returns key, value.
  5. Add use_legacy_cache flag, defaults to None, i.e. Falsey. This breaks generate for now, until 1) the cache is used in generate() or 2) use_legacy_cache is defaulted to True in generate() until we change it in another PR.
  6. Separate key_cache and value_cache. Some work is still needed to see if the SinkCache can conveniently be implemented with just one update method.

* Integrate (Sink)Cache with Llama FA2
* Move from/to_legacy_cache to ...Model class
* Undo unnecessary newline change
* Match import style
* working generate
* Add tests; Simplify code; Apply changes to Mistral and Persimmon
* fix rebase mess
* a few more manual fixes
* last manual fix
* propagate changes to phi
* upgrade test
* add use_legacy_cache docstring; beef up tests
* reintroduce unwanted deletes

---------

Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>
1 parent a40037d commit 2e66bc6
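
Reviewer note: the snippet below is a minimal sketch of the cache API this commit introduces, driven directly rather than through a model. It assumes this commit is checked out; the number of layers and the tensor shapes are illustrative only, and update() follows the update(key, value, layer_idx) signature described in the message above.

# Sketch: exercising the DynamicCache added in src/transformers/cache_utils.py.
import torch

from transformers.cache_utils import DynamicCache

num_layers, bsz, num_heads, head_dim = 2, 1, 4, 8  # illustrative sizes

cache = DynamicCache()
for step_len in (5, 1, 1):  # a 5-token prefill followed by two single-token decode steps
    for layer_idx in range(num_layers):
        key_states = torch.randn(bsz, num_heads, step_len, head_dim)
        value_states = torch.randn(bsz, num_heads, step_len, head_dim)
        # update() concatenates along the sequence dimension and returns the full per-layer key/value tensors
        keys, values = cache.update(key_states, value_states, layer_idx)

print(cache.get_seq_length())   # 7 cached positions per layer (5 + 1 + 1)
print(len(cache))               # 2, i.e. the number of layers seen so far
print(cache[0][0].shape[-2])    # 7, via the backwards-compatible indexing kept on the base Cache class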

File tree

8 files changed: +232 / -107 lines changed


src/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1303,6 +1303,7 @@
     _import_structure["activations"] = []
     _import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
     _import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
+    _import_structure["cache_utils"] = []
     _import_structure["data.datasets"] = [
         "GlueDataset",
         "GlueDataTrainingArguments",

src/transformers/cache_utils.py

Lines changed: 41 additions & 19 deletions
@@ -1,18 +1,40 @@
-from abc import ABC, abstractmethod
-from typing import Dict, List, Optional, Tuple, TypeVar
+from typing import List, Optional, Tuple

 import torch


-T = TypeVar("T")
-
-
-class Cache(ABC):
+class Cache:
     def __init__(self) -> None:
-        self.key_cache: Dict[int, Tuple[torch.Tensor]] = {}
-        self.value_cache: Dict[int, Tuple[torch.Tensor]] = {}
+        self.key_cache: List[Tuple[torch.Tensor]] = []
+        self.value_cache: List[Tuple[torch.Tensor]] = []
+
+    def __getitem__(self, key: int) -> List[Tuple[torch.Tensor]]:
+        """
+        Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
+        sequence length.
+        """
+        if key == 0:
+            return self.key_cache
+        elif key == 1:
+            return self.value_cache
+        else:
+            raise KeyError(f"Cache only supports 0 (key) and 1 (value) indexing, got {key}")
+
+    def __iter__(self):
+        """
+        Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
+        keys and values
+        """
+        yield self.key_cache
+        yield self.value_cache
+
+    def __len__(self):
+        """
+        Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
+        to the number of layers in the model.
+        """
+        return len(self.key_cache)

-    @abstractmethod
     def update(
         self,
         key_states: torch.Tensor,
@@ -21,10 +43,10 @@ def update(
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        pass
+        raise NotImplementedError("Make sure to implement `update` in a subclass.")

-    def get_seq_length(self, layer_idx: int = 0) -> int:
-        if layer_idx not in self.key_cache:
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
+        if len(self.key_cache) <= layer_idx:
             return 0
         return self.key_cache[layer_idx].shape[-2]

@@ -53,9 +75,9 @@ def update(
         cos: Optional[torch.Tensor] = None,
         sin: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if layer_idx not in self.key_cache:
-            self.key_cache[layer_idx] = key_states
-            self.value_cache[layer_idx] = value_states
+        if len(self.key_cache) <= layer_idx:
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)
         else:
             self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
             self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
@@ -109,7 +131,7 @@ def get_rerotation_cos_sin(
            )
        return self.cos_sin_cache[key_states.shape[-2]]

-    def get_seq_length(self, layer_idx: int = 0) -> int:
+    def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         # Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
         return min(super().get_seq_length(layer_idx), self.window_length - 1)

@@ -122,10 +144,10 @@ def update(
         sin: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # [bsz, num_heads, seq_len, head_dim]
-        if layer_idx not in self.key_cache:
+        if len(self.key_cache) <= layer_idx:
             # Empty cache
-            self.key_cache[layer_idx] = key_states
-            self.value_cache[layer_idx] = value_states
+            self.key_cache.append(key_states)
+            self.value_cache.append(value_states)

         elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
             # Growing cache
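
The list-based storage above keeps the pre-existing tuple-of-tuples past_key_values layout recoverable. A hedged sketch of the round trip, using the from_legacy_cache/to_legacy_cache helpers referenced in the generate and Llama diffs below (treat the exact signatures as an assumption inferred from those hunks):

import torch

from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per layer, each tensor shaped [bsz, num_heads, seq_len, head_dim]
legacy = tuple((torch.randn(1, 4, 6, 8), torch.randn(1, 4, 6, 8)) for _ in range(2))

cache = DynamicCache.from_legacy_cache(legacy)
assert cache.get_seq_length() == 6
assert len(cache) == 2

roundtrip = cache.to_legacy_cache()  # back to the tuple format for callers that have not been migrated yet
assert torch.equal(roundtrip[0][0], legacy[0][0])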

src/transformers/generation/utils.py

Lines changed: 9 additions & 0 deletions
@@ -24,6 +24,7 @@
 import torch.distributed as dist
 from torch import nn

+from ..cache_utils import DynamicCache
 from ..integrations.deepspeed import is_deepspeed_zero3_enabled
 from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
 from ..models.auto import (
@@ -3226,6 +3227,8 @@ def beam_search(
             )
             if model_kwargs["past_key_values"] is not None:
                 model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
+                if not model_kwargs.get("use_legacy_cache"):
+                    model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(model_kwargs["past_key_values"])

             if return_dict_in_generate and output_scores:
                 beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
@@ -3561,6 +3564,8 @@ def beam_sample(
             )
             if model_kwargs["past_key_values"] is not None:
                 model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
+                if not model_kwargs.get("use_legacy_cache"):
+                    model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(model_kwargs["past_key_values"])

             if return_dict_in_generate and output_scores:
                 beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
@@ -3948,6 +3953,8 @@ def group_beam_search(
                 model_kwargs["past_key_values"] = self._reorder_cache(
                     model_kwargs["past_key_values"], reordering_indices
                 )
+                if not model_kwargs.get("use_legacy_cache"):
+                    model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(model_kwargs["past_key_values"])

             # increase cur_len
             cur_len = cur_len + 1
@@ -4288,6 +4295,8 @@ def constrained_beam_search(
             )
             if model_kwargs["past_key_values"] is not None:
                 model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
+                if not model_kwargs.get("use_legacy_cache"):
+                    model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(model_kwargs["past_key_values"])

             if return_dict_in_generate and output_scores:
                 beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
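
In the four beam-search variants above, the cache is reordered in its legacy form by _reorder_cache and then re-wrapped as a DynamicCache unless the caller asked for the legacy format. A rough sketch of what that reorder-plus-rewrap amounts to; reorder_legacy_cache below is a hypothetical stand-in for a model's _reorder_cache:

import torch

from transformers.cache_utils import DynamicCache


def reorder_legacy_cache(past_key_values, beam_idx):
    # Hypothetical stand-in for a model-specific _reorder_cache: select beams along the batch dimension.
    return tuple(
        tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
        for layer_past in past_key_values
    )


legacy = tuple((torch.randn(4, 2, 3, 8), torch.randn(4, 2, 3, 8)) for _ in range(2))  # 4 beams, 2 layers
beam_idx = torch.tensor([2, 2, 0, 1])

reordered = reorder_legacy_cache(legacy, beam_idx)
use_legacy_cache = False  # mirrors model_kwargs.get("use_legacy_cache") in the hunks above
past_key_values = reordered if use_legacy_cache else DynamicCache.from_legacy_cache(reordered)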

src/transformers/models/llama/modeling_llama.py

Lines changed: 10 additions & 13 deletions
@@ -284,11 +284,11 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 class LlamaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""

-    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: LlamaConfig, layer_idx: int):
         super().__init__()
         self.config = config
-        self.attention_dropout = config.attention_dropout
         self.layer_idx = layer_idx
+        self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.hidden_size // self.num_heads
@@ -435,7 +435,7 @@ def forward(
         if not output_attentions:
             attn_weights = None

-        return attn_output, attn_weights, (past_key_value if use_cache else None)
+        return attn_output, attn_weights, past_key_value


 class LlamaFlashAttention2(LlamaAttention):
@@ -539,7 +539,7 @@ def forward(
         if not output_attentions:
             attn_weights = None

-        return attn_output, attn_weights, (past_key_value if use_cache else None)
+        return attn_output, attn_weights, past_key_value

     def _flash_attention_forward(
         self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
@@ -640,7 +640,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query


 class LlamaDecoderLayer(nn.Module):
-    def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
+    def __init__(self, config: LlamaConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = (
@@ -816,6 +816,9 @@ def _init_weights(self, module):
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+       use_legacy_cache (`bool`, *optional*):
+           If set to `True` (default), will return `past_key_values` as described input above. Otherwise, will return
+           a subclass of `Cache`
 """


@@ -887,7 +890,7 @@ def forward(
         past_key_values_length = 0
         if use_cache:
             if not isinstance(past_key_values, Cache):
-                past_key_values = self.from_legacy_cache(past_key_values)
+                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             past_key_values_length = past_key_values.get_seq_length()

         if position_ids is None:
@@ -964,7 +967,7 @@ def forward(

         next_cache = None
         if use_cache:
-            next_cache = self.to_legacy_cache(next_decoder_cache) if use_legacy_cache else next_decoder_cache
+            next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
         if not return_dict:
             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
         return BaseModelOutputWithPast(
@@ -974,12 +977,6 @@ def forward(
             attentions=all_self_attns,
         )

-    def from_legacy_cache(self, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]) -> Cache:
-        return DynamicCache.from_legacy_cache(past_key_values)
-
-    def to_legacy_cache(self, past_key_values: Cache) -> Tuple[Tuple[torch.Tensor]]:
-        return past_key_values.to_legacy_cache()
-

 class LlamaForCausalLM(LlamaPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
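
To see the new use_legacy_cache flag and the Cache-aware forward pass end to end, a sketch with a tiny randomly initialised Llama. This assumes this commit is installed and that LlamaModel.forward exposes the flag exactly as the docstring addition above describes; config sizes are arbitrary.

import torch

from transformers import LlamaConfig, LlamaModel
from transformers.cache_utils import DynamicCache

# Tiny config so the sketch runs without downloading any weights.
config = LlamaConfig(
    hidden_size=64, intermediate_size=128, num_hidden_layers=2,
    num_attention_heads=4, vocab_size=128,
)
model = LlamaModel(config).eval()
input_ids = torch.randint(0, config.vocab_size, (1, 5))

with torch.no_grad():
    legacy_out = model(input_ids, use_cache=True, use_legacy_cache=True)
    assert isinstance(legacy_out.past_key_values, tuple)       # the pre-existing tuple-of-tuples format

    cache_out = model(input_ids, use_cache=True, use_legacy_cache=False)
    assert isinstance(cache_out.past_key_values, DynamicCache)  # the new Cache subclass

    # Either form can be fed back in: the forward converts non-Cache inputs via DynamicCache.from_legacy_cache.
    next_ids = torch.randint(0, config.vocab_size, (1, 1))
    model(next_ids, past_key_values=cache_out.past_key_values, use_cache=True)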
