
Conversation

@gante (Contributor) commented Apr 29, 2025

What does this PR do?

The main purpose of this PR is to convert a few slow tests targeted at one cache implementation into fast tests that run on ALL cache implementations.

Secondarily, it makes RUN_SLOW=1 py.test tests/utils/test_cache_utils.py green 🟢. These tests also become much, much faster (3 min -> 1 min on my machine), despite covering a larger number of features.

This is a follow-up to #37684, which paved the way for this PR. After this PR is merged, I can go back to #37394 and properly test things!

👉 torch.compile was benchmarked with gemma2/hybrid and qwen3/static, no speed regressions.
👉 no regressions in RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py
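
A rough sketch of that pattern (hypothetical test, fixture names, and implementation list; not the PR's exact code): one small-model generation check, parametrized over cache implementations, so each cache class is exercised by a fast test.

import pytest

CACHE_IMPLEMENTATIONS = ["static", "sliding_window", "hybrid", "offloaded"]  # illustrative subset

@pytest.mark.parametrize("cache_implementation", CACHE_IMPLEMENTATIONS)
def test_generate_with_cache_implementation(cache_implementation, tiny_model, tiny_inputs):
    # `tiny_model` / `tiny_inputs` are placeholder fixtures for a small model and a tokenized prompt
    out = tiny_model.generate(
        **tiny_inputs, max_new_tokens=5, cache_implementation=cache_implementation
    )
    assert out.shape[-1] == tiny_inputs["input_ids"].shape[-1] + 5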

@github-actions bot marked this pull request as draft April 29, 2025 18:18
@github-actions (bot):

Hi 👋, thank you for opening this pull request! The pull request is converted to draft by default. The CI will be paused while the PR is in draft mode. When it is ready for review, please click the Ready for review button (at the bottom of the PR page). This will assign reviewers and trigger CI.


class SinkCache(Cache):
"""
Deprecated.
@gante (Contributor Author):

SinkCache has been broken on some edge cases for over a year, the issues are non-trivial to fix, and it is no longer relevant -- we can achieve a similar effect with a few other flags. See deprecation warning below.
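
A minimal, self-contained sketch of such a deprecation warning (illustrative only; the real class subclasses Cache and the exact wording lives in this PR):

import warnings

class SinkCache:  # stand-in for the real class, which subclasses `Cache`
    def __init__(self, *args, **kwargs) -> None:
        warnings.warn(
            "`SinkCache` is deprecated and will be removed in a future release; a similar "
            "effect can be achieved with other cache/generation flags.",
            FutureWarning,
        )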

@gante marked this pull request as ready for review April 29, 2025 18:21
slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
cache_position = cache_position.clamp(0, self.max_cache_len - 1)
to_shift = cache_position >= self.max_cache_len - 1  # removed in this PR
to_shift = cache_position > self.max_cache_len - 1  # added in this PR
@gante (Contributor Author), Apr 29, 2025:

Off by one: we were applying the shifting update one token too early. This happens on the last token when we initialize the sliding window cache with the exact size of the generation (e.g. with model.generate(..., cache_implementation="sliding_window")).

This effectively means our models were micro-underperforming with sliding window caches, specifically on the last generated token :D One of the new tests caught this issue.
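
A small, self-contained illustration of the boundary case described above (numbers are illustrative; see the discussion below for the general case, which is disputed):

import torch

max_cache_len = 4                      # sliding-window cache with capacity 4
cache_position = torch.tensor([3])     # writing the last slot that still fits without sliding

old_to_shift = cache_position >= max_cache_len - 1  # tensor([True]): shift applied one token early
new_to_shift = cache_position > max_cache_len - 1   # tensor([False]): the last slot is filled without shifting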

Contributor:

At first glance, this likely fixes the issue(s) raised in #37574 👀

Member:

This change is wrong in general and leads to garbage generation on sequences longer than the sliding window! I am opening a PR to revert it, with examples 😉 What you observed reflects the fact that the prefill and later stages should be treated separately in terms of the states they return.

Contributor:

@Cyrilvallez You should give #37972 a look before :D

"config and it's not set to None."
)
self.max_cache_len = max_cache_len
self._sliding_window_max_len = min(config.sliding_window, max_cache_len)
@gante (Contributor Author):

HybridCache had the right pattern, but some of the other hybrid caches did not: generation crashed when we tried to generate a max length smaller than the sliding window length. Caught by one of the new tests.
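
A minimal sketch of the guard this line adds (illustrative numbers): when the requested cache length is shorter than the model's sliding window, the sliding layers are capped at the requested length instead of assuming a full window.

sliding_window = 4096   # e.g. config.sliding_window
max_cache_len = 256     # e.g. model.generate(..., max_length=256)

sliding_window_max_len = min(sliding_window, max_cache_len)
assert sliding_window_max_len == 256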

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@require_torch_accelerator
class CacheIntegrationTest(unittest.TestCase):
"""Cache tests that require loading models"""
"""Fast cache integration tests that share the same small model"""
@gante (Contributor Author), Apr 29, 2025:

Separated into two classes to make the best use of setUpClass: loading the model is the most costly part of these tests, and now we only do it once.
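
A minimal sketch of that pattern (checkpoint and test body are placeholders, not the PR's code): the expensive model load happens once per test class via setUpClass, and every test reuses it.

import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer

class CacheIntegrationTest(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        checkpoint = "hf-internal-testing/tiny-random-LlamaForCausalLM"  # placeholder tiny model
        cls.tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        cls.model = AutoModelForCausalLM.from_pretrained(checkpoint)

    def test_dynamic_cache(self):
        inputs = self.tokenizer("Hello", return_tensors="pt")
        out = self.model.generate(**inputs, max_new_tokens=2)
        self.assertEqual(out.shape[-1], inputs["input_ids"].shape[-1] + 2)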


# DynamicCache and the legacy cache format should be equivalent
set_seed(0)
gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256)
@gante (Contributor Author), Apr 29, 2025:

The default is now DynamicCache(), so the two generate calls in this test were identical.

self.assertEqual(decoded[0], expected_text)

@slow
def test_dynamic_cache_batched(self):
@gante (Contributor Author):

adapted into CacheIntegrationTest

self.assertListEqual(decoded, expected_text)

@slow
def test_dynamic_cache_beam_search(self):
@gante (Contributor Author):

adapted into CacheIntegrationTest

self.assertListEqual(decoded, expected_text)

@slow
def test_hybrid_cache_n_sequences(self):
@gante (Contributor Author):

redundant with the tests in CacheIntegrationTest (more specifically, test_cache_batched and test_cache_beam_search)

@require_non_xpu
@require_gptq
@slow
def test_sink_cache_hard(self):
@gante (Contributor Author):

test was broken and SinkCache is being deprecated

self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))

@slow
def test_sink_cache_iterative_prompts(self):
@gante (Contributor Author):

test was broken and SinkCache is being deprecated

self.assertListEqual(decoded, EXPECTED_GENERATION)

@slow
def test_dynamic_cache_extra_left_padding(self):
@gante (Contributor Author):

adapted into CacheIntegrationTest

self.assertListEqual(decoded, EXPECTED_GENERATION)

@slow
def test_static_cache_extra_left_padding(self):
@gante (Contributor Author):

adapted into CacheIntegrationTest


@require_torch_accelerator
@slow
def test_offloaded_cache_equivalent_to_dynamic_cache(self):
@gante (Contributor Author):

we implicitly test this in CacheIntegrationTest

responses.append(response)

EXPECTED_DECODED_TEXT = [
"You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
@gante (Contributor Author), Apr 29, 2025:

If we check out the commit that added this test, we get a different output 👀 possibly due to different hardware/software? (Anyway, I don't think it's worth pinning down the exact cause.)

# on `main`, prior to #36543, this would send stderr messages about cuda graphs being skipped.
with CaptureStderr() as cap:
model.generate(**inputs, max_new_tokens=2, cache_implementation="static")
self.assertEqual(cap.err, "")
@gante (Contributor Author):

Failing on main if we have kernels installed; this change makes the test green regardless of the installed packages.

self.skipTest("Quanto is not available")

if cache_implementation == "offloaded_hybrid_chunked":
# TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the
@gante (Contributor Author):

I don't think offloaded_hybrid_chunked + beam_search is worth the dive for now 🤔

Collaborator:

nope agree with you!


from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, StaticCache  # removed in this PR
from ...cache_utils import Cache, DynamicCache  # added in this PR
@gante (Contributor Author):

(same diff on all models)

Collaborator:

LGTM

@ArthurZucker (Collaborator) left a comment:

Very nice! Thanks 🤗
Would be nice to have a fast test for the HybridChunked cache to make sure compile is fine, maybe using a dummy gemma2 model?

TP is also an option to test 👀 but more of a TODO later!
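
A sketch of the kind of fast test suggested here (the checkpoint name and the cache_implementation string are assumptions, not confirmed in this thread): compile the forward pass of a tiny gemma2-style model and check that a short generation with the hybrid chunked cache still runs.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def test_hybrid_chunked_cache_with_compile():
    checkpoint = "hf-internal-testing/tiny-random-Gemma2ForCausalLM"  # placeholder dummy model
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    model = AutoModelForCausalLM.from_pretrained(checkpoint)
    model.forward = torch.compile(model.forward)  # exercise the compiled path during generate

    inputs = tokenizer("Hello", return_tensors="pt")
    out = model.generate(**inputs, max_new_tokens=4, cache_implementation="hybrid_chunked")
    assert out.shape[-1] == inputs["input_ids"].shape[-1] + 4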


@gante (Contributor Author) commented Apr 30, 2025

@ArthurZucker yeah, generalist cache + compile tests will be up next! :D

@gante merged commit 1b22290 into huggingface:main Apr 30, 2025
20 checks passed
@gante deleted the test_all_caches branch April 30, 2025 14:37
@Cyrilvallez mentioned this pull request May 9, 2025
zucchini-nlp pushed a commit to zucchini-nlp/transformers that referenced this pull request May 14, 2025