diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 81d7a95454d7..bce17b291478 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -239,6 +239,8 @@ title: AsymmetricAutoencoderKL - local: api/models/autoencoder_tiny title: Tiny AutoEncoder + - local: api/models/autoencoder_oobleck + title: Oobleck AutoEncoder - local: api/models/consistency_decoder_vae title: ConsistencyDecoderVAE - local: api/models/transformer2d @@ -259,6 +261,8 @@ title: TransformerTemporalModel - local: api/models/sd3_transformer2d title: SD3Transformer2DModel + - local: api/models/stable_audio_transformer + title: StableAudioDiTModel - local: api/models/prior_transformer title: PriorTransformer - local: api/models/controlnet @@ -362,6 +366,8 @@ title: Semantic Guidance - local: api/pipelines/shap_e title: Shap-E + - local: api/pipelines/stable_audio + title: Stable Audio - local: api/pipelines/stable_cascade title: Stable Cascade - sections: @@ -425,6 +431,8 @@ title: CMStochasticIterativeScheduler - local: api/schedulers/consistency_decoder title: ConsistencyDecoderScheduler + - local: api/schedulers/cosine_dpm + title: CosineDPMSolverMultistepScheduler - local: api/schedulers/ddim_inverse title: DDIMInverseScheduler - local: api/schedulers/ddim diff --git a/docs/source/en/api/models/autoencoder_oobleck.md b/docs/source/en/api/models/autoencoder_oobleck.md new file mode 100644 index 000000000000..bbc00e048b64 --- /dev/null +++ b/docs/source/en/api/models/autoencoder_oobleck.md @@ -0,0 +1,38 @@ + + +# AutoencoderOobleck + +The Oobleck variational autoencoder (VAE) model with KL loss was introduced in [Stability-AI/stable-audio-tools](https://github.com/Stability-AI/stable-audio-tools) and [Stable Audio Open](https://huggingface.co/papers/2407.14358) by Stability AI. The model is used in 🤗 Diffusers to encode audio waveforms into latents and to decode latent representations into audio waveforms. + +The abstract from the paper is: + +*Open generative models are vitally important for the community, allowing for fine-tunes and serving as baselines when presenting new models. However, most current text-to-audio models are private and not accessible for artists and researchers to build upon. Here we describe the architecture and training process of a new open-weights text-to-audio model trained with Creative Commons data. Our evaluation shows that the model's performance is competitive with the state-of-the-art across various metrics. Notably, the reported FDopenl3 results (measuring the realism of the generations) showcase its potential for high-quality stereo sound synthesis at 44.1kHz.* + +## AutoencoderOobleck + +[[autodoc]] AutoencoderOobleck + - decode + - encode + - all + +## OobleckDecoderOutput + +[[autodoc]] models.autoencoders.autoencoder_oobleck.OobleckDecoderOutput + +## OobleckDecoderOutput + +[[autodoc]] models.autoencoders.autoencoder_oobleck.OobleckDecoderOutput + +## AutoencoderOobleckOutput + +[[autodoc]] models.autoencoders.autoencoder_oobleck.AutoencoderOobleckOutput diff --git a/docs/source/en/api/models/stable_audio_transformer.md b/docs/source/en/api/models/stable_audio_transformer.md new file mode 100644 index 000000000000..396b96c8c710 --- /dev/null +++ b/docs/source/en/api/models/stable_audio_transformer.md @@ -0,0 +1,19 @@ + + +# StableAudioDiTModel + +A Transformer model for audio waveforms from [Stable Audio Open](https://huggingface.co/papers/2407.14358). + +## StableAudioDiTModel + +[[autodoc]] StableAudioDiTModel diff --git a/docs/source/en/api/pipelines/overview.md b/docs/source/en/api/pipelines/overview.md index e7b8bf4936c0..bb4dd57fd132 100644 --- a/docs/source/en/api/pipelines/overview.md +++ b/docs/source/en/api/pipelines/overview.md @@ -71,6 +71,7 @@ The table below lists all the pipelines currently available in 🤗 Diffusers an | [Semantic Guidance](semantic_stable_diffusion) | text2image | | [Shap-E](shap_e) | text-to-3D, image-to-3D | | [Spectrogram Diffusion](spectrogram_diffusion) | | +| [Stable Audio](stable_audio) | text2audio | | [Stable Diffusion](stable_diffusion/overview) | text2image, image2image, depth2image, inpainting, image variation, latent upscaler, super-resolution | | [Stable Diffusion Model Editing](model_editing) | model editing | | [Stable Diffusion XL](stable_diffusion/stable_diffusion_xl) | text2image, image2image, inpainting | diff --git a/docs/source/en/api/pipelines/stable_audio.md b/docs/source/en/api/pipelines/stable_audio.md new file mode 100644 index 000000000000..3e7b2857e4eb --- /dev/null +++ b/docs/source/en/api/pipelines/stable_audio.md @@ -0,0 +1,42 @@ + + +# Stable Audio + +Stable Audio was proposed in [Stable Audio Open](https://arxiv.org/abs/2407.14358) by Zach Evans et al. . it takes a text prompt as input and predicts the corresponding sound or music sample. + +Stable Audio Open generates variable-length (up to 47s) stereo audio at 44.1kHz from text prompts. It comprises three components: an autoencoder that compresses waveforms into a manageable sequence length, a T5-based text embedding for text conditioning, and a transformer-based diffusion (DiT) model that operates in the latent space of the autoencoder. + +Stable Audio is trained on a corpus of around 48k audio recordings, where around 47k are from Freesound and the rest are from the Free Music Archive (FMA). All audio files are licensed under CC0, CC BY, or CC Sampling+. This data is used to train the autoencoder and the DiT. + +The abstract of the paper is the following: +*Open generative models are vitally important for the community, allowing for fine-tunes and serving as baselines when presenting new models. However, most current text-to-audio models are private and not accessible for artists and researchers to build upon. Here we describe the architecture and training process of a new open-weights text-to-audio model trained with Creative Commons data. Our evaluation shows that the model's performance is competitive with the state-of-the-art across various metrics. Notably, the reported FDopenl3 results (measuring the realism of the generations) showcase its potential for high-quality stereo sound synthesis at 44.1kHz.* + +This pipeline was contributed by [Yoach Lacombe](https://huggingface.co/ylacombe). The original codebase can be found at [Stability-AI/stable-audio-tool](https://github.com/Stability-AI/stable-audio-tool). + +## Tips + +When constructing a prompt, keep in mind: + +* Descriptive prompt inputs work best; use adjectives to describe the sound (for example, "high quality" or "clear") and make the prompt context specific where possible (e.g. "melodic techno with a fast beat and synths" works better than "techno"). +* Using a *negative prompt* can significantly improve the quality of the generated audio. Try using a negative prompt of "low quality, average quality". + +During inference: + +* The _quality_ of the generated audio sample can be controlled by the `num_inference_steps` argument; higher steps give higher quality audio at the expense of slower inference. +* Multiple waveforms can be generated in one go: set `num_waveforms_per_prompt` to a value greater than 1 to enable. Automatic scoring will be performed between the generated waveforms and prompt text, and the audios ranked from best to worst accordingly. + + +## StableAudioPipeline +[[autodoc]] StableAudioPipeline + - all + - __call__ diff --git a/docs/source/en/api/schedulers/cosine_dpm.md b/docs/source/en/api/schedulers/cosine_dpm.md new file mode 100644 index 000000000000..7685269c2145 --- /dev/null +++ b/docs/source/en/api/schedulers/cosine_dpm.md @@ -0,0 +1,24 @@ + + +# CosineDPMSolverMultistepScheduler + +The [`CosineDPMSolverMultistepScheduler`] is a variant of [`DPMSolverMultistepScheduler`] with cosine schedule, proposed by Nichol and Dhariwal (2021). +It is being used in the [Stable Audio Open](https://arxiv.org/abs/2407.14358) paper and the [Stability-AI/stable-audio-tool](https://github.com/Stability-AI/stable-audio-tool) codebase. + +This scheduler was contributed by [Yoach Lacombe](https://huggingface.co/ylacombe). + +## CosineDPMSolverMultistepScheduler +[[autodoc]] CosineDPMSolverMultistepScheduler + +## SchedulerOutput +[[autodoc]] schedulers.scheduling_utils.SchedulerOutput diff --git a/scripts/convert_stable_audio.py b/scripts/convert_stable_audio.py new file mode 100644 index 000000000000..a0f9d0f87d90 --- /dev/null +++ b/scripts/convert_stable_audio.py @@ -0,0 +1,279 @@ +# Run this script to convert the Stable Cascade model weights to a diffusers pipeline. +import argparse +import json +import os +from contextlib import nullcontext + +import torch +from safetensors.torch import load_file +from transformers import ( + AutoTokenizer, + T5EncoderModel, +) + +from diffusers import ( + AutoencoderOobleck, + CosineDPMSolverMultistepScheduler, + StableAudioDiTModel, + StableAudioPipeline, + StableAudioProjectionModel, +) +from diffusers.models.modeling_utils import load_model_dict_into_meta +from diffusers.utils import is_accelerate_available + + +if is_accelerate_available(): + from accelerate import init_empty_weights + + +def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_layers=5): + projection_model_state_dict = { + k.replace("conditioner.conditioners.", "").replace("embedder.embedding", "time_positional_embedding"): v + for (k, v) in state_dict.items() + if "conditioner.conditioners" in k + } + + # NOTE: we assume here that there's no projection layer from the text encoder to the latent space, script should be adapted a bit if there is. + for key, value in list(projection_model_state_dict.items()): + new_key = key.replace("seconds_start", "start_number_conditioner").replace( + "seconds_total", "end_number_conditioner" + ) + projection_model_state_dict[new_key] = projection_model_state_dict.pop(key) + + model_state_dict = {k.replace("model.model.", ""): v for (k, v) in state_dict.items() if "model.model." in k} + for key, value in list(model_state_dict.items()): + # attention layers + new_key = ( + key.replace("transformer.", "") + .replace("layers", "transformer_blocks") + .replace("self_attn", "attn1") + .replace("cross_attn", "attn2") + .replace("ff.ff", "ff.net") + ) + new_key = ( + new_key.replace("pre_norm", "norm1") + .replace("cross_attend_norm", "norm2") + .replace("ff_norm", "norm3") + .replace("to_out", "to_out.0") + ) + new_key = new_key.replace("gamma", "weight").replace("beta", "bias") # replace layernorm + + # other layers + new_key = ( + new_key.replace("project", "proj") + .replace("to_timestep_embed", "timestep_proj") + .replace("timestep_features", "time_proj") + .replace("to_global_embed", "global_proj") + .replace("to_cond_embed", "cross_attention_proj") + ) + + # we're using diffusers implementation of time_proj (GaussianFourierProjection) which creates a 1D tensor + if new_key == "time_proj.weight": + model_state_dict[key] = model_state_dict[key].squeeze(1) + + if "to_qkv" in new_key: + q, k, v = torch.chunk(model_state_dict.pop(key), 3, dim=0) + model_state_dict[new_key.replace("qkv", "q")] = q + model_state_dict[new_key.replace("qkv", "k")] = k + model_state_dict[new_key.replace("qkv", "v")] = v + elif "to_kv" in new_key: + k, v = torch.chunk(model_state_dict.pop(key), 2, dim=0) + model_state_dict[new_key.replace("kv", "k")] = k + model_state_dict[new_key.replace("kv", "v")] = v + else: + model_state_dict[new_key] = model_state_dict.pop(key) + + autoencoder_state_dict = { + k.replace("pretransform.model.", "").replace("coder.layers.0", "coder.conv1"): v + for (k, v) in state_dict.items() + if "pretransform.model." in k + } + + for key, _ in list(autoencoder_state_dict.items()): + new_key = key + if "coder.layers" in new_key: + # get idx of the layer + idx = int(new_key.split("coder.layers.")[1].split(".")[0]) + + new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx-1}") + + if "encoder" in new_key: + for i in range(3): + new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i+1}") + new_key = new_key.replace(f"block.{idx-1}.layers.3", f"block.{idx-1}.snake1") + new_key = new_key.replace(f"block.{idx-1}.layers.4", f"block.{idx-1}.conv1") + else: + for i in range(2, 5): + new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i-1}") + new_key = new_key.replace(f"block.{idx-1}.layers.0", f"block.{idx-1}.snake1") + new_key = new_key.replace(f"block.{idx-1}.layers.1", f"block.{idx-1}.conv_t1") + + new_key = new_key.replace("layers.0.beta", "snake1.beta") + new_key = new_key.replace("layers.0.alpha", "snake1.alpha") + new_key = new_key.replace("layers.2.beta", "snake2.beta") + new_key = new_key.replace("layers.2.alpha", "snake2.alpha") + new_key = new_key.replace("layers.1.bias", "conv1.bias") + new_key = new_key.replace("layers.1.weight_", "conv1.weight_") + new_key = new_key.replace("layers.3.bias", "conv2.bias") + new_key = new_key.replace("layers.3.weight_", "conv2.weight_") + + if idx == num_autoencoder_layers + 1: + new_key = new_key.replace(f"block.{idx-1}", "snake1") + elif idx == num_autoencoder_layers + 2: + new_key = new_key.replace(f"block.{idx-1}", "conv2") + + else: + new_key = new_key + + value = autoencoder_state_dict.pop(key) + if "snake" in new_key: + value = value.unsqueeze(0).unsqueeze(-1) + if new_key in autoencoder_state_dict: + raise ValueError(f"{new_key} already in state dict.") + autoencoder_state_dict[new_key] = value + + return model_state_dict, projection_model_state_dict, autoencoder_state_dict + + +parser = argparse.ArgumentParser(description="Convert Stable Audio 1.0 model weights to a diffusers pipeline") +parser.add_argument("--model_folder_path", type=str, help="Location of Stable Audio weights and config") +parser.add_argument("--use_safetensors", action="store_true", help="Use SafeTensors for conversion") +parser.add_argument( + "--save_directory", + type=str, + default="./tmp/stable-audio-1.0", + help="Directory to save a pipeline to. Will be created if it doesn't exist.", +) +parser.add_argument( + "--repo_id", + type=str, + default="stable-audio-1.0", + help="Hub organization to save the pipelines to", +) +parser.add_argument("--push_to_hub", action="store_true", help="Push to hub") +parser.add_argument("--variant", type=str, help="Set to bf16 to save bfloat16 weights") + +args = parser.parse_args() + +checkpoint_path = ( + os.path.join(args.model_folder_path, "model.safetensors") + if args.use_safetensors + else os.path.join(args.model_folder_path, "model.ckpt") +) +config_path = os.path.join(args.model_folder_path, "model_config.json") + +device = "cpu" +if args.variant == "bf16": + dtype = torch.bfloat16 +else: + dtype = torch.float32 + +with open(config_path) as f_in: + config_dict = json.load(f_in) + +conditioning_dict = { + conditioning["id"]: conditioning["config"] for conditioning in config_dict["model"]["conditioning"]["configs"] +} + +t5_model_config = conditioning_dict["prompt"] + +# T5 Text encoder +text_encoder = T5EncoderModel.from_pretrained(t5_model_config["t5_model_name"]) +tokenizer = AutoTokenizer.from_pretrained( + t5_model_config["t5_model_name"], truncation=True, model_max_length=t5_model_config["max_length"] +) + + +# scheduler +scheduler = CosineDPMSolverMultistepScheduler( + sigma_min=0.3, + sigma_max=500, + solver_order=2, + prediction_type="v_prediction", + sigma_data=1.0, + sigma_schedule="exponential", +) +ctx = init_empty_weights if is_accelerate_available() else nullcontext + + +if args.use_safetensors: + orig_state_dict = load_file(checkpoint_path, device=device) +else: + orig_state_dict = torch.load(checkpoint_path, map_location=device) + + +model_config = config_dict["model"]["diffusion"]["config"] + +model_state_dict, projection_model_state_dict, autoencoder_state_dict = convert_stable_audio_state_dict_to_diffusers( + orig_state_dict +) + + +with ctx(): + projection_model = StableAudioProjectionModel( + text_encoder_dim=text_encoder.config.d_model, + conditioning_dim=config_dict["model"]["conditioning"]["cond_dim"], + min_value=conditioning_dict["seconds_start"][ + "min_val" + ], # assume `seconds_start` and `seconds_total` have the same min / max values. + max_value=conditioning_dict["seconds_start"][ + "max_val" + ], # assume `seconds_start` and `seconds_total` have the same min / max values. + ) +if is_accelerate_available(): + load_model_dict_into_meta(projection_model, projection_model_state_dict) +else: + projection_model.load_state_dict(projection_model_state_dict) + +attention_head_dim = model_config["embed_dim"] // model_config["num_heads"] +with ctx(): + model = StableAudioDiTModel( + sample_size=int(config_dict["sample_size"]) + / int(config_dict["model"]["pretransform"]["config"]["downsampling_ratio"]), + in_channels=model_config["io_channels"], + num_layers=model_config["depth"], + attention_head_dim=attention_head_dim, + num_key_value_attention_heads=model_config["cond_token_dim"] // attention_head_dim, + num_attention_heads=model_config["num_heads"], + out_channels=model_config["io_channels"], + cross_attention_dim=model_config["cond_token_dim"], + time_proj_dim=256, + global_states_input_dim=model_config["global_cond_dim"], + cross_attention_input_dim=model_config["cond_token_dim"], + ) +if is_accelerate_available(): + load_model_dict_into_meta(model, model_state_dict) +else: + model.load_state_dict(model_state_dict) + + +autoencoder_config = config_dict["model"]["pretransform"]["config"] +with ctx(): + autoencoder = AutoencoderOobleck( + encoder_hidden_size=autoencoder_config["encoder"]["config"]["channels"], + downsampling_ratios=autoencoder_config["encoder"]["config"]["strides"], + decoder_channels=autoencoder_config["decoder"]["config"]["channels"], + decoder_input_channels=autoencoder_config["decoder"]["config"]["latent_dim"], + audio_channels=autoencoder_config["io_channels"], + channel_multiples=autoencoder_config["encoder"]["config"]["c_mults"], + sampling_rate=config_dict["sample_rate"], + ) + +if is_accelerate_available(): + load_model_dict_into_meta(autoencoder, autoencoder_state_dict) +else: + autoencoder.load_state_dict(autoencoder_state_dict) + + +# Prior pipeline +pipeline = StableAudioPipeline( + transformer=model, + tokenizer=tokenizer, + text_encoder=text_encoder, + scheduler=scheduler, + vae=autoencoder, + projection_model=projection_model, +) +pipeline.to(dtype).save_pretrained( + args.save_directory, repo_id=args.repo_id, push_to_hub=args.push_to_hub, variant=args.variant +) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index f42ccc064624..10bda1316bd7 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -79,6 +79,7 @@ "AuraFlowTransformer2DModel", "AutoencoderKL", "AutoencoderKLTemporalDecoder", + "AutoencoderOobleck", "AutoencoderTiny", "ConsistencyDecoderVAE", "ControlNetModel", @@ -100,6 +101,7 @@ "SD3MultiControlNetModel", "SD3Transformer2DModel", "SparseControlNetModel", + "StableAudioDiTModel", "StableCascadeUNet", "T2IAdapter", "T5FilmDecoder", @@ -210,7 +212,7 @@ ] else: - _import_structure["schedulers"].extend(["DPMSolverSDEScheduler"]) + _import_structure["schedulers"].extend(["CosineDPMSolverMultistepScheduler", "DPMSolverSDEScheduler"]) try: if not (is_torch_available() and is_transformers_available()): @@ -293,6 +295,8 @@ "SemanticStableDiffusionPipeline", "ShapEImg2ImgPipeline", "ShapEPipeline", + "StableAudioPipeline", + "StableAudioProjectionModel", "StableCascadeCombinedPipeline", "StableCascadeDecoderPipeline", "StableCascadePriorPipeline", @@ -515,6 +519,7 @@ AuraFlowTransformer2DModel, AutoencoderKL, AutoencoderKLTemporalDecoder, + AutoencoderOobleck, AutoencoderTiny, ConsistencyDecoderVAE, ControlNetModel, @@ -536,6 +541,7 @@ SD3MultiControlNetModel, SD3Transformer2DModel, SparseControlNetModel, + StableAudioDiTModel, T2IAdapter, T5FilmDecoder, Transformer2DModel, @@ -632,7 +638,7 @@ except OptionalDependencyNotAvailable: from .utils.dummy_torch_and_torchsde_objects import * # noqa F403 else: - from .schedulers import DPMSolverSDEScheduler + from .schedulers import CosineDPMSolverMultistepScheduler, DPMSolverSDEScheduler try: if not (is_torch_available() and is_transformers_available()): @@ -707,6 +713,8 @@ SemanticStableDiffusionPipeline, ShapEImg2ImgPipeline, ShapEPipeline, + StableAudioPipeline, + StableAudioProjectionModel, StableCascadeCombinedPipeline, StableCascadeDecoderPipeline, StableCascadePriorPipeline, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index d35786ee7642..76fe2b682a46 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -29,6 +29,7 @@ _import_structure["autoencoders.autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"] _import_structure["autoencoders.autoencoder_kl"] = ["AutoencoderKL"] _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"] + _import_structure["autoencoders.autoencoder_oobleck"] = ["AutoencoderOobleck"] _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"] _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"] _import_structure["autoencoders.vq_model"] = ["VQModel"] @@ -47,6 +48,7 @@ _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"] _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"] _import_structure["transformers.prior_transformer"] = ["PriorTransformer"] + _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"] _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"] _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"] _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"] @@ -75,6 +77,7 @@ AsymmetricAutoencoderKL, AutoencoderKL, AutoencoderKLTemporalDecoder, + AutoencoderOobleck, AutoencoderTiny, ConsistencyDecoderVAE, VQModel, @@ -96,6 +99,7 @@ PixArtTransformer2DModel, PriorTransformer, SD3Transformer2DModel, + StableAudioDiTModel, T5FilmDecoder, Transformer2DModel, TransformerTemporalModel, diff --git a/src/diffusers/models/activations.py b/src/diffusers/models/activations.py index 28ee92ddb2e3..fb24a36bae75 100644 --- a/src/diffusers/models/activations.py +++ b/src/diffusers/models/activations.py @@ -123,6 +123,28 @@ def forward(self, hidden_states, *args, **kwargs): return hidden_states * self.gelu(gate) +class SwiGLU(nn.Module): + r""" + A [variant](https://arxiv.org/abs/2002.05202) of the gated linear unit activation function. It's similar to `GEGLU` + but uses SiLU / Swish instead of GeLU. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + self.activation = nn.SiLU() + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.activation(gate) + + class ApproximateGELU(nn.Module): r""" The approximate form of the Gaussian Error Linear Unit (GELU). For more details, see section 2 of this diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py index f6969470c36e..b204770e6d37 100644 --- a/src/diffusers/models/attention.py +++ b/src/diffusers/models/attention.py @@ -19,7 +19,7 @@ from ..utils import deprecate, logging from ..utils.torch_utils import maybe_allow_in_graph -from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU +from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU from .attention_processor import Attention, JointAttnProcessor2_0 from .embeddings import SinusoidalPositionalEmbedding from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm @@ -820,6 +820,8 @@ def __init__( act_fn = GEGLU(dim, inner_dim, bias=bias) elif activation_fn == "geglu-approximate": act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + elif activation_fn == "swiglu": + act_fn = SwiGLU(dim, inner_dim, bias=bias) self.net = nn.ModuleList([]) # project in diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 6669222c695d..5c5464c37683 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -13,7 +13,7 @@ # limitations under the License. import inspect import math -from typing import Callable, List, Optional, Union +from typing import Callable, List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -49,6 +49,10 @@ class Attention(nn.Module): The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + kv_heads (`int`, *optional*, defaults to `None`): + The number of key and value heads to use for multi-head attention. Defaults to `heads`. If + `kv_heads=heads`, the model will use Multi Head Attention (MHA), if `kv_heads=1` the model will use Multi + Query Attention (MQA) otherwise GQA is used. dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. dropout (`float`, *optional*, defaults to 0.0): @@ -1624,6 +1628,137 @@ def __call__( return hidden_states +class StableAudioAttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is + used in the Stable Audio model. It applies rotary embedding on query and key vector, and allows MHA, GQA or MQA. + """ + + def __init__(self): + if not hasattr(F, "scaled_dot_product_attention"): + raise ImportError( + "StableAudioAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0." + ) + + def apply_partial_rotary_emb( + self, + x: torch.Tensor, + freqs_cis: Tuple[torch.Tensor], + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + rot_dim = freqs_cis[0].shape[-1] + x_to_rotate, x_unrotated = x[..., :rot_dim], x[..., rot_dim:] + + x_rotated = apply_rotary_emb(x_to_rotate, freqs_cis, use_real=True, use_real_unbind_dim=-2) + + out = torch.cat((x_rotated, x_unrotated), dim=-1) + return out + + def __call__( + self, + attn: Attention, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + rotary_emb: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + from .embeddings import apply_rotary_emb + + residual = hidden_states + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + ) + + if attention_mask is not None: + attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + head_dim = query.shape[-1] // attn.heads + kv_heads = key.shape[-1] // head_dim + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + key = key.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, kv_heads, head_dim).transpose(1, 2) + + if kv_heads != attn.heads: + # if GQA or MQA, repeat the key/value heads to reach the number of query heads. + heads_per_kv_head = attn.heads // kv_heads + key = torch.repeat_interleave(key, heads_per_kv_head, dim=1) + value = torch.repeat_interleave(value, heads_per_kv_head, dim=1) + + if attn.norm_q is not None: + query = attn.norm_q(query) + if attn.norm_k is not None: + key = attn.norm_k(key) + + # Apply RoPE if needed + if rotary_emb is not None: + query_dtype = query.dtype + key_dtype = key.dtype + query = query.to(torch.float32) + key = key.to(torch.float32) + + rot_dim = rotary_emb[0].shape[-1] + query_to_rotate, query_unrotated = query[..., :rot_dim], query[..., rot_dim:] + query_rotated = apply_rotary_emb(query_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) + + query = torch.cat((query_rotated, query_unrotated), dim=-1) + + if not attn.is_cross_attention: + key_to_rotate, key_unrotated = key[..., :rot_dim], key[..., rot_dim:] + key_rotated = apply_rotary_emb(key_to_rotate, rotary_emb, use_real=True, use_real_unbind_dim=-2) + + key = torch.cat((key_rotated, key_unrotated), dim=-1) + + query = query.to(query_dtype) + key = key.to(key_dtype) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + # TODO: add support for attn.scale when we move to Torch 2.1 + hidden_states = F.scaled_dot_product_attention( + query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False + ) + + hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) + hidden_states = hidden_states.to(query.dtype) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + class HunyuanAttnProcessor2_0: r""" Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is diff --git a/src/diffusers/models/autoencoders/__init__.py b/src/diffusers/models/autoencoders/__init__.py index 5c47748d62e0..885007b54ea1 100644 --- a/src/diffusers/models/autoencoders/__init__.py +++ b/src/diffusers/models/autoencoders/__init__.py @@ -1,6 +1,7 @@ from .autoencoder_asym_kl import AsymmetricAutoencoderKL from .autoencoder_kl import AutoencoderKL from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder +from .autoencoder_oobleck import AutoencoderOobleck from .autoencoder_tiny import AutoencoderTiny from .consistency_decoder_vae import ConsistencyDecoderVAE from .vq_model import VQModel diff --git a/src/diffusers/models/autoencoders/autoencoder_oobleck.py b/src/diffusers/models/autoencoders/autoencoder_oobleck.py new file mode 100644 index 000000000000..e8e372a709d7 --- /dev/null +++ b/src/diffusers/models/autoencoders/autoencoder_oobleck.py @@ -0,0 +1,464 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +from torch.nn.utils import weight_norm + +from ...configuration_utils import ConfigMixin, register_to_config +from ...utils import BaseOutput +from ...utils.accelerate_utils import apply_forward_hook +from ...utils.torch_utils import randn_tensor +from ..modeling_utils import ModelMixin + + +class Snake1d(nn.Module): + """ + A 1-dimensional Snake activation function module. + """ + + def __init__(self, hidden_dim, logscale=True): + super().__init__() + self.alpha = nn.Parameter(torch.zeros(1, hidden_dim, 1)) + self.beta = nn.Parameter(torch.zeros(1, hidden_dim, 1)) + + self.alpha.requires_grad = True + self.beta.requires_grad = True + self.logscale = logscale + + def forward(self, hidden_states): + shape = hidden_states.shape + + alpha = self.alpha if not self.logscale else torch.exp(self.alpha) + beta = self.beta if not self.logscale else torch.exp(self.beta) + + hidden_states = hidden_states.reshape(shape[0], shape[1], -1) + hidden_states = hidden_states + (beta + 1e-9).reciprocal() * torch.sin(alpha * hidden_states).pow(2) + hidden_states = hidden_states.reshape(shape) + return hidden_states + + +class OobleckResidualUnit(nn.Module): + """ + A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations. + """ + + def __init__(self, dimension: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + + self.snake1 = Snake1d(dimension) + self.conv1 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)) + self.snake2 = Snake1d(dimension) + self.conv2 = weight_norm(nn.Conv1d(dimension, dimension, kernel_size=1)) + + def forward(self, hidden_state): + """ + Forward pass through the residual unit. + + Args: + hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`): + Input tensor . + + Returns: + output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`) + Input tensor after passing through the residual unit. + """ + output_tensor = hidden_state + output_tensor = self.conv1(self.snake1(output_tensor)) + output_tensor = self.conv2(self.snake2(output_tensor)) + + padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2 + if padding > 0: + hidden_state = hidden_state[..., padding:-padding] + output_tensor = hidden_state + output_tensor + return output_tensor + + +class OobleckEncoderBlock(nn.Module): + """Encoder block used in Oobleck encoder.""" + + def __init__(self, input_dim, output_dim, stride: int = 1): + super().__init__() + + self.res_unit1 = OobleckResidualUnit(input_dim, dilation=1) + self.res_unit2 = OobleckResidualUnit(input_dim, dilation=3) + self.res_unit3 = OobleckResidualUnit(input_dim, dilation=9) + self.snake1 = Snake1d(input_dim) + self.conv1 = weight_norm( + nn.Conv1d(input_dim, output_dim, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)) + ) + + def forward(self, hidden_state): + hidden_state = self.res_unit1(hidden_state) + hidden_state = self.res_unit2(hidden_state) + hidden_state = self.snake1(self.res_unit3(hidden_state)) + hidden_state = self.conv1(hidden_state) + + return hidden_state + + +class OobleckDecoderBlock(nn.Module): + """Decoder block used in Oobleck decoder.""" + + def __init__(self, input_dim, output_dim, stride: int = 1): + super().__init__() + + self.snake1 = Snake1d(input_dim) + self.conv_t1 = weight_norm( + nn.ConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ) + ) + self.res_unit1 = OobleckResidualUnit(output_dim, dilation=1) + self.res_unit2 = OobleckResidualUnit(output_dim, dilation=3) + self.res_unit3 = OobleckResidualUnit(output_dim, dilation=9) + + def forward(self, hidden_state): + hidden_state = self.snake1(hidden_state) + hidden_state = self.conv_t1(hidden_state) + hidden_state = self.res_unit1(hidden_state) + hidden_state = self.res_unit2(hidden_state) + hidden_state = self.res_unit3(hidden_state) + + return hidden_state + + +class OobleckDiagonalGaussianDistribution(object): + def __init__(self, parameters: torch.Tensor, deterministic: bool = False): + self.parameters = parameters + self.mean, self.scale = parameters.chunk(2, dim=1) + self.std = nn.functional.softplus(self.scale) + 1e-4 + self.var = self.std * self.std + self.logvar = torch.log(self.var) + self.deterministic = deterministic + + def sample(self, generator: Optional[torch.Generator] = None) -> torch.Tensor: + # make sure sample is on the same device as the parameters and has same dtype + sample = randn_tensor( + self.mean.shape, + generator=generator, + device=self.parameters.device, + dtype=self.parameters.dtype, + ) + x = self.mean + self.std * sample + return x + + def kl(self, other: "OobleckDiagonalGaussianDistribution" = None) -> torch.Tensor: + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return (self.mean * self.mean + self.var - self.logvar - 1.0).sum(1).mean() + else: + normalized_diff = torch.pow(self.mean - other.mean, 2) / other.var + var_ratio = self.var / other.var + logvar_diff = self.logvar - other.logvar + + kl = normalized_diff + var_ratio + logvar_diff - 1 + + kl = kl.sum(1).mean() + return kl + + def mode(self) -> torch.Tensor: + return self.mean + + +@dataclass +class AutoencoderOobleckOutput(BaseOutput): + """ + Output of AutoencoderOobleck encoding method. + + Args: + latent_dist (`OobleckDiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and standard deviation of + `OobleckDiagonalGaussianDistribution`. `OobleckDiagonalGaussianDistribution` allows for sampling latents + from the distribution. + """ + + latent_dist: "OobleckDiagonalGaussianDistribution" # noqa: F821 + + +@dataclass +class OobleckDecoderOutput(BaseOutput): + r""" + Output of decoding method. + + Args: + sample (`torch.Tensor` of shape `(batch_size, audio_channels, sequence_length)`): + The decoded output sample from the last layer of the model. + """ + + sample: torch.Tensor + + +class OobleckEncoder(nn.Module): + """Oobleck Encoder""" + + def __init__(self, encoder_hidden_size, audio_channels, downsampling_ratios, channel_multiples): + super().__init__() + + strides = downsampling_ratios + channel_multiples = [1] + channel_multiples + + # Create first convolution + self.conv1 = weight_norm(nn.Conv1d(audio_channels, encoder_hidden_size, kernel_size=7, padding=3)) + + self.block = [] + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride_index, stride in enumerate(strides): + self.block += [ + OobleckEncoderBlock( + input_dim=encoder_hidden_size * channel_multiples[stride_index], + output_dim=encoder_hidden_size * channel_multiples[stride_index + 1], + stride=stride, + ) + ] + + self.block = nn.ModuleList(self.block) + d_model = encoder_hidden_size * channel_multiples[-1] + self.snake1 = Snake1d(d_model) + self.conv2 = weight_norm(nn.Conv1d(d_model, encoder_hidden_size, kernel_size=3, padding=1)) + + def forward(self, hidden_state): + hidden_state = self.conv1(hidden_state) + + for module in self.block: + hidden_state = module(hidden_state) + + hidden_state = self.snake1(hidden_state) + hidden_state = self.conv2(hidden_state) + + return hidden_state + + +class OobleckDecoder(nn.Module): + """Oobleck Decoder""" + + def __init__(self, channels, input_channels, audio_channels, upsampling_ratios, channel_multiples): + super().__init__() + + strides = upsampling_ratios + channel_multiples = [1] + channel_multiples + + # Add first conv layer + self.conv1 = weight_norm(nn.Conv1d(input_channels, channels * channel_multiples[-1], kernel_size=7, padding=3)) + + # Add upsampling + MRF blocks + block = [] + for stride_index, stride in enumerate(strides): + block += [ + OobleckDecoderBlock( + input_dim=channels * channel_multiples[len(strides) - stride_index], + output_dim=channels * channel_multiples[len(strides) - stride_index - 1], + stride=stride, + ) + ] + + self.block = nn.ModuleList(block) + output_dim = channels + self.snake1 = Snake1d(output_dim) + self.conv2 = weight_norm(nn.Conv1d(channels, audio_channels, kernel_size=7, padding=3, bias=False)) + + def forward(self, hidden_state): + hidden_state = self.conv1(hidden_state) + + for layer in self.block: + hidden_state = layer(hidden_state) + + hidden_state = self.snake1(hidden_state) + hidden_state = self.conv2(hidden_state) + + return hidden_state + + +class AutoencoderOobleck(ModelMixin, ConfigMixin): + r""" + An autoencoder for encoding waveforms into latents and decoding latent representations into waveforms. First + introduced in Stable Audio. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + encoder_hidden_size (`int`, *optional*, defaults to 128): + Intermediate representation dimension for the encoder. + downsampling_ratios (`List[int]`, *optional*, defaults to `[2, 4, 4, 8, 8]`): + Ratios for downsampling in the encoder. These are used in reverse order for upsampling in the decoder. + channel_multiples (`List[int]`, *optional*, defaults to `[1, 2, 4, 8, 16]`): + Multiples used to determine the hidden sizes of the hidden layers. + decoder_channels (`int`, *optional*, defaults to 128): + Intermediate representation dimension for the decoder. + decoder_input_channels (`int`, *optional*, defaults to 64): + Input dimension for the decoder. Corresponds to the latent dimension. + audio_channels (`int`, *optional*, defaults to 2): + Number of channels in the audio data. Either 1 for mono or 2 for stereo. + sampling_rate (`int`, *optional*, defaults to 44100): + The sampling rate at which the audio waveform should be digitalized expressed in hertz (Hz). + """ + + _supports_gradient_checkpointing = False + + @register_to_config + def __init__( + self, + encoder_hidden_size=128, + downsampling_ratios=[2, 4, 4, 8, 8], + channel_multiples=[1, 2, 4, 8, 16], + decoder_channels=128, + decoder_input_channels=64, + audio_channels=2, + sampling_rate=44100, + ): + super().__init__() + + self.encoder_hidden_size = encoder_hidden_size + self.downsampling_ratios = downsampling_ratios + self.decoder_channels = decoder_channels + self.upsampling_ratios = downsampling_ratios[::-1] + self.hop_length = int(np.prod(downsampling_ratios)) + self.sampling_rate = sampling_rate + + self.encoder = OobleckEncoder( + encoder_hidden_size=encoder_hidden_size, + audio_channels=audio_channels, + downsampling_ratios=downsampling_ratios, + channel_multiples=channel_multiples, + ) + + self.decoder = OobleckDecoder( + channels=decoder_channels, + input_channels=decoder_input_channels, + audio_channels=audio_channels, + upsampling_ratios=self.upsampling_ratios, + channel_multiples=channel_multiples, + ) + + self.use_slicing = False + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @apply_forward_hook + def encode( + self, x: torch.Tensor, return_dict: bool = True + ) -> Union[AutoencoderOobleckOutput, Tuple[OobleckDiagonalGaussianDistribution]]: + """ + Encode a batch of images into latents. + + Args: + x (`torch.Tensor`): Input batch of images. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple. + + Returns: + The latent representations of the encoded images. If `return_dict` is True, a + [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned. + """ + if self.use_slicing and x.shape[0] > 1: + encoded_slices = [self.encoder(x_slice) for x_slice in x.split(1)] + h = torch.cat(encoded_slices) + else: + h = self.encoder(x) + + posterior = OobleckDiagonalGaussianDistribution(h) + + if not return_dict: + return (posterior,) + + return AutoencoderOobleckOutput(latent_dist=posterior) + + def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[OobleckDecoderOutput, torch.Tensor]: + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return OobleckDecoderOutput(sample=dec) + + @apply_forward_hook + def decode( + self, z: torch.FloatTensor, return_dict: bool = True, generator=None + ) -> Union[OobleckDecoderOutput, torch.FloatTensor]: + """ + Decode a batch of images. + + Args: + z (`torch.Tensor`): Input batch of latent vectors. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`~models.vae.OobleckDecoderOutput`] instead of a plain tuple. + + Returns: + [`~models.vae.OobleckDecoderOutput`] or `tuple`: + If return_dict is True, a [`~models.vae.OobleckDecoderOutput`] is returned, otherwise a plain `tuple` + is returned. + + """ + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return OobleckDecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.Tensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[OobleckDecoderOutput, torch.Tensor]: + r""" + Args: + sample (`torch.Tensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`OobleckDecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return OobleckDecoderOutput(sample=dec) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 7684fdf9cd6c..71e301d0d707 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -352,7 +352,13 @@ def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, n def get_1d_rotary_pos_embed( - dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False, linear_factor=1.0, ntk_factor=1.0 + dim: int, + pos: Union[np.ndarray, int], + theta: float = 10000.0, + use_real=False, + linear_factor=1.0, + ntk_factor=1.0, + repeat_interleave_real=True, ): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. @@ -372,6 +378,9 @@ def get_1d_rotary_pos_embed( Scaling factor for the context extrapolation. Defaults to 1.0. ntk_factor (`float`, *optional*, defaults to 1.0): Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. + repeat_interleave_real (`bool`, *optional*, defaults to `True`): + If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. + Otherwise, they are concateanted with themselves. Returns: `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] """ @@ -383,10 +392,14 @@ def get_1d_rotary_pos_embed( freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor # [D/2] t = torch.from_numpy(pos).to(freqs.device) # type: ignore # [S] freqs = torch.outer(t, freqs).float() # type: ignore # [S, D/2] - if use_real: + if use_real and repeat_interleave_real: freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D] freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D] return freqs_cos, freqs_sin + elif use_real: + freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1) # [S, D] + freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1) # [S, D] + return freqs_cos, freqs_sin else: freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] return freqs_cis @@ -396,6 +409,7 @@ def apply_rotary_emb( x: torch.Tensor, freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]], use_real: bool = True, + use_real_unbind_dim: int = -1, ) -> Tuple[torch.Tensor, torch.Tensor]: """ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings @@ -417,8 +431,17 @@ def apply_rotary_emb( sin = sin[None, None] cos, sin = cos.to(x.device), sin.to(x.device) - x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] - x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + if use_real_unbind_dim == -1: + # Use for example in Lumina + x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2] + x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3) + elif use_real_unbind_dim == -2: + # Use for example in Stable Audio + x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2] + x_rotated = torch.cat([-x_imag, x_real], dim=-1) + else: + raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.") + out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype) return out diff --git a/src/diffusers/models/transformers/__init__.py b/src/diffusers/models/transformers/__init__.py index ae5103160790..8d4b8d9d6ecb 100644 --- a/src/diffusers/models/transformers/__init__.py +++ b/src/diffusers/models/transformers/__init__.py @@ -10,6 +10,7 @@ from .lumina_nextdit2d import LuminaNextDiT2DModel from .pixart_transformer_2d import PixArtTransformer2DModel from .prior_transformer import PriorTransformer + from .stable_audio_transformer import StableAudioDiTModel from .t5_film_transformer import T5FilmDecoder from .transformer_2d import Transformer2DModel from .transformer_sd3 import SD3Transformer2DModel diff --git a/src/diffusers/models/transformers/stable_audio_transformer.py b/src/diffusers/models/transformers/stable_audio_transformer.py new file mode 100644 index 000000000000..e3462b51a412 --- /dev/null +++ b/src/diffusers/models/transformers/stable_audio_transformer.py @@ -0,0 +1,458 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from typing import Any, Dict, Optional, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.attention import FeedForward +from ...models.attention_processor import ( + Attention, + AttentionProcessor, + StableAudioAttnProcessor2_0, +) +from ...models.modeling_utils import ModelMixin +from ...models.transformers.transformer_2d import Transformer2DModelOutput +from ...utils import is_torch_version, logging +from ...utils.torch_utils import maybe_allow_in_graph + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableAudioGaussianFourierProjection(nn.Module): + """Gaussian Fourier embeddings for noise levels.""" + + # Copied from diffusers.models.embeddings.GaussianFourierProjection.__init__ + def __init__( + self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False + ): + super().__init__() + self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.log = log + self.flip_sin_to_cos = flip_sin_to_cos + + if set_W_to_weight: + # to delete later + del self.weight + self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False) + self.weight = self.W + del self.W + + def forward(self, x): + if self.log: + x = torch.log(x) + + x_proj = 2 * np.pi * x[:, None] @ self.weight[None, :] + + if self.flip_sin_to_cos: + out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1) + else: + out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1) + return out + + +@maybe_allow_in_graph +class StableAudioDiTBlock(nn.Module): + r""" + Transformer block used in Stable Audio model (https://github.com/Stability-AI/stable-audio-tools). Allow skip + connection and QKNorm + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for the query states. + num_key_value_attention_heads (`int`): The number of heads to use for the key and value states. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + num_key_value_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + upcast_attention: bool = False, + norm_eps: float = 1e-5, + ff_inner_dim: Optional[int] = None, + ): + super().__init__() + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm1 = nn.LayerNorm(dim, elementwise_affine=True, eps=norm_eps) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + out_bias=False, + processor=StableAudioAttnProcessor2_0(), + ) + + # 2. Cross-Attn + self.norm2 = nn.LayerNorm(dim, norm_eps, True) + + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + kv_heads=num_key_value_attention_heads, + dropout=dropout, + bias=False, + upcast_attention=upcast_attention, + out_bias=False, + processor=StableAudioAttnProcessor2_0(), + ) # is self-attn if encoder_hidden_states is none + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim, norm_eps, True) + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn="swiglu", + final_dropout=False, + inner_dim=ff_inner_dim, + bias=True, + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + encoder_hidden_states: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + rotary_embedding: Optional[torch.FloatTensor] = None, + ) -> torch.Tensor: + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + norm_hidden_states = self.norm1(hidden_states) + + attn_output = self.attn1( + norm_hidden_states, + attention_mask=attention_mask, + rotary_emb=rotary_embedding, + ) + + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention + norm_hidden_states = self.norm2(hidden_states) + + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + ) + hidden_states = attn_output + hidden_states + + # 3. Feed-forward + norm_hidden_states = self.norm3(hidden_states) + ff_output = self.ff(norm_hidden_states) + + hidden_states = ff_output + hidden_states + + return hidden_states + + +class StableAudioDiTModel(ModelMixin, ConfigMixin): + """ + The Diffusion Transformer model introduced in Stable Audio. + + Reference: https://github.com/Stability-AI/stable-audio-tools + + Parameters: + sample_size ( `int`, *optional*, defaults to 1024): The size of the input sample. + in_channels (`int`, *optional*, defaults to 64): The number of channels in the input. + num_layers (`int`, *optional*, defaults to 24): The number of layers of Transformer blocks to use. + attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head. + num_attention_heads (`int`, *optional*, defaults to 24): The number of heads to use for the query states. + num_key_value_attention_heads (`int`, *optional*, defaults to 12): + The number of heads to use for the key and value states. + out_channels (`int`, defaults to 64): Number of output channels. + cross_attention_dim ( `int`, *optional*, defaults to 768): Dimension of the cross-attention projection. + time_proj_dim ( `int`, *optional*, defaults to 256): Dimension of the timestep inner projection. + global_states_input_dim ( `int`, *optional*, defaults to 1536): + Input dimension of the global hidden states projection. + cross_attention_input_dim ( `int`, *optional*, defaults to 768): + Input dimension of the cross-attention projection + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: int = 1024, + in_channels: int = 64, + num_layers: int = 24, + attention_head_dim: int = 64, + num_attention_heads: int = 24, + num_key_value_attention_heads: int = 12, + out_channels: int = 64, + cross_attention_dim: int = 768, + time_proj_dim: int = 256, + global_states_input_dim: int = 1536, + cross_attention_input_dim: int = 768, + ): + super().__init__() + self.sample_size = sample_size + self.out_channels = out_channels + self.inner_dim = num_attention_heads * attention_head_dim + + self.time_proj = StableAudioGaussianFourierProjection( + embedding_size=time_proj_dim // 2, + flip_sin_to_cos=True, + log=False, + set_W_to_weight=False, + ) + + self.timestep_proj = nn.Sequential( + nn.Linear(time_proj_dim, self.inner_dim, bias=True), + nn.SiLU(), + nn.Linear(self.inner_dim, self.inner_dim, bias=True), + ) + + self.global_proj = nn.Sequential( + nn.Linear(global_states_input_dim, self.inner_dim, bias=False), + nn.SiLU(), + nn.Linear(self.inner_dim, self.inner_dim, bias=False), + ) + + self.cross_attention_proj = nn.Sequential( + nn.Linear(cross_attention_input_dim, cross_attention_dim, bias=False), + nn.SiLU(), + nn.Linear(cross_attention_dim, cross_attention_dim, bias=False), + ) + + self.preprocess_conv = nn.Conv1d(in_channels, in_channels, 1, bias=False) + self.proj_in = nn.Linear(in_channels, self.inner_dim, bias=False) + + self.transformer_blocks = nn.ModuleList( + [ + StableAudioDiTBlock( + dim=self.inner_dim, + num_attention_heads=num_attention_heads, + num_key_value_attention_heads=num_key_value_attention_heads, + attention_head_dim=attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for i in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=False) + self.postprocess_conv = nn.Conv1d(self.out_channels, self.out_channels, 1, bias=False) + + self.gradient_checkpointing = False + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor() + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.transformers.hunyuan_transformer_2d.HunyuanDiT2DModel.set_default_attn_processor with Hunyuan->StableAudio + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + self.set_attn_processor(StableAudioAttnProcessor2_0()) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + hidden_states: torch.FloatTensor, + timestep: torch.LongTensor = None, + encoder_hidden_states: torch.FloatTensor = None, + global_hidden_states: torch.FloatTensor = None, + rotary_embedding: torch.FloatTensor = None, + return_dict: bool = True, + attention_mask: Optional[torch.LongTensor] = None, + encoder_attention_mask: Optional[torch.LongTensor] = None, + ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: + """ + The [`StableAudioDiTModel`] forward method. + + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, in_channels, sequence_len)`): + Input `hidden_states`. + timestep ( `torch.LongTensor`): + Used to indicate denoising step. + encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, encoder_sequence_len, cross_attention_input_dim)`): + Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. + global_hidden_states (`torch.FloatTensor` of shape `(batch size, global_sequence_len, global_states_input_dim)`): + Global embeddings that will be prepended to the hidden states. + rotary_embedding (`torch.Tensor`): + The rotary embeddings to apply on query and key tensors during attention calculation. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain + tuple. + attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*): + Mask to avoid performing attention on padding token indices, formed by concatenating the attention + masks + for the two text encoders together. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + encoder_attention_mask (`torch.Tensor` of shape `(batch_size, sequence_len)`, *optional*): + Mask to avoid performing attention on padding token cross-attention indices, formed by concatenating + the attention masks + for the two text encoders together. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + cross_attention_hidden_states = self.cross_attention_proj(encoder_hidden_states) + global_hidden_states = self.global_proj(global_hidden_states) + time_hidden_states = self.timestep_proj(self.time_proj(timestep.to(self.dtype))) + + global_hidden_states = global_hidden_states + time_hidden_states.unsqueeze(1) + + hidden_states = self.preprocess_conv(hidden_states) + hidden_states + # (batch_size, dim, sequence_length) -> (batch_size, sequence_length, dim) + hidden_states = hidden_states.transpose(1, 2) + + hidden_states = self.proj_in(hidden_states) + + # prepend global states to hidden states + hidden_states = torch.cat([global_hidden_states, hidden_states], dim=-2) + if attention_mask is not None: + prepend_mask = torch.ones((hidden_states.shape[0], 1), device=hidden_states.device, dtype=torch.bool) + attention_mask = torch.cat([prepend_mask, attention_mask], dim=-1) + + for block in self.transformer_blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + attention_mask, + cross_attention_hidden_states, + encoder_attention_mask, + rotary_embedding, + **ckpt_kwargs, + ) + + else: + hidden_states = block( + hidden_states=hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=cross_attention_hidden_states, + encoder_attention_mask=encoder_attention_mask, + rotary_embedding=rotary_embedding, + ) + + hidden_states = self.proj_out(hidden_states) + + # (batch_size, sequence_length, dim) -> (batch_size, dim, sequence_length) + # remove prepend length that has been added by global hidden states + hidden_states = hidden_states.transpose(1, 2)[:, :, 1:] + hidden_states = self.postprocess_conv(hidden_states) + hidden_states + + if not return_dict: + return (hidden_states,) + + return Transformer2DModelOutput(sample=hidden_states) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index 7bc50b297566..f1d41b60d090 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -231,6 +231,10 @@ _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"] _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"] _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"] + _import_structure["stable_audio"] = [ + "StableAudioProjectionModel", + "StableAudioPipeline", + ] _import_structure["stable_cascade"] = [ "StableCascadeCombinedPipeline", "StableCascadeDecoderPipeline", @@ -533,6 +537,7 @@ from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline from .semantic_stable_diffusion import SemanticStableDiffusionPipeline from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline + from .stable_audio import StableAudioPipeline, StableAudioProjectionModel from .stable_cascade import ( StableCascadeCombinedPipeline, StableCascadeDecoderPipeline, diff --git a/src/diffusers/pipelines/stable_audio/__init__.py b/src/diffusers/pipelines/stable_audio/__init__.py new file mode 100644 index 000000000000..dfdd419ae991 --- /dev/null +++ b/src/diffusers/pipelines/stable_audio/__init__.py @@ -0,0 +1,50 @@ +from typing import TYPE_CHECKING + +from ...utils import ( + DIFFUSERS_SLOW_IMPORT, + OptionalDependencyNotAvailable, + _LazyModule, + get_objects_from_module, + is_torch_available, + is_transformers_available, + is_transformers_version, +) + + +_dummy_objects = {} +_import_structure = {} + +try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() +except OptionalDependencyNotAvailable: + from ...utils import dummy_torch_and_transformers_objects + + _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects)) +else: + _import_structure["modeling_stable_audio"] = ["StableAudioProjectionModel"] + _import_structure["pipeline_stable_audio"] = ["StableAudioPipeline"] + + +if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: + try: + if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")): + raise OptionalDependencyNotAvailable() + except OptionalDependencyNotAvailable: + from ...utils.dummy_torch_and_transformers_objects import * + + else: + from .modeling_stable_audio import StableAudioProjectionModel + from .pipeline_stable_audio import StableAudioPipeline + +else: + import sys + + sys.modules[__name__] = _LazyModule( + __name__, + globals()["__file__"], + _import_structure, + module_spec=__spec__, + ) + for name, value in _dummy_objects.items(): + setattr(sys.modules[__name__], name, value) diff --git a/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py b/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py new file mode 100644 index 000000000000..b8f8a705de21 --- /dev/null +++ b/src/diffusers/pipelines/stable_audio/modeling_stable_audio.py @@ -0,0 +1,158 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from math import pi +from typing import Optional + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from ...configuration_utils import ConfigMixin, register_to_config +from ...models.modeling_utils import ModelMixin +from ...utils import BaseOutput, logging + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class StableAudioPositionalEmbedding(nn.Module): + """Used for continuous time""" + + def __init__(self, dim: int): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, times: torch.Tensor) -> torch.Tensor: + times = times[..., None] + freqs = times * self.weights[None] * 2 * pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((times, fouriered), dim=-1) + return fouriered + + +@dataclass +class StableAudioProjectionModelOutput(BaseOutput): + """ + Args: + Class for StableAudio projection layer's outputs. + text_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*): + Sequence of hidden-states obtained by linearly projecting the hidden-states for the text encoder. + seconds_start_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*): + Sequence of hidden-states obtained by linearly projecting the audio start hidden states. + seconds_end_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*): + Sequence of hidden-states obtained by linearly projecting the audio end hidden states. + """ + + text_hidden_states: Optional[torch.Tensor] = None + seconds_start_hidden_states: Optional[torch.Tensor] = None + seconds_end_hidden_states: Optional[torch.Tensor] = None + + +class StableAudioNumberConditioner(nn.Module): + """ + A simple linear projection model to map numbers to a latent space. + + Args: + number_embedding_dim (`int`): + Dimensionality of the number embeddings. + min_value (`int`): + The minimum value of the seconds number conditioning modules. + max_value (`int`): + The maximum value of the seconds number conditioning modules + internal_dim (`int`): + Dimensionality of the intermediate number hidden states. + """ + + def __init__( + self, + number_embedding_dim, + min_value, + max_value, + internal_dim: Optional[int] = 256, + ): + super().__init__() + self.time_positional_embedding = nn.Sequential( + StableAudioPositionalEmbedding(internal_dim), + nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim), + ) + + self.number_embedding_dim = number_embedding_dim + self.min_value = min_value + self.max_value = max_value + + def forward( + self, + floats: torch.Tensor, + ): + floats = floats.clamp(self.min_value, self.max_value) + + normalized_floats = (floats - self.min_value) / (self.max_value - self.min_value) + + # Cast floats to same type as embedder + embedder_dtype = next(self.time_positional_embedding.parameters()).dtype + normalized_floats = normalized_floats.to(embedder_dtype) + + embedding = self.time_positional_embedding(normalized_floats) + float_embeds = embedding.view(-1, 1, self.number_embedding_dim) + + return float_embeds + + +class StableAudioProjectionModel(ModelMixin, ConfigMixin): + """ + A simple linear projection model to map the conditioning values to a shared latent space. + + Args: + text_encoder_dim (`int`): + Dimensionality of the text embeddings from the text encoder (T5). + conditioning_dim (`int`): + Dimensionality of the output conditioning tensors. + min_value (`int`): + The minimum value of the seconds number conditioning modules. + max_value (`int`): + The maximum value of the seconds number conditioning modules + """ + + @register_to_config + def __init__(self, text_encoder_dim, conditioning_dim, min_value, max_value): + super().__init__() + self.text_projection = ( + nn.Identity() if conditioning_dim == text_encoder_dim else nn.Linear(text_encoder_dim, conditioning_dim) + ) + self.start_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value) + self.end_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value) + + def forward( + self, + text_hidden_states: Optional[torch.Tensor] = None, + start_seconds: Optional[torch.Tensor] = None, + end_seconds: Optional[torch.Tensor] = None, + ): + text_hidden_states = ( + text_hidden_states if text_hidden_states is None else self.text_projection(text_hidden_states) + ) + seconds_start_hidden_states = ( + start_seconds if start_seconds is None else self.start_number_conditioner(start_seconds) + ) + seconds_end_hidden_states = end_seconds if end_seconds is None else self.end_number_conditioner(end_seconds) + + return StableAudioProjectionModelOutput( + text_hidden_states=text_hidden_states, + seconds_start_hidden_states=seconds_start_hidden_states, + seconds_end_hidden_states=seconds_end_hidden_states, + ) diff --git a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py new file mode 100644 index 000000000000..779c4f0dd173 --- /dev/null +++ b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py @@ -0,0 +1,745 @@ +# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import torch +from transformers import ( + T5EncoderModel, + T5Tokenizer, + T5TokenizerFast, +) + +from ...models import AutoencoderOobleck, StableAudioDiTModel +from ...models.embeddings import get_1d_rotary_pos_embed +from ...schedulers import EDMDPMSolverMultistepScheduler +from ...utils import ( + logging, + replace_example_docstring, +) +from ...utils.torch_utils import randn_tensor +from ..pipeline_utils import AudioPipelineOutput, DiffusionPipeline +from .modeling_stable_audio import StableAudioProjectionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import scipy + >>> import torch + >>> import soundfile as sf + >>> from diffusers import StableAudioPipeline + + >>> repo_id = "ylacombe/stable-audio-1.0" # TODO (YL): change once set + >>> pipe = StableAudioPipeline.from_pretrained(repo_id, torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> # define the prompts + >>> prompt = "The sound of a hammer hitting a wooden surface." + >>> negative_prompt = "Low quality." + + >>> # set the seed for generator + >>> generator = torch.Generator("cuda").manual_seed(0) + + >>> # run the generation + >>> audio = pipe( + ... prompt, + ... negative_prompt=negative_prompt, + ... num_inference_steps=200, + ... audio_end_in_s=10.0, + ... num_waveforms_per_prompt=3, + ... generator=generator, + ... ).audios + + >>> output = audio[0].T.float().cpu().numpy() + >>> sf.write("hammer.wav", output, pipe.vae.sampling_rate) + ``` +""" + + +class StableAudioPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-audio generation using StableAudio. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + Args: + vae ([`AutoencoderOobleck`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.T5EncoderModel`]): + Frozen text-encoder. StableAudio uses the encoder of + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [google-t5/t5-base](https://huggingface.co/google-t5/t5-base) variant. + projection_model ([`StableAudioProjectionModel`]): + A trained model used to linearly project the hidden-states from the text encoder model and the start and + end seconds. The projected hidden-states from the encoder and the conditional seconds are concatenated to + give the input to the transformer model. + tokenizer ([`~transformers.T5Tokenizer`]): + Tokenizer to tokenize text for the frozen text-encoder. + transformer ([`StableAudioDiTModel`]): + A `StableAudioDiTModel` to denoise the encoded audio latents. + scheduler ([`EDMDPMSolverMultistepScheduler`]): + A scheduler to be used in combination with `transformer` to denoise the encoded audio latents. + """ + + model_cpu_offload_seq = "text_encoder->projection_model->transformer->vae" + + def __init__( + self, + vae: AutoencoderOobleck, + text_encoder: T5EncoderModel, + projection_model: StableAudioProjectionModel, + tokenizer: Union[T5Tokenizer, T5TokenizerFast], + transformer: StableAudioDiTModel, + scheduler: EDMDPMSolverMultistepScheduler, + ): + super().__init__() + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + projection_model=projection_model, + tokenizer=tokenizer, + transformer=transformer, + scheduler=scheduler, + ) + self.rotary_embed_dim = self.transformer.config.attention_head_dim // 2 + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.pipeline_utils.StableDiffusionMixin.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def encode_prompt( + self, + prompt, + device, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + negative_attention_mask: Optional[torch.LongTensor] = None, + ): + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # 1. Tokenize text + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + f"The following part of your input was truncated because {self.text_encoder.config.model_type} can " + f"only handle sequences up to {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + text_input_ids = text_input_ids.to(device) + attention_mask = attention_mask.to(device) + + # 2. Text encoder forward + self.text_encoder.eval() + prompt_embeds = self.text_encoder( + text_input_ids, + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + if do_classifier_free_guidance and negative_prompt is not None: + uncond_tokens: List[str] + if type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # 1. Tokenize text + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + uncond_input_ids = uncond_input.input_ids.to(device) + negative_attention_mask = uncond_input.attention_mask.to(device) + + # 2. Text encoder forward + self.text_encoder.eval() + negative_prompt_embeds = self.text_encoder( + uncond_input_ids, + attention_mask=negative_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if negative_attention_mask is not None: + # set the masked tokens to the null embed + negative_prompt_embeds = torch.where( + negative_attention_mask.to(torch.bool).unsqueeze(2), negative_prompt_embeds, 0.0 + ) + + # 3. Project prompt_embeds and negative_prompt_embeds + if do_classifier_free_guidance and negative_prompt_embeds is not None: + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the negative and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + if attention_mask is not None and negative_attention_mask is None: + negative_attention_mask = torch.ones_like(attention_mask) + elif attention_mask is None and negative_attention_mask is not None: + attention_mask = torch.ones_like(negative_attention_mask) + + if attention_mask is not None: + attention_mask = torch.cat([negative_attention_mask, attention_mask]) + + prompt_embeds = self.projection_model( + text_hidden_states=prompt_embeds, + ).text_hidden_states + if attention_mask is not None: + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype) + prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).to(prompt_embeds.dtype) + + return prompt_embeds + + def encode_duration( + self, + audio_start_in_s, + audio_end_in_s, + device, + do_classifier_free_guidance, + batch_size, + ): + audio_start_in_s = audio_start_in_s if isinstance(audio_start_in_s, list) else [audio_start_in_s] + audio_end_in_s = audio_end_in_s if isinstance(audio_end_in_s, list) else [audio_end_in_s] + + if len(audio_start_in_s) == 1: + audio_start_in_s = audio_start_in_s * batch_size + if len(audio_end_in_s) == 1: + audio_end_in_s = audio_end_in_s * batch_size + + # Cast the inputs to floats + audio_start_in_s = [float(x) for x in audio_start_in_s] + audio_start_in_s = torch.tensor(audio_start_in_s).to(device) + + audio_end_in_s = [float(x) for x in audio_end_in_s] + audio_end_in_s = torch.tensor(audio_end_in_s).to(device) + + projection_output = self.projection_model( + start_seconds=audio_start_in_s, + end_seconds=audio_end_in_s, + ) + seconds_start_hidden_states = projection_output.seconds_start_hidden_states + seconds_end_hidden_states = projection_output.seconds_end_hidden_states + + # For classifier free guidance, we need to do two forward passes. + # Here we repeat the audio hidden states to avoid doing two forward passes + if do_classifier_free_guidance: + seconds_start_hidden_states = torch.cat([seconds_start_hidden_states, seconds_start_hidden_states], dim=0) + seconds_end_hidden_states = torch.cat([seconds_end_hidden_states, seconds_end_hidden_states], dim=0) + + return seconds_start_hidden_states, seconds_end_hidden_states + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + audio_start_in_s, + audio_end_in_s, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + attention_mask=None, + negative_attention_mask=None, + initial_audio_waveforms=None, + initial_audio_sampling_rate=None, + ): + if audio_end_in_s < audio_start_in_s: + raise ValueError( + f"`audio_end_in_s={audio_end_in_s}' must be higher than 'audio_start_in_s={audio_start_in_s}` but " + ) + + if ( + audio_start_in_s < self.projection_model.config.min_value + or audio_start_in_s > self.projection_model.config.max_value + ): + raise ValueError( + f"`audio_start_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but " + f"is {audio_start_in_s}." + ) + + if ( + audio_end_in_s < self.projection_model.config.min_value + or audio_end_in_s > self.projection_model.config.max_value + ): + raise ValueError( + f"`audio_end_in_s` must be greater than or equal to {self.projection_model.config.min_value}, and lower than or equal to {self.projection_model.config.max_value} but " + f"is {audio_end_in_s}." + ) + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and (prompt_embeds is None): + raise ValueError( + "Provide either `prompt`, or `prompt_embeds`. Cannot leave" + "`prompt` undefined without specifying `prompt_embeds`." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if attention_mask is not None and attention_mask.shape != prompt_embeds.shape[:2]: + raise ValueError( + "`attention_mask should have the same batch size and sequence length as `prompt_embeds`, but got:" + f"`attention_mask: {attention_mask.shape} != `prompt_embeds` {prompt_embeds.shape}" + ) + + if initial_audio_sampling_rate is None and initial_audio_waveforms is not None: + raise ValueError( + "`initial_audio_waveforms' is provided but the sampling rate is not. Make sure to pass `initial_audio_sampling_rate`." + ) + + if initial_audio_sampling_rate is not None and initial_audio_sampling_rate != self.vae.sampling_rate: + raise ValueError( + f"`initial_audio_sampling_rate` must be {self.vae.hop_length}' but is `{initial_audio_sampling_rate}`." + "Make sure to resample the `initial_audio_waveforms` and to correct the sampling rate. " + ) + + def prepare_latents( + self, + batch_size, + num_channels_vae, + sample_size, + dtype, + device, + generator, + latents=None, + initial_audio_waveforms=None, + num_waveforms_per_prompt=None, + audio_channels=None, + ): + shape = (batch_size, num_channels_vae, sample_size) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + + # encode the initial audio for use by the model + if initial_audio_waveforms is not None: + # check dimension + if initial_audio_waveforms.ndim == 2: + initial_audio_waveforms = initial_audio_waveforms.unsqueeze(1) + elif initial_audio_waveforms.ndim != 3: + raise ValueError( + f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but has `{initial_audio_waveforms.ndim}` dimensions" + ) + + audio_vae_length = self.transformer.config.sample_size * self.vae.hop_length + audio_shape = (batch_size // num_waveforms_per_prompt, audio_channels, audio_vae_length) + + # check num_channels + if initial_audio_waveforms.shape[1] == 1 and audio_channels == 2: + initial_audio_waveforms = initial_audio_waveforms.repeat(1, 2, 1) + elif initial_audio_waveforms.shape[1] == 2 and audio_channels == 1: + initial_audio_waveforms = initial_audio_waveforms.mean(1, keepdim=True) + + if initial_audio_waveforms.shape[:2] != audio_shape[:2]: + raise ValueError( + f"`initial_audio_waveforms` must be of shape `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)` but is of shape `{initial_audio_waveforms.shape}`" + ) + + # crop or pad + audio_length = initial_audio_waveforms.shape[-1] + if audio_length < audio_vae_length: + logger.warning( + f"The provided input waveform is shorter ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be padded." + ) + elif audio_length > audio_vae_length: + logger.warning( + f"The provided input waveform is longer ({audio_length}) than the required audio length ({audio_vae_length}) of the model and will thus be cropped." + ) + + audio = initial_audio_waveforms.new_zeros(audio_shape) + audio[:, :, : min(audio_length, audio_vae_length)] = initial_audio_waveforms[:, :, :audio_vae_length] + + encoded_audio = self.vae.encode(audio).latent_dist.sample(generator) + encoded_audio = encoded_audio.repeat((num_waveforms_per_prompt, 1, 1)) + latents = encoded_audio + latents + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + audio_end_in_s: Optional[float] = None, + audio_start_in_s: Optional[float] = 0.0, + num_inference_steps: int = 100, + guidance_scale: float = 7.0, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_waveforms_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.Tensor] = None, + initial_audio_waveforms: Optional[torch.Tensor] = None, + initial_audio_sampling_rate: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.Tensor] = None, + negative_prompt_embeds: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.LongTensor] = None, + negative_attention_mask: Optional[torch.LongTensor] = None, + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.Tensor], None]] = None, + callback_steps: Optional[int] = 1, + output_type: Optional[str] = "pt", + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide audio generation. If not defined, you need to pass `prompt_embeds`. + audio_end_in_s (`float`, *optional*, defaults to 47.55): + Audio end index in seconds. + audio_start_in_s (`float`, *optional*, defaults to 0): + Audio start index in seconds. + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality audio at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.0): + A higher guidance scale value encourages the model to generate audio that is closely linked to the text + `prompt` at the expense of lower sound quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in audio generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_waveforms_per_prompt (`int`, *optional*, defaults to 1): + The number of waveforms to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.Tensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for audio + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + initial_audio_waveforms (`torch.Tensor`, *optional*): + Optional initial audio waveforms to use as the initial audio waveform for generation. Must be of shape + `(batch_size, num_channels, audio_length)` or `(batch_size, audio_length)`, where `batch_size` + corresponds to the number of prompts passed to the model. + initial_audio_sampling_rate (`int`, *optional*): + Sampling rate of the `initial_audio_waveforms`, if they are provided. Must be the same as the model. + prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed text embeddings from the text encoder model. Can be used to easily tweak text inputs, + *e.g.* prompt weighting. If not provided, text embeddings will be computed from `prompt` input + argument. + negative_prompt_embeds (`torch.Tensor`, *optional*): + Pre-computed negative text embeddings from the text encoder model. Can be used to easily tweak text + inputs, *e.g.* prompt weighting. If not provided, negative_prompt_embeds will be computed from + `negative_prompt` input argument. + attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `prompt_embeds`. If not provided, attention mask will + be computed from `prompt` input argument. + negative_attention_mask (`torch.LongTensor`, *optional*): + Pre-computed attention mask to be applied to the `negative_text_audio_duration_embeds`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.Tensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + output_type (`str`, *optional*, defaults to `"pt"`): + The output format of the generated audio. Choose between `"np"` to return a NumPy `np.ndarray` or + `"pt"` to return a PyTorch `torch.Tensor` object. Set to `"latent"` to return the latent diffusion + model (LDM) output. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated audio. + """ + # 0. Convert audio input length from seconds to latent length + downsample_ratio = self.vae.hop_length + + max_audio_length_in_s = self.transformer.config.sample_size * downsample_ratio / self.vae.config.sampling_rate + if audio_end_in_s is None: + audio_end_in_s = max_audio_length_in_s + + if audio_end_in_s - audio_start_in_s > max_audio_length_in_s: + raise ValueError( + f"The total audio length requested ({audio_end_in_s-audio_start_in_s}s) is longer than the model maximum possible length ({max_audio_length_in_s}). Make sure that 'audio_end_in_s-audio_start_in_s<={max_audio_length_in_s}'." + ) + + waveform_start = int(audio_start_in_s * self.vae.config.sampling_rate) + waveform_end = int(audio_end_in_s * self.vae.config.sampling_rate) + waveform_length = int(self.transformer.config.sample_size) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + audio_start_in_s, + audio_end_in_s, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + attention_mask, + negative_attention_mask, + initial_audio_waveforms, + initial_audio_sampling_rate, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self.encode_prompt( + prompt, + device, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + attention_mask, + negative_attention_mask, + ) + + # Encode duration + seconds_start_hidden_states, seconds_end_hidden_states = self.encode_duration( + audio_start_in_s, + audio_end_in_s, + device, + do_classifier_free_guidance and (negative_prompt is not None or negative_prompt_embeds is not None), + batch_size, + ) + + # Create text_audio_duration_embeds and audio_duration_embeds + text_audio_duration_embeds = torch.cat( + [prompt_embeds, seconds_start_hidden_states, seconds_end_hidden_states], dim=1 + ) + + audio_duration_embeds = torch.cat([seconds_start_hidden_states, seconds_end_hidden_states], dim=2) + + # In case of classifier free guidance without negative prompt, we need to create unconditional embeddings and + # to concatenate it to the embeddings + if do_classifier_free_guidance and negative_prompt_embeds is None and negative_prompt is None: + negative_text_audio_duration_embeds = torch.zeros_like( + text_audio_duration_embeds, device=text_audio_duration_embeds.device + ) + text_audio_duration_embeds = torch.cat( + [negative_text_audio_duration_embeds, text_audio_duration_embeds], dim=0 + ) + audio_duration_embeds = torch.cat([audio_duration_embeds, audio_duration_embeds], dim=0) + + bs_embed, seq_len, hidden_size = text_audio_duration_embeds.shape + # duplicate audio_duration_embeds and text_audio_duration_embeds for each generation per prompt, using mps friendly method + text_audio_duration_embeds = text_audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) + text_audio_duration_embeds = text_audio_duration_embeds.view( + bs_embed * num_waveforms_per_prompt, seq_len, hidden_size + ) + + audio_duration_embeds = audio_duration_embeds.repeat(1, num_waveforms_per_prompt, 1) + audio_duration_embeds = audio_duration_embeds.view( + bs_embed * num_waveforms_per_prompt, -1, audio_duration_embeds.shape[-1] + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_vae = self.transformer.config.in_channels + latents = self.prepare_latents( + batch_size * num_waveforms_per_prompt, + num_channels_vae, + waveform_length, + text_audio_duration_embeds.dtype, + device, + generator, + latents, + initial_audio_waveforms, + num_waveforms_per_prompt, + audio_channels=self.vae.config.audio_channels, + ) + + # 6. Prepare extra step kwargs + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Prepare rotary positional embedding + rotary_embedding = get_1d_rotary_pos_embed( + self.rotary_embed_dim, + latents.shape[2] + audio_duration_embeds.shape[1], + use_real=True, + repeat_interleave_real=False, + ) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.transformer( + latent_model_input, + t.unsqueeze(0), + encoder_hidden_states=text_audio_duration_embeds, + global_hidden_states=audio_duration_embeds, + rotary_embedding=rotary_embedding, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + # 9. Post-processing + if not output_type == "latent": + audio = self.vae.decode(latents).sample + else: + return AudioPipelineOutput(audios=latents) + + audio = audio[:, :, waveform_start:waveform_end] + + if output_type == "np": + audio = audio.cpu().float().numpy() + + self.maybe_free_model_hooks() + + if not return_dict: + return (audio,) + + return AudioPipelineOutput(audios=audio) diff --git a/src/diffusers/schedulers/__init__.py b/src/diffusers/schedulers/__init__.py index dfee479bfa96..696e9c3ad5d5 100644 --- a/src/diffusers/schedulers/__init__.py +++ b/src/diffusers/schedulers/__init__.py @@ -118,6 +118,7 @@ _dummy_modules.update(get_objects_from_module(dummy_torch_and_torchsde_objects)) else: + _import_structure["scheduling_cosine_dpmsolver_multistep"] = ["CosineDPMSolverMultistepScheduler"] _import_structure["scheduling_dpmsolver_sde"] = ["DPMSolverSDEScheduler"] if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: @@ -205,6 +206,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_torchsde_objects import * # noqa F403 else: + from .scheduling_cosine_dpmsolver_multistep import CosineDPMSolverMultistepScheduler from .scheduling_dpmsolver_sde import DPMSolverSDEScheduler else: diff --git a/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py new file mode 100644 index 000000000000..ab56650dbac5 --- /dev/null +++ b/src/diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py @@ -0,0 +1,572 @@ +# Copyright 2024 TSAIL Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This file is strongly influenced by https://github.com/LuChengTHU/dpm-solver and https://github.com/NVlabs/edm + +import math +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from ..configuration_utils import ConfigMixin, register_to_config +from .scheduling_dpmsolver_sde import BrownianTreeNoiseSampler +from .scheduling_utils import SchedulerMixin, SchedulerOutput + + +class CosineDPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin): + """ + Implements a variant of `DPMSolverMultistepScheduler` with cosine schedule, proposed by Nichol and Dhariwal (2021). + This scheduler was used in Stable Audio Open [1]. + + [1] Evans, Parker, et al. "Stable Audio Open" https://arxiv.org/abs/2407.14358 + + This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic + methods the library implements for all schedulers such as loading and saving. + + Args: + sigma_min (`float`, *optional*, defaults to 0.3): + Minimum noise magnitude in the sigma schedule. This was set to 0.3 in Stable Audio Open [1]. + sigma_max (`float`, *optional*, defaults to 500): + Maximum noise magnitude in the sigma schedule. This was set to 500 in Stable Audio Open [1]. + sigma_data (`float`, *optional*, defaults to 1.0): + The standard deviation of the data distribution. This is set to 1.0 in Stable Audio Open [1]. + sigma_schedule (`str`, *optional*, defaults to `exponential`): + Sigma schedule to compute the `sigmas`. By default, we the schedule introduced in the EDM paper + (https://arxiv.org/abs/2206.00364). Other acceptable value is "exponential". The exponential schedule was + incorporated in this model: https://huggingface.co/stabilityai/cosxl. + num_train_timesteps (`int`, defaults to 1000): + The number of diffusion steps to train the model. + solver_order (`int`, defaults to 2): + The DPMSolver order which can be `1` or `2`. It is recommended to use `solver_order=2`. + prediction_type (`str`, defaults to `v_prediction`, *optional*): + Prediction type of the scheduler function; can be `epsilon` (predicts the noise of the diffusion process), + `sample` (directly predicts the noisy sample`) or `v_prediction` (see section 2.4 of [Imagen + Video](https://imagen.research.google/video/paper.pdf) paper). + solver_type (`str`, defaults to `midpoint`): + Solver type for the second-order solver; can be `midpoint` or `heun`. The solver type slightly affects the + sample quality, especially for a small number of steps. It is recommended to use `midpoint` solvers. + lower_order_final (`bool`, defaults to `True`): + Whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. This can + stabilize the sampling of DPMSolver for steps < 15, especially for steps <= 10. + euler_at_final (`bool`, defaults to `False`): + Whether to use Euler's method in the final step. It is a trade-off between numerical stability and detail + richness. This can stabilize the sampling of the SDE variant of DPMSolver for small number of inference + steps, but sometimes may result in blurring. + final_sigmas_type (`str`, defaults to `"zero"`): + The final `sigma` value for the noise schedule during the sampling process. If `"sigma_min"`, the final + sigma is the same as the last sigma in the training schedule. If `zero`, the final sigma is set to 0. + """ + + _compatibles = [] + order = 1 + + @register_to_config + def __init__( + self, + sigma_min: float = 0.3, + sigma_max: float = 500, + sigma_data: float = 1.0, + sigma_schedule: str = "exponential", + num_train_timesteps: int = 1000, + solver_order: int = 2, + prediction_type: str = "v_prediction", + rho: float = 7.0, + solver_type: str = "midpoint", + lower_order_final: bool = True, + euler_at_final: bool = False, + final_sigmas_type: Optional[str] = "zero", # "zero", "sigma_min" + ): + if solver_type not in ["midpoint", "heun"]: + if solver_type in ["logrho", "bh1", "bh2"]: + self.register_to_config(solver_type="midpoint") + else: + raise NotImplementedError(f"{solver_type} is not implemented for {self.__class__}") + + ramp = torch.linspace(0, 1, num_train_timesteps) + if sigma_schedule == "karras": + sigmas = self._compute_karras_sigmas(ramp) + elif sigma_schedule == "exponential": + sigmas = self._compute_exponential_sigmas(ramp) + + self.timesteps = self.precondition_noise(sigmas) + + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + + # setable values + self.num_inference_steps = None + self.model_outputs = [None] * solver_order + self.lower_order_nums = 0 + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + @property + def init_noise_sigma(self): + # standard deviation of the initial noise distribution + return (self.config.sigma_max**2 + 1) ** 0.5 + + @property + def step_index(self): + """ + The index counter for current timestep. It will increase 1 after each scheduler step. + """ + return self._step_index + + @property + def begin_index(self): + """ + The index for the first timestep. It should be set from pipeline with `set_begin_index` method. + """ + return self._begin_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index + def set_begin_index(self, begin_index: int = 0): + """ + Sets the begin index for the scheduler. This function should be run from pipeline before the inference. + + Args: + begin_index (`int`): + The begin index for the scheduler. + """ + self._begin_index = begin_index + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_inputs + def precondition_inputs(self, sample, sigma): + c_in = 1 / ((sigma**2 + self.config.sigma_data**2) ** 0.5) + scaled_sample = sample * c_in + return scaled_sample + + def precondition_noise(self, sigma): + if not isinstance(sigma, torch.Tensor): + sigma = torch.tensor([sigma]) + + return sigma.atan() / math.pi * 2 + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.precondition_outputs + def precondition_outputs(self, sample, model_output, sigma): + sigma_data = self.config.sigma_data + c_skip = sigma_data**2 / (sigma**2 + sigma_data**2) + + if self.config.prediction_type == "epsilon": + c_out = sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + elif self.config.prediction_type == "v_prediction": + c_out = -sigma * sigma_data / (sigma**2 + sigma_data**2) ** 0.5 + else: + raise ValueError(f"Prediction type {self.config.prediction_type} is not supported.") + + denoised = c_skip * sample + c_out * model_output + + return denoised + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler.scale_model_input + def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. + + Args: + sample (`torch.Tensor`): + The input sample. + timestep (`int`, *optional*): + The current timestep in the diffusion chain. + + Returns: + `torch.Tensor`: + A scaled input sample. + """ + if self.step_index is None: + self._init_step_index(timestep) + + sigma = self.sigmas[self.step_index] + sample = self.precondition_inputs(sample, sigma) + + self.is_scale_input_called = True + return sample + + def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain (to be run before inference). + + Args: + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + """ + + self.num_inference_steps = num_inference_steps + + ramp = torch.linspace(0, 1, self.num_inference_steps) + if self.config.sigma_schedule == "karras": + sigmas = self._compute_karras_sigmas(ramp) + elif self.config.sigma_schedule == "exponential": + sigmas = self._compute_exponential_sigmas(ramp) + + sigmas = sigmas.to(dtype=torch.float32, device=device) + self.timesteps = self.precondition_noise(sigmas) + + if self.config.final_sigmas_type == "sigma_min": + sigma_last = self.config.sigma_min + elif self.config.final_sigmas_type == "zero": + sigma_last = 0 + else: + raise ValueError( + f"`final_sigmas_type` must be one of 'zero', or 'sigma_min', but got {self.config.final_sigmas_type}" + ) + + self.sigmas = torch.cat([sigmas, torch.tensor([sigma_last], dtype=torch.float32, device=device)]) + + self.model_outputs = [ + None, + ] * self.config.solver_order + self.lower_order_nums = 0 + + # add an index counter for schedulers that allow duplicated timesteps + self._step_index = None + self._begin_index = None + self.sigmas = self.sigmas.to("cpu") # to avoid too much CPU/GPU communication + + # if a noise sampler is used, reinitialise it + self.noise_sampler = None + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_karras_sigmas + def _compute_karras_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: + """Constructs the noise schedule of Karras et al. (2022).""" + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + + rho = self.config.rho + min_inv_rho = sigma_min ** (1 / rho) + max_inv_rho = sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return sigmas + + # Copied from diffusers.schedulers.scheduling_edm_euler.EDMEulerScheduler._compute_exponential_sigmas + def _compute_exponential_sigmas(self, ramp, sigma_min=None, sigma_max=None) -> torch.Tensor: + """Implementation closely follows k-diffusion. + + https://github.com/crowsonkb/k-diffusion/blob/6ab5146d4a5ef63901326489f31f1d8e7dd36b48/k_diffusion/sampling.py#L26 + """ + sigma_min = sigma_min or self.config.sigma_min + sigma_max = sigma_max or self.config.sigma_max + sigmas = torch.linspace(math.log(sigma_min), math.log(sigma_max), len(ramp)).exp().flip(0) + return sigmas + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t + def _sigma_to_t(self, sigma, log_sigmas): + # get log sigma + log_sigma = np.log(np.maximum(sigma, 1e-10)) + + # get distribution + dists = log_sigma - log_sigmas[:, np.newaxis] + + # get sigmas range + low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2) + high_idx = low_idx + 1 + + low = log_sigmas[low_idx] + high = log_sigmas[high_idx] + + # interpolate sigmas + w = (low - log_sigma) / (low - high) + w = np.clip(w, 0, 1) + + # transform interpolation to time range + t = (1 - w) * low_idx + w * high_idx + t = t.reshape(sigma.shape) + return t + + def _sigma_to_alpha_sigma_t(self, sigma): + alpha_t = torch.tensor(1) # Inputs are pre-scaled before going into unet, so alpha_t = 1 + sigma_t = sigma + + return alpha_t, sigma_t + + def convert_model_output( + self, + model_output: torch.Tensor, + sample: torch.Tensor = None, + ) -> torch.Tensor: + """ + Convert the model output to the corresponding type the DPMSolver/DPMSolver++ algorithm needs. DPM-Solver is + designed to discretize an integral of the noise prediction model, and DPM-Solver++ is designed to discretize an + integral of the data prediction model. + + + + The algorithm and model type are decoupled. You can use either DPMSolver or DPMSolver++ for both noise + prediction and data prediction models. + + + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The converted model output. + """ + sigma = self.sigmas[self.step_index] + x0_pred = self.precondition_outputs(sample, model_output, sigma) + + return x0_pred + + def dpm_solver_first_order_update( + self, + model_output: torch.Tensor, + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + One step for the first-order DPMSolver (equivalent to DDIM). + + Args: + model_output (`torch.Tensor`): + The direct output from the learned diffusion model. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + sigma_t, sigma_s = self.sigmas[self.step_index + 1], self.sigmas[self.step_index] + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s, sigma_s = self._sigma_to_alpha_sigma_t(sigma_s) + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s = torch.log(alpha_s) - torch.log(sigma_s) + + h = lambda_t - lambda_s + assert noise is not None + x_t = ( + (sigma_t / sigma_s * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * model_output + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + + return x_t + + def multistep_dpm_solver_second_order_update( + self, + model_output_list: List[torch.Tensor], + sample: torch.Tensor = None, + noise: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """ + One step for the second-order multistep DPMSolver. + + Args: + model_output_list (`List[torch.Tensor]`): + The direct outputs from learned diffusion model at current and latter timesteps. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + + Returns: + `torch.Tensor`: + The sample tensor at the previous timestep. + """ + sigma_t, sigma_s0, sigma_s1 = ( + self.sigmas[self.step_index + 1], + self.sigmas[self.step_index], + self.sigmas[self.step_index - 1], + ) + + alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t) + alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0) + alpha_s1, sigma_s1 = self._sigma_to_alpha_sigma_t(sigma_s1) + + lambda_t = torch.log(alpha_t) - torch.log(sigma_t) + lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0) + lambda_s1 = torch.log(alpha_s1) - torch.log(sigma_s1) + + m0, m1 = model_output_list[-1], model_output_list[-2] + + h, h_0 = lambda_t - lambda_s0, lambda_s0 - lambda_s1 + r0 = h_0 / h + D0, D1 = m0, (1.0 / r0) * (m0 - m1) + + # sde-dpmsolver++ + assert noise is not None + if self.config.solver_type == "midpoint": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + 0.5 * (alpha_t * (1 - torch.exp(-2.0 * h))) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + elif self.config.solver_type == "heun": + x_t = ( + (sigma_t / sigma_s0 * torch.exp(-h)) * sample + + (alpha_t * (1 - torch.exp(-2.0 * h))) * D0 + + (alpha_t * ((1.0 - torch.exp(-2.0 * h)) / (-2.0 * h) + 1.0)) * D1 + + sigma_t * torch.sqrt(1.0 - torch.exp(-2 * h)) * noise + ) + + return x_t + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.index_for_timestep + def index_for_timestep(self, timestep, schedule_timesteps=None): + if schedule_timesteps is None: + schedule_timesteps = self.timesteps + + index_candidates = (schedule_timesteps == timestep).nonzero() + + if len(index_candidates) == 0: + step_index = len(self.timesteps) - 1 + # The sigma index that is taken for the **very** first `step` + # is always the second index (or the last index if there is only 1) + # This way we can ensure we don't accidentally skip a sigma in + # case we start in the middle of the denoising schedule (e.g. for image-to-image) + elif len(index_candidates) > 1: + step_index = index_candidates[1].item() + else: + step_index = index_candidates[0].item() + + return step_index + + # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._init_step_index + def _init_step_index(self, timestep): + """ + Initialize the step_index counter for the scheduler. + """ + + if self.begin_index is None: + if isinstance(timestep, torch.Tensor): + timestep = timestep.to(self.timesteps.device) + self._step_index = self.index_for_timestep(timestep) + else: + self._step_index = self._begin_index + + def step( + self, + model_output: torch.Tensor, + timestep: Union[int, torch.Tensor], + sample: torch.Tensor, + generator=None, + return_dict: bool = True, + ) -> Union[SchedulerOutput, Tuple]: + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the sample with + the multistep DPMSolver. + + Args: + model_output (`torch.Tensor`): + The direct output from learned diffusion model. + timestep (`int`): + The current discrete timestep in the diffusion chain. + sample (`torch.Tensor`): + A current instance of a sample created by the diffusion process. + generator (`torch.Generator`, *optional*): + A random number generator. + return_dict (`bool`): + Whether or not to return a [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.SchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.scheduling_utils.SchedulerOutput`] is returned, otherwise a + tuple is returned where the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if self.step_index is None: + self._init_step_index(timestep) + + # Improve numerical stability for small number of steps + lower_order_final = (self.step_index == len(self.timesteps) - 1) and ( + self.config.euler_at_final + or (self.config.lower_order_final and len(self.timesteps) < 15) + or self.config.final_sigmas_type == "zero" + ) + lower_order_second = ( + (self.step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 + ) + + model_output = self.convert_model_output(model_output, sample=sample) + for i in range(self.config.solver_order - 1): + self.model_outputs[i] = self.model_outputs[i + 1] + self.model_outputs[-1] = model_output + + if self.noise_sampler is None: + seed = None + if generator is not None: + seed = ( + [g.initial_seed() for g in generator] if isinstance(generator, list) else generator.initial_seed() + ) + self.noise_sampler = BrownianTreeNoiseSampler( + model_output, sigma_min=self.config.sigma_min, sigma_max=self.config.sigma_max, seed=seed + ) + noise = self.noise_sampler(self.sigmas[self.step_index], self.sigmas[self.step_index + 1]).to( + model_output.device + ) + + if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: + prev_sample = self.dpm_solver_first_order_update(model_output, sample=sample, noise=noise) + elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: + prev_sample = self.multistep_dpm_solver_second_order_update(self.model_outputs, sample=sample, noise=noise) + + if self.lower_order_nums < self.config.solver_order: + self.lower_order_nums += 1 + + # upon completion increase step index by one + self._step_index += 1 + + if not return_dict: + return (prev_sample,) + + return SchedulerOutput(prev_sample=prev_sample) + + # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise + def add_noise( + self, + original_samples: torch.Tensor, + noise: torch.Tensor, + timesteps: torch.Tensor, + ) -> torch.Tensor: + # Make sure sigmas and timesteps have the same device and dtype as original_samples + sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) + if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): + # mps does not support float64 + schedule_timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) + timesteps = timesteps.to(original_samples.device, dtype=torch.float32) + else: + schedule_timesteps = self.timesteps.to(original_samples.device) + timesteps = timesteps.to(original_samples.device) + + # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index + if self.begin_index is None: + step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps] + elif self.step_index is not None: + # add_noise is called after first denoising step (for inpainting) + step_indices = [self.step_index] * timesteps.shape[0] + else: + # add noise is called before first denoising step to create initial latent(img2img) + step_indices = [self.begin_index] * timesteps.shape[0] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < len(original_samples.shape): + sigma = sigma.unsqueeze(-1) + + noisy_samples = original_samples + noise * sigma + return noisy_samples + + def __len__(self): + return self.config.num_train_timesteps diff --git a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py index 21ae1df00a88..c49e8e9a191a 100644 --- a/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +++ b/src/diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py @@ -134,7 +134,7 @@ def __init__( self.timesteps = self.precondition_noise(sigmas) - self.sigmas = self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) + self.sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)]) # setable values self.num_inference_steps = None diff --git a/src/diffusers/utils/dummy_pt_objects.py b/src/diffusers/utils/dummy_pt_objects.py index 3ead6fd99d10..230b0b29b2c2 100644 --- a/src/diffusers/utils/dummy_pt_objects.py +++ b/src/diffusers/utils/dummy_pt_objects.py @@ -62,6 +62,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class AutoencoderOobleck(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class AutoencoderTiny(metaclass=DummyObject): _backends = ["torch"] @@ -377,6 +392,21 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch"]) +class StableAudioDiTModel(metaclass=DummyObject): + _backends = ["torch"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch"]) + + class T2IAdapter(metaclass=DummyObject): _backends = ["torch"] diff --git a/src/diffusers/utils/dummy_torch_and_torchsde_objects.py b/src/diffusers/utils/dummy_torch_and_torchsde_objects.py index a81bbb316f32..6ff14231b9cc 100644 --- a/src/diffusers/utils/dummy_torch_and_torchsde_objects.py +++ b/src/diffusers/utils/dummy_torch_and_torchsde_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class CosineDPMSolverMultistepScheduler(metaclass=DummyObject): + _backends = ["torch", "torchsde"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "torchsde"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "torchsde"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "torchsde"]) + + class DPMSolverSDEScheduler(metaclass=DummyObject): _backends = ["torch", "torchsde"] diff --git a/src/diffusers/utils/dummy_torch_and_transformers_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_objects.py index 8e40e5128854..105c1a5def9d 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_objects.py @@ -992,6 +992,36 @@ def from_pretrained(cls, *args, **kwargs): requires_backends(cls, ["torch", "transformers"]) +class StableAudioPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + +class StableAudioProjectionModel(metaclass=DummyObject): + _backends = ["torch", "transformers"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers"]) + + class StableCascadeCombinedPipeline(metaclass=DummyObject): _backends = ["torch", "transformers"] diff --git a/tests/models/autoencoders/test_models_vae.py b/tests/models/autoencoders/test_models_vae.py index 0fc185b602a3..cff2ce63c8e3 100644 --- a/tests/models/autoencoders/test_models_vae.py +++ b/tests/models/autoencoders/test_models_vae.py @@ -18,12 +18,14 @@ import numpy as np import torch +from datasets import load_dataset from parameterized import parameterized from diffusers import ( AsymmetricAutoencoderKL, AutoencoderKL, AutoencoderKLTemporalDecoder, + AutoencoderOobleck, AutoencoderTiny, ConsistencyDecoderVAE, StableDiffusionPipeline, @@ -128,6 +130,18 @@ def get_consistency_vae_config(block_out_channels=None, norm_num_groups=None): } +def get_autoencoder_oobleck_config(block_out_channels=None): + init_dict = { + "encoder_hidden_size": 12, + "decoder_channels": 12, + "decoder_input_channels": 6, + "audio_channels": 2, + "downsampling_ratios": [2, 4], + "channel_multiples": [1, 2], + } + return init_dict + + class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = AutoencoderKL main_input_name = "sample" @@ -480,6 +494,41 @@ def test_gradient_checkpointing(self): self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5)) +class AutoencoderOobleckTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): + model_class = AutoencoderOobleck + main_input_name = "sample" + base_precision = 1e-2 + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 2 + seq_len = 24 + + waveform = floats_tensor((batch_size, num_channels, seq_len)).to(torch_device) + + return {"sample": waveform, "sample_posterior": False} + + @property + def input_shape(self): + return (2, 24) + + @property + def output_shape(self): + return (2, 24) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = get_autoencoder_oobleck_config() + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_forward_signature(self): + pass + + def test_forward_with_norm_groups(self): + pass + + @slow class AutoencoderTinyIntegrationTests(unittest.TestCase): def tearDown(self): @@ -1100,3 +1149,118 @@ def test_vae_tiling(self): for shape in shapes: image = torch.zeros(shape, device=torch_device, dtype=pipe.vae.dtype) pipe.vae.decode(image) + + +@slow +class AutoencoderOobleckIntegrationTests(unittest.TestCase): + def tearDown(self): + # clean up the VRAM after each test + super().tearDown() + gc.collect() + backend_empty_cache(torch_device) + + def _load_datasamples(self, num_samples): + ds = load_dataset( + "hf-internal-testing/librispeech_asr_dummy", "clean", split="validation", trust_remote_code=True + ) + # automatic decoding with librispeech + speech_samples = ds.sort("id").select(range(num_samples))[:num_samples]["audio"] + + return torch.nn.utils.rnn.pad_sequence( + [torch.from_numpy(x["array"]) for x in speech_samples], batch_first=True + ) + + def get_audio(self, audio_sample_size=2097152, fp16=False): + dtype = torch.float16 if fp16 else torch.float32 + audio = self._load_datasamples(2).to(torch_device).to(dtype) + + # pad / crop to audio_sample_size + audio = torch.nn.functional.pad(audio[:, :audio_sample_size], pad=(0, audio_sample_size - audio.shape[-1])) + + # todo channel + audio = audio.unsqueeze(1).repeat(1, 2, 1).to(torch_device) + + return audio + + def get_oobleck_vae_model( + self, model_id="ylacombe/stable-audio-1.0", fp16=False + ): # TODO (YL): change repo id once moved + torch_dtype = torch.float16 if fp16 else torch.float32 + + model = AutoencoderOobleck.from_pretrained( + model_id, + subfolder="vae", + torch_dtype=torch_dtype, + ) + model.to(torch_device) + + return model + + def get_generator(self, seed=0): + generator_device = "cpu" if not torch_device.startswith("cuda") else "cuda" + if torch_device != "mps": + return torch.Generator(device=generator_device).manual_seed(seed) + return torch.manual_seed(seed) + + @parameterized.expand( + [ + # fmt: off + [33, [1.193e-4, 6.56e-05, 1.314e-4, 3.80e-05, -4.01e-06], 0.001192], + [44, [2.77e-05, -2.65e-05, 1.18e-05, -6.94e-05, -9.57e-05], 0.001196], + # fmt: on + ] + ) + def test_stable_diffusion(self, seed, expected_slice, expected_mean_absolute_diff): + model = self.get_oobleck_vae_model() + audio = self.get_audio() + generator = self.get_generator(seed) + + with torch.no_grad(): + sample = model(audio, generator=generator, sample_posterior=True).sample + + assert sample.shape == audio.shape + assert ((sample - audio).abs().mean() - expected_mean_absolute_diff).abs() <= 1e-6 + + output_slice = sample[-1, 1, 5:10].cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=1e-5) + + def test_stable_diffusion_mode(self): + model = self.get_oobleck_vae_model() + audio = self.get_audio() + + with torch.no_grad(): + sample = model(audio, sample_posterior=False).sample + + assert sample.shape == audio.shape + + @parameterized.expand( + [ + # fmt: off + [33, [1.193e-4, 6.56e-05, 1.314e-4, 3.80e-05, -4.01e-06], 0.001192], + [44, [2.77e-05, -2.65e-05, 1.18e-05, -6.94e-05, -9.57e-05], 0.001196], + # fmt: on + ] + ) + def test_stable_diffusion_encode_decode(self, seed, expected_slice, expected_mean_absolute_diff): + model = self.get_oobleck_vae_model() + audio = self.get_audio() + generator = self.get_generator(seed) + + with torch.no_grad(): + x = audio + posterior = model.encode(x).latent_dist + z = posterior.sample(generator=generator) + sample = model.decode(z).sample + + # (batch_size, latent_dim, sequence_length) + assert posterior.mean.shape == (audio.shape[0], model.config.decoder_input_channels, 1024) + + assert sample.shape == audio.shape + assert ((sample - audio).abs().mean() - expected_mean_absolute_diff).abs() <= 1e-6 + + output_slice = sample[-1, 1, 5:10].cpu() + expected_output_slice = torch.tensor(expected_slice) + + assert torch_all_close(output_slice, expected_output_slice, atol=1e-5) diff --git a/tests/pipelines/stable_audio/__init__.py b/tests/pipelines/stable_audio/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/pipelines/stable_audio/test_stable_audio.py b/tests/pipelines/stable_audio/test_stable_audio.py new file mode 100644 index 000000000000..d89bd70575c9 --- /dev/null +++ b/tests/pipelines/stable_audio/test_stable_audio.py @@ -0,0 +1,462 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import gc +import unittest + +import numpy as np +import torch +from transformers import ( + T5EncoderModel, + T5Tokenizer, +) + +from diffusers import ( + AutoencoderOobleck, + CosineDPMSolverMultistepScheduler, + StableAudioDiTModel, + StableAudioPipeline, + StableAudioProjectionModel, +) +from diffusers.utils import is_xformers_available +from diffusers.utils.testing_utils import enable_full_determinism, nightly, require_torch_gpu, torch_device + +from ..pipeline_params import TEXT_TO_AUDIO_BATCH_PARAMS +from ..test_pipelines_common import PipelineTesterMixin + + +enable_full_determinism() + + +class StableAudioPipelineFastTests(PipelineTesterMixin, unittest.TestCase): + pipeline_class = StableAudioPipeline + params = frozenset( + [ + "prompt", + "audio_end_in_s", + "audio_start_in_s", + "guidance_scale", + "negative_prompt", + "prompt_embeds", + "negative_prompt_embeds", + "initial_audio_waveforms", + ] + ) + batch_params = TEXT_TO_AUDIO_BATCH_PARAMS + required_optional_params = frozenset( + [ + "num_inference_steps", + "num_waveforms_per_prompt", + "generator", + "latents", + "output_type", + "return_dict", + "callback", + "callback_steps", + ] + ) + + def get_dummy_components(self): + torch.manual_seed(0) + transformer = StableAudioDiTModel( + sample_size=4, + in_channels=3, + num_layers=2, + attention_head_dim=4, + num_key_value_attention_heads=2, + out_channels=3, + cross_attention_dim=4, + time_proj_dim=8, + global_states_input_dim=8, + cross_attention_input_dim=4, + ) + scheduler = CosineDPMSolverMultistepScheduler( + solver_order=2, + prediction_type="v_prediction", + sigma_data=1.0, + sigma_schedule="exponential", + ) + torch.manual_seed(0) + vae = AutoencoderOobleck( + encoder_hidden_size=6, + downsampling_ratios=[1, 2], + decoder_channels=3, + decoder_input_channels=3, + audio_channels=2, + channel_multiples=[2, 4], + sampling_rate=4, + ) + torch.manual_seed(0) + t5_repo_id = "hf-internal-testing/tiny-random-T5ForConditionalGeneration" + text_encoder = T5EncoderModel.from_pretrained(t5_repo_id) + tokenizer = T5Tokenizer.from_pretrained(t5_repo_id, truncation=True, model_max_length=25) + + torch.manual_seed(0) + projection_model = StableAudioProjectionModel( + text_encoder_dim=text_encoder.config.d_model, + conditioning_dim=4, + min_value=0, + max_value=32, + ) + + components = { + "transformer": transformer, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "projection_model": projection_model, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + inputs = { + "prompt": "A hammer hitting a wooden surface", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + } + return inputs + + def test_save_load_local(self): + # increase tolerance from 1e-4 -> 7e-3 to account for large composite model + super().test_save_load_local(expected_max_difference=7e-3) + + def test_save_load_optional_components(self): + # increase tolerance from 1e-4 -> 7e-3 to account for large composite model + super().test_save_load_optional_components(expected_max_difference=7e-3) + + def test_stable_audio_ddim(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + + components = self.get_dummy_components() + stable_audio_pipe = StableAudioPipeline(**components) + stable_audio_pipe = stable_audio_pipe.to(torch_device) + stable_audio_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + output = stable_audio_pipe(**inputs) + audio = output.audios[0] + + assert audio.ndim == 2 + assert audio.shape == (2, 7) + + def test_stable_audio_without_prompts(self): + components = self.get_dummy_components() + stable_audio_pipe = StableAudioPipeline(**components) + stable_audio_pipe = stable_audio_pipe.to(torch_device) + stable_audio_pipe = stable_audio_pipe.to(torch_device) + stable_audio_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + inputs["prompt"] = 3 * [inputs["prompt"]] + + # forward + output = stable_audio_pipe(**inputs) + audio_1 = output.audios[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = 3 * [inputs.pop("prompt")] + + text_inputs = stable_audio_pipe.tokenizer( + prompt, + padding="max_length", + max_length=stable_audio_pipe.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ).to(torch_device) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + prompt_embeds = stable_audio_pipe.text_encoder( + text_input_ids, + attention_mask=attention_mask, + )[0] + + inputs["prompt_embeds"] = prompt_embeds + inputs["attention_mask"] = attention_mask + + # forward + output = stable_audio_pipe(**inputs) + audio_2 = output.audios[0] + + assert (audio_1 - audio_2).abs().max() < 1e-2 + + def test_stable_audio_negative_without_prompts(self): + components = self.get_dummy_components() + stable_audio_pipe = StableAudioPipeline(**components) + stable_audio_pipe = stable_audio_pipe.to(torch_device) + stable_audio_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(torch_device) + negative_prompt = 3 * ["this is a negative prompt"] + inputs["negative_prompt"] = negative_prompt + inputs["prompt"] = 3 * [inputs["prompt"]] + + # forward + output = stable_audio_pipe(**inputs) + audio_1 = output.audios[0] + + inputs = self.get_dummy_inputs(torch_device) + prompt = 3 * [inputs.pop("prompt")] + + text_inputs = stable_audio_pipe.tokenizer( + prompt, + padding="max_length", + max_length=stable_audio_pipe.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ).to(torch_device) + text_input_ids = text_inputs.input_ids + attention_mask = text_inputs.attention_mask + + prompt_embeds = stable_audio_pipe.text_encoder( + text_input_ids, + attention_mask=attention_mask, + )[0] + + inputs["prompt_embeds"] = prompt_embeds + inputs["attention_mask"] = attention_mask + + negative_text_inputs = stable_audio_pipe.tokenizer( + negative_prompt, + padding="max_length", + max_length=stable_audio_pipe.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ).to(torch_device) + negative_text_input_ids = negative_text_inputs.input_ids + negative_attention_mask = negative_text_inputs.attention_mask + + negative_prompt_embeds = stable_audio_pipe.text_encoder( + negative_text_input_ids, + attention_mask=negative_attention_mask, + )[0] + + inputs["negative_prompt_embeds"] = negative_prompt_embeds + inputs["negative_attention_mask"] = negative_attention_mask + + # forward + output = stable_audio_pipe(**inputs) + audio_2 = output.audios[0] + + assert (audio_1 - audio_2).abs().max() < 1e-2 + + def test_stable_audio_negative_prompt(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + stable_audio_pipe = StableAudioPipeline(**components) + stable_audio_pipe = stable_audio_pipe.to(device) + stable_audio_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + negative_prompt = "egg cracking" + output = stable_audio_pipe(**inputs, negative_prompt=negative_prompt) + audio = output.audios[0] + + assert audio.ndim == 2 + assert audio.shape == (2, 7) + + def test_stable_audio_num_waveforms_per_prompt(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + stable_audio_pipe = StableAudioPipeline(**components) + stable_audio_pipe = stable_audio_pipe.to(device) + stable_audio_pipe.set_progress_bar_config(disable=None) + + prompt = "A hammer hitting a wooden surface" + + # test num_waveforms_per_prompt=1 (default) + audios = stable_audio_pipe(prompt, num_inference_steps=2).audios + + assert audios.shape == (1, 2, 7) + + # test num_waveforms_per_prompt=1 (default) for batch of prompts + batch_size = 2 + audios = stable_audio_pipe([prompt] * batch_size, num_inference_steps=2).audios + + assert audios.shape == (batch_size, 2, 7) + + # test num_waveforms_per_prompt for single prompt + num_waveforms_per_prompt = 2 + audios = stable_audio_pipe( + prompt, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt + ).audios + + assert audios.shape == (num_waveforms_per_prompt, 2, 7) + + # test num_waveforms_per_prompt for batch of prompts + batch_size = 2 + audios = stable_audio_pipe( + [prompt] * batch_size, num_inference_steps=2, num_waveforms_per_prompt=num_waveforms_per_prompt + ).audios + + assert audios.shape == (batch_size * num_waveforms_per_prompt, 2, 7) + + def test_stable_audio_audio_end_in_s(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + stable_audio_pipe = StableAudioPipeline(**components) + stable_audio_pipe = stable_audio_pipe.to(torch_device) + stable_audio_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + output = stable_audio_pipe(audio_end_in_s=1.5, **inputs) + audio = output.audios[0] + + assert audio.ndim == 2 + assert audio.shape[1] / stable_audio_pipe.vae.sampling_rate == 1.5 + + output = stable_audio_pipe(audio_end_in_s=1.1875, **inputs) + audio = output.audios[0] + + assert audio.ndim == 2 + assert audio.shape[1] / stable_audio_pipe.vae.sampling_rate == 1.0 + + def test_attention_slicing_forward_pass(self): + self._test_attention_slicing_forward_pass(test_mean_pixel_difference=False) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=5e-4) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(test_mean_pixel_difference=False) + + def test_stable_audio_input_waveform(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + stable_audio_pipe = StableAudioPipeline(**components) + stable_audio_pipe = stable_audio_pipe.to(device) + stable_audio_pipe.set_progress_bar_config(disable=None) + + prompt = "A hammer hitting a wooden surface" + + initial_audio_waveforms = torch.ones((1, 5)) + + # test raises error when no sampling rate + with self.assertRaises(ValueError): + audios = stable_audio_pipe( + prompt, num_inference_steps=2, initial_audio_waveforms=initial_audio_waveforms + ).audios + + # test raises error when wrong sampling rate + with self.assertRaises(ValueError): + audios = stable_audio_pipe( + prompt, + num_inference_steps=2, + initial_audio_waveforms=initial_audio_waveforms, + initial_audio_sampling_rate=stable_audio_pipe.vae.sampling_rate - 1, + ).audios + + audios = stable_audio_pipe( + prompt, + num_inference_steps=2, + initial_audio_waveforms=initial_audio_waveforms, + initial_audio_sampling_rate=stable_audio_pipe.vae.sampling_rate, + ).audios + assert audios.shape == (1, 2, 7) + + # test works with num_waveforms_per_prompt + num_waveforms_per_prompt = 2 + audios = stable_audio_pipe( + prompt, + num_inference_steps=2, + num_waveforms_per_prompt=num_waveforms_per_prompt, + initial_audio_waveforms=initial_audio_waveforms, + initial_audio_sampling_rate=stable_audio_pipe.vae.sampling_rate, + ).audios + + assert audios.shape == (num_waveforms_per_prompt, 2, 7) + + # test num_waveforms_per_prompt for batch of prompts and input audio (two channels) + batch_size = 2 + initial_audio_waveforms = torch.ones((batch_size, 2, 5)) + audios = stable_audio_pipe( + [prompt] * batch_size, + num_inference_steps=2, + num_waveforms_per_prompt=num_waveforms_per_prompt, + initial_audio_waveforms=initial_audio_waveforms, + initial_audio_sampling_rate=stable_audio_pipe.vae.sampling_rate, + ).audios + + assert audios.shape == (batch_size * num_waveforms_per_prompt, 2, 7) + + @unittest.skip("Not supported yet") + def test_sequential_cpu_offload_forward_pass(self): + pass + + @unittest.skip("Not supported yet") + def test_sequential_offload_forward_pass_twice(self): + pass + + +@nightly +@require_torch_gpu +class StableAudioPipelineIntegrationTests(unittest.TestCase): + def setUp(self): + super().setUp() + gc.collect() + torch.cuda.empty_cache() + + def tearDown(self): + super().tearDown() + gc.collect() + torch.cuda.empty_cache() + + def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0): + generator = torch.Generator(device=generator_device).manual_seed(seed) + latents = np.random.RandomState(seed).standard_normal((1, 64, 1024)) + latents = torch.from_numpy(latents).to(device=device, dtype=dtype) + inputs = { + "prompt": "A hammer hitting a wooden surface", + "latents": latents, + "generator": generator, + "num_inference_steps": 3, + "audio_end_in_s": 30, + "guidance_scale": 2.5, + } + return inputs + + def test_stable_audio(self): + stable_audio_pipe = StableAudioPipeline.from_pretrained( + "ylacombe/stable-audio-1.0" + ) # TODO (YL): change once changed + stable_audio_pipe = stable_audio_pipe.to(torch_device) + stable_audio_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_inputs(torch_device) + inputs["num_inference_steps"] = 25 + audio = stable_audio_pipe(**inputs).audios[0] + + assert audio.ndim == 2 + assert audio.shape == (2, int(inputs["audio_end_in_s"] * stable_audio_pipe.vae.sampling_rate)) + # check the portion of the generated audio with the largest dynamic range (reduces flakiness) + audio_slice = audio[0, 447590:447600] + # fmt: off + expected_slice = np.array( + [-0.0278, 0.1096, 0.1877, 0.3178, 0.5329, 0.6990, 0.6972, 0.6186, 0.5608, 0.5060] + ) + # fmt: one + max_diff = np.abs(expected_slice - audio_slice.detach().cpu().numpy()).max() + assert max_diff < 1.5e-3