[tests] Test all cache implementations #37873
Conversation
```python
class SinkCache(Cache):
    """
    Deprecated.
```
SinkCache has been broken on some edge cases for over a year, the issues are non-trivial to fix, and it is no longer relevant -- we can achieve a similar effect with a few other flags. See deprecation warning below.
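For illustration, a minimal sketch of the deprecation pattern (the exact message and class body here are assumptions, not the PR's actual code):

```python
# Hypothetical sketch: emit a FutureWarning on construction so callers get a
# heads-up before the class is removed.
import warnings

from transformers.cache_utils import Cache


class SinkCache(Cache):
    """Deprecated."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "`SinkCache` is deprecated and will be removed in a future release. "
            "A similar effect can be achieved with other generation flags.",
            FutureWarning,
        )
        super().__init__(*args, **kwargs)
```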
```diff
  slicing = torch.ones(self.max_cache_len, dtype=torch.long, device=value_states.device).cumsum(0)
  cache_position = cache_position.clamp(0, self.max_cache_len - 1)
- to_shift = cache_position >= self.max_cache_len - 1
+ to_shift = cache_position > self.max_cache_len - 1
```
Off by one: we were applying the shifting update one token too early. This happens on the last token when we initialize the sliding window cache with the exact size of the generation (e.g. with `model.generate(..., cache_implementation="sliding_window")`).
This effectively means our models were micro-underperforming with sliding window caches, more specifically on the last generated token :D One of the new tests caught this issue.
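A small worked example of the boundary (toy values, not the real cache code):

```python
import torch

max_cache_len = 4  # cache slots are indices 0..3
cache_position = torch.arange(6)  # positions of 6 generated tokens

# Old condition: starts shifting at position 3, while slot 3 is still free.
to_shift_old = cache_position >= max_cache_len - 1
# New condition: starts shifting at position 4, once the cache is actually full.
to_shift_new = cache_position > max_cache_len - 1

print(to_shift_old.tolist())  # [False, False, False, True, True, True]
print(to_shift_new.tolist())  # [False, False, False, False, True, True]
```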
At first glance, this likely fixes the issue(s) raised in #37574 👀
This change is wrong in general and leads to garbage generation on sequences longer than the sliding window! I am opening a PR to revert it, with examples 😉 What you observed reflects the fact that prefill and later stages should be treated separately in terms of the states they return.
@Cyrilvallez You should give #37972 a look first :D
| "config and it's not set to None." | ||
| ) | ||
| self.max_cache_len = max_cache_len | ||
| self._sliding_window_max_len = min(config.sliding_window, max_cache_len) |
HybridCache had the right pattern, but some of the other hybrid caches did not: generation was crashing when the requested max length was shorter than the sliding window length. Caught by one of the new tests.
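A stripped-down sketch of the pattern (the class name is hypothetical; the attribute names mirror the diff above):

```python
class HybridCacheSketch:
    def __init__(self, config, max_cache_len):
        self.max_cache_len = max_cache_len
        # Sizing the window buffer from config.sliding_window alone crashed when
        # max_cache_len was smaller (e.g. max_cache_len=8, sliding_window=4096);
        # clamping with min() follows HybridCache's pattern.
        self._sliding_window_max_len = min(config.sliding_window, max_cache_len)
```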
```diff
  @require_torch_accelerator
  class CacheIntegrationTest(unittest.TestCase):
-     """Cache tests that require loading models"""
+     """Fast cache integration tests that share the same small model"""
```
Separated into two classes to make the best use of `setUpClass`. Loading the model is the most costly part of these tests, and this way we only do it once.
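A rough sketch of the pattern (the checkpoint id and test body are placeholders):

```python
import unittest

from transformers import AutoModelForCausalLM, AutoTokenizer


class CacheIntegrationTest(unittest.TestCase):
    """Fast cache integration tests that share the same small model"""

    @classmethod
    def setUpClass(cls):
        # The expensive part -- loading the model -- runs once for the whole
        # class, instead of once per test as it would with setUp.
        cls.tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-gpt2")
        cls.model = AutoModelForCausalLM.from_pretrained("hf-internal-testing/tiny-random-gpt2")

    def test_cache_batched(self):
        inputs = self.tokenizer("hello there", return_tensors="pt")
        self.model.generate(**inputs, max_new_tokens=4)
```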
```python
        # DynamicCache and the legacy cache format should be equivalent
        set_seed(0)
        gen_out_legacy = model.generate(**inputs, do_sample=True, max_new_tokens=256)
```
The default cache is now `DynamicCache()`, so the two `generate` calls in this test were identical.
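In other words (a sketch; `model`, `inputs`, and `set_seed` are assumed to come from the test context above), passing an explicit `DynamicCache()` is now redundant with the default:

```python
from transformers import DynamicCache, set_seed

set_seed(0)
out_default = model.generate(**inputs, do_sample=True, max_new_tokens=256)

set_seed(0)
out_explicit = model.generate(
    **inputs, do_sample=True, max_new_tokens=256, past_key_values=DynamicCache()
)
# With the same seed, both calls produce identical tokens.
```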
```python
        self.assertEqual(decoded[0], expected_text)

    @slow
    def test_dynamic_cache_batched(self):
```
adapted into CacheIntegrationTest
```python
        self.assertListEqual(decoded, expected_text)

    @slow
    def test_dynamic_cache_beam_search(self):
```
adapted into CacheIntegrationTest
```python
        self.assertListEqual(decoded, expected_text)

    @slow
    def test_hybrid_cache_n_sequences(self):
```
redundant with the tests in CacheIntegrationTest (more specifically, test_cache_batched and test_cache_beam_search)
```python
    @require_non_xpu
    @require_gptq
    @slow
    def test_sink_cache_hard(self):
```
test was broken and SinkCache is being deprecated
```python
        self.assertTrue(decoded[0].endswith("to perform a variety of tasks. The Transformer is a neural network"))

    @slow
    def test_sink_cache_iterative_prompts(self):
```
test was broken and SinkCache is being deprecated
```python
        self.assertListEqual(decoded, EXPECTED_GENERATION)

    @slow
    def test_dynamic_cache_extra_left_padding(self):
```
adapted into CacheIntegrationTest
```python
        self.assertListEqual(decoded, EXPECTED_GENERATION)

    @slow
    def test_static_cache_extra_left_padding(self):
```
adapted into CacheIntegrationTest
```python
    @require_torch_accelerator
    @slow
    def test_offloaded_cache_equivalent_to_dynamic_cache(self):
```
we implicitly test this in CacheIntegrationTest
```python
        responses.append(response)

        EXPECTED_DECODED_TEXT = [
            "You are a helpful assistant. Help me to write a blogpost about travelling.\n\nTraveling is an enriching experience that broadens our horizons and exposes us to new cultures, landscapes, and people. Whether it's a week",
```
If we check out the commit that added this test, we get a different output 👀 possibly due to different hardware/software? (Anyway, I don't think it's worth pinning down the exact cause.)
```python
        # on `main`, prior to #36543, this would send stderr messages about cuda graphs being skipped.
        with CaptureStderr() as cap:
            model.generate(**inputs, max_new_tokens=2, cache_implementation="static")
        self.assertEqual(cap.err, "")
```
This was failing on main when the `kernels` package is installed; the change makes the test green regardless of the installed packages.
```python
            self.skipTest("Quanto is not available")

        if cache_implementation == "offloaded_hybrid_chunked":
            # TODO (joao, cyril): something is off with `offloaded_hybrid_chunked` aka `OffloadedHybridCache`: the
```
I don't think offloaded_hybrid_chunked + beam_search is worth the dive for now 🤔
Nope, agree with you!
```diff
  from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache, StaticCache
+ from ...cache_utils import Cache, DynamicCache
```
(same diff on all models)
LGTM
ArthurZucker left a comment:
Very nice! Thanks 🤗
It would be nice to have a fast test for HybridChunked to make sure compile is fine, maybe using a dummy gemma2 model?
TP is also an option to test 👀 but that's more of a TODO for later!
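Something along these lines, perhaps (a hedged sketch: the checkpoint id and the `cache_implementation` string are assumptions, not code from this PR):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Tiny random checkpoint so the test stays fast; the exact repo id is an assumption.
repo = "hf-internal-testing/tiny-random-Gemma2ForCausalLM"
model = AutoModelForCausalLM.from_pretrained(repo)
tokenizer = AutoTokenizer.from_pretrained(repo)
inputs = tokenizer("hello", return_tensors="pt")

# Compile the forward pass and run a short generation with a hybrid chunked
# cache; the test would only assert that this completes without errors.
model.forward = torch.compile(model.forward)
model.generate(**inputs, max_new_tokens=4, cache_implementation="hybrid_chunked")
```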
@ArthurZucker yeah, generalist cache + compile tests will be up next! :D
What does this PR do?
The main purpose of this PR is to convert a few slow tests targeted at one cache implementation into fast tests that run on ALL cache implementations.
Secondarily, it makes `RUN_SLOW=1 py.test tests/utils/test_cache_utils.py` green 🟢 These tests also become much, much faster (3 min -> 1 min on my machine), despite covering a larger number of features.

This is a follow-up to #37684, which paved the way for this PR. After this PR is merged, I can go back to #37394 and properly test things!
👉 torch.compile was benchmarked with gemma2/hybrid and qwen3/static, no speed regressions.
👉 no regressions in `RUN_SLOW=1 py.test tests/models/llama/test_modeling_llama.py`
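For reference, a minimal sketch of the "one fast test over all cache implementations" pattern described above (class name, checkpoint id, and the implementation list are illustrative, not the PR's exact code):

```python
import unittest

from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer


class CacheSweepTest(unittest.TestCase):
    """Illustrative sketch: one fast test swept over several cache implementations."""

    @classmethod
    def setUpClass(cls):
        repo = "hf-internal-testing/tiny-random-gpt2"  # placeholder checkpoint
        cls.tokenizer = AutoTokenizer.from_pretrained(repo)
        cls.model = AutoModelForCausalLM.from_pretrained(repo)

    # Trimmed list so this runs on CPU; the real tests sweep every registered
    # implementation and skip those needing extra deps or an accelerator
    # (e.g. quantized caches skip when Quanto is not available).
    @parameterized.expand(["dynamic", "static"])
    def test_generate_with_cache(self, cache_implementation):
        inputs = self.tokenizer("hello", return_tensors="pt")
        out = self.model.generate(
            **inputs, max_new_tokens=4, cache_implementation=cache_implementation
        )
        self.assertEqual(out.shape[-1], inputs["input_ids"].shape[-1] + 4)
```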