Stable Audio integration #8716
Changes from 8 commits
@@ -0,0 +1,244 @@
# Run this script to convert the Stable Audio model weights to a diffusers pipeline.
import argparse
import json
import os
from contextlib import nullcontext

import torch
from safetensors.torch import load_file
from transformers import (
    AutoTokenizer,
    T5EncoderModel,
)

from diffusers import (
    AutoencoderOobleck,
    DPMSolverMultistepScheduler,
    StableAudioDiTModel,
    StableAudioPipeline,
    StableAudioProjectionModel,
)
from diffusers.models.modeling_utils import load_model_dict_into_meta
from diffusers.utils import is_accelerate_available


if is_accelerate_available():
    from accelerate import init_empty_weights
def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_layers=5):
    projection_model_state_dict = {
        k.replace("conditioner.conditioners.", "").replace("embedder.embedding", "time_positional_embedding"): v
        for (k, v) in state_dict.items()
        if "conditioner.conditioners" in k
    }

    # NOTE: we assume here that there is no projection layer from the text encoder to the latent space;
    # the script should be adapted a bit if there is one.
    for key, value in list(projection_model_state_dict.items()):
        new_key = key.replace("seconds_start", "start_number_conditioner").replace("seconds_total", "end_number_conditioner")
        projection_model_state_dict[new_key] = projection_model_state_dict.pop(key)

    model_state_dict = {k.replace("model.model.", ""): v for (k, v) in state_dict.items() if "model.model." in k}
    for key, value in list(model_state_dict.items()):
        # attention layers
        new_key = key.replace("transformer.", "").replace("layers", "transformer_blocks").replace("self_attn", "attn1").replace("cross_attn", "attn2").replace("ff.ff", "ff.net")
        new_key = new_key.replace("pre_norm", "norm1").replace("cross_attend_norm", "norm2").replace("ff_norm", "norm3").replace("to_out", "to_out.0")
        new_key = new_key.replace("gamma", "weight").replace("beta", "bias")  # replace layernorm

        # other layers
        new_key = new_key.replace("project", "proj").replace("to_timestep_embed", "timestep_proj").replace("to_global_embed", "global_proj").replace("to_cond_embed", "cross_attention_proj")
        # TODO (YL): compared to the original Stable Audio weights we're missing `rotary_pos_emb.inv_freq`;
        # we probably don't need it, but this should be verified.
        # we're using the diffusers implementation of timestep_features (GaussianFourierProjection), which creates a 1D tensor
        if new_key == "timestep_features.weight":
            model_state_dict[key] = model_state_dict[key].squeeze(1)

        if "to_qkv" in new_key:
            q, k, v = torch.chunk(model_state_dict.pop(key), 3, dim=0)
            model_state_dict[new_key.replace("qkv", "q")] = q
            model_state_dict[new_key.replace("qkv", "k")] = k
            model_state_dict[new_key.replace("qkv", "v")] = v
        elif "to_kv" in new_key:
            k, v = torch.chunk(model_state_dict.pop(key), 2, dim=0)
            model_state_dict[new_key.replace("kv", "k")] = k
            model_state_dict[new_key.replace("kv", "v")] = v
        else:
            model_state_dict[new_key] = model_state_dict.pop(key)

    autoencoder_state_dict = {
        k.replace("pretransform.model.", "").replace("coder.layers.0", "coder.conv1"): v
        for (k, v) in state_dict.items()
        if "pretransform.model." in k
    }

    for key, _ in list(autoencoder_state_dict.items()):
        new_key = key
        if "coder.layers" in new_key:
            # get idx of the layer
            idx = int(new_key.split("coder.layers.")[1].split(".")[0])

            new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx-1}")

            if "encoder" in new_key:
                for i in range(3):
                    new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i+1}")
                new_key = new_key.replace(f"block.{idx-1}.layers.3", f"block.{idx-1}.snake1")
                new_key = new_key.replace(f"block.{idx-1}.layers.4", f"block.{idx-1}.conv1")
            else:
                for i in range(2, 5):
                    new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i-1}")
                new_key = new_key.replace(f"block.{idx-1}.layers.0", f"block.{idx-1}.snake1")
                new_key = new_key.replace(f"block.{idx-1}.layers.1", f"block.{idx-1}.conv_t1")

            new_key = new_key.replace("layers.0.beta", "snake1.beta")
            new_key = new_key.replace("layers.0.alpha", "snake1.alpha")
            new_key = new_key.replace("layers.2.beta", "snake2.beta")
            new_key = new_key.replace("layers.2.alpha", "snake2.alpha")
            new_key = new_key.replace("layers.1.bias", "conv1.bias")
            new_key = new_key.replace("layers.1.weight_", "conv1.weight_")
            new_key = new_key.replace("layers.3.bias", "conv2.bias")
            new_key = new_key.replace("layers.3.weight_", "conv2.weight_")

            if idx == num_autoencoder_layers + 1:
                new_key = new_key.replace(f"block.{idx-1}", "snake1")
            elif idx == num_autoencoder_layers + 2:
                new_key = new_key.replace(f"block.{idx-1}", "conv2")
        else:
            new_key = new_key

        value = autoencoder_state_dict.pop(key)
        if "snake" in new_key:
            value = value.unsqueeze(0).unsqueeze(-1)
        if new_key in autoencoder_state_dict:
            raise ValueError(f"{new_key} already in state dict.")
        autoencoder_state_dict[new_key] = value

    return model_state_dict, projection_model_state_dict, autoencoder_state_dict
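As an aside, here is a minimal, self-contained sketch of what the `to_qkv` branch above does: a fused attention projection weight is split into separate `to_q` / `to_k` / `to_v` entries with `torch.chunk`. The key name and tensor shapes are illustrative only, not taken from a real checkpoint.

import torch

# Toy fused qkv weight for a hidden size of 4 (real Stable Audio tensors are much larger).
toy_state_dict = {"transformer_blocks.0.attn1.to_qkv.weight": torch.randn(3 * 4, 4)}

key = "transformer_blocks.0.attn1.to_qkv.weight"
q, k, v = torch.chunk(toy_state_dict.pop(key), 3, dim=0)
toy_state_dict[key.replace("qkv", "q")] = q
toy_state_dict[key.replace("qkv", "k")] = k
toy_state_dict[key.replace("qkv", "v")] = v

print({name: tuple(t.shape) for name, t in toy_state_dict.items()})
# {'...to_q.weight': (4, 4), '...to_k.weight': (4, 4), '...to_v.weight': (4, 4)}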
parser = argparse.ArgumentParser(description="Convert Stable Audio 1.0 model weights to a diffusers pipeline")
parser.add_argument("--model_folder_path", type=str, help="Location of the Stable Audio weights and config")
parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion")
parser.add_argument(
    "--save_directory",
    type=str,
    default="./tmp/stable-audio-1.0",
    help="Directory to save the pipeline to. Will be created if it doesn't exist.",
)
parser.add_argument(
    "--repo_id",
    type=str,
    default="stable-audio-1.0",
    help="Hub repository id to push the pipeline to",
)
parser.add_argument("--push_to_hub", action="store_true", help="Push to hub")
parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights")

args = parser.parse_args()
checkpoint_path = (
    os.path.join(args.model_folder_path, "model.safetensors")
    if args.use_safetensors
    else os.path.join(args.model_folder_path, "model.ckpt")
)
config_path = os.path.join(args.model_folder_path, "model_config.json")

device = "cpu"
if args.variant == "bf16":
    dtype = torch.bfloat16
else:
    dtype = torch.float32

with open(config_path) as f_in:
    config_dict = json.load(f_in)

conditioning_dict = {
    conditioning["id"]: conditioning["config"] for conditioning in config_dict["model"]["conditioning"]["configs"]
}

t5_model_config = conditioning_dict["prompt"]

# T5 Text encoder
text_encoder = T5EncoderModel.from_pretrained(t5_model_config["t5_model_name"])
tokenizer = AutoTokenizer.from_pretrained(
    t5_model_config["t5_model_name"], truncation=True, model_max_length=t5_model_config["max_length"]
)

# scheduler
scheduler = DPMSolverMultistepScheduler(solver_order=2, algorithm_type="sde-dpmsolver++", use_exponential_sigmas=True)
Review comment: I still need to find the right scheduler for this!
scheduler.config["sigma_min"] = 0.3
scheduler.config["sigma_max"] = 500

ctx = init_empty_weights if is_accelerate_available() else nullcontext

if args.use_safetensors:
    orig_state_dict = load_file(checkpoint_path, device=device)
else:
    orig_state_dict = torch.load(checkpoint_path, map_location=device)

model_config = config_dict["model"]["diffusion"]["config"]

model_state_dict, projection_model_state_dict, autoencoder_state_dict = convert_stable_audio_state_dict_to_diffusers(
    orig_state_dict
)
with ctx():
    projection_model = StableAudioProjectionModel(
        text_encoder_dim=text_encoder.config.d_model,
        conditioning_dim=config_dict["model"]["conditioning"]["cond_dim"],
        # assume `seconds_start` and `seconds_total` have the same min / max values
        min_value=conditioning_dict["seconds_start"]["min_val"],
        max_value=conditioning_dict["seconds_start"]["max_val"],
    )
if is_accelerate_available():
    load_model_dict_into_meta(projection_model, projection_model_state_dict)
else:
    projection_model.load_state_dict(projection_model_state_dict)

attention_head_dim = model_config["embed_dim"] // model_config["num_heads"]
with ctx():
    model = StableAudioDiTModel(
        sample_size=int(config_dict["sample_size"]) / int(config_dict["model"]["pretransform"]["config"]["downsampling_ratio"]),
        in_channels=model_config["io_channels"],
        num_layers=model_config["depth"],
        attention_head_dim=attention_head_dim,
        num_key_value_attention_heads=model_config["cond_token_dim"] // attention_head_dim,
        num_attention_heads=model_config["num_heads"],
        out_channels=model_config["io_channels"],
        cross_attention_dim=model_config["cond_token_dim"],
        timestep_features_dim=256,
        global_states_input_dim=model_config["global_cond_dim"],
        cross_attention_input_dim=model_config["cond_token_dim"],
    )
if is_accelerate_available():
    load_model_dict_into_meta(model, model_state_dict)
else:
    model.load_state_dict(model_state_dict)
autoencoder_config = config_dict["model"]["pretransform"]["config"]
with ctx():
    autoencoder = AutoencoderOobleck(
        encoder_hidden_size=autoencoder_config["encoder"]["config"]["channels"],
        downsampling_ratios=autoencoder_config["encoder"]["config"]["strides"],
        decoder_channels=autoencoder_config["decoder"]["config"]["channels"],
        decoder_input_channels=autoencoder_config["decoder"]["config"]["latent_dim"],
        audio_channels=autoencoder_config["io_channels"],
        channel_multiples=autoencoder_config["encoder"]["config"]["c_mults"],
        sampling_rate=config_dict["sample_rate"],
    )

if is_accelerate_available():
    load_model_dict_into_meta(autoencoder, autoencoder_state_dict)
else:
    autoencoder.load_state_dict(autoencoder_state_dict)
# Pipeline
pipeline = StableAudioPipeline(
    transformer=model,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    scheduler=scheduler,
    vae=autoencoder,
    projection_model=projection_model,
)
pipeline.to(dtype).save_pretrained(
    args.save_directory, repo_id=args.repo_id, push_to_hub=args.push_to_hub, variant=args.variant
)

# TODO (YL): remove
pipeline.to(dtype).save_pretrained(
    args.save_directory, push_to_hub=False, variant=args.variant
)
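For completeness, here is a hedged sketch of how the converted pipeline could be exercised once the script has run. The script name, paths, prompt, and generation arguments are assumptions for illustration, not part of this PR, and argument names may still change while the PR is in flight.

# Hypothetical invocation of the conversion script above, assuming the original
# checkpoint and model_config.json live in ./stable-audio-open-1.0:
#   python convert_stable_audio.py --model_folder_path ./stable-audio-open-1.0 \
#       --use_safetensors --save_directory ./tmp/stable-audio-1.0

import torch
from diffusers import StableAudioPipeline

# Load the pipeline saved by the conversion script (the path is an assumption).
pipeline = StableAudioPipeline.from_pretrained("./tmp/stable-audio-1.0", torch_dtype=torch.float32).to("cuda")

# Generate roughly 10 seconds of audio with the pipeline introduced in this PR.
audio = pipeline(
    "ambient electronic music with soft pads",
    num_inference_steps=100,
    audio_end_in_s=10.0,
).audios[0]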
Review comment: I had to add a new variant of GEGLU here.
@@ -122,6 +122,31 @@ def forward(self, hidden_states, *args, **kwargs):
        hidden_states, gate = hidden_states.chunk(2, dim=-1)
        return hidden_states * self.gelu(gate)


class GLU(nn.Module):
r""" | ||
A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. | ||
It's similar to `GEGLU` but uses SiLU / Swish instead of GeLU. | ||
|
||
Parameters: | ||
dim_in (`int`): The number of channels in the input. | ||
dim_out (`int`): The number of channels in the output. | ||
act_fn (str): Name of activation function used. | ||
bias (`bool`, defaults to True): Whether to use a bias in the linear layer. | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm still using this in the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sure. Since that Feedforward is very focused on GELU and GeGELU I thought it might be a better option to have its own class. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it is very reasonable to use this in FeedForward, no need to move There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The feedforward does not have GeLU though. |
||
    def __init__(self, dim_in: int, dim_out: int, bias: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias)
        self.activation = nn.SiLU()

    def forward(self, hidden_states, *args, **kwargs):
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        hidden_states = self.proj(hidden_states)
        hidden_states, gate = hidden_states.chunk(2, dim=-1)
        return hidden_states * self.activation(gate)
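A quick usage sketch of the new activation, assuming this PR branch is installed; the import path is an assumption based on where the class is added in this diff.

import torch
from diffusers.models.activations import GLU

glu = GLU(dim_in=64, dim_out=128)
hidden_states = torch.randn(2, 10, 64)  # (batch, sequence, channels)

out = glu(hidden_states)
print(out.shape)  # torch.Size([2, 10, 128]): one projected half gated by SiLU of the other half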
class ApproximateGELU(nn.Module):
    r"""

Review comment: 🐍 nice param name