15 changes: 13 additions & 2 deletions src/transformers/cache_utils.py
@@ -3,7 +3,7 @@
 import json
 import os
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 import torch
 from packaging import version
@@ -358,12 +358,23 @@ class DynamicCache(Cache):
     ```
     """

-    def __init__(self) -> None:
+    def __init__(self, _distributed_cache_data: Iterable = None) -> None:
         super().__init__()
         self._seen_tokens = 0  # Used in `generate` to keep tally of how many tokens the cache has seen
         self.key_cache: List[torch.Tensor] = []
         self.value_cache: List[torch.Tensor] = []

+        # `_distributed_cache_data` was originally added for compatibility with `torch.distributed` (DDP). See #36121
+        # and #36373 for more information. In a nutshell, it is `map(gather_map, zip(*caches))`, i.e. each item in the
+        # iterable contains the key and value states for a layer gathered across replicas by torch.distributed
+        # (shape=[global batch size, num_heads, seq_len, head_dim]).
+        # WARNING: `_distributed_cache_data` must be the first argument in `__init__`, otherwise we'll break
+        # compatibility. The name of the argument doesn't matter.
+        if _distributed_cache_data is not None:
+            for key_states, value_states in _distributed_cache_data:
+                self.key_cache.append(key_states)
+                self.value_cache.append(value_states)
+
     def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
         """
         Support for backwards-compatible `past_key_value` indexing, e.g. `past_key_value[0][0].shape[2]` to get the
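To make the new constructor path concrete, here is a minimal sketch (not part of the diff) of feeding `DynamicCache` an iterable of per-layer key/value pairs, in the same `[global batch size, num_heads, seq_len, head_dim]` layout the comment describes. It assumes an installed transformers build that already includes this change; the shapes and layer count below are arbitrary.

```python
import torch
from transformers.cache_utils import DynamicCache

# Arbitrary illustrative sizes: 2 layers, global batch of 4, 8 heads, 5 tokens, head_dim 16.
num_layers, global_batch, num_heads, seq_len, head_dim = 2, 4, 8, 5, 16

# One (key, value) pair per layer, as produced by the cross-replica gather.
gathered = [
    (
        torch.zeros(global_batch, num_heads, seq_len, head_dim),  # key states for layer i
        torch.zeros(global_batch, num_heads, seq_len, head_dim),  # value states for layer i
    )
    for _ in range(num_layers)
]

# The iterable goes in as the first positional argument, mirroring how the gather
# step rebuilds the cache via `map(gather_map, zip(*caches))`.
cache = DynamicCache(gathered)
assert len(cache) == num_layers
assert cache[0][0].shape == (global_batch, num_heads, seq_len, head_dim)
```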
34 changes: 34 additions & 0 deletions tests/utils/test_cache_utils.py
@@ -20,12 +20,14 @@

 from transformers import set_seed
 from transformers.testing_utils import (
+    get_gpu_count,
     is_torch_available,
     require_gptq,
     require_non_xpu,
     require_read_token,
     require_torch,
     require_torch_gpu,
+    require_torch_multi_gpu,
     slow,
     torch_device,
 )
@@ -620,3 +622,35 @@ def test_cache_copy(self):
             'You are a helpful assistant. What is the capital of France?\n\n\n## Response:Paris is the capital of France.\n\n\n\n\n\n## Query:\n\nIn a detailed analysis, compare the economic impacts of the introduction of the'
         ] # fmt: skip
         self.assertEqual(responses, EXPECTED_DECODED_TEXT)
+
+    @require_torch_multi_gpu
+    def test_data_parallel_dynamic_cache(self):
+        """
+        Tests that the dynamic cache works with nn.DataParallel. Under the hood, the `DynamicCache` is rebuilt from
+        the per-replica `DynamicCache` objects in the gather step.
+        """
+
+        model_repo = "hf-internal-testing/tiny-random-MistralForCausalLM"
+        model = AutoModelForCausalLM.from_pretrained(model_repo).to(torch_device)
+        tokenizer = AutoTokenizer.from_pretrained(model_repo)
+
+        # w/o DP: batch_size = num_gpus (single device)
+        # w/ DP: batch_size = 1 per replica (num_gpus replicas)
+        num_gpus = get_gpu_count()
+        model_inputs = tokenizer(["foo bar"] * num_gpus, return_tensors="pt").to(model.device)
+
+        # w/o DP
+        no_parallelism_cache = model(**model_inputs).past_key_values
+        self.assertIsInstance(no_parallelism_cache, DynamicCache)
+
+        # w/ DP
+        model = torch.nn.DataParallel(model)
+        parallelism_cache = model(**model_inputs).past_key_values
+        self.assertIsInstance(parallelism_cache, DynamicCache)
+
+        # Check that the caches are the same
+        for layer_idx in range(len(no_parallelism_cache)):
+            for kv_idx in range(2):  # 0 = key, 1 = value
+                torch.testing.assert_close(
+                    actual=parallelism_cache[layer_idx][kv_idx], expected=no_parallelism_cache[layer_idx][kv_idx]
+                )