
Commit 5e2183f

Make cache traceable (#35873)
Simply make the cache traceable.
1 parent: 31bb662

File tree

3 files changed: 21 additions & 30 deletions


src/transformers/cache_utils.py

Lines changed: 6 additions & 19 deletions
@@ -9,12 +9,7 @@
 from packaging import version
 
 from .configuration_utils import PretrainedConfig
-from .utils import (
-    is_hqq_available,
-    is_optimum_quanto_available,
-    is_torchdynamo_compiling,
-    logging,
-)
+from .utils import is_hqq_available, is_optimum_quanto_available, logging
 from .utils.deprecation import deprecate_kwarg
 
 
@@ -24,7 +19,7 @@
 logger = logging.get_logger(__name__)
 
 
-class Cache(torch.nn.Module):
+class Cache:
     """
     Base, abstract class for all caches. The actual data structure is specific to each subclass.
     """
@@ -1140,18 +1135,10 @@ def __init__(
             layer_device = self.device
             new_layer_key_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
             new_layer_value_cache = torch.zeros(cache_shape, dtype=self.dtype, device=layer_device)
-            # Notes:
-            # 1. `mark_static_address` is used to tag the cache as an fixed data pointer, preventing cuda graph
-            #    breaks when updating the cache. It can't be used if the cache code is being compiled (but in that case
-            #    it is not needed anyway)
-            # 2. `torch.export()` requires mutations to be registered as buffers.
-            if not is_torchdynamo_compiling():
-                self.register_buffer(f"key_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
-                self.register_buffer(f"value_cache_{idx}", torch.zeros(cache_shape, dtype=dtype, device=layer_device))
-                new_layer_key_cache = getattr(self, f"key_cache_{idx}")
-                new_layer_value_cache = getattr(self, f"value_cache_{idx}")
-                torch._dynamo.mark_static_address(new_layer_key_cache)
-                torch._dynamo.mark_static_address(new_layer_value_cache)
+            # Note: `mark_static_address` is used to tag the cache as a fixed data pointer,
+            # preventing compiled graph breaks when updating the cache.
+            torch._dynamo.mark_static_address(new_layer_key_cache)
+            torch._dynamo.mark_static_address(new_layer_value_cache)
             self.key_cache.append(new_layer_key_cache)
             self.value_cache.append(new_layer_value_cache)
 
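In short, `Cache` no longer inherits from `torch.nn.Module`, and `StaticCache` stops registering its per-layer tensors as buffers on itself; the preallocated tensors are only tagged with `torch._dynamo.mark_static_address`. A minimal sketch of that pattern follows; the shapes, the module-level lists, and the `update` helper are illustrative assumptions, not the transformers implementation.

```python
# Minimal sketch (assumed shapes and helper, not the transformers code): preallocate a
# static key/value cache and tag it with `mark_static_address` so a compiled update
# step sees a fixed data pointer across decoding steps.
import torch

num_layers, batch, heads, max_len, head_dim = 2, 1, 4, 16, 8
key_cache, value_cache = [], []
for _ in range(num_layers):
    k = torch.zeros(batch, heads, max_len, head_dim)
    v = torch.zeros(batch, heads, max_len, head_dim)
    # Fixed-address tag: avoids graph breaks / re-recording when the same storage
    # is mutated in place on every step.
    torch._dynamo.mark_static_address(k)
    torch._dynamo.mark_static_address(v)
    key_cache.append(k)
    value_cache.append(v)


@torch.compile
def update(layer_idx: int, pos: torch.Tensor, new_k: torch.Tensor, new_v: torch.Tensor):
    # In-place write of the new key/value states at positions `pos` along the
    # sequence dimension (dim 2).
    key_cache[layer_idx].index_copy_(2, pos, new_k)
    value_cache[layer_idx].index_copy_(2, pos, new_v)
    return key_cache[layer_idx], value_cache[layer_idx]


# One decoding step writing a single new position into layer 0.
update(0, torch.tensor([3]), torch.randn(batch, heads, 1, head_dim), torch.randn(batch, heads, 1, head_dim))
```

Dropping the buffer registration keeps the cache a plain Python object that tracing can follow, while the static-address tag still prevents re-recording when the same storage is mutated on every step.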

src/transformers/integrations/executorch.py

Lines changed: 12 additions & 8 deletions
@@ -16,10 +16,7 @@
 
 
 if is_torch_available():
-    from transformers import (
-        PreTrainedModel,
-        StaticCache,
-    )
+    from transformers import PreTrainedModel, StaticCache
     from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
 
 
@@ -72,9 +69,13 @@ def __init__(self, model: PreTrainedModel):
             config=self.model.config,
             batch_size=self.model.generation_config.cache_config.batch_size,
             max_cache_len=self.model.generation_config.cache_config.max_cache_len,
-            dtype=self.model.dtype,
             device=self.model.generation_config.cache_config.device,
+            dtype=self.model.dtype,
         )
+        for i in range(len(self.static_cache.key_cache)):
+            self.register_buffer(f"key_cache_{i}", self.static_cache.key_cache[i], persistent=False)
+            self.register_buffer(f"value_cache_{i}", self.static_cache.value_cache[i], persistent=False)
+
         self.is_causal = any("CausalLM" in arch for arch in self.model.config.architectures)
         if self.is_causal:
             causal_mask = torch.tril(
@@ -109,12 +110,15 @@ def forward(self, input_ids: torch.Tensor, cache_position: torch.Tensor):
         """
         _, seqlen = input_ids.shape
         attn_mask = self.mask[cache_position, :seqlen] if self.is_causal else None
+        position_ids = cache_position.unsqueeze(0)
+        past_key_values = self.static_cache
+
         outs = self.model(
             input_ids=input_ids,
             attention_mask=attn_mask,
-            position_ids=cache_position.unsqueeze(0),
+            position_ids=position_ids,
             cache_position=cache_position,
-            past_key_values=self.static_cache,
+            past_key_values=past_key_values,
             use_cache=True,
         )
         return outs.logits
@@ -143,7 +147,7 @@ def generate(
         prompt_token_len = prompt_token_ids.shape[-1]
         max_generation_length = prompt_token_len + max_new_tokens
         for buffer_name, buffer in exported_program.named_buffers():
-            if buffer_name.startswith("static_cache.key_cache"):
+            if buffer_name.startswith("key_cache"):
                 max_cache_len = buffer.shape[2]
                 max_generation_length = min(max_generation_length, max_cache_len)
                 break
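The wrapper change is the other half of making the cache traceable: since the `StaticCache` tensors are no longer buffers of the cache itself, the exportable wrapper in executorch.py registers each layer's key/value tensor as a non-persistent buffer on itself, so `torch.export` exposes them under flat names such as `key_cache_0`. A reduced sketch of that pattern is below; the `Wrapper` class, the shapes, and the trivial `forward` are assumptions for illustration, not the transformers API.

```python
# Reduced sketch (assumed class and shapes, not the transformers API): registering the
# cache tensors directly on the exported module gives them flat buffer names in the
# resulting ExportedProgram.
import torch


class Wrapper(torch.nn.Module):
    def __init__(self, key_cache, value_cache):
        super().__init__()
        for i, (k, v) in enumerate(zip(key_cache, value_cache)):
            # Non-persistent: the cache is scratch space, not a model weight.
            self.register_buffer(f"key_cache_{i}", k, persistent=False)
            self.register_buffer(f"value_cache_{i}", v, persistent=False)

    def forward(self, pos: torch.Tensor) -> torch.Tensor:
        # Toy read from the cache; the real wrapper instead passes the cache to the
        # model as `past_key_values`, which updates it in place during decoding.
        return self.key_cache_0.index_select(2, pos)


key_cache = [torch.zeros(1, 4, 16, 8)]
value_cache = [torch.zeros(1, 4, 16, 8)]
exported = torch.export.export(Wrapper(key_cache, value_cache), args=(torch.tensor([0]),))
print([name for name, _ in exported.named_buffers()])  # ['key_cache_0', 'value_cache_0']
```

Because the buffers now live directly on the wrapper, their exported names drop the `static_cache.` prefix, which is why `generate` above and the test below match on the bare `key_cache` / `value_cache` prefixes.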

tests/utils/test_cache_utils.py

Lines changed: 3 additions & 3 deletions
@@ -215,11 +215,11 @@ def test_static_cache_exportability(self):
         # Check if the exported model is configured with the `StaticCache` correctly
         n_static_key_caches = n_static_value_caches = 0
         for buffer_name, buffer in exported_program.named_buffers():
-            if buffer_name.startswith("static_cache.key_cache"):
+            if buffer_name.startswith("key_cache"):
                 self.assertTrue(buffer.shape[0] == batch_size)
                 self.assertTrue(buffer.shape[2] == max_cache_len)
                 n_static_key_caches = n_static_key_caches + 1
-            if buffer_name.startswith("static_cache.value_cache"):
+            if buffer_name.startswith("value_cache"):
                 self.assertTrue(buffer.shape[0] == batch_size)
                 self.assertTrue(buffer.shape[2] == max_cache_len)
                 n_static_value_caches = n_static_value_caches + 1
@@ -619,4 +619,4 @@ def test_cache_copy(self):
             "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
             'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
         ]  # fmt: skip
-        self.assertTrue(responses == EXPECTED_DECODED_TEXT)
+        self.assertEqual(responses, EXPECTED_DECODED_TEXT)
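The test mirrors the rename: it looks the cache buffers up by their flat names and checks the static-cache layout, presumably (batch_size, num_heads, max_cache_len, head_dim), via `shape[0]` and `shape[2]`. A small helper in the same spirit, assuming an `ExportedProgram` produced as in the test (the function name is an assumption):

```python
# Sketch: find the static cache length from an ExportedProgram's buffers by their
# flat names, as the updated `generate` and test code do.
from typing import Optional

import torch


def find_max_cache_len(exported_program: "torch.export.ExportedProgram") -> Optional[int]:
    for buffer_name, buffer in exported_program.named_buffers():
        # Static cache buffers carry the cache length on dim 2.
        if buffer_name.startswith("key_cache"):
            return buffer.shape[2]
    return None
```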
