Merged
Changes from 14 commits
src/transformers/cache_utils.py (25 changes: 18 additions & 7 deletions)
@@ -9,6 +9,7 @@
from packaging import version

from .configuration_utils import PretrainedConfig
+from .pytorch_utils import is_torch_greater_or_equal_than_2_7
from .utils import is_hqq_available, is_optimum_quanto_available, logging


@@ -537,10 +538,10 @@ def batch_select_indices(self, indices: torch.Tensor):

class OffloadedCache(DynamicCache):
"""
-A drop-in replacement for DynamicCache that conserves GPU memory at the expense of more CPU memory.
+A drop-in replacement for DynamicCache that conserves accelerator (GPU, XPU) memory at the expense of more CPU memory.
Useful for generating from models with very long context.

-In addition to the default CUDA stream, where all forward() computations happen,
+In addition to the default accelerator stream, where all forward() computations happen,
this class uses another stream, the prefetch stream, which it creates itself.
Since scheduling of operations on separate streams happens independently, this class uses
the prefetch stream to asynchronously prefetch the KV cache of layer k+1 when layer k is executing.
@@ -549,17 +550,24 @@ class OffloadedCache(DynamicCache):
"""

def __init__(self) -> None:
-if not torch.cuda.is_available():
-    raise RuntimeError("OffloadedCache can only be used with a GPU")
+if not (torch.cuda.is_available() or (is_torch_greater_or_equal_than_2_7 and torch.xpu.is_available())):
+    raise RuntimeError(
+        "OffloadedCache can only be used with a GPU"
+        + (" or XPU" if is_torch_greater_or_equal_than_2_7 else "")
+    )
+
super().__init__()
self.original_device = []
-self.prefetch_stream = torch.cuda.Stream()
+self.prefetch_stream = None
+self.prefetch_stream = torch.Stream() if is_torch_greater_or_equal_than_2_7 else torch.cuda.Stream()
self.beam_idx = None # used to delay beam search operations

def prefetch_layer(self, layer_idx: int):
"Starts prefetching the next layer cache"
if layer_idx < len(self):
-with torch.cuda.stream(self.prefetch_stream):
+with (
+    self.prefetch_stream if is_torch_greater_or_equal_than_2_7 else torch.cuda.stream(self.prefetch_stream)
+):
# Prefetch next layer tensors to GPU
device = self.original_device[layer_idx]
self.key_cache[layer_idx] = self.key_cache[layer_idx].to(device, non_blocking=True)
@@ -577,7 +585,10 @@ def __getitem__(self, layer_idx: int) -> List[Tuple[torch.Tensor]]:
"Gets the cache for this layer to the device. Prefetches the next and evicts the previous layer."
if layer_idx < len(self):
# Evict the previous layer if necessary
-torch.cuda.current_stream().synchronize()
+if is_torch_greater_or_equal_than_2_7:
+    torch.accelerator.current_stream().synchronize()
+else:
+    torch.cuda.current_stream().synchronize()
self.evict_previous_layer(layer_idx)
# Load current layer cache to its original device if not already there
original_device = self.original_device[layer_idx]
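Note (not part of the diff): a minimal sketch of the stream pattern the change above adopts, assuming torch and this branch of transformers are installed. On torch >= 2.7 the device-agnostic torch.Stream can be created and entered directly as a context manager on CUDA or XPU, while older torch needs the CUDA-specific torch.cuda.Stream wrapped in torch.cuda.stream(...). The helper names below are illustrative, not transformers API.

```python
import torch

from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_7


def make_prefetch_stream():
    # Illustrative helper: torch.Stream() targets the current accelerator
    # (CUDA or XPU) on torch >= 2.7; older versions fall back to the
    # CUDA-only stream object.
    if is_torch_greater_or_equal_than_2_7:
        return torch.Stream()
    return torch.cuda.Stream()


def copy_on_prefetch_stream(stream, tensor, device):
    # Entering the stream differs between the two paths, as in
    # OffloadedCache.prefetch_layer: a torch.Stream is itself a context
    # manager, while a torch.cuda.Stream must be wrapped in torch.cuda.stream(...).
    ctx = stream if is_torch_greater_or_equal_than_2_7 else torch.cuda.stream(stream)
    with ctx:
        # Asynchronous copy issued on the side stream, overlapping with
        # compute running on the default stream.
        return tensor.to(device, non_blocking=True)
```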
src/transformers/pytorch_utils.py (1 change: 1 addition & 0 deletions)
@@ -31,6 +31,7 @@

parsed_torch_version_base = version.parse(version.parse(torch.__version__).base_version)

+is_torch_greater_or_equal_than_2_7 = parsed_torch_version_base >= version.parse("2.7")
is_torch_greater_or_equal_than_2_4 = parsed_torch_version_base >= version.parse("2.4")
is_torch_greater_or_equal_than_2_3 = parsed_torch_version_base >= version.parse("2.3")
is_torch_greater_or_equal_than_2_2 = parsed_torch_version_base >= version.parse("2.2")
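For illustration only, a hypothetical helper showing how the new flag is meant to combine with device queries: CUDA is always acceptable for OffloadedCache, while XPU additionally requires torch >= 2.7 (the short-circuit also keeps torch.xpu from being touched on older torch builds).

```python
import torch

from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_7


def offloaded_cache_device_available() -> bool:
    # Hypothetical helper mirroring the guard in OffloadedCache.__init__ above:
    # CUDA always works, XPU only with torch >= 2.7.
    return torch.cuda.is_available() or (
        is_torch_greater_or_equal_than_2_7 and torch.xpu.is_available()
    )
```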
tests/utils/test_cache_utils.py (36 changes: 28 additions & 8 deletions)
@@ -27,6 +27,7 @@
require_non_xpu,
require_read_token,
require_torch,
+require_torch_accelerator,
require_torch_gpu,
require_torch_multi_gpu,
slow,
@@ -48,7 +49,10 @@
StaticCache,
convert_and_export_with_cache,
)
-from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_3
+from transformers.pytorch_utils import (
+    is_torch_greater_or_equal_than_2_3,
+    is_torch_greater_or_equal_than_2_7,
+)


@require_torch
@@ -230,7 +234,7 @@ def test_static_cache_exportability(self):
self.assertEqual(n_static_value_caches, model.config.num_hidden_layers)


-@require_torch_gpu
+@require_torch_accelerator
@slow
class CacheIntegrationTest(unittest.TestCase):
def test_dynamic_cache_hard(self):
@@ -542,13 +546,17 @@ def test_static_cache_extra_left_padding(self, cache_implementation):
def test_static_cache_beam_search(self):
pass

-@require_torch_gpu
+@require_torch_accelerator
def test_offloaded_cache_equivalent_to_dynamic_cache(self):
"""Tests that OffloadedCache produces the same result as the default DynamicCache"""
model_name = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
device = model.device

+if not is_torch_greater_or_equal_than_2_7 and device.type == "xpu":
+    self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
+
input_text = "Fun fact:"
inputs = tokenizer(input_text, return_tensors="pt").to(device)
common = {
@@ -566,13 +574,17 @@ def test_offloaded_cache_equivalent_to_dynamic_cache(self):
for original_output, offloaded_output in zip(original_outputs, offloaded_outputs):
assert torch.all(original_output == offloaded_output).item()

-@require_torch_gpu
+@require_torch_accelerator
def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self):
"""Tests that OffloadedCache uses less memory than the default DynamicCache"""
model_name = "microsoft/Phi-3-mini-4k-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
device = model.device

+if not is_torch_greater_or_equal_than_2_7 and device.type == "xpu":
+    self.skipTest(reason="This test requires torch >= 2.7 to run on xpu.")
+
input_text = "Fun fact:"
inputs = tokenizer(input_text, return_tensors="pt").to(device)
common = {
@@ -585,12 +597,20 @@ def test_offloaded_cache_uses_less_memory_than_dynamic_cache(self):
}
original = GenerationConfig(**common)
offloaded = GenerationConfig(cache_implementation="offloaded", **common)
-torch.cuda.reset_peak_memory_stats(device)
+
+torch_accelerator_module = None
+if device.type == "cuda":
+    torch_accelerator_module = torch.cuda
+elif device.type == "xpu":
+    torch_accelerator_module = torch.xpu
+
+torch_accelerator_module.reset_peak_memory_stats(device)
model.generate(generation_config=original, **inputs)
-original_peak_memory = torch.cuda.max_memory_allocated(device)
-torch.cuda.reset_peak_memory_stats(device)
+original_peak_memory = torch_accelerator_module.max_memory_allocated(device)
+torch_accelerator_module.reset_peak_memory_stats(device)
model.generate(generation_config=offloaded, **inputs)
-offloaded_peak_memory = torch.cuda.max_memory_allocated(device)
+offloaded_peak_memory = torch_accelerator_module.max_memory_allocated(device)
+print(f"original_peak_memory: {original_peak_memory}, offloaded_peak_memory: {offloaded_peak_memory}")
assert offloaded_peak_memory < original_peak_memory

@require_torch_gpu
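As an end-to-end illustration of what the updated tests exercise, a hedged usage sketch: generate with the offloaded cache on whichever accelerator is present and read peak memory back through the matching torch backend module. The checkpoint matches the tests above; get_backend_module is our own name, not a transformers or torch API, and running on XPU assumes torch >= 2.7 as required by this change.

```python
import torch

from transformers import AutoModelForCausalLM, AutoTokenizer


def get_backend_module(device: torch.device):
    # torch.cuda and torch.xpu expose the same memory-statistics calls used below.
    return torch.xpu if device.type == "xpu" else torch.cuda


model_name = "microsoft/Phi-3-mini-4k-instruct"  # same checkpoint as the tests above
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
inputs = tokenizer("Fun fact:", return_tensors="pt").to(model.device)

backend = get_backend_module(model.device)
backend.reset_peak_memory_stats(model.device)
outputs = model.generate(**inputs, max_new_tokens=20, cache_implementation="offloaded")
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
print("peak memory with offloaded cache:", backend.max_memory_allocated(model.device))
```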