Stable Audio integration #8716
Merged · +3,771 −9

Showing changes from 41 of 79 commits.

Commits (79, all by ylacombe):
6151db5  WIP modeling code and pipeline
656561b  add custom attention processor + custom activation + add to init
819d746  correct ProjectionModel forward
8a1a9d8  add stable audio to __init__
960339d  add autoencoder and update pipeline and modeling code
51c838f  add half Rope
87f1e26  add partial rotary v2
2f2bb8a  add temporary modfis to scheduler
dc3f0eb  add EDM DPM Solver
07fc3c3  remove TODOs
b49a3d5  clean GLU
d1b3e20  remove att.group_norm to attn processor
23be1a3  revert back src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
9d32408  refactor GLU -> SwiGLU
661d4f1  Merge branch 'main' into add-stable-audio
3689af0  remove redundant args
282e478  add channel multiples in autoencoder docstrings
c9fef25  changes in docsrtings and copyright headers
e51ffb2  clean pipeline
ab6824c  further cleaning
eeb19fe  remove peft and lora and fromoriginalmodel
a43dfc5  Delete src/diffusers/pipelines/stable_audio/diffusers.code-workspace
e7185e5  make style
3c6715e  dummy models
14fa2bf  fix copied from
21d0171  add fast oobleck tests
9cc7c02  add brownian tree
c5eeafe  oobleck autoencoder slow tests
0a2d065  remove TODO
29e794b  fast stable audio pipeline tests
1bad287  add slow tests
cf15409  make style
dec61b3  add first version of docs
1961cc9  wrap is_torchsde_available to the scheduler
3c7df74  fix slow test
92392fd  test with input waveform
d826f0f  add input waveform
94c2a25  remove some todos
ad8660e  create stableaudio gaussian projection + make style
55b2a14  add pipeline to toctree
42a05c5  fix copied from
8919ba0  Merge branch 'huggingface:main' into add-stable-audio
2df8e41  make quality
68a5b56  refactor timestep_features->time_proj
a81f46d  refactor joint_attention_kwargs->cross_attention_kwargs
8e910d3  remove forward_chunk
406f02a  move StableAudioDitModel to transformers folder
3a1dddb  correct convert + remove partial rotary embed
c44d0a4  apply suggestions from yiyixuxu -> removing attn.kv_heads
e5859f1  remove temb
d35451d  remove cross_attention_kwargs
76debd5  further removal of cross_attention_kwargs
acde6d5  remove text encoder autocast to fp16
566972d  continue removing autocast
f187d65  make style
af4f2ab  Merge branch 'huggingface:main' into add-stable-audio
8aa2e11  refactor how text and audio are embedded
58ca32c  add paper
a4b6930  update example code
c0873dc  make style
bc36933  unify projection model forward + fix device placement
f318e15  make style
8382156  remove fuse qkv
6ff9cf6  Merge branch 'huggingface:main' into add-stable-audio
f91b084  apply suggestions from review
29dc552  Update src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
ff62035  make style
d61a1a9  smaller models in fast tests
f1c9585  pass sequential offloading fast tests
8893373  add docs for vae and autoencoder
0b93804  Merge branch 'main' into add-stable-audio
264dd6d  make style and update example
0277c7f  remove useless import
1565d8a  add cosine scheduler
d820e68  dummy classes
fea9f8e  cosine scheduler docs
8abdb61  Merge branch 'main' into add-stable-audio
81dedd9  better description of scheduler
6d5d663  Merge branch 'huggingface:main' into add-stable-audio

New file: Stable Audio documentation (+39 lines)

<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Stable Audio

Stable Audio was proposed by Stability AI. It takes a text prompt as input and predicts the corresponding sound or music sample.

Stable Audio Open generates variable-length (up to 47s) stereo audio at 44.1kHz from text prompts. It comprises three components: an autoencoder that compresses waveforms into a manageable sequence length, a T5-based text embedding for text conditioning, and a transformer-based diffusion (DiT) model that operates in the latent space of the autoencoder.

Stable Audio is trained on a corpus of around 48k audio recordings, of which around 47k come from Freesound and the rest from the Free Music Archive (FMA). All audio files are licensed under CC0, CC BY, or CC Sampling+. This data is used to train the autoencoder and the DiT.
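
These three components map directly onto the pipeline's modules. A minimal sketch of inspecting them; the checkpoint id `stabilityai/stable-audio-open-1.0` is an assumption, substitute whatever repository the converted weights are published under:

```python
from diffusers import StableAudioPipeline

# Hypothetical checkpoint id; replace with the published repository.
pipe = StableAudioPipeline.from_pretrained("stabilityai/stable-audio-open-1.0")

print(type(pipe.vae).__name__)           # AutoencoderOobleck (waveform compression)
print(type(pipe.text_encoder).__name__)  # T5EncoderModel (text conditioning)
print(type(pipe.transformer).__name__)   # StableAudioDiTModel (latent-space diffusion)
```
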
This pipeline was contributed by [Yoach Lacombe](https://huggingface.co/ylacombe).

## Tips

When constructing a prompt, keep in mind:
* Descriptive prompt inputs work best; use adjectives to describe the sound (for example, "high quality" or "clear") and make the prompt context specific where possible (e.g. "melodic techno with a fast beat and synths" works better than "techno").
* Using a *negative prompt* can significantly improve the quality of the generated audio. Try a negative prompt of "low quality, average quality".

During inference:

* The _quality_ of the generated audio sample can be controlled by the `num_inference_steps` argument; more steps give higher-quality audio at the expense of slower inference.
* Multiple waveforms can be generated in one go by setting `num_waveforms_per_prompt` to a value greater than 1. Automatic scoring is performed between the generated waveforms and the prompt text, and the audio is ranked from best to worst accordingly; see the sketch after this list.
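
A minimal end-to-end sketch putting these knobs together (again assuming the hypothetical checkpoint id above):

```python
import torch
import soundfile as sf
from diffusers import StableAudioPipeline

# Hypothetical checkpoint id; replace with the published repository.
pipe = StableAudioPipeline.from_pretrained("stabilityai/stable-audio-open-1.0", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

audio = pipe(
    "melodic techno with a fast beat and synths",
    negative_prompt="low quality, average quality",
    num_inference_steps=200,      # more steps: higher quality, slower inference
    num_waveforms_per_prompt=3,   # generated waveforms are ranked best to worst
    generator=torch.Generator("cuda").manual_seed(0),
).audios

# audio[0] is the best-ranked waveform, shaped (channels, samples)
sf.write("techno.wav", audio[0].T.float().cpu().numpy(), pipe.vae.sampling_rate)
```
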
## StableAudioPipeline

[[autodoc]] StableAudioPipeline
  - all
  - __call__

New file: Stable Audio weight conversion script (+281 lines)

# Run this script to convert the Stable Audio model weights to a diffusers pipeline.
import argparse
import json
import os
from contextlib import nullcontext

import torch
from safetensors.torch import load_file
from transformers import (
    AutoTokenizer,
    T5EncoderModel,
)

from diffusers import (
    AutoencoderOobleck,
    EDMDPMSolverMultistepScheduler,
    StableAudioDiTModel,
    StableAudioPipeline,
    StableAudioProjectionModel,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.utils import is_accelerate_available


if is_accelerate_available():
    from accelerate import init_empty_weights


def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_layers=5):
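    """Split the original combined checkpoint into the three diffusers components
    (projection model, DiT, autoencoder) and rename keys to the diffusers layout."""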
    projection_model_state_dict = {
        k.replace("conditioner.conditioners.", "").replace("embedder.embedding", "time_positional_embedding"): v
        for (k, v) in state_dict.items()
        if "conditioner.conditioners" in k
    }

    # NOTE: we assume here that there's no projection layer from the text encoder to the latent space;
    # the script should be adapted a bit if there is.
    for key, value in list(projection_model_state_dict.items()):
        new_key = key.replace("seconds_start", "start_number_conditioner").replace(
            "seconds_total", "end_number_conditioner"
        )
        projection_model_state_dict[new_key] = projection_model_state_dict.pop(key)

    model_state_dict = {k.replace("model.model.", ""): v for (k, v) in state_dict.items() if "model.model." in k}
    for key, value in list(model_state_dict.items()):
        # attention layers
        new_key = (
            key.replace("transformer.", "")
            .replace("layers", "transformer_blocks")
            .replace("self_attn", "attn1")
            .replace("cross_attn", "attn2")
            .replace("ff.ff", "ff.net")
        )
        new_key = (
            new_key.replace("pre_norm", "norm1")
            .replace("cross_attend_norm", "norm2")
            .replace("ff_norm", "norm3")
            .replace("to_out", "to_out.0")
        )
        new_key = new_key.replace("gamma", "weight").replace("beta", "bias")  # replace layernorm

        # other layers
        new_key = (
            new_key.replace("project", "proj")
            .replace("to_timestep_embed", "timestep_proj")
            .replace("to_global_embed", "global_proj")
            .replace("to_cond_embed", "cross_attention_proj")
        )

        # we're using the diffusers implementation of timestep_features (GaussianFourierProjection), which creates a 1D tensor
        if new_key == "timestep_features.weight":
            model_state_dict[key] = model_state_dict[key].squeeze(1)

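        # The original checkpoint fuses the attention projections into a single linear
        # layer (to_qkv / to_kv); diffusers keeps separate to_q / to_k / to_v, so split along dim 0.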
if "to_qkv" in new_key: | ||
q, k, v = torch.chunk(model_state_dict.pop(key), 3, dim=0) | ||
model_state_dict[new_key.replace("qkv", "q")] = q | ||
model_state_dict[new_key.replace("qkv", "k")] = k | ||
model_state_dict[new_key.replace("qkv", "v")] = v | ||
elif "to_kv" in new_key: | ||
k, v = torch.chunk(model_state_dict.pop(key), 2, dim=0) | ||
model_state_dict[new_key.replace("kv", "k")] = k | ||
model_state_dict[new_key.replace("kv", "v")] = v | ||
else: | ||
model_state_dict[new_key] = model_state_dict.pop(key) | ||
|
||
    autoencoder_state_dict = {
        k.replace("pretransform.model.", "").replace("coder.layers.0", "coder.conv1"): v
        for (k, v) in state_dict.items()
        if "pretransform.model." in k
    }

    for key, _ in list(autoencoder_state_dict.items()):
        new_key = key
        if "coder.layers" in new_key:
            # get idx of the layer
            idx = int(new_key.split("coder.layers.")[1].split(".")[0])

            new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx-1}")

            if "encoder" in new_key:
                for i in range(3):
                    new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i+1}")
                new_key = new_key.replace(f"block.{idx-1}.layers.3", f"block.{idx-1}.snake1")
                new_key = new_key.replace(f"block.{idx-1}.layers.4", f"block.{idx-1}.conv1")
            else:
                for i in range(2, 5):
                    new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i-1}")
                new_key = new_key.replace(f"block.{idx-1}.layers.0", f"block.{idx-1}.snake1")
                new_key = new_key.replace(f"block.{idx-1}.layers.1", f"block.{idx-1}.conv_t1")

            new_key = new_key.replace("layers.0.beta", "snake1.beta")
            new_key = new_key.replace("layers.0.alpha", "snake1.alpha")
            new_key = new_key.replace("layers.2.beta", "snake2.beta")
            new_key = new_key.replace("layers.2.alpha", "snake2.alpha")
            new_key = new_key.replace("layers.1.bias", "conv1.bias")
            new_key = new_key.replace("layers.1.weight_", "conv1.weight_")
            new_key = new_key.replace("layers.3.bias", "conv2.bias")
            new_key = new_key.replace("layers.3.weight_", "conv2.weight_")

            if idx == num_autoencoder_layers + 1:
                new_key = new_key.replace(f"block.{idx-1}", "snake1")
            elif idx == num_autoencoder_layers + 2:
                new_key = new_key.replace(f"block.{idx-1}", "conv2")

        else:
            new_key = new_key

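        # Snake activation parameters are stored as 1D tensors in the original checkpoint;
        # diffusers expects them shaped (1, channels, 1), hence the unsqueeze below.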
        value = autoencoder_state_dict.pop(key)
        if "snake" in new_key:
            value = value.unsqueeze(0).unsqueeze(-1)
        if new_key in autoencoder_state_dict:
            raise ValueError(f"{new_key} already in state dict.")
        autoencoder_state_dict[new_key] = value

    return model_state_dict, projection_model_state_dict, autoencoder_state_dict


parser = argparse.ArgumentParser(description="Convert Stable Audio 1.0 model weights to a diffusers pipeline")
parser.add_argument("--model_folder_path", type=str, help="Location of Stable Audio weights and config")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument(
    "--save_directory",
    type=str,
    default="./tmp/stable-audio-1.0",
    help="Directory to save a pipeline to. Will be created if it doesn't exist.",
)
parser.add_argument(
    "--repo_id",
    type=str,
    default="stable-audio-1.0",
    help="Hub repository to push the pipeline to",
)
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")

args = parser.parse_args()

checkpoint_path = (
    os.path.join(args.model_folder_path, "model.safetensors")
    if args.use_safetensors
    else os.path.join(args.model_folder_path, "model.ckpt")
)
config_path = os.path.join(args.model_folder_path, "model_config.json")

device = "cpu"
if args.variant == "bf16":
    dtype = torch.bfloat16
else:
    dtype = torch.float32

with open(config_path) as f_in:
    config_dict = json.load(f_in)

conditioning_dict = {
    conditioning["id"]: conditioning["config"] for conditioning in config_dict["model"]["conditioning"]["configs"]
}

t5_model_config = conditioning_dict["prompt"]

# T5 text encoder
text_encoder = T5EncoderModel.from_pretrained(t5_model_config["t5_model_name"])
tokenizer = AutoTokenizer.from_pretrained(
    t5_model_config["t5_model_name"], truncation=True, model_max_length=t5_model_config["max_length"]
)

# scheduler
scheduler = EDMDPMSolverMultistepScheduler(
    solver_order=2,
    prediction_type="v_prediction",
    noise_preconditioning_strategy="atan",
    sigma_data=1.0,
    algorithm_type="sde-dpmsolver++",
    sigma_schedule="exponential",
    noise_sampling_strategy="brownian_tree",
)
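# Stable Audio samples sigmas in [0.3, 500]; override the scheduler defaults to match.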
scheduler.config["sigma_min"] = 0.3 | ||
scheduler.config["sigma_max"] = 500 | ||
ctx = init_empty_weights if is_accelerate_available() else nullcontext

if args.use_safetensors:
    orig_state_dict = load_file(checkpoint_path, device=device)
else:
    orig_state_dict = torch.load(checkpoint_path, map_location=device)

model_config = config_dict["model"]["diffusion"]["config"]

model_state_dict, projection_model_state_dict, autoencoder_state_dict = convert_stable_audio_state_dict_to_diffusers(
    orig_state_dict
)

with ctx():
    projection_model = StableAudioProjectionModel(
        text_encoder_dim=text_encoder.config.d_model,
        conditioning_dim=config_dict["model"]["conditioning"]["cond_dim"],
        # assume `seconds_start` and `seconds_total` share the same min / max values
        min_value=conditioning_dict["seconds_start"]["min_val"],
        max_value=conditioning_dict["seconds_start"]["max_val"],
    )
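# With accelerate, the modules above were created on the meta device (no allocated
# weights), so the converted state dicts are loaded directly into them.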
if is_accelerate_available():
    load_model_dict_into_meta(projection_model, projection_model_state_dict)
else:
    projection_model.load_state_dict(projection_model_state_dict)

attention_head_dim = model_config["embed_dim"] // model_config["num_heads"]
with ctx():
    model = StableAudioDiTModel(
        sample_size=int(config_dict["sample_size"])
        / int(config_dict["model"]["pretransform"]["config"]["downsampling_ratio"]),
        in_channels=model_config["io_channels"],
        num_layers=model_config["depth"],
        attention_head_dim=attention_head_dim,
        num_key_value_attention_heads=model_config["cond_token_dim"] // attention_head_dim,
        num_attention_heads=model_config["num_heads"],
        out_channels=model_config["io_channels"],
        cross_attention_dim=model_config["cond_token_dim"],
        timestep_features_dim=256,
        global_states_input_dim=model_config["global_cond_dim"],
        cross_attention_input_dim=model_config["cond_token_dim"],
    )
if is_accelerate_available():
    load_model_dict_into_meta(model, model_state_dict)
else:
    model.load_state_dict(model_state_dict)

autoencoder_config = config_dict["model"]["pretransform"]["config"]
with ctx():
    autoencoder = AutoencoderOobleck(
        encoder_hidden_size=autoencoder_config["encoder"]["config"]["channels"],
        downsampling_ratios=autoencoder_config["encoder"]["config"]["strides"],
        decoder_channels=autoencoder_config["decoder"]["config"]["channels"],
        decoder_input_channels=autoencoder_config["decoder"]["config"]["latent_dim"],
        audio_channels=autoencoder_config["io_channels"],
        channel_multiples=autoencoder_config["encoder"]["config"]["c_mults"],
        sampling_rate=config_dict["sample_rate"],
    )

if is_accelerate_available():
    load_model_dict_into_meta(autoencoder, autoencoder_state_dict)
else:
    autoencoder.load_state_dict(autoencoder_state_dict)

# Assemble the full pipeline and save it
pipeline = StableAudioPipeline(
    transformer=model,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    scheduler=scheduler,
    vae=autoencoder,
    projection_model=projection_model,
)
pipeline.to(dtype).save_pretrained(
    args.save_directory, repo_id=args.repo_id, push_to_hub=args.push_to_hub, variant=args.variant
)
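
For reference, a typical invocation of the script above (the filename `convert_stable_audio.py` is hypothetical, since the PR view does not show it here) would be `python convert_stable_audio.py --model_folder_path ./stable-audio-open --use_safetensors --save_directory ./tmp/stable-audio-1.0`, where `./stable-audio-open` is a local folder containing `model.safetensors` and `model_config.json`.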

Review comments: "🐍", "nice param name"