25 changes: 6 additions & 19 deletions src/transformers/cache_utils.py
@@ -9,12 +9,7 @@
from packaging import version

from .configuration_utils import PretrainedConfig
-from .utils import (
-    is_hqq_available,
-    is_optimum_quanto_available,
-    is_torchdynamo_compiling,
-    logging,
-)
+from .utils import is_hqq_available, is_optimum_quanto_available, logging
from .utils.deprecation import deprecate_kwarg


@@ -24,7 +19,7 @@
logger = logging.get_logger(__name__)


-class Cache(torch.nn.Module):
+class Cache:
"""
Base, abstract class for all caches. The actual data structure is specific to each subclass.
"""
@@ -1144,18 +1139,10 @@ def __init__(
        layer_device = self.device
        new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
        new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
-        # Notes:
-        # 1. `mark_static_address` is used to tag the cache as a fixed data pointer, preventing cuda graph
-        #    breaks when updating the cache. It can't be used if the cache code is being compiled (but in that
-        #    case it is not needed anyway)
-        # 2. `torch.export()` requires mutations to be registered as buffers.
-        if not is_torchdynamo_compiling():
-            self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
-            self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
-            new_layer_key_cache = getattr(self, f"key_cache_{idx}")
-            new_layer_value_cache = getattr(self, f"value_cache_{idx}")
-            torch._dynamo.mark_static_address(new_layer_key_cache)
-            torch._dynamo.mark_static_address(new_layer_value_cache)
+        # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
+        # preventing compiled graph breaks when updating the cache.
+        torch._dynamo.mark_static_address(new_layer_key_cache)
+        torch._dynamo.mark_static_address(new_layer_value_cache)
        self.key_cache.append(new_layer_key_cache)
        self.value_cache.append(new_layer_value_cache)
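As background for this change: `torch._dynamo.mark_static_address` tells the compiler that a tensor's data pointer stays fixed across calls, so in-place cache updates can be captured once (e.g. in a CUDA graph) without re-recording. A minimal sketch of the idea, with illustrative shapes, not code from this PR:

```python
import torch

# Illustrative cache shape: (batch, num_heads, max_seq_len, head_dim)
cache = torch.zeros(1, 4, 16, 8)
# Tag the cache as a fixed data pointer for compiled code.
torch._dynamo.mark_static_address(cache)

def update(cache: torch.Tensor, pos: torch.Tensor, states: torch.Tensor) -> torch.Tensor:
    # In-place write: the tensor's storage address never changes.
    cache.index_copy_(2, pos, states)
    return cache

compiled_update = torch.compile(update)
compiled_update(cache, torch.tensor([0]), torch.ones(1, 4, 1, 8))
```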

20 changes: 12 additions & 8 deletions src/transformers/integrations/executorch.py
@@ -16,10 +16,7 @@


if is_torch_available():
-    from transformers import (
-        PreTrainedModel,
-        StaticCache,
-    )
+    from transformers import PreTrainedModel, StaticCache
    from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3


@@ -72,9 +69,13 @@ def __init__(self, model: PreTrainedModel):
            config=self.model.config,
            batch_size=self.model.generation_config.cache_config.batch_size,
            max_cache_len=self.model.generation_config.cache_config.max_cache_len,
-            dtype=self.model.dtype,
            device=self.model.generation_config.cache_config.device,
+            dtype=self.model.dtype,
        )
+        for i in range(len(self.static_cache.key_cache)):
+            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
[Review comment from a Contributor]
Curious: why is non-persistent preferred, in your opinion? It probably doesn't matter too much for inference, since generation always starts by filling the cache with prompt tokens, even if the buffers are persistent.

[Reply from the Member who authored the PR]
Non-persistent buffers are not saved with the model's state_dict. In the case of a big model with a long-sequence-length static cache, the cache tensors would have a non-negligible memory footprint when exporting and saving the model.
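To illustrate the author's point (an editor's sketch, not part of the diff): only persistent buffers are serialized into `state_dict`, so `persistent=False` keeps large static-cache tensors out of saved checkpoints.

```python
import torch

class Demo(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Persistent buffers are included in state_dict...
        self.register_buffer("persisted", torch.zeros(4), persistent=True)
        # ...non-persistent buffers are not saved.
        self.register_buffer("transient", torch.zeros(4), persistent=False)

print(list(Demo().state_dict().keys()))  # ['persisted']
```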

self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)

self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
if self.is_causal:
causal_mask = torch.tril(
@@ -109,12 +110,15 @@ def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
"""
_, seqlen = input_ids.shape
attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
position_ids = cache_position.unsqueeze(0)
past_key_values = self.static_cache

outs = self.model(
input_ids=input_ids,
attention_mask=attn_mask,
position_ids=cache_position.unsqueeze(0),
position_ids=position_ids,
cache_position=cache_position,
past_key_values=self.static_cache,
past_key_values=past_key_values,
use_cache=True,
)
return outs.logits
@@ -143,7 +147,7 @@ def generate(
        prompt_token_len = prompt_token_ids.shape[-1]
        max_generation_length = prompt_token_len + max_new_tokens
        for buffer_name, buffer in exported_program.named_buffers():
-            if buffer_name.startswith("static_cache.key_cache"):
+            if buffer_name.startswith("key_cache"):
                max_cache_len = buffer.shape[2]
                max_generation_length = min(max_generation_length, max_cache_len)
                break
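The prefix change follows from registering the cache tensors directly on the wrapper module instead of leaving them on a `static_cache` submodule: top-level buffers export under bare names like `key_cache_0`. A rough sketch of that naming behavior, using a hypothetical module and illustrative shapes:

```python
import torch

class Wrapper(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Registered at the top level, so the exported name has no submodule prefix.
        self.register_buffer("key_cache_0", torch.zeros(1, 2, 8, 4), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.key_cache_0.sum()

exported = torch.export.export(Wrapper(), (torch.zeros(1),))
for name, buf in exported.named_buffers():
    print(name, tuple(buf.shape))  # key_cache_0 (1, 2, 8, 4)
```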
8 changes: 4 additions & 4 deletions tests/utils/test_cache_utils.py
@@ -215,11 +215,11 @@ def test_static_cache_exportability(self):
        # Check if the exported model is configured with the `StaticCache` correctly
        n_static_key_caches = n_static_value_caches = 0
        for buffer_name, buffer in exported_program.named_buffers():
-            if buffer_name.startswith("static_cache.key_cache"):
+            if buffer_name.startswith("key_cache"):
                self.assertTrue(buffer.shape[0] == batch_size)
                self.assertTrue(buffer.shape[2] == max_cache_len)
                n_static_key_caches = n_static_key_caches + 1
-            if buffer_name.startswith("static_cache.value_cache"):
+            if buffer_name.startswith("value_cache"):
                self.assertTrue(buffer.shape[0] == batch_size)
                self.assertTrue(buffer.shape[2] == max_cache_len)
                n_static_value_caches = n_static_value_caches + 1
@@ -364,7 +364,7 @@ def test_sink_cache_iterative_prompts(self):
            input_ids = gen_out

        # We went well beyond the cache length
-        self.assertTrue(input_ids.shape[1] > cache.get_max_length() * 1.5)
+        self.assertTrue(input_ids.shape[1] > cache.get_max_cache_shape() * 1.5)

# And it still produces a coherent english
decoded = tokenizer.batch_decode(input_ids, skip_special_tokens=True)
@@ -619,4 +619,4 @@ def test_cache_copy(self):
"You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
] # fmt: skip
self.assertTrue(responses == EXPECTED_DECODED_TEXT)
self.assertEqual(responses, EXPECTED_DECODED_TEXT)