Fix: warn and set use_cache=False if activation checkpointing via Accelerate is enabled
#36610
Conversation
Note: the failing test seems to be unrelated; it is failing elsewhere as well.
This looks like an interaction between generation caching and activation checkpointing, so cc @muellerzr @gante
From PyTorch folks:
ArthurZucker left a comment:
Feel free to merge @muellerzr if it works for you!
muellerzr left a comment:
Yes indeed :)
@S1ro1 can you resolve the conflicts? Then @ArthurZucker can likely merge (since the failures are unrelated).
@ArthurZucker EDIT: still failing. Until then, let's not merge.
The underlying issue with cache and activation checkpointing has been fixed as a side effect of #36311, so this PR no longer needs to be merged. Closing now.
What does this PR do?
Fixes #34928 (partially); the underlying issue also references many others with a similar problem.
What's the issue?
We have many checks that unset `use_cache` in case of gradient checkpointing, but none that do so for activation checkpointing (AC). This becomes an issue because, to the best of my knowledge, we have no reliable way to detect that a model uses AC. That's why we can only partially mitigate it when AC is used together with Accelerate, which sets the `FSDP_ACTIVATION_CHECKPOINTING` flag when AC is enabled. To the best of my knowledge, other ways of applying AC do not create any side effects that would let us detect it. A sketch of the resulting check is shown below.
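As an illustration, here is a minimal sketch of such a check, assuming Accelerate exports `FSDP_ACTIVATION_CHECKPOINTING` as a "true"/"false" string in the environment (the helper name and warning text are illustrative, not the PR's literal diff):

```python
import logging
import os

logger = logging.getLogger(__name__)


def maybe_disable_use_cache(config):
    # Accelerate exports this env var when its FSDP plugin enables activation
    # checkpointing; it is the only detectable side effect we can rely on.
    ac_enabled = os.environ.get("FSDP_ACTIVATION_CHECKPOINTING", "false").lower() == "true"
    if ac_enabled and getattr(config, "use_cache", False):
        logger.warning(
            "Activation checkpointing is enabled via Accelerate, which is "
            "incompatible with `use_cache=True`. Setting `use_cache=False`."
        )
        config.use_cache = False
```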
Why in `PreTrainedModel.__init__` and not lower?
It doesn't really matter whether the model is created via `from_pretrained`, `from_config`, or its constructor: this issue affects all of them. It also affects any model that uses `DynamicCache`, i.e. basically any generation model, so we can't warn everywhere; a check in `__init__` covers every construction path (see the sketch after this paragraph).
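To illustrate (the checkpoint name is just an example), all three creation paths run `PreTrainedModel.__init__`:

```python
from transformers import AutoConfig, AutoModelForCausalLM, Qwen2ForCausalLM

config = AutoConfig.from_pretrained("Qwen/Qwen2-0.5B")

# Each path ends up running PreTrainedModel.__init__, so a single check
# there covers all of them.
model_a = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
model_b = AutoModelForCausalLM.from_config(config)
model_c = Qwen2ForCausalLM(config)
```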
Minimal Repro
This would fail before on:
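The original snippet is not shown here; what follows is a minimal sketch of the kind of script that triggers the failure, assuming Accelerate's FSDP plugin with `activation_checkpointing=True` and a small Qwen checkpoint (the checkpoint name is illustrative; run under `accelerate launch`, since FSDP requires distributed initialization):

```python
import torch
from accelerate import Accelerator
from accelerate.utils import FullyShardedDataParallelPlugin
from transformers import AutoModelForCausalLM

# Enable activation checkpointing through Accelerate's FSDP plugin; this is
# what sets the FSDP_ACTIVATION_CHECKPOINTING flag mentioned above.
fsdp_plugin = FullyShardedDataParallelPlugin(activation_checkpointing=True)
accelerator = Accelerator(fsdp_plugin=fsdp_plugin)

model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B")
model = accelerator.prepare(model)
model.train()

input_ids = torch.randint(0, model.config.vocab_size, (1, 16), device=accelerator.device)
loss = model(input_ids=input_ids, labels=input_ids).loss
accelerator.backward(loss)  # fails here when use_cache=True under AC
```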
It fails because Qwen is one of the models that sets `use_cache=True`.
cc @muellerzr
Before submitting
- Did you read the contributor guideline, Pull Request section?
- Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
- Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.