143 changes: 99 additions & 44 deletions src/transformers/cache_utils.py
@@ -1,5 +1,6 @@
import copy
import importlib.metadata
import inspect
import json
import os
from dataclasses import dataclass
@@ -9,12 +10,7 @@
from packaging import version

from .configuration_utils import PretrainedConfig
from .utils import (
is_hqq_available,
is_optimum_quanto_available,
is_torchdynamo_compiling,
logging,
)
from .utils import is_hqq_available, is_optimum_quanto_available, logging
from .utils.deprecation import deprecate_kwarg


@@ -24,13 +20,82 @@
logger = logging.get_logger(__name__)


class Cache(torch.nn.Module):
class Cache(torch.Tensor):
"""
Base, abstract class for all caches. The actual data structure is specific to each subclass.
"""

def __init__(self):
super().__init__()
@staticmethod
def __new__(cls, *args, **kwargs):
# We use a tensor wrapper to allow for torch script tracing when using the cache as an input in a forward method

wrapper_kwargs = {}
init_signature = inspect.signature(cls.__init__)
init_arguments = list(init_signature.parameters.keys())
init_defaults = {
k: v.default for k, v in init_signature.parameters.items() if v.default is not inspect.Parameter.empty
}

for argument in ["dtype", "device"]:
if argument in init_arguments:
arg_idx = init_arguments.index(argument)
if len(args) > arg_idx and args[arg_idx] is not None:
wrapper_kwargs[argument] = args[arg_idx]
elif kwargs.get(argument, None) is not None:
wrapper_kwargs[argument] = kwargs[argument]
elif init_defaults[argument] is not None:
wrapper_kwargs[argument] = init_defaults[argument]

if "cache_config" in init_arguments:
cache_config_idx = init_arguments.index("cache_config")
if len(args) > cache_config_idx and args[cache_config_idx] is not None:
wrapper_kwargs["device"] = args[cache_config_idx].device
elif kwargs.get("cache_config", None) is not None:
wrapper_kwargs["device"] = kwargs["cache_config"].device
elif init_defaults["cache_config"] is not None:
wrapper_kwargs["device"] = init_defaults["cache_config"].device

self = torch.Tensor._make_wrapper_subclass(cls, (), **wrapper_kwargs, requires_grad=False)
# we create a dummy empty tensor for generic tensor flattening/unflattening
self._empty_tensor = torch.tensor([], **wrapper_kwargs, requires_grad=False)
return self

@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs):
assert (
func.__name__ in cls.__dict__
), f"Class {cls.__name__} is a tensor wrapper and does not implement method {func.__name__}"
return getattr(cls, func.__name__)(*args, **kwargs)

def __repr__(self):
return f"{self.__class__.__name__}()"

def __bool__(self):
# `past_key_values` is often truth-tested with `if past_key_values:` as a shorthand for
# `if past_key_values is not None:`; keep that pattern working by always being truthy.
return self is not None  # always True

def to(self, *args, **kwargs):
# start from the wrapper's current dtype/device
wrapper_kwargs = {"dtype": getattr(self, "dtype", None), "device": getattr(self, "device", None)}

# apply any dtype/device overrides passed to `.to()`
for arg in list(args) + list(kwargs.values()):
if isinstance(arg, (torch.device, str, int)):
wrapper_kwargs["device"] = arg
elif isinstance(arg, torch.dtype):
wrapper_kwargs["dtype"] = arg

# build a new wrapper carrying the updated metadata
new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs)
new_self.__dict__ = {k: v for k, v in self.__dict__.items() if k not in ["device", "dtype"]}
return new_self

def clone(self):
wrapper_kwargs = {"dtype": getattr(self, "dtype", None), "device": getattr(self, "device", None)}
new_self = torch.Tensor._make_wrapper_subclass(self.__class__, (), **wrapper_kwargs, requires_grad=False)
new_self.__dict__ = copy.deepcopy(self.__dict__)
return new_self

def update(
self,
@@ -304,7 +369,7 @@ class StaticCacheConfig(CacheConfig):

cache_implementation = "static"

def __init__(self, batch_size: int, max_cache_len: int, device="cpu"):
def __init__(self, batch_size: int, max_cache_len: int, device: Union[str, torch.device] = torch.device("cpu")):
self.batch_size = batch_size
self.max_cache_len = max_cache_len
self.device = device
@@ -361,6 +426,16 @@ class DynamicCache(Cache):
```
"""

def __tensor_flatten__(self):
return ["_empty_tensor"], {"_seen_tokens": self._seen_tokens}

@staticmethod
def __tensor_unflatten__(inner_tensors, meta, _, __):
cache = DynamicCache()
cache._seen_tokens = meta["_seen_tokens"]
cache._empty_tensor = inner_tensors["_empty_tensor"]
return cache

@deprecate_kwarg("num_hidden_layers", version="4.47.0")
def __init__(self, num_hidden_layers: Optional[int] = None) -> None:
super().__init__()
@@ -448,7 +523,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
or len(self.key_cache) <= layer_idx # skipped `layer_idx` and hasn't run a layer with cache after it
or len(self.key_cache[layer_idx]) == 0 # the layer has no cache
)
layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else 0
layer_seq_length = self.key_cache[layer_idx].shape[-2] if not is_empty_layer else torch.tensor(0)
return layer_seq_length

def get_max_cache_shape(self) -> Optional[int]:
@@ -675,9 +750,6 @@ def __init__(self, cache_config: QuantizedCacheConfig) -> None:
self.axis_key = cache_config.axis_key
self.axis_value = cache_config.axis_value
self.compute_dtype = cache_config.compute_dtype
self.device = cache_config.device

super().__init__()

def update(
self,
@@ -777,7 +849,7 @@ def __init__(self, cache_config: CacheConfig) -> None:
raise ImportError(
f"You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. Detected version {optimum_quanto_version}."
)
from optimum.quanto import MaxOptimizer, qint2, qint4
from optimum.quanto import MaxOptimizer, qint2, qint4 # type: ignore

if self.nbits not in [2, 4]:
raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
@@ -796,7 +868,7 @@ def __init__(self, cache_config: CacheConfig) -> None:
def _quantize(self, tensor, axis):
# We have two different API since in optimum-quanto, we don't use AffineQuantizer anymore
if is_optimum_quanto_available():
from optimum.quanto import quantize_weight
from optimum.quanto import quantize_weight # type: ignore

scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
qtensor = quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)
@@ -1105,7 +1177,7 @@ def __init__(
config: PretrainedConfig,
batch_size: int = None,
max_cache_len: int = None,
device: torch.device = None,
device: Union[torch.device, str] = torch.device("meta"),
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1116,7 +1188,6 @@ def __init__(
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)

self.max_batch_size = batch_size or max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len

@@ -1125,8 +1196,6 @@ def __init__(
config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
)

self.dtype = dtype
self.device = torch.device(device) if device is not None else torch.device("meta")
self.num_key_value_heads = (
config.num_attention_heads
if getattr(config, "num_key_value_heads", None) is None
@@ -1144,18 +1213,10 @@ def __init__(
layer_device = self.device
new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
# Notes:
# 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
# breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
# it is not needed anyway)
# 2. `torch.export()` requires mutations to be registered as buffers.
if not is_torchdynamo_compiling():
self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
new_layer_key_cache = getattr(self, f"key_cache_{idx}")
new_layer_value_cache = getattr(self, f"value_cache_{idx}")
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
# Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
# preventing compiled graph breaks when updating the cache.
torch._dynamo.mark_static_address(new_layer_key_cache)
torch._dynamo.mark_static_address(new_layer_value_cache)
self.key_cache.append(new_layer_key_cache)
self.value_cache.append(new_layer_value_cache)

@@ -1304,7 +1365,7 @@ def __init__(
config: PretrainedConfig,
batch_size: int = None,
max_cache_len: int = None,
device: torch.device = None,
device: Union[torch.device, str] = torch.device("meta"),
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1619,7 +1680,7 @@ def __init__(
config: PretrainedConfig,
batch_size: int = None,
max_cache_len: int = None,
device: Union[torch.device, str] = None,
device: Union[torch.device, str] = torch.device("meta"),
dtype: torch.dtype = torch.float32,
max_batch_size: Optional[int] = None,
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
@@ -1648,7 +1709,6 @@ def __init__(
config.num_attention_heads if config.num_key_value_heads is None else config.num_key_value_heads
)

self.device = torch.device(device) if device is not None else torch.device("meta")
layer_switch = config.sliding_window_pattern if hasattr(config, "sliding_window_pattern") else 2 # 2 is for BC
self.is_sliding = torch.tensor(
[bool((i + 1) % layer_switch) for i in range(config.num_hidden_layers)], dtype=torch.bool
@@ -1781,7 +1841,7 @@ def batch_size(self):
return self.max_batch_size


class MambaCache:
class MambaCache(Cache):
"""
Cache for mamba model which does not have attention mechanism and key value states.

@@ -1838,20 +1898,18 @@ def __init__(
config: PretrainedConfig,
batch_size: int = None,
dtype: torch.dtype = torch.float16,
device: Optional[Union[torch.device, str]] = None,
device: Union[torch.device, str] = torch.device("meta"),
max_batch_size: Optional[int] = None,
):
if batch_size is not None:
logger.warning_once(
f"The 'batch_size' argument of {self.__class__.__name__} is deprecated and will be removed in "
"v4.49. Use the more precisely named 'max_batch_size' argument instead."
)
self.dtype = dtype
self.max_batch_size = batch_size or max_batch_size
self.intermediate_size = config.intermediate_size
self.ssm_state_size = config.state_size
self.conv_kernel_size = config.conv_kernel
self.device = torch.device(device) if device is not None else torch.device("meta")

self.conv_states: List[torch.Tensor] = []
self.ssm_states: List[torch.Tensor] = []
@@ -1981,17 +2039,14 @@ def __init__(
config: PretrainedConfig,
max_batch_size: int,
max_cache_len: Optional[int],
device: Union[str, torch.device],
dtype: Optional[torch.dtype] = None,
device: Union[torch.device, str] = torch.device("meta"),
dtype: torch.dtype = torch.float32,
offload_device: Union[str, torch.device] = torch.device("cpu"),
layer_device_map: Optional[Dict[int, Union[str, torch.device, int]]] = None,
) -> None:
super(Cache, self).__init__()
self.max_batch_size = max_batch_size
self.max_cache_len = config.max_position_embeddings if max_cache_len is None else max_cache_len
self.device = torch.device(device) if layer_device_map is None else torch.device(layer_device_map[0])
self.offload_device = torch.device(offload_device)
self.dtype = dtype if dtype is not None else torch.float32

# Some model define a custom `head_dim` != config.hidden_size // config.num_attention_heads
head_dim = config.head_dim if hasattr(config, "head_dim") else config.hidden_size // config.num_attention_heads
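For context on the pattern this file now relies on: `Cache` becomes a storage-less tensor wrapper built with `torch.Tensor._make_wrapper_subclass`, and `__torch_dispatch__` intercepts any torch op applied to it. The sketch below is a minimal, self-contained illustration of those two hooks; `MetadataTensor` and `seen_tokens` are hypothetical names used only for this example, not code from the PR.

```python
import torch


class MetadataTensor(torch.Tensor):
    """Toy wrapper subclass: looks like a tensor, owns no storage."""

    @staticmethod
    def __new__(cls, seen_tokens=0, dtype=torch.float32, device="cpu"):
        # A shapeless wrapper: isinstance(obj, torch.Tensor) is True, but the
        # payload lives in plain Python attributes, as with the new Cache.
        self = torch.Tensor._make_wrapper_subclass(
            cls, (), dtype=dtype, device=device, requires_grad=False
        )
        self.seen_tokens = seen_tokens
        return self

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        # Every torch op applied to the wrapper lands here; refuse anything
        # the subclass does not explicitly support instead of failing silently.
        raise NotImplementedError(f"{cls.__name__} does not implement {func}")


wrapped = MetadataTensor(seen_tokens=7)
print(isinstance(wrapped, torch.Tensor))  # True: passes tensor-only plumbing
print(wrapped.device, wrapped.dtype)      # metadata supplied at construction
```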
11 changes: 6 additions & 5 deletions src/transformers/generation/utils.py
@@ -731,6 +731,7 @@ def _expand_dict_for_generation(dict_to_expand):
key != "cache_position"
and dict_to_expand[key] is not None
and isinstance(dict_to_expand[key], torch.Tensor)
and not isinstance(dict_to_expand[key], Cache)
):
dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
return dict_to_expand
@@ -4519,13 +4520,13 @@ def _split(data, full_batch_size: int, num_hidden_layers: int, split_size: int =
"""
if data is None:
return [None] * (full_batch_size // split_size)
if isinstance(data, torch.Tensor):
return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
# New cache format
elif isinstance(data, DynamicCache) or (
isinstance(data, EncoderDecoderCache) and isinstance(data.self_attention_cache, DynamicCache)
):
return data.batch_split(full_batch_size, split_size, num_hidden_layers)
if isinstance(data, torch.Tensor):
return [data[i : i + split_size] for i in range(0, full_batch_size, split_size)]
elif isinstance(data, tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0], tuple):
@@ -4632,13 +4633,13 @@ def _concat(data):
"""
if any(data is None for data in data):
return None
if isinstance(data[0], torch.Tensor):
return torch.cat(data, dim=0)
# New cache format
elif isinstance(data[0], DynamicCache):
if isinstance(data[0], DynamicCache):
return DynamicCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
elif isinstance(data[0], EncoderDecoderCache):
return EncoderDecoderCache.from_batch_splits(data, num_hidden_layers=num_hidden_layers)
elif isinstance(data[0], torch.Tensor):
return torch.cat(data, dim=0)
elif isinstance(data[0], tuple):
# If the elements of the tuple are also tuples (e.g., past_key_values in our earlier example)
if isinstance(data[0][0], tuple):
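One consequence of `Cache` subclassing `torch.Tensor` is visible in the reordered branches of `_split` and `_concat` above: a cache now also satisfies `isinstance(..., torch.Tensor)`, so the cache-specific branches have to be checked first. Below is a small sketch of that ordering, assuming this branch of transformers is installed; the `dispatch` helper is illustrative only.

```python
import torch
from transformers.cache_utils import DynamicCache

past = DynamicCache()
print(isinstance(past, torch.Tensor))  # True on this branch: Cache is a tensor wrapper


def dispatch(data):
    # Mirror the fixed ordering: cache types before the generic tensor branch,
    # otherwise a cache would be batch-sliced like an ordinary tensor.
    if isinstance(data, DynamicCache):
        return "cache path"
    if isinstance(data, torch.Tensor):
        return "tensor path"
    return "other"


print(dispatch(past))            # "cache path"
print(dispatch(torch.zeros(2)))  # "tensor path"
```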