Merged
1 change: 1 addition & 0 deletions src/transformers/__init__.py
@@ -1037,6 +1037,7 @@
_import_structure["activations"] = []
_import_structure["benchmark.benchmark"] = ["PyTorchBenchmark"]
_import_structure["benchmark.benchmark_args"] = ["PyTorchBenchmarkArguments"]
_import_structure["cache_utils"] = []
_import_structure["data.datasets"] = [
"GlueDataset",
"GlueDataTrainingArguments",
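A note on what the empty list means here: in transformers' lazy import machinery, registering a submodule under `_import_structure` with an empty list makes the submodule importable without re-exporting any of its names at the top level. A minimal sketch of the resulting import path (the class names come from the rest of this diff, not from this hunk):

```python
# Sketch only: with `_import_structure["cache_utils"] = []`, the classes are
# reachable through the submodule path, not as top-level `transformers` attributes.
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
print(cache.get_seq_length())  # 0 for an empty cache
```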
60 changes: 41 additions & 19 deletions src/transformers/cache_utils.py
@@ -1,18 +1,40 @@
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, TypeVar
from typing import List, Optional, Tuple

import torch


T = TypeVar("T")


class Cache(ABC):
class Cache:
def __init__(self) -> None:
self.key_cache: Dict[int, Tuple[torch.Tensor]] = {}
self.value_cache: Dict[int, Tuple[torch.Tensor]] = {}
self.key_cache: List[Tuple[torch.Tensor]] = []
self.value_cache: List[Tuple[torch.Tensor]] = []

def __getitem__(self, key: int) -> List[Tuple[torch.Tensor]]:
"""
Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
sequence length.
"""
if key == 0:
return self.key_cache
elif key == 1:
return self.value_cache
else:
raise KeyError(f"Cache only supports 0 (key) and 1 (value) indexing, got {key}")

def __iter__(self):
"""
Support for backwards-compatible `past_key_value` iteration, e.g. `for x in past_key_value:` to iterate over
keys and values
"""
yield self.key_cache
yield self.value_cache

def __len__(self):
"""
Support for backwards-compatible `past_key_value` length, e.g. `len(past_key_value)`. This value corresponds
to the number of layers in the model.
"""
return len(self.key_cache)

@abstractmethod
def update(
self,
key_states: torch.Tensor,
@@ -21,10 +43,10 @@ def update(
cos: Optional[torch.Tensor] = None,
sin: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
pass
raise NotImplementedError("Make sure to implement `update` in a subclass.")

def get_seq_length(self, layer_idx: int = 0) -> int:
if layer_idx not in self.key_cache:
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
if len(self.key_cache) <= layer_idx:
return 0
return self.key_cache[layer_idx].shape[-2]

@@ -53,9 +75,9 @@ def update(
cos: Optional[torch.Tensor] = None,
sin: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
if layer_idx not in self.key_cache:
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
if len(self.key_cache) <= layer_idx:
self.key_cache.append(key_states)
self.value_cache.append(value_states)
else:
self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
@@ -109,7 +131,7 @@ def get_rerotation_cos_sin(
)
return self.cos_sin_cache[key_states.shape[-2]]

def get_seq_length(self, layer_idx: int = 0) -> int:
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
# Workaround to make 'key_states.shape[-2] + past_key_value.get_seq_length(self.layer_idx)' <= window_length
return min(super().get_seq_length(layer_idx), self.window_length - 1)

@@ -122,10 +144,10 @@ def update(
sin: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
# [bsz, num_heads, seq_len, head_dim]
if layer_idx not in self.key_cache:
if len(self.key_cache) <= layer_idx:
# Empty cache
self.key_cache[layer_idx] = key_states
self.value_cache[layer_idx] = value_states
self.key_cache.append(key_states)
self.value_cache.append(value_states)

elif key_states.shape[-2] + self.get_seq_length(layer_idx) < self.window_length:
# Growing cache
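To make the shift away from the legacy `Tuple[Tuple[torch.Tensor]]` format concrete, here is a minimal sketch of the new container in isolation. The keyword form of the `update` call is an assumption, since the middle of its signature is collapsed in the hunk above; the indexing, iteration, and length behaviour follows the docstrings of `__getitem__`, `__iter__`, and `__len__`.

```python
import torch

from transformers.cache_utils import DynamicCache

# Toy tensors shaped [batch, num_heads, seq_len, head_dim]
key = torch.randn(1, 2, 3, 4)
value = torch.randn(1, 2, 3, 4)

cache = DynamicCache()
cache.update(key, value, layer_idx=0)  # first call for a layer: tensors are appended
cache.update(key, value, layer_idx=0)  # later calls: concatenated along the seq dim (-2)

print(cache.get_seq_length())   # 6 (two updates of length 3)
print(len(cache))               # 1: number of layers with cached states
print(cache[0][0].shape[2])     # 6: legacy-style `past_key_value[0][0].shape[2]` still works
keys, values = cache            # legacy-style iteration also still works
```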
9 changes: 9 additions & 0 deletions src/transformers/generation/utils.py
@@ -24,6 +24,7 @@
import torch.distributed as dist
from torch import nn

from ..cache_utils import DynamicCache
from ..integrations.deepspeed import is_deepspeed_zero3_enabled
from ..modeling_outputs import CausalLMOutputWithPast, Seq2SeqLMOutput
from ..models.auto import (
@@ -3263,6 +3264,8 @@ def beam_search(
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
if not model_kwargs.get("use_legacy_cache"):
model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(model_kwargs["past_key_values"])

if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
@@ -3598,6 +3601,8 @@ def beam_sample(
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
if not model_kwargs.get("use_legacy_cache"):
model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(model_kwargs["past_key_values"])

if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
@@ -3985,6 +3990,8 @@ def group_beam_search(
model_kwargs["past_key_values"] = self._reorder_cache(
model_kwargs["past_key_values"], reordering_indices
)
if not model_kwargs.get("use_legacy_cache"):
model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(model_kwargs["past_key_values"])

# increase cur_len
cur_len = cur_len + 1
@@ -4325,6 +4332,8 @@ def constrained_beam_search(
)
if model_kwargs["past_key_values"] is not None:
model_kwargs["past_key_values"] = self._reorder_cache(model_kwargs["past_key_values"], beam_idx)
if not model_kwargs.get("use_legacy_cache"):
model_kwargs["past_key_values"] = DynamicCache.from_legacy_cache(model_kwargs["past_key_values"])

if return_dict_in_generate and output_scores:
beam_indices = tuple((beam_indices[beam_idx[i]] + (beam_idx[i],) for i in range(len(beam_indices))))
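The four call sites above share one pattern: the cache is reordered in the legacy tuple format and then, unless the caller opted into `use_legacy_cache`, re-wrapped as a `DynamicCache` before the next decoding step. Below is a standalone sketch of that pattern; the `reorder_legacy_cache` helper is hypothetical and only mirrors what `_reorder_cache` does for decoder-only models (index-selecting every cached tensor along the beam dimension).

```python
import torch

from transformers.cache_utils import DynamicCache


def reorder_legacy_cache(past_key_values, beam_idx):
    """Hypothetical stand-in for `_reorder_cache`: gather the beam dimension
    of every cached key/value tensor according to `beam_idx`."""
    return tuple(
        tuple(state.index_select(0, beam_idx) for state in layer)
        for layer in past_key_values
    )


# Toy legacy cache: 2 layers of (key, value), shaped [num_beams, heads, seq, head_dim]
legacy = tuple((torch.randn(4, 2, 5, 8), torch.randn(4, 2, 5, 8)) for _ in range(2))
beam_idx = torch.tensor([2, 2, 0, 1])
use_legacy_cache = False  # mirrors `model_kwargs.get("use_legacy_cache")`

past_key_values = reorder_legacy_cache(legacy, beam_idx)
if not use_legacy_cache:
    past_key_values = DynamicCache.from_legacy_cache(past_key_values)

print(type(past_key_values).__name__)    # DynamicCache
print(past_key_values.get_seq_length())  # 5
```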
23 changes: 10 additions & 13 deletions src/transformers/models/llama/modeling_llama.py
@@ -279,11 +279,11 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""

def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
def __init__(self, config: LlamaConfig, layer_idx: int):
Review comment (Collaborator, PR author): layer_idx is not an optional argument as it is required in the cache update :)

super().__init__()
self.config = config
self.attention_dropout = config.attention_dropout
self.layer_idx = layer_idx
self.attention_dropout = config.attention_dropout
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
@@ -430,7 +430,7 @@ def forward(
if not output_attentions:
attn_weights = None

return attn_output, attn_weights, (past_key_value if use_cache else None)
return attn_output, attn_weights, past_key_value
Review comment (Collaborator, PR author): (past_key_value if use_cache else None) is somewhat redundant since past_key_value is already None if use_cache=False, so we can keep things simpler here 🙌

Review comment (Owner): Awesome!



class LlamaFlashAttention2(LlamaAttention):
@@ -524,7 +524,7 @@ def forward(
if not output_attentions:
attn_weights = None

return attn_output, attn_weights, (past_key_value if use_cache else None)
return attn_output, attn_weights, past_key_value

def _flash_attention_forward(
self, query_states, key_states, value_states, attention_mask, query_length, dropout=0.0, softmax_scale=None
@@ -619,7 +619,7 @@ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query


class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig, layer_idx: Optional[int] = None):
def __init__(self, config: LlamaConfig, layer_idx: int):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = (
@@ -795,6 +795,9 @@ def _init_weights(self, module):
more detail.
return_dict (`bool`, *optional*):
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
use_legacy_cache (`bool`, *optional*):
If set to `True` (default), will return `past_key_values` in the legacy format described in the input above. Otherwise,
will return a subclass of `Cache`.
"""


@@ -866,7 +869,7 @@ def forward(
past_key_values_length = 0
if use_cache:
if not isinstance(past_key_values, Cache):
past_key_values = self.from_legacy_cache(past_key_values)
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_seq_length()

if position_ids is None:
@@ -943,7 +946,7 @@ def forward(

next_cache = None
if use_cache:
next_cache = self.to_legacy_cache(next_decoder_cache) if use_legacy_cache else next_decoder_cache
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
@@ -953,12 +956,6 @@
attentions=all_self_attns,
)

def from_legacy_cache(self, past_key_values: Optional[Tuple[Tuple[torch.Tensor]]]) -> Cache:
Review comment (Collaborator, PR author): Removed these thin wrappers since a) we usually avoid thin wrappers; b) less code to transition :D

return DynamicCache.from_legacy_cache(past_key_values)

def to_legacy_cache(self, past_key_values: Cache) -> Tuple[Tuple[torch.Tensor]]:
return past_key_values.to_legacy_cache()


class LlamaForCausalLM(LlamaPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
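With the thin `from_legacy_cache`/`to_legacy_cache` wrappers dropped from `LlamaModel`, callers use the classmethod and instance method on the cache itself. Here is a minimal round-trip sketch of what `forward` now does internally, assuming the two conversions are exact inverses on the legacy tuple format, which the usage in this diff implies:

```python
import torch

from transformers.cache_utils import DynamicCache

# Legacy format: one (key, value) pair per layer, each [bsz, heads, seq, head_dim]
legacy = tuple((torch.randn(1, 2, 4, 8), torch.randn(1, 2, 4, 8)) for _ in range(2))

# What LlamaModel.forward does on entry when past_key_values is not a `Cache`
cache = DynamicCache.from_legacy_cache(legacy)
print(cache.get_seq_length())  # 4

# What it does on exit when `use_legacy_cache` is truthy (the documented default)
round_tripped = cache.to_legacy_cache()
print(all(
    torch.equal(k0, k1) and torch.equal(v0, v1)
    for (k0, v0), (k1, v1) in zip(legacy, round_tripped)
))  # True, assuming the conversion is loss-free
```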