Stable Audio integration #8716

Merged

merged 79 commits into huggingface:main from add-stable-audio on Jul 30, 2024

Changes from 41 commits

Commits (79)
6151db5
WIP modeling code and pipeline
ylacombe Jun 26, 2024
656561b
add custom attention processor + custom activation + add to init
ylacombe Jul 1, 2024
819d746
correct ProjectionModel forward
ylacombe Jul 2, 2024
8a1a9d8
add stable audio to __init__
ylacombe Jul 9, 2024
960339d
add autoencoder and update pipeline and modeling code
ylacombe Jul 9, 2024
51c838f
add half Rope
ylacombe Jul 9, 2024
87f1e26
add partial rotary v2
ylacombe Jul 9, 2024
2f2bb8a
add temporary modifs to scheduler
ylacombe Jul 9, 2024
dc3f0eb
add EDM DPM Solver
ylacombe Jul 10, 2024
07fc3c3
remove TODOs
ylacombe Jul 10, 2024
b49a3d5
clean GLU
ylacombe Jul 10, 2024
d1b3e20
remove att.group_norm to attn processor
ylacombe Jul 10, 2024
23be1a3
revert back src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
ylacombe Jul 10, 2024
9d32408
refactor GLU -> SwiGLU
ylacombe Jul 15, 2024
661d4f1
Merge branch 'main' into add-stable-audio
ylacombe Jul 15, 2024
3689af0
remove redundant args
ylacombe Jul 15, 2024
282e478
add channel multiples in autoencoder docstrings
ylacombe Jul 15, 2024
c9fef25
changes in docstrings and copyright headers
ylacombe Jul 15, 2024
e51ffb2
clean pipeline
ylacombe Jul 15, 2024
ab6824c
further cleaning
ylacombe Jul 15, 2024
eeb19fe
remove peft and lora and fromoriginalmodel
ylacombe Jul 15, 2024
a43dfc5
Delete src/diffusers/pipelines/stable_audio/diffusers.code-workspace
ylacombe Jul 15, 2024
e7185e5
make style
ylacombe Jul 15, 2024
3c6715e
dummy models
ylacombe Jul 15, 2024
14fa2bf
fix copied from
ylacombe Jul 15, 2024
21d0171
add fast oobleck tests
ylacombe Jul 15, 2024
9cc7c02
add brownian tree
ylacombe Jul 16, 2024
c5eeafe
oobleck autoencoder slow tests
ylacombe Jul 17, 2024
0a2d065
remove TODO
ylacombe Jul 17, 2024
29e794b
fast stable audio pipeline tests
ylacombe Jul 17, 2024
1bad287
add slow tests
ylacombe Jul 17, 2024
cf15409
make style
ylacombe Jul 17, 2024
dec61b3
add first version of docs
ylacombe Jul 17, 2024
1961cc9
wrap is_torchsde_available to the scheduler
ylacombe Jul 18, 2024
3c7df74
fix slow test
ylacombe Jul 18, 2024
92392fd
test with input waveform
ylacombe Jul 18, 2024
d826f0f
add input waveform
ylacombe Jul 18, 2024
94c2a25
remove some todos
ylacombe Jul 18, 2024
ad8660e
create stableaudio gaussian projection + make style
ylacombe Jul 18, 2024
55b2a14
add pipeline to toctree
ylacombe Jul 18, 2024
42a05c5
fix copied from
ylacombe Jul 18, 2024
8919ba0
Merge branch 'huggingface:main' into add-stable-audio
ylacombe Jul 18, 2024
2df8e41
make quality
ylacombe Jul 18, 2024
68a5b56
refactor timestep_features->time_proj
ylacombe Jul 24, 2024
a81f46d
refactor joint_attention_kwargs->cross_attention_kwargs
ylacombe Jul 24, 2024
8e910d3
remove forward_chunk
ylacombe Jul 24, 2024
406f02a
move StableAudioDitModel to transformers folder
ylacombe Jul 24, 2024
3a1dddb
correct convert + remove partial rotary embed
ylacombe Jul 24, 2024
c44d0a4
apply suggestions from yiyixuxu -> removing attn.kv_heads
ylacombe Jul 24, 2024
e5859f1
remove temb
ylacombe Jul 24, 2024
d35451d
remove cross_attention_kwargs
ylacombe Jul 24, 2024
76debd5
further removal of cross_attention_kwargs
ylacombe Jul 24, 2024
acde6d5
remove text encoder autocast to fp16
ylacombe Jul 24, 2024
566972d
continue removing autocast
ylacombe Jul 24, 2024
f187d65
make style
ylacombe Jul 24, 2024
af4f2ab
Merge branch 'huggingface:main' into add-stable-audio
ylacombe Jul 24, 2024
8aa2e11
refactor how text and audio are embedded
ylacombe Jul 24, 2024
58ca32c
add paper
ylacombe Jul 24, 2024
a4b6930
update example code
ylacombe Jul 24, 2024
c0873dc
make style
ylacombe Jul 24, 2024
bc36933
unify projection model forward + fix device placement
ylacombe Jul 25, 2024
f318e15
make style
ylacombe Jul 25, 2024
8382156
remove fuse qkv
ylacombe Jul 25, 2024
6ff9cf6
Merge branch 'huggingface:main' into add-stable-audio
ylacombe Jul 25, 2024
f91b084
apply suggestions from review
ylacombe Jul 25, 2024
29dc552
Update src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
ylacombe Jul 26, 2024
ff62035
make style
ylacombe Jul 26, 2024
d61a1a9
smaller models in fast tests
ylacombe Jul 26, 2024
f1c9585
pass sequential offloading fast tests
ylacombe Jul 26, 2024
8893373
add docs for vae and autoencoder
ylacombe Jul 26, 2024
0b93804
Merge branch 'main' into add-stable-audio
ylacombe Jul 26, 2024
264dd6d
make style and update example
ylacombe Jul 26, 2024
0277c7f
remove useless import
ylacombe Jul 29, 2024
1565d8a
add cosine scheduler
ylacombe Jul 29, 2024
d820e68
dummy classes
ylacombe Jul 29, 2024
fea9f8e
cosine scheduler docs
ylacombe Jul 29, 2024
8abdb61
Merge branch 'main' into add-stable-audio
ylacombe Jul 29, 2024
81dedd9
better description of scheduler
ylacombe Jul 30, 2024
6d5d663
Merge branch 'huggingface:main' into add-stable-audio
ylacombe Jul 30, 2024
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -358,6 +358,8 @@
    title: Semantic Guidance
  - local: api/pipelines/shap_e
    title: Shap-E
  - local: api/pipelines/stable_audio
    title: Stable Audio
  - local: api/pipelines/stable_cascade
    title: Stable Cascade
  - sections:
1 change: 1 addition & 0 deletions docs/source/en/api/pipelines/overview.md
@@ -71,6 +71,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an
| [Semantic Guidance](semantic_stable_diffusion) | text2image |
| [Shap-E](shap_e) | text-to-3D, image-to-3D |
| [Spectrogram Diffusion](spectrogram_diffusion) | |
| [Stable Audio](stable_audio) | text2audio |
| [Stable Diffusion](stable_diffusion/overview) | text2image, image2image, depth2image, inpainting, image variation, latent upscaler, super-resolution |
| [Stable Diffusion Model Editing](model_editing) | model editing |
| [Stable Diffusion XL](stable_diffusion/stable_diffusion_xl) | text2image, image2image, inpainting |
39 changes: 39 additions & 0 deletions docs/source/en/api/pipelines/stable_audio.md
@@ -0,0 +1,39 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Stable Audio

Stable Audio was proposed by Stability AI. It takes a text prompt as input and predicts the corresponding sound or music sample.

Stable Audio Open generates variable-length (up to 47s) stereo audio at 44.1kHz from text prompts. It comprises three components: an autoencoder that compresses waveforms into a manageable sequence length, a T5-based text embedding for text conditioning, and a transformer-based diffusion (DiT) model that operates in the latent space of the autoencoder.

Stable Audio is trained on a corpus of around 48k audio recordings, of which around 47k are from Freesound and the rest are from the Free Music Archive (FMA). All audio files are licensed under CC0, CC BY, or CC Sampling+. This data is used to train both the autoencoder and the DiT.

This pipeline was contributed by [Yoach Lacombe](https://huggingface.co/ylacombe).

## Tips

When constructing a prompt, keep in mind:

* Descriptive prompt inputs work best; use adjectives to describe the sound (for example, "high quality" or "clear") and make the prompt context specific where possible (e.g. "melodic techno with a fast beat and synths" works better than "techno").
* Using a *negative prompt* can significantly improve the quality of the generated audio. Try using a negative prompt of "low quality, average quality".

During inference:

* The _quality_ of the generated audio sample can be controlled by the `num_inference_steps` argument; more steps give higher-quality audio at the expense of slower inference.
* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1 to enable this. The generated waveforms are automatically scored against the prompt text and returned ranked from best to worst, as in the example below.
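
The snippet below puts these tips together. It is a minimal usage sketch, assuming the [stabilityai/stable-audio-open-1.0](https://huggingface.co/stabilityai/stable-audio-open-1.0) weights and the `soundfile` library for saving audio; adapt both to your setup.

```python
import soundfile as sf
import torch

from diffusers import StableAudioPipeline

# assumed checkpoint; any Stable Audio checkpoint in the diffusers format works
pipe = StableAudioPipeline.from_pretrained("stabilityai/stable-audio-open-1.0", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# a descriptive, context-specific prompt plus a quality-oriented negative prompt
prompt = "The sound of a hammer hitting a wooden surface"
negative_prompt = "low quality, average quality"

# fix the seed for reproducible generations
generator = torch.Generator("cuda").manual_seed(0)

audio = pipe(
    prompt,
    negative_prompt=negative_prompt,
    num_inference_steps=200,  # more steps -> higher quality, slower inference
    audio_end_in_s=10.0,  # length of the generated clip, in seconds
    num_waveforms_per_prompt=3,  # waveforms are scored against the prompt and returned best-first
    generator=generator,
).audios

# save the top-ranked waveform; the autoencoder carries the output sampling rate
output = audio[0].T.float().cpu().numpy()
sf.write("hammer.wav", output, pipe.vae.sampling_rate)
```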


## StableAudioPipeline
[[autodoc]] StableAudioPipeline
- all
- __call__
281 changes: 281 additions & 0 deletions scripts/convert_stable_audio.py
@@ -0,0 +1,281 @@
# Run this script to convert the Stable Audio model weights to a diffusers pipeline.
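#
# Example invocation (paths here are illustrative, not from the original PR):
#
#   python scripts/convert_stable_audio.py \
#       --model_folder_path ./stable-audio-open-1.0 \
#       --use_safetensors \
#       --save_directory ./stable-audio-diffusers
#
# `--model_folder_path` must contain the original `model.safetensors` (or `model.ckpt`)
# and `model_config.json`; pass `--variant bf16` to save bfloat16 weights.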
import argparse
import json
import os
from contextlib import nullcontext

import torch
from safetensors.torch import load_file
from transformers import (
    AutoTokenizer,
    T5EncoderModel,
)

from diffusers import (
    AutoencoderOobleck,
    EDMDPMSolverMultistepScheduler,
    StableAudioDiTModel,
    StableAudioPipeline,
    StableAudioProjectionModel,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.utils import is_accelerate_available


if is_accelerate_available():
    from accelerate import init_empty_weights

def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_layers=5):
    projection_model_state_dict = {
        k.replace("conditioner.conditioners.", "").replace("embedder.embedding", "time_positional_embedding"): v
        for (k, v) in state_dict.items()
        if "conditioner.conditioners" in k
    }

    # NOTE: we assume here that there's no projection layer from the text encoder to the latent space, script should be adapted a bit if there is.
    for key, value in list(projection_model_state_dict.items()):
        new_key = key.replace("seconds_start", "start_number_conditioner").replace(
            "seconds_total", "end_number_conditioner"
        )
        projection_model_state_dict[new_key] = projection_model_state_dict.pop(key)

    model_state_dict = {k.replace("model.model.", ""): v for (k, v) in state_dict.items() if "model.model." in k}
    for key, value in list(model_state_dict.items()):
        # attention layers
        new_key = (
            key.replace("transformer.", "")
            .replace("layers", "transformer_blocks")
            .replace("self_attn", "attn1")
            .replace("cross_attn", "attn2")
            .replace("ff.ff", "ff.net")
        )
        new_key = (
            new_key.replace("pre_norm", "norm1")
            .replace("cross_attend_norm", "norm2")
            .replace("ff_norm", "norm3")
            .replace("to_out", "to_out.0")
        )
        new_key = new_key.replace("gamma", "weight").replace("beta", "bias")  # replace layernorm

        # other layers
        new_key = (
            new_key.replace("project", "proj")
            .replace("to_timestep_embed", "timestep_proj")
            .replace("to_global_embed", "global_proj")
            .replace("to_cond_embed", "cross_attention_proj")
        )

        # we're using diffusers implementation of timestep_features (GaussianFourierProjection) which creates a 1D tensor
        if new_key == "timestep_features.weight":
            model_state_dict[key] = model_state_dict[key].squeeze(1)

        if "to_qkv" in new_key:
            q, k, v = torch.chunk(model_state_dict.pop(key), 3, dim=0)
            model_state_dict[new_key.replace("qkv", "q")] = q
            model_state_dict[new_key.replace("qkv", "k")] = k
            model_state_dict[new_key.replace("qkv", "v")] = v
        elif "to_kv" in new_key:
            k, v = torch.chunk(model_state_dict.pop(key), 2, dim=0)
            model_state_dict[new_key.replace("kv", "k")] = k
            model_state_dict[new_key.replace("kv", "v")] = v
        else:
            model_state_dict[new_key] = model_state_dict.pop(key)

    autoencoder_state_dict = {
        k.replace("pretransform.model.", "").replace("coder.layers.0", "coder.conv1"): v
        for (k, v) in state_dict.items()
        if "pretransform.model." in k
    }

    for key, _ in list(autoencoder_state_dict.items()):
        new_key = key
        if "coder.layers" in new_key:
            # get idx of the layer
            idx = int(new_key.split("coder.layers.")[1].split(".")[0])

            new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx-1}")

            if "encoder" in new_key:
                for i in range(3):
                    new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i+1}")
                new_key = new_key.replace(f"block.{idx-1}.layers.3", f"block.{idx-1}.snake1")
                new_key = new_key.replace(f"block.{idx-1}.layers.4", f"block.{idx-1}.conv1")
            else:
                for i in range(2, 5):
                    new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i-1}")
                new_key = new_key.replace(f"block.{idx-1}.layers.0", f"block.{idx-1}.snake1")
                new_key = new_key.replace(f"block.{idx-1}.layers.1", f"block.{idx-1}.conv_t1")

            new_key = new_key.replace("layers.0.beta", "snake1.beta")
            new_key = new_key.replace("layers.0.alpha", "snake1.alpha")
            new_key = new_key.replace("layers.2.beta", "snake2.beta")
            new_key = new_key.replace("layers.2.alpha", "snake2.alpha")
            new_key = new_key.replace("layers.1.bias", "conv1.bias")
            new_key = new_key.replace("layers.1.weight_", "conv1.weight_")
            new_key = new_key.replace("layers.3.bias", "conv2.bias")
            new_key = new_key.replace("layers.3.weight_", "conv2.weight_")

            if idx == num_autoencoder_layers + 1:
                new_key = new_key.replace(f"block.{idx-1}", "snake1")
            elif idx == num_autoencoder_layers + 2:
                new_key = new_key.replace(f"block.{idx-1}", "conv2")

        else:
            new_key = new_key

        value = autoencoder_state_dict.pop(key)
        if "snake" in new_key:
            value = value.unsqueeze(0).unsqueeze(-1)
[Review comment on lines +129 to +130 (Member): 🐍 nice param name]

        if new_key in autoencoder_state_dict:
            raise ValueError(f"{new_key} already in state dict.")
        autoencoder_state_dict[new_key] = value

    return model_state_dict, projection_model_state_dict, autoencoder_state_dict


parser = argparse.ArgumentParser(description="Convert Stable Audio 1.0 model weights to a diffusers pipeline")
parser.add_argument("--model_folder_path", type=str, help="Location of Stable Audio weights and config")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument(
    "--save_directory",
    type=str,
    default="./tmp/stable-audio-1.0",
    help="Directory to save a pipeline to. Will be created if it doesn't exist.",
)
parser.add_argument(
    "--repo_id",
    type=str,
    default="stable-audio-1.0",
    help="Hub repo id to save the pipeline to",
)
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")

args = parser.parse_args()

checkpoint_path = (
    os.path.join(args.model_folder_path, "model.safetensors")
    if args.use_safetensors
    else os.path.join(args.model_folder_path, "model.ckpt")
)
config_path = os.path.join(args.model_folder_path, "model_config.json")

device = "cpu"
if args.variant == "bf16":
    dtype = torch.bfloat16
else:
    dtype = torch.float32

with open(config_path) as f_in:
    config_dict = json.load(f_in)

conditioning_dict = {
    conditioning["id"]: conditioning["config"] for conditioning in config_dict["model"]["conditioning"]["configs"]
}

t5_model_config = conditioning_dict["prompt"]

# T5 Text encoder
text_encoder = T5EncoderModel.from_pretrained(t5_model_config["t5_model_name"])
tokenizer = AutoTokenizer.from_pretrained(
    t5_model_config["t5_model_name"], truncation=True, model_max_length=t5_model_config["max_length"]
)


# scheduler
scheduler = EDMDPMSolverMultistepScheduler(
    solver_order=2,
    prediction_type="v_prediction",
    noise_preconditioning_strategy="atan",
    sigma_data=1.0,
    algorithm_type="sde-dpmsolver++",
    sigma_schedule="exponential",
    noise_sampling_strategy="brownian_tree",
)
scheduler.config["sigma_min"] = 0.3
scheduler.config["sigma_max"] = 500
ctx = init_empty_weights if is_accelerate_available() else nullcontext


if args.use_safetensors:
    orig_state_dict = load_file(checkpoint_path, device=device)
else:
    orig_state_dict = torch.load(checkpoint_path, map_location=device)


model_config = config_dict["model"]["diffusion"]["config"]

model_state_dict, projection_model_state_dict, autoencoder_state_dict = convert_stable_audio_state_dict_to_diffusers(
    orig_state_dict
)


with ctx():
    projection_model = StableAudioProjectionModel(
        text_encoder_dim=text_encoder.config.d_model,
        conditioning_dim=config_dict["model"]["conditioning"]["cond_dim"],
        min_value=conditioning_dict["seconds_start"][
            "min_val"
        ],  # assume `seconds_start` and `seconds_total` have the same min / max values.
        max_value=conditioning_dict["seconds_start"][
            "max_val"
        ],  # assume `seconds_start` and `seconds_total` have the same min / max values.
    )
if is_accelerate_available():
    load_model_dict_into_meta(projection_model, projection_model_state_dict)
else:
    projection_model.load_state_dict(projection_model_state_dict)

attention_head_dim = model_config["embed_dim"] // model_config["num_heads"]
with ctx():
    model = StableAudioDiTModel(
        sample_size=int(config_dict["sample_size"])
        / int(config_dict["model"]["pretransform"]["config"]["downsampling_ratio"]),
        in_channels=model_config["io_channels"],
        num_layers=model_config["depth"],
        attention_head_dim=attention_head_dim,
        num_key_value_attention_heads=model_config["cond_token_dim"] // attention_head_dim,
        num_attention_heads=model_config["num_heads"],
        out_channels=model_config["io_channels"],
        cross_attention_dim=model_config["cond_token_dim"],
        timestep_features_dim=256,
        global_states_input_dim=model_config["global_cond_dim"],
        cross_attention_input_dim=model_config["cond_token_dim"],
    )
if is_accelerate_available():
    load_model_dict_into_meta(model, model_state_dict)
else:
    model.load_state_dict(model_state_dict)


autoencoder_config = config_dict["model"]["pretransform"]["config"]
with ctx():
    autoencoder = AutoencoderOobleck(
        encoder_hidden_size=autoencoder_config["encoder"]["config"]["channels"],
        downsampling_ratios=autoencoder_config["encoder"]["config"]["strides"],
        decoder_channels=autoencoder_config["decoder"]["config"]["channels"],
        decoder_input_channels=autoencoder_config["decoder"]["config"]["latent_dim"],
        audio_channels=autoencoder_config["io_channels"],
        channel_multiples=autoencoder_config["encoder"]["config"]["c_mults"],
        sampling_rate=config_dict["sample_rate"],
    )

if is_accelerate_available():
    load_model_dict_into_meta(autoencoder, autoencoder_state_dict)
else:
    autoencoder.load_state_dict(autoencoder_state_dict)


# Assemble and save the full pipeline
pipeline = StableAudioPipeline(
    transformer=model,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    scheduler=scheduler,
    vae=autoencoder,
    projection_model=projection_model,
)
pipeline.to(dtype).save_pretrained(
    args.save_directory, repo_id=args.repo_id, push_to_hub=args.push_to_hub, variant=args.variant
)
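
# Optional sanity check (a sketch, not part of the original script): reload the converted
# pipeline from the save directory to confirm all components were serialized correctly.
#
# from diffusers import StableAudioPipeline
# pipe = StableAudioPipeline.from_pretrained(args.save_directory, variant=args.variant)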
8 changes: 8 additions & 0 deletions src/diffusers/__init__.py
@@ -79,6 +79,7 @@
"AuraFlowTransformer2DModel",
"AutoencoderKL",
"AutoencoderKLTemporalDecoder",
"AutoencoderOobleck",
"AutoencoderTiny",
"ConsistencyDecoderVAE",
"ControlNetModel",
@@ -291,6 +292,9 @@
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
"StableAudioDiTModel",
"StableAudioPipeline",
"StableAudioProjectionModel",
"StableCascadeCombinedPipeline",
"StableCascadeDecoderPipeline",
"StableCascadePriorPipeline",
@@ -513,6 +517,7 @@
AuraFlowTransformer2DModel,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
ConsistencyDecoderVAE,
ControlNetModel,
@@ -703,6 +708,9 @@
SemanticStableDiffusionPipeline,
ShapEImg2ImgPipeline,
ShapEPipeline,
StableAudioDiTModel,
StableAudioPipeline,
StableAudioProjectionModel,
StableCascadeCombinedPipeline,
StableCascadeDecoderPipeline,
StableCascadePriorPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
@@ -29,6 +29,7 @@
_import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"]
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["autoencoders.vq_model"] = ["VQModel"]
@@ -74,6 +75,7 @@
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderOobleck,
AutoencoderTiny,
ConsistencyDecoderVAE,
VQModel,