Stable Audio integration #8716
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
thanks for the PR! Overall, it looks pretty aligned with diffusers' design! Here is my initial feedback:
Hey @yiyixuxu, thanks for the feedback here! I think the main reason for the separate projection model is that
Is this something that we'd want to do in the transformers? IMO, no, but happy to change the way it's implemented!
Thanks for quickly making it ready for reviews.
        seconds_end_hidden_states=seconds_end_hidden_states,
    )

@maybe_allow_in_graph
Nice ❤️
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
def fuse_qkv_projections(self):
    """
    Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
    are fused. For cross-attention modules, key and value projection matrices are fused.

    <Tip warning={true}>

    This API is 🧪 experimental.

    </Tip>
    """
    self.original_attn_processors = None

    for _, attn_processor in self.attn_processors.items():
        if "Added" in str(attn_processor.__class__.__name__):
            raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

    self.original_attn_processors = self.attn_processors

    for module in self.modules():
        if isinstance(module, Attention):
            module.fuse_projections(fuse=True)

# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
All of these optimizations can come after we have a basic implementation ready that matches the original outputs. It's easier to review as well. WDYT?
""" | ||
self.vae.disable_slicing() | ||
|
||
def enable_model_cpu_offload(self, gpu_id=0): |
Why do we need to implement it separately here? Can't we just specify a model_cpu_offload_seq like this?
I copied this snippet from another pipeline; I'm removing it and adding model_cpu_offload_seq as you proposed.
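For reference, a minimal sketch of what that could look like — the exact component order here is my assumption, not the final pipeline code:

from diffusers import DiffusionPipeline

class StableAudioPipelineSketch(DiffusionPipeline):
    # Declaring the offload order lets DiffusionPipeline.enable_model_cpu_offload()
    # handle device placement, so no custom implementation is needed.
    # The component order below is assumed for illustration only.
    model_cpu_offload_seq = "text_encoder->projection_model->transformer->vae"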
encoder_hidden_states: torch.FloatTensor = None,
global_hidden_states: torch.FloatTensor = None,
How are these two different?
The decoder here has self-attention and cross-attention layers: encoder_hidden_states is used in the cross-attention layer, whereas global_hidden_states is simply prepended to the hidden_states before being passed to the attention layer.
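A rough sketch of that distinction, with made-up shapes (not the actual modeling code):

import torch

batch, seq_len, dim = 2, 128, 64
hidden_states = torch.randn(batch, seq_len, dim)        # latent audio tokens
global_hidden_states = torch.randn(batch, 1, dim)       # e.g. pooled global conditioning
encoder_hidden_states = torch.randn(batch, 77, dim)     # text conditioning

# global_hidden_states is simply prepended along the sequence axis before self-attention
hidden_states = torch.cat([global_hidden_states, hidden_states], dim=1)  # (batch, 1 + seq_len, dim)

# encoder_hidden_states, by contrast, only enters as the key/value input of cross-attention:
# cross_attn(query=hidden_states, key=encoder_hidden_states, value=encoder_hidden_states)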
@ylacombe IMO, ideally, we do want to move the projection layers into the transformer, but since the original implementation is done this way, let's keep it as-is for now. I can help look into this later. One way I think we could go about this is to make the pipeline do something like:

if self.do_classifier_free_guidance:
    audio_end_in_s = torch.cat([audio_end_in_s, audio_end_in_s], dim=0)
elif ..:
    neg_audio_end_in_s = torch.tensor([0])
    audio_end_in_s = torch.cat([neg_audio_end_in_s, audio_end_in_s], dim=0)

This way, when these arguments reach the transformer, they already contain the CFG information. But I'm just making this up here; I don't know if it would work, so let's not worry about it and continue with your implementation :)
Hey @sayakpaul and @yiyixuxu!
While I figure out how to reconcile schedulers, I've left some comments on implementation choices for reference.
I'll address the current comments once we agree on the scheduler!
Thanks for your help!
scripts/convert_stable_audio.py
# scheduler
scheduler = DPMSolverMultistepScheduler(solver_order=2, algorithm_type="sde-dpmsolver++", use_exponential_sigmas=True)
I still need to find the right scheduler for this!
I had to add a new variant of GEGLU here
kv_heads (`int`, *optional*, defaults to `None`):
    The number of key and value heads to use for multi-head attention. Defaults to `heads`.
    If `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use
    Multi Query Attention (MQA) otherwise GQA is used.
I've also added kv_heads for Grouped-Query Attention (GQA) support; it only works with StableAudioAttnProcessor2_0 for now.
Let me know if that works for you!
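For context, here is roughly what grouped-query attention does with kv_heads — an illustrative sketch, not the actual StableAudioAttnProcessor2_0 code:

import torch
import torch.nn.functional as F

batch, seq, heads, kv_heads, head_dim = 2, 16, 8, 2, 32
query = torch.randn(batch, heads, seq, head_dim)
key = torch.randn(batch, kv_heads, seq, head_dim)
value = torch.randn(batch, kv_heads, seq, head_dim)

# Each group of heads // kv_heads query heads shares one key/value head:
# kv_heads == heads gives MHA, kv_heads == 1 gives MQA, anything in between is GQA.
key = key.repeat_interleave(heads // kv_heads, dim=1)
value = value.repeat_interleave(heads // kv_heads, dim=1)

out = F.scaled_dot_product_attention(query, key, value)  # (batch, heads, seq, head_dim)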
We already have kv_heads:
kv_heads: Optional[int] = None,
Possibly leverage that? StableAudioAttnProcessor2_0 could perhaps be renamed to GroupedQueryAttnProcessor2_0?
Also, could we leverage this one?
class LuminaAttnProcessor2_0:
Oh right, kv_heads was just added 3 days ago in #8652! Contrary to Lumina, we don't use a qk normalization layer, which is something I've never seen before! (I can rename the Stable Audio one to GroupedQueryAttnProcessor2_0 if necessary.)
StableAudioAttnProcessor2_0 should be the way to go. Even with grouped-query attention, you can have different configurations (Lumina has qk norm and an unusual way of applying it, you can have it with or without rotary embeddings, etc.), and there's no need to complicate things by forcing them to use the same attention processor.
Stable Audio uses a partial rotary position embedding and also performs some operations differently
src/diffusers/models/activations.py
class GLU(nn.Module):
    r"""
    A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function.
    It's similar to `GEGLU` but uses SiLU / Swish instead of GeLU.

    Parameters:
        dim_in (`int`): The number of channels in the input.
        dim_out (`int`): The number of channels in the output.
        act_fn (str): Name of activation function used.
        bias (`bool`, defaults to True): Whether to use a bias in the linear layer.
    """
I'm still using this in the FeedForward, but happy to move everything to the modeling file if @yiyixuxu agrees!
Sure. Since that FeedForward is very focused on GELU and GEGLU, I thought it might be a better option to have its own class.
I think it is very reasonable to use this in FeedForward, no need to move
The feedforward does not have GeLU though.
src/diffusers/models/embeddings.py
Apply partial rotary embeddings (Wang et al. GPT-J) to input tensors using the given frequency tensor. This function applies rotary embeddings
to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
tensors contain rotary embeddings and are returned as real tensors.
(nit): Provide some references that make use of partial rotary embeddings?
thanks! looking great!
I did an initial review and left some comments, mainly focused on the changes introduced into non-Stable-Audio files for this review; will do a second round next week!
src/diffusers/models/embeddings.py
else:
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64 # [S, D/2]
return freqs_cis


def apply_partial_rotary_emb(
Let's use apply_rotary_emb to support this; we can do the "apply partial" part inside the attention processor:
- split x into x_to_rotate and x_unrotated before calling apply_rotary_emb
- only pass x_to_rotate to apply_rotary_emb
- on the output of apply_rotary_emb, do out = torch.cat((out, x_unrotated), dim=-1)
There's a small nuance in both methods:
- apply_rotary_emb creates x_real and x_imag by reshaping like this: x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
- apply_partial_rotary_emb reshapes in a different way (the last two dimensions are swapped): x_to_rotate.reshape(*x_to_rotate.shape[:-1], 2, -1).unbind(dim=-2)
The resulting tensors are totally different. Happy to solve this a different way; I could add a boolean to reshape one way or the other, WDYT?
Sounds good, maybe add a use_real_unbind_dim that defaults to -1?
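To make the difference concrete, here is a toy illustration of the two unbind conventions (the use_real_unbind_dim names mirror the suggestion above; this is not the library code):

import torch

x = torch.arange(8.0)

# use_real_unbind_dim=-1 style: interleaved pairs -> [0, 2, 4, 6] and [1, 3, 5, 7]
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)

# use_real_unbind_dim=-2 style: first half / second half -> [0, 1, 2, 3] and [4, 5, 6, 7]
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)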
@@ -107,6 +110,7 @@ def __init__(
    lower_order_final: bool = True,
    euler_at_final: bool = False,
    final_sigmas_type: Optional[str] = "zero",  # "zero", "sigma_min"
    noise_preconditioning_strategy: str = "log",
Let me know what you prefer!
it seems that it's not easy to make it work with the non-EDM version?
None that I've seen!
Thanks for your hard work. I know you have been working quite hard on this. My comments should feel minor. I think we're nearing merge.
if "snake" in new_key: | ||
value = value.unsqueeze(0).unsqueeze(-1) |
🐍
nice param name
@@ -123,6 +123,28 @@ def forward(self, hidden_states, *args, **kwargs):
    return hidden_states * self.gelu(gate)


class SwiGLU(nn.Module):
Finally the OG FF.
elif activation_fn == "swiglu":
    act_fn = SwiGLU(dim, inner_dim, bias=bias)
I think this should be okay as we don't name our FeedForward in a way that indicates that it's restrictive toward SwiGLU.
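For reference, SwiGLU in this style is essentially one linear projection split into a value and a gate, with SiLU applied to the gate — a sketch, not the exact diffusers class:

import torch
import torch.nn as nn

class SwiGLUSketch(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        # one projection produces both the value and the gate halves
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
        self.activation = nn.SiLU()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
        return hidden_states * self.activation(gate)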
kv_heads (`int`, *optional*, defaults to `None`):
    The number of key and value heads to use for multi-head attention. Defaults to `heads`. If
    `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi
    Query Attention (MQA) otherwise GQA is used.
Okay this is a doc-related change. Specifically, you are adding this entry to the doc-string. Thanks for doing that!
return float_embeds


class StableAudioProjectionModel(ModelMixin, ConfigMixin):
@yiyixuxu sorry for the re-iteration, but what are some strong reasons to add it as a ModelMixin and not simply an nn.Module?
I summarized them here.
To re-iterate, CFG is applied on the output of the projection model, in a non-straightforward way. I'd rather keep the CFG logic in the pipeline.
>>> # Peak normalize, clip, convert to int16, and save to file
>>> output = (
...     audio[0]
...     .to(torch.float32)
...     .div(torch.max(torch.abs(audio[0])))
...     .clamp(-1, 1)
...     .mul(32767)
...     .to(torch.int16)
...     .cpu()
... )
>>> torchaudio.save("hammer.wav", output, pipe.vae.sampling_rate)
@yiyixuxu does it not make sense to have a postprocessing util for this stuff? This code block is large enough to warrant such a block IMO.
I've removed that post-processing; it turns out we don't need it.
if output_type == "np": | ||
audio = audio.cpu().float().numpy() | ||
|
||
if not return_dict: | ||
return (audio,) |
We should really post-process the audio here like we do in the other image pipelines.
I'm not entirely sure whether the post-processing done in the code snippet is necessary, or whether it's just a way to prepare the audio for saving with torchaudio. In the former case, we'd add it to a post-processing function; in the latter, we don't need to. Let me verify.
So, if I understand correctly, the post-processing depends on which library you want to save the generated audio with?
Exactly! It turns out we don't need that post-processing if we save with something else; I've updated the example snippet accordingly.
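For example, with soundfile (assuming audio and pipe come from the pipeline call as in the earlier snippet, with audio[0] a float tensor of shape (channels, samples)), saving could look roughly like:

import soundfile as sf

# soundfile accepts float waveforms in (samples, channels) layout,
# so no int16 conversion is needed
output = audio[0].T.float().cpu().numpy()
sf.write("output.wav", output, pipe.vae.sampling_rate)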
Co-authored-by: YiYi Xu <[email protected]>
Thanks! Looking great! I left one comment about making a new scheduler; it should be pretty easy to do, no? Let me know. We can merge this after that!
def apply_partial_rotary_emb(
    self,
    x: torch.Tensor,
    freqs_cis: Tuple[torch.Tensor],
) -> torch.Tensor:
    from .embeddings import apply_rotary_emb

    rot_dim = freqs_cis[0].shape[-1]
    x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:]

    x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2)

    out = torch.cat((x_rotated, x_unrotated), dim=-1)
    return out
Can we move it to where it's used inside __call__? It is just two lines wrapped around the apply_rotary_emb method.
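i.e., inlined at the call site it would be roughly this (a sketch, shown for the query only; rotary_emb is assumed to be the (cos, sin) tuple, and the key would be handled the same way):

# inside the processor's __call__
rot_dim = rotary_emb[0].shape[-1]
query_rotated = apply_rotary_emb(query[..., :rot_dim], rotary_emb, use_real=True, use_real_unbind_dim=-2)
query = torch.cat((query_rotated, query[..., rot_dim:]), dim=-1)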
if is_torchsde_available():
    from .scheduling_dpmsolver_sde import BrownianTreeNoiseSampler


class EDMDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
Can we make a CosineDPMSolverMultistepScheduler and move all the changes there? It's OK to only support BrownianTree there. It is not EDM, and the DPMSolverMultistepScheduler is super bloated already, so I think the easiest way is to make a new one!
Of course, I'm not sure about how to best describe it in docstrings and docs though!
Thanks! I updated it here: #8716 (comment).
The scheduler itself is not entirely new, but this combination hasn't been used in any other model, at least in diffusers I think, and the "cosine schedule" part is the only part that's not in the DPM scheduler, so let's just make a simple note of that.
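For that note, a "cosine schedule" in the continuous-time / v-diffusion sense usually means alpha_t = cos(t·π/2) and sigma_t = sin(t·π/2) for t in [0, 1] — a tiny illustration only, not the actual scheduler implementation, which should be checked against the original code:

import math
import torch

t = torch.linspace(0, 1, 5)
alphas = torch.cos(t * math.pi / 2)  # signal scale, goes 1 -> 0
sigmas = torch.sin(t * math.pi / 2)  # noise scale, goes 0 -> 1
print(list(zip(alphas.tolist(), sigmas.tolist())))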
src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py
Thank you so much @ylacombe for your hard work here. Navigating through 1e14 comments and addressing them like you did is NO SMALL FEAT. Thank you once again!
Thank you @ylacombe! Is there an example of how to use initial_audio_waveforms somewhere? Is that for extending or zero-shot generation?
* WIP modeling code and pipeline
* add custom attention processor + custom activation + add to init
* correct ProjectionModel forward
* add stable audio to __init__
* add autoencoder and update pipeline and modeling code
* add half Rope
* add partial rotary v2
* add temporary modfis to scheduler
* add EDM DPM Solver
* remove TODOs
* clean GLU
* remove att.group_norm to attn processor
* revert back src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
* refactor GLU -> SwiGLU
* remove redundant args
* add channel multiples in autoencoder docstrings
* changes in docsrtings and copyright headers
* clean pipeline
* further cleaning
* remove peft and lora and fromoriginalmodel
* Delete src/diffusers/pipelines/stable_audio/diffusers.code-workspace
* make style
* dummy models
* fix copied from
* add fast oobleck tests
* add brownian tree
* oobleck autoencoder slow tests
* remove TODO
* fast stable audio pipeline tests
* add slow tests
* make style
* add first version of docs
* wrap is_torchsde_available to the scheduler
* fix slow test
* test with input waveform
* add input waveform
* remove some todos
* create stableaudio gaussian projection + make style
* add pipeline to toctree
* fix copied from
* make quality
* refactor timestep_features->time_proj
* refactor joint_attention_kwargs->cross_attention_kwargs
* remove forward_chunk
* move StableAudioDitModel to transformers folder
* correct convert + remove partial rotary embed
* apply suggestions from yiyixuxu -> removing attn.kv_heads
* remove temb
* remove cross_attention_kwargs
* further removal of cross_attention_kwargs
* remove text encoder autocast to fp16
* continue removing autocast
* make style
* refactor how text and audio are embedded
* add paper
* update example code
* make style
* unify projection model forward + fix device placement
* make style
* remove fuse qkv
* apply suggestions from review
* Update src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py (Co-authored-by: YiYi Xu <[email protected]>)
* make style
* smaller models in fast tests
* pass sequential offloading fast tests
* add docs for vae and autoencoder
* make style and update example
* remove useless import
* add cosine scheduler
* dummy classes
* cosine scheduler docs
* better description of scheduler
---------
Co-authored-by: YiYi Xu <[email protected]>
What does this PR do?
Stability AI recently open-sourced Stable Audio 1.0, which can be run using their toolkit library.
Contrary to most diffusion models, the diffusion process here operates on a 1D latent signal, so I had to depart a bit from other models.
For now, I've sketched out how the pipeline will work, namely:
For this to work, I'm waiting for DAC to be integrated into transformers in this PR, in order to use the encoder and decoder code for the VAE.
Left TODO
cc @sayakpaul and @yiyixuxu !
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.