Implement TeaCache #12652
Conversation
Work done
Waiting for feedback and review :)

Hi @sayakpaul @dhruvrnaik any updates?

@LawJarp-A sorry about the delay on our end. @DN6 will review it soon.

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Hi @LawJarp-A I think we would need TeaCache to be implemented in a model-agnostic way in order to merge the PR. The First Block Cache implementation is a good reference for this.

Yep @DN6, I agree. I wanted to first implement it just for a single model and get feedback on that before I work on a model-agnostic full implementation. I'm sort of working on it, just haven't pushed it yet. I'll take a look at First Block Cache for reference as well.

@DN6 updated it in a more model-agnostic way.
…th auto-detection
Added multi-model support; still testing it thoroughly.

Hi @DN6 @sayakpaul
In the meantime, any feedback would be appreciated.
Thanks @LawJarp-A!
You can refer to #12569 for testing.
Yes, I think that is informative for users.
sayakpaul left a comment
Some initial feedback. The most important question: it seems like we need to craft different logic depending on the model? Can we not keep it model-agnostic?
```python
_TEACACHE_HOOK = "teacache"
```

```python
# Model-specific polynomial coefficients from TeaCache paper/reference implementations
```
Do we know if these depend only on the model, or are there other dependencies as well (for example num_inference_steps, guidance_scale, etc.)?
Also, can we add a calibration step similar to #12648 so that users can log these coefficients for other models?
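For reference, a minimal sketch of what such a calibration pass could look like, assuming hypothetical helper names (`log_calibration_sample`, `fit_coefficients`) rather than anything from this PR or #12648: record the raw relative L1 distance of the modulated input alongside the observed relative change of the block output during an uncached reference run, then fit a degree-4 polynomial offline.

```python
# Hypothetical calibration sketch (not part of this PR): collect pairs of
# (raw relative L1 distance of the modulated input, actual relative change of
# the block output), then fit the degree-4 rescaling polynomial offline.
from typing import List

import numpy as np
import torch

raw_distances: List[float] = []
actual_changes: List[float] = []


def log_calibration_sample(
    prev_mod: torch.Tensor, curr_mod: torch.Tensor, prev_out: torch.Tensor, curr_out: torch.Tensor
) -> None:
    raw = ((curr_mod - prev_mod).abs().mean() / prev_mod.abs().mean()).item()
    actual = ((curr_out - prev_out).abs().mean() / prev_out.abs().mean()).item()
    raw_distances.append(raw)
    actual_changes.append(actual)


def fit_coefficients() -> List[float]:
    # np.polyfit returns the highest-order coefficient first, matching the
    # ordering of the coefficient lists shown elsewhere in this PR.
    return np.polyfit(np.asarray(raw_distances), np.asarray(actual_changes), deg=4).tolist()
```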
I am trying to think of ways we can avoid having a forward for each model now. Initially that seemed like the best option; it was fine when I wrote it for Flux, but Lumina needed multi-stage preprocessing.
…torch.compile support, and clean up coefficient flow Signed-off-by: Prajwal A <[email protected]>
@sayakpaul
@sayakpaul @DN6 I got the core logic working, and tested it for the models my GPU can handle. The current implementation puts all model handlers in a single file.

Potential refactor: Registry + Handler pattern. Each handler self-registers and encapsulates its logic:

```python
# handlers/flux.py
from .base import BaseTeaCacheHandler
from ..registry import register_handler


@register_handler("Flux", "FluxKontext")
class FluxHandler(BaseTeaCacheHandler):
    coefficients = [4.98651651e02, -2.83781631e02, ...]

    def extract_modulated_input(self, module, hidden_states, temb):
        return module.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]

    def handle_forward(self, module, *args, **kwargs):
        # FLUX-specific forward with ControlNet, LoRA, etc.
        ...
```

```python
# registry.py
_HANDLER_REGISTRY = {}


def register_handler(*model_names):
    def decorator(cls):
        for name in model_names:
            _HANDLER_REGISTRY[name] = cls
        return cls
    return decorator


def get_handler(module) -> BaseTeaCacheHandler:
    for name, handler_cls in _HANDLER_REGISTRY.items():
        if name in module.__class__.__name__:
            return handler_cls()
    raise ValueError(f"No TeaCache handler for {module.__class__.__name__}")
```

This is similar to how attention processors and schedulers are organized. Happy to refactor if you think it's worth it, or we can keep it simple like it is now. Since this has proven a bit more of a challenge to integrate than I thought xD, I would be happy to know if you have some ideas.
Hey @DN6 @sayakpaul, any updates? :)

@sayakpaul @DN6 checking in again :)
DN6 left a comment
Some high level feedback on the design. The control flow is hard to follow as it switches between the hook object and the adapter. The adapters themselves are thin wrappers around a modified forward function, so it would be better to just define them as standalone functions, e.g.

```python
def _flux_forward(
    state: "TeaCacheState",  # pass the state to the function, not the hook object
    coefficients: List[float],
    rel_l1_thresh: float,
    module: torch.nn.Module,
    hidden_states: torch.Tensor,
    timestep: torch.Tensor,
    pooled_projections: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    txt_ids: torch.Tensor,
    img_ids: torch.Tensor,
    return_dict: bool = True,
    **kwargs,
):
    # (computation of modulated_inp / original_hidden_states elided in this sketch)
    if _should_use_cache(state, modulated_inp, coefficients, rel_l1_thresh):
        hidden_states = _apply_cached_residual(state, hidden_states, modulated_inp)
    else:
        # run compute
        _update_cache(state, hidden_states, original_hidden_states, modulated_inp)
```

Since we're hooking the top level forward of the model, we can map this forward function using the class name during hook initialization.
```python
def initialize_hook(self, module):
    """Initialize hook with model-specific configuration."""
    model_config = _MODEL_CONFIG.get(module.__class__.__name__)
    if model_config is None:
        raise ValueError
    if self.config.coefficients is not None:
        self.coefficients = self.config.coefficients
    else:
        self.coefficients = model_config["coefficients"]
    # Initialize state
    self.state_manager = StateManager(TeaCacheState)
    self.forward_fn = model_config["forward_func"]
    return module
```

Where `_MODEL_CONFIG` is just a mapping for the forward functions and coefficients:
```python
_MODEL_CONFIG = {
    "FluxTransformer2DModel": {
        "forward_func": _flux_forward,
        "coefficients": [4.98651651e02, -2.83781631e02, 5.58554382e01, -3.82021401e00, 2.64230861e-01],
    },
}
```

Similarly, the methods defined in the hook object could also be turned into utility functions.
```python
def _compute_rescaled_distance(rel_distance: float, coefficients: List[float]) -> float:
    return (
        coefficients[0] * rel_distance**4
        + coefficients[1] * rel_distance**3
        + coefficients[2] * rel_distance**2
        + coefficients[3] * rel_distance
        + coefficients[4]
    )


def _should_use_cache(state: "TeaCacheState", ...):
    # Return True or False based on whether to use the cache.
    return


def _update_cache(state: "TeaCacheState", ...):
    return


def _apply_cached_residual(
    state: "TeaCacheState", input_base: torch.Tensor, modulated_inp: torch.Tensor
) -> torch.Tensor:
    """
    Apply cached residual to input (fast path).
    """
    output = input_base + state.previous_residual
    state.previous_modulated_input = modulated_inp
    state.cnt += 1
    return output
```

Let's remove passing `cache_fn` and `compute_fn` between the hook and the adapter. Use operations directly on the cache state + globally available utility methods. We can also remove the modulation extractors and move that logic into the model specific forward functions.
src/diffusers/hooks/teacache.py (Outdated)
```python
)
if self.rel_l1_thresh < 0.05:
    import warnings
    warnings.warn(
```
Use logger.warning
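For example, using the module-level logger that diffusers files typically define (the `_check_threshold` wrapper and the warning text below are illustrative, not from this PR):

```python
# Sketch: replace warnings.warn with the diffusers module-level logger.
from diffusers.utils import logging

logger = logging.get_logger(__name__)


def _check_threshold(rel_l1_thresh: float) -> None:
    if rel_l1_thresh < 0.05:
        # Placeholder message; the actual wording lives in the PR.
        logger.warning("rel_l1_thresh is very low; TeaCache will rarely reuse cached residuals.")
```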
```python
registry._set_context(None)
```

```python
def enable_teacache(self, rel_l1_thresh: float = 0.2, num_inference_steps: int = None, **kwargs):
```
Caching should only be enabled through `enable_cache`, passing the relevant config. Cache-specific enabling is not supported.
```python
pipe.to("cuda")

# Enable TeaCache with auto-detection (1.5x speedup)
pipe.transformer.enable_teacache(rel_l1_thresh=0.2)
```
Should be

```python
pipe.transformer.enable_cache(...)
```

We don't enable specific caching methods directly.
```python
logger.info(f"TeaCache: Using {state.num_steps} inference steps")
```

```python
def initialize_hook(self, module):
    self.state_manager.set_context("teacache")
```
Cache context is typically set in the denoising loop? I think in this case, both conditional and unconditional branches would write to the same cache state when using CFG.
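A rough illustration of the concern (the names below are illustrative, not this PR's or the diffusers API): keeping one state per context, with the context switched from the denoising loop, prevents the conditional and unconditional passes from overwriting each other's cached residuals.

```python
# Illustrative sketch: one TeaCacheState per cache context so the cond and
# uncond branches of classifier-free guidance keep separate caches.
from dataclasses import dataclass
from typing import Dict, Optional

import torch


@dataclass
class TeaCacheState:
    previous_modulated_input: Optional[torch.Tensor] = None
    previous_residual: Optional[torch.Tensor] = None
    accumulated_distance: float = 0.0
    cnt: int = 0


_states: Dict[str, TeaCacheState] = {}
_current_context: str = "default"


def set_cache_context(name: str) -> None:
    # Called from the denoising loop, e.g. "cond" before the conditional pass
    # and "uncond" before the unconditional pass.
    global _current_context
    _current_context = name


def get_state() -> TeaCacheState:
    return _states.setdefault(_current_context, TeaCacheState())
```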
```python
def _flux_modulated_input_extractor(module, hidden_states, timestep_emb):
    """Extract modulated input for FLUX models."""
    return module.transformer_blocks[0].norm1(hidden_states, emb=timestep_emb)[0]
```
I think these extractor functions can be folded into the adapter functions of each model. They're thin wrappers around a single line of code.
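That is, folded directly into the model-specific adapter; an abbreviated sketch (signature simplified, temb computation and the rest of the forward elided):

```python
# Sketch: the FLUX extractor inlined into the adapter instead of a separate helper.
import torch


def _flux_forward(state, coefficients, rel_l1_thresh, module, hidden_states: torch.Tensor,
                  temb: torch.Tensor, *args, **kwargs):
    # Previously _flux_modulated_input_extractor, now just an inline expression.
    modulated_inp = module.transformer_blocks[0].norm1(hidden_states, emb=temb)[0]
    ...
```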
```python
self.model_type = None
```

```python
@staticmethod
def _create_rescale_func(coefficients):
```
Why do we need to create a rescale func? If we have coefficients set, we should be able to just call the function directly?
```python
def rescale_fn(self, x):
    return (
        self.coefficients[0] * x**4
        + self.coefficients[1] * x**3
        + self.coefficients[2] * x**2
        + self.coefficients[3] * x
        + self.coefficients[4]
    )
```
Thanks for the feedback @DN6
What does this PR do?
What is TeaCache?
TeaCache (Timestep Embedding Aware Cache) is a training-free caching technique that speeds up diffusion model inference by 1.5x-2.6x by reusing transformer block computations when consecutive timestep embeddings are similar.
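A minimal sketch of the core decision rule, assuming the general TeaCache recipe rather than this PR's exact code: compute the relative L1 change of the timestep-modulated input between consecutive steps, rescale it with a model-specific degree-4 polynomial, accumulate it, and reuse the cached residual while the accumulated value stays below the threshold.

```python
# Sketch of the TeaCache decision rule (general recipe, not this PR's exact code).
from typing import List, Tuple

import torch


def teacache_should_skip(
    prev_modulated: torch.Tensor,
    curr_modulated: torch.Tensor,
    accumulated: float,
    coefficients: List[float],  # highest power first, as listed in the PR
    rel_l1_thresh: float,
) -> Tuple[bool, float]:
    """Return (skip_blocks, new_accumulated_distance) for one denoising step."""
    # Relative L1 change of the timestep-modulated input between steps.
    rel_distance = (
        (curr_modulated - prev_modulated).abs().mean() / prev_modulated.abs().mean()
    ).item()
    # Rescale with the model-specific degree-4 polynomial and accumulate.
    accumulated += sum(c * rel_distance ** (4 - i) for i, c in enumerate(coefficients))
    if accumulated < rel_l1_thresh:
        return True, accumulated  # reuse the cached residual, skip the blocks
    return False, 0.0  # run the blocks and reset the accumulator
```

In the reference TeaCache implementation the first and last steps are always computed in full, so there is always a fresh residual to add back when a later step is skipped.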
Architecture
Integrates with existing `HookRegistry` and `CacheMixin` patterns in diffusers.

Supported Models
Benchmark Results (FLUX.1-schnell, 20 steps, 512x512)
Benchmark Results (Lumina2, 28 steps, 512x512)
Benchmark Results (CogVideoX-2b, 50 steps, 720x720, 49 frames)
Test Hardware: NVIDIA A100-SXM4-40GB
Framework: Diffusers with TeaCache hooks
All tests: Same seed (42) for reproducibility
Pros & Cons
Pros:
`enable_teacache()`
Cons:
Usage
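(A sketch filled in from the diff excerpt above; the prompt and step count are illustrative, and per the review the final API is expected to route through `enable_cache` rather than a TeaCache-specific method.)

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Enable TeaCache with auto-detection (as currently exposed by this PR; the
# review asks for this to go through pipe.transformer.enable_cache(...)).
pipe.transformer.enable_teacache(rel_l1_thresh=0.2)

image = pipe("An astronaut riding a horse", num_inference_steps=20).images[0]
```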
Files Changed
`src/diffusers/hooks/teacache.py` - Core implementation
`src/diffusers/models/cache_utils.py` - CacheMixin integration
`tests/hooks/test_teacache.py` - Unit tests
Fixes # (issue)
#12589
#12635
Who can review?
@sayakpaul @yiyixuxu