diff --git a/examples/controlnet/README.md b/examples/controlnet/README.md index 571e9e708cf2..15b0170d5120 100644 --- a/examples/controlnet/README.md +++ b/examples/controlnet/README.md @@ -274,9 +274,9 @@ pipe = StableDiffusionControlNetPipeline.from_pretrained( # speed up diffusion process with faster scheduler and memory optimization pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) -# remove following line if xformers is not installed +# remove following line if xformers is not installed or when using Torch 2.0. pipe.enable_xformers_memory_efficient_attention() - +# memory optimization. pipe.enable_model_cpu_offload() control_image = load_image("./conditioning_image_1.png") @@ -285,9 +285,8 @@ prompt = "pale golden rod circle with old lace background" # generate image generator = torch.manual_seed(0) image = pipe( - prompt, num_inference_steps=20, generator=generator, image=control_image + prompt, num_inference_steps=20, generator=generator, image=control_image ).images[0] - image.save("./output.png") ``` @@ -460,3 +459,7 @@ The profile can then be inspected at http://localhost:6006/#profile Sometimes you'll get version conflicts (error messages like `Duplicate plugins for name projector`), which means that you have to uninstall and reinstall all versions of Tensorflow/Tensorboard (e.g. with `pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile && pip install tf-nightly tbp-nightly tensorboard-plugin-profile`). Note that the debugging functionality of the Tensorboard `profile` plugin is still under active development. Not all views are fully functional, and for example the `trace_viewer` cuts off events after 1M (which can result in all your device traces getting lost if you for example profile the compilation step by accident). + +## Support for Stable Diffusion XL + +We provide a training script for training a ControlNet with [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). Please refer to [README_sdxl.md](./README_sdxl.md) for more details. diff --git a/examples/controlnet/README_sdxl.md b/examples/controlnet/README_sdxl.md new file mode 100644 index 000000000000..4e4035066f8c --- /dev/null +++ b/examples/controlnet/README_sdxl.md @@ -0,0 +1,131 @@ +# DreamBooth training example for Stable Diffusion XL (SDXL) + +The `train_controlnet_sdxl.py` script shows how to implement the training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). + +## Running locally with PyTorch + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: + +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install -e . +``` + +Then cd in the `examples/controlnet` folder and run +```bash +pip install -r requirements_sdxl.txt +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +Or for a default accelerate configuration without answering questions about your environment + +```bash +accelerate config default +``` + +Or if your environment doesn't support an interactive shell (e.g., a notebook) + +```python +from accelerate.utils import write_basic_config +write_basic_config() +``` + +When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. + +## Circle filling dataset + +The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script. + +## Training + +Our training examples use two test conditioning images. They can be downloaded by running + +```sh +wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png + +wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png +``` + +Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub. + +```bash +export MODEL_DIR="stabilityai/stable-diffusion-xl-base-0.9" +export OUTPUT_DIR="path to save model" + +accelerate launch train_controlnet_sdxl.py \ + --pretrained_model_name_or_path=$MODEL_DIR \ + --output_dir=$OUTPUT_DIR \ + --dataset_name=fusing/fill50k \ + --mixed_precision="fp16" \ + --resolution=1024 \ + --learning_rate=1e-5 \ + --max_train_steps=15000 \ + --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \ + --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ + --validation_steps=100 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 \ + --report_to="wandb" \ + --seed=42 \ + --push_to_hub +``` + +To better track our training experiments, we're using the following flags in the command above: + +* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`. +* `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected. + +Our experiments were conducted on a single 40GB A100 GPU. + +### Inference + +Once training is done, we can perform inference like so: + +```python +from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler +from diffusers.utils import load_image +import torch + +base_model_path = "stabilityai/stable-diffusion-xl-base-0.9" +controlnet_path = "path to controlnet" + +controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) +pipe = StableDiffusionXLControlNetPipeline.from_pretrained( + base_model_path, controlnet=controlnet, torch_dtype=torch.float16 +) + +# speed up diffusion process with faster scheduler and memory optimization +pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) +# remove following line if xformers is not installed or when using Torch 2.0. +pipe.enable_xformers_memory_efficient_attention() +# memory optimization. +pipe.enable_model_cpu_offload() + +control_image = load_image("./conditioning_image_1.png") +prompt = "pale golden rod circle with old lace background" + +# generate image +generator = torch.manual_seed(0) +image = pipe( + prompt, num_inference_steps=20, generator=generator, image=control_image +).images[0] +image.save("./output.png") +``` + +## Notes + +### Specifying a better VAE + +SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)). \ No newline at end of file diff --git a/examples/controlnet/requirements_sdxl.txt b/examples/controlnet/requirements_sdxl.txt new file mode 100644 index 000000000000..0192e03ddb4c --- /dev/null +++ b/examples/controlnet/requirements_sdxl.txt @@ -0,0 +1,9 @@ +accelerate>=0.16.0 +torchvision +transformers>=4.25.1 +ftfy +tensorboard +Jinja2 +invisible-watermark>=0.2.0 +datasets +wandb diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py new file mode 100644 index 000000000000..5f5a35a4e37b --- /dev/null +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -0,0 +1,1246 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2023 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import functools +import gc +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import accelerate +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import AutoTokenizer, PretrainedConfig + +import diffusers +from diffusers import ( + AutoencoderKL, + ControlNetModel, + DDPMScheduler, + StableDiffusionXLControlNetPipeline, + UNet2DConditionModel, + UniPCMultistepScheduler, +) +from diffusers.optimization import get_scheduler +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.import_utils import is_xformers_available + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.18.0.dev0") + +logger = get_logger(__name__) + + +def image_grid(imgs, rows, cols): + assert len(imgs) == rows * cols + + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid + + +def log_validation(vae, unet, controlnet, args, accelerator, weight_dtype, step): + logger.info("Running validation... ") + + controlnet = accelerator.unwrap_model(controlnet) + + pipeline = StableDiffusionXLControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + vae=vae, + unet=unet, + controlnet=controlnet, + safety_checker=None, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline.scheduler = UniPCMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) + + if args.enable_xformers_memory_efficient_attention: + pipeline.enable_xformers_memory_efficient_attention() + + if args.seed is None: + generator = None + else: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + + if len(args.validation_image) == len(args.validation_prompt): + validation_images = args.validation_image + validation_prompts = args.validation_prompt + elif len(args.validation_image) == 1: + validation_images = args.validation_image * len(args.validation_prompt) + validation_prompts = args.validation_prompt + elif len(args.validation_prompt) == 1: + validation_images = args.validation_image + validation_prompts = args.validation_prompt * len(args.validation_image) + else: + raise ValueError( + "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" + ) + + image_logs = [] + + for validation_prompt, validation_image in zip(validation_prompts, validation_images): + validation_image = Image.open(validation_image).convert("RGB") + validation_image = validation_image.resize((args.resolution, args.resolution)) + + images = [] + + for _ in range(args.num_validation_images): + with torch.autocast("cuda"): + image = pipeline( + validation_prompt, validation_image, num_inference_steps=20, generator=generator + ).images[0] + images.append(image) + + image_logs.append( + {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} + ) + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images = [] + + formatted_images.append(np.asarray(validation_image)) + + for image in images: + formatted_images.append(np.asarray(image)) + + formatted_images = np.stack(formatted_images) + + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) + + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({"validation": formatted_images}) + else: + logger.warn(f"image logging not implemented for {tracker.name}") + + del pipeline + gc.collect() + torch.cuda.empty_cache() + + return image_logs + + +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision, use_auth_token=True + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + else: + raise ValueError(f"{model_class} is not supported.") + + +def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): + img_str = "" + if image_logs is not None: + img_str = "You can find some example images below.\n" + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += f"prompt: {validation_prompt}\n" + images = [validation_image] + images + image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +tags: +- stable-diffusion-xl +- stable-diffusion-xl-diffusers +- text-to-image +- diffusers +- controlnet +inference: true +--- + """ + model_card = f""" +# controlnet-{repo_id} + +These are controlnet weights trained on {base_model} with new type of conditioning. +{img_str} +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--pretrained_vae_model_name_or_path", + type=str, + default=None, + help="Path to an improved VAE to stabilize training. For more details check out: https://github.com/huggingface/diffusers/pull/4038.", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="Path to pretrained controlnet model or model identifier from huggingface.co/models." + " If not specified controlnet weights are initialized from unet.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be" + " float32 precision." + ), + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--output_dir", + type=str, + default="controlnet-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--crops_coords_top_left_h", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--crops_coords_top_left_w", + type=int, + default=0, + help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that 🤗 Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the controlnet conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="sd_xl_train_controlnet", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") + + if args.dataset_name is not None and args.train_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + if args.validation_prompt is not None and args.validation_image is None: + raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") + + if args.validation_prompt is None and args.validation_image is not None: + raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") + + if ( + args.validation_image is not None + and args.validation_prompt is not None + and len(args.validation_image) != 1 + and len(args.validation_prompt) != 1 + and len(args.validation_image) != len(args.validation_prompt) + ): + raise ValueError( + "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," + " or the same number of `--validation_prompt`s and `--validation_image`s" + ) + + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." + ) + + return args + + +def get_train_dataset(args, accelerator): + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + if args.train_data_dir is not None: + dataset = load_dataset( + args.train_data_dir, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.caption_column is None: + caption_column = column_names[1] + logger.info(f"caption column defaulting to {caption_column}") + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.conditioning_image_column is None: + conditioning_image_column = column_names[2] + logger.info(f"conditioning image column defaulting to {conditioning_image_column}") + else: + conditioning_image_column = args.conditioning_image_column + if conditioning_image_column not in column_names: + raise ValueError( + f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + with accelerator.main_process_first(): + train_dataset = dataset["train"].shuffle(seed=args.seed) + if args.max_train_samples is not None: + train_dataset = train_dataset.select(range(args.max_train_samples)) + return train_dataset + + +# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt +def encode_prompt(prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train=True): + prompt_embeds_list = [] + + captions = [] + for caption in prompt_batch: + if random.random() < proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + + with torch.no_grad(): + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + captions, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder( + text_input_ids.to(text_encoder.device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + bs_embed, seq_len, _ = prompt_embeds.shape + prompt_embeds = prompt_embeds.view(bs_embed, seq_len, -1) + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1) + return prompt_embeds, pooled_prompt_embeds + + +def prepare_train_dataset(dataset, accelerator): + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + conditioning_image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[args.image_column]] + images = [image_transforms(image) for image in images] + + conditioning_images = [image.convert("RGB") for image in examples[args.conditioning_image_column]] + conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images] + + examples["pixel_values"] = images + examples["conditioning_pixel_values"] = conditioning_images + + return examples + + with accelerator.main_process_first(): + dataset = dataset.with_transform(preprocess_train) + + return dataset + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + + prompt_ids = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]) + + add_text_embeds = torch.stack([torch.tensor(example["text_embeds"]) for example in examples]) + add_time_ids = torch.stack([torch.tensor(example["time_ids"]) for example in examples]) + + return { + "pixel_values": pixel_values, + "conditioning_pixel_values": conditioning_pixel_values, + "prompt_ids": prompt_ids, + "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids}, + } + + +def main(args): + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, + exist_ok=True, + token=args.hub_token, + private=True, + ).repo_id + + # Load the tokenizers + tokenizer_one = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision, use_fast=False + ) + tokenizer_two = AutoTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer_2", revision=args.revision, use_fast=False + ) + + # import correct text encoder classes + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + + # Load scheduler and models + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + text_encoder_one = text_encoder_cls_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision + ) + text_encoder_two = text_encoder_cls_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision + ) + vae_path = ( + args.pretrained_model_name_or_path + if args.pretrained_vae_model_name_or_path is None + else args.pretrained_vae_model_name_or_path + ) + vae = AutoencoderKL.from_pretrained( + vae_path, + subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, + revision=args.revision, + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision, use_auth_token=True + ) + + if args.controlnet_model_name_or_path: + logger.info("Loading existing controlnet weights") + controlnet = ControlNetModel.from_pretrained(args.controlnet_model_name_or_path) + else: + logger.info("Initializing controlnet weights from unet") + controlnet = ControlNetModel.from_unet(unet) + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + i = len(weights) - 1 + + while len(weights) > 0: + weights.pop() + model = models[i] + + sub_dir = "controlnet" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + + i -= 1 + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + vae.requires_grad_(False) + unet.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + controlnet.train() + + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + import xformers + + xformers_version = version.parse(xformers.__version__) + if xformers_version == version.parse("0.0.16"): + logger.warn( + "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details." + ) + unet.enable_xformers_memory_efficient_attention() + controlnet.enable_xformers_memory_efficient_attention() + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + if args.gradient_checkpointing: + controlnet.enable_gradient_checkpointing() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if accelerator.unwrap_model(controlnet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {accelerator.unwrap_model(controlnet).dtype}. {low_precision_error_string}" + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = controlnet.parameters() + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae, unet and text_encoder to device and cast to weight_dtype + # The VAE is in float32 to avoid NaN losses. + if args.pretrained_vae_model_name_or_path is not None: + vae.to(accelerator.device, dtype=weight_dtype) + else: + vae.to(accelerator.device, dtype=torch.float32) + unet.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + + # Here, we compute not just the text embeddings but also the additional embeddings + # needed for the SD XL UNet to operate. + def compute_embeddings(batch, proportion_empty_prompts, text_encoders, tokenizers, is_train=True): + original_size = (args.resolution, args.resolution) + target_size = (args.resolution, args.resolution) + crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w) + prompt_batch = batch[args.caption_column] + + prompt_embeds, pooled_prompt_embeds = encode_prompt( + prompt_batch, text_encoders, tokenizers, proportion_empty_prompts, is_train + ) + add_text_embeds = pooled_prompt_embeds + + # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids + add_time_ids = list(original_size + crops_coords_top_left + target_size) + add_time_ids = torch.tensor([add_time_ids]) + + prompt_embeds = prompt_embeds.to(accelerator.device) + add_text_embeds = add_text_embeds.to(accelerator.device) + add_time_ids = add_time_ids.repeat(len(prompt_batch), 1) + add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype) + unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + + return {"prompt_embeds": prompt_embeds, **unet_added_cond_kwargs} + + # Let's first compute all the embeddings so that we can free up the text encoders + # from memory. + text_encoders = [text_encoder_one, text_encoder_two] + tokenizers = [tokenizer_one, tokenizer_two] + train_dataset = get_train_dataset(args, accelerator) + compute_embeddings_fn = functools.partial( + compute_embeddings, + text_encoders=text_encoders, + tokenizers=tokenizers, + proportion_empty_prompts=args.proportion_empty_prompts, + ) + with accelerator.main_process_first(): + train_dataset = train_dataset.map(compute_embeddings_fn, batched=True) + + del text_encoders, tokenizers + gc.collect() + torch.cuda.empty_cache() + + # Then get the training dataset ready to be passed to the dataloader. + train_dataset = prepare_train_dataset(train_dataset, accelerator) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + + # tensorboard cannot handle list types for config + tracker_config.pop("validation_prompt") + tracker_config.pop("validation_image") + + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + image_logs = None + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(controlnet): + # Convert images to latent space + if args.pretrained_vae_model_name_or_path is not None: + pixel_values = batch["pixel_values"].to(dtype=weight_dtype) + if vae.dtype != weight_dtype: + vae.to(dtype=weight_dtype) + else: + pixel_values = batch["pixel_values"] + latents = vae.encode(pixel_values).latent_dist.sample() + latents = latents * vae.config.scaling_factor + if args.pretrained_vae_model_name_or_path is None: + latents = latents.to(weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + + # Sample a random timestep for each image + timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) + timesteps = timesteps.long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # ControlNet conditioning. + controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) + down_block_res_samples, mid_block_res_sample = controlnet( + noisy_latents, + timesteps, + encoder_hidden_states=batch["prompt_ids"], + added_cond_kwargs=batch["unet_added_conditions"], + controlnet_cond=controlnet_image, + return_dict=False, + ) + + # Predict the noise residual + model_pred = unet( + noisy_latents, + timesteps, + encoder_hidden_states=batch["prompt_ids"], + added_cond_kwargs=batch["unet_added_conditions"], + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), + ).sample + + # Get the target for loss depending on the prediction type + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + image_logs = log_validation( + vae, unet, controlnet, args, accelerator, weight_dtype, global_step + ) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + controlnet = accelerator.unwrap_model(controlnet) + controlnet.save_pretrained(args.output_dir) + + if args.push_to_hub: + save_model_card( + repo_id, + image_logs=image_logs, + base_model=args.pretrained_model_name_or_path, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 0c0869fb52bd..445e63b4d406 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -199,6 +199,7 @@ from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 else: from .pipelines import ( + StableDiffusionXLControlNetPipeline, StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline, diff --git a/src/diffusers/models/controlnet.py b/src/diffusers/models/controlnet.py index b0f566020079..acebee4b40a5 100644 --- a/src/diffusers/models/controlnet.py +++ b/src/diffusers/models/controlnet.py @@ -21,7 +21,7 @@ from ..configuration_utils import ConfigMixin, register_to_config from ..utils import BaseOutput, logging from .attention_processor import AttentionProcessor, AttnProcessor -from .embeddings import TimestepEmbedding, Timesteps +from .embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps from .modeling_utils import ModelMixin from .unet_2d_blocks import ( CrossAttnDownBlock2D, @@ -131,12 +131,25 @@ class ControlNetModel(ModelMixin, ConfigMixin): The epsilon to use for the normalization. cross_attention_dim (`int`, defaults to 1280): The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): The dimension of the attention heads. use_linear_projection (`bool`, defaults to `False`): class_embed_type (`str`, *optional*, defaults to `None`): The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. num_class_embeds (`int`, *optional*, defaults to 0): Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing class conditioning with `class_embed_type` equal to `None`. @@ -177,10 +190,15 @@ def __init__( norm_num_groups: Optional[int] = 32, norm_eps: float = 1e-5, cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, attention_head_dim: Union[int, Tuple[int]] = 8, num_attention_heads: Optional[Union[int, Tuple[int]]] = None, use_linear_projection: bool = False, class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, num_class_embeds: Optional[int] = None, upcast_attention: bool = False, resnet_time_scale_shift: str = "default", @@ -188,6 +206,7 @@ def __init__( controlnet_conditioning_channel_order: str = "rgb", conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), global_pool_conditions: bool = False, + addition_embed_type_num_heads=64, ): super().__init__() @@ -215,6 +234,9 @@ def __init__( f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." ) + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + # input conv_in_kernel = 3 conv_in_padding = (conv_in_kernel - 1) // 2 @@ -224,16 +246,43 @@ def __init__( # time time_embed_dim = block_out_channels[0] * 4 - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) timestep_input_dim = block_out_channels[0] - self.time_embedding = TimestepEmbedding( timestep_input_dim, time_embed_dim, act_fn=act_fn, ) + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + # class embedding if class_embed_type is None and num_class_embeds is not None: self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) @@ -257,6 +306,29 @@ def __init__( else: self.class_embedding = None + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + # control net conditioning embedding self.controlnet_cond_embedding = ControlNetConditioningEmbedding( conditioning_embedding_channels=block_out_channels[0], @@ -291,6 +363,7 @@ def __init__( down_block = get_down_block( down_block_type, num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], in_channels=input_channel, out_channels=output_channel, temb_channels=time_embed_dim, @@ -327,6 +400,7 @@ def __init__( self.controlnet_mid_block = controlnet_block self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], in_channels=mid_block_channel, temb_channels=time_embed_dim, resnet_eps=norm_eps, @@ -356,7 +430,22 @@ def from_unet( The UNet model weights to copy to the [`ControlNetModel`]. All configuration options are also copied where applicable. """ + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None + encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None + addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None + addition_time_embed_dim = ( + unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None + ) + controlnet = cls( + encoder_hid_dim=encoder_hid_dim, + encoder_hid_dim_type=encoder_hid_dim_type, + addition_embed_type=addition_embed_type, + addition_time_embed_dim=addition_time_embed_dim, + transformer_layers_per_block=transformer_layers_per_block, in_channels=unet.config.in_channels, flip_sin_to_cos=unet.config.flip_sin_to_cos, freq_shift=unet.config.freq_shift, @@ -542,6 +631,7 @@ def forward( class_labels: Optional[torch.Tensor] = None, timestep_cond: Optional[torch.Tensor] = None, attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, cross_attention_kwargs: Optional[Dict[str, Any]] = None, guess_mode: bool = False, return_dict: bool = True, @@ -564,7 +654,9 @@ def forward( Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): attention_mask (`torch.Tensor`, *optional*, defaults to `None`): - cross_attention_kwargs(`dict[str]`, *optional*, defaults to `None`): + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): A kwargs dictionary that if specified is passed along to the `AttnProcessor`. guess_mode (`bool`, defaults to `False`): In this mode, the ControlNet encoder tries its best to recognize the input content of the input even if @@ -618,6 +710,7 @@ def forward( t_emb = t_emb.to(dtype=sample.dtype) emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None if self.class_embedding is not None: if class_labels is None: @@ -629,6 +722,30 @@ def forward( class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) emb = emb + class_emb + if "addition_embed_type" in self.config: + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + # 2. pre-process sample = self.conv_in(sample) diff --git a/src/diffusers/pipelines/__init__.py b/src/diffusers/pipelines/__init__.py index aa09e7e81130..802ae4f5bc94 100644 --- a/src/diffusers/pipelines/__init__.py +++ b/src/diffusers/pipelines/__init__.py @@ -120,6 +120,7 @@ except OptionalDependencyNotAvailable: from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import * # noqa F403 else: + from .controlnet import StableDiffusionXLControlNetPipeline from .stable_diffusion_xl import ( StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, diff --git a/src/diffusers/pipelines/controlnet/__init__.py b/src/diffusers/pipelines/controlnet/__init__.py index 76ab63bdb116..cd0b9a724e5e 100644 --- a/src/diffusers/pipelines/controlnet/__init__.py +++ b/src/diffusers/pipelines/controlnet/__init__.py @@ -1,11 +1,16 @@ from ...utils import ( OptionalDependencyNotAvailable, is_flax_available, + is_invisible_watermark_available, is_torch_available, is_transformers_available, ) +if is_transformers_available() and is_torch_available() and is_invisible_watermark_available(): + from .pipeline_controlnet_sd_xl import StableDiffusionXLControlNetPipeline + + try: if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() diff --git a/src/diffusers/pipelines/controlnet/multicontrolnet.py b/src/diffusers/pipelines/controlnet/multicontrolnet.py index 921895b8fd92..2df8c7edd7ae 100644 --- a/src/diffusers/pipelines/controlnet/multicontrolnet.py +++ b/src/diffusers/pipelines/controlnet/multicontrolnet.py @@ -45,17 +45,17 @@ def forward( ) -> Union[ControlNetOutput, Tuple]: for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)): down_samples, mid_sample = controlnet( - sample, - timestep, - encoder_hidden_states, - image, - scale, - class_labels, - timestep_cond, - attention_mask, - cross_attention_kwargs, - guess_mode, - return_dict, + sample=sample, + timestep=timestep, + encoder_hidden_states=encoder_hidden_states, + controlnet_cond=image, + conditioning_scale=scale, + class_labels=class_labels, + timestep_cond=timestep_cond, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + guess_mode=guess_mode, + return_dict=return_dict, ) # merge samples diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py new file mode 100644 index 000000000000..fd4338946e6f --- /dev/null +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -0,0 +1,960 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import PIL.Image +import torch +import torch.nn.functional as F +from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from ...image_processor import VaeImageProcessor +from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin +from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel +from ...models.attention_processor import ( + AttnProcessor2_0, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + XFormersAttnProcessor, +) +from ...schedulers import KarrasDiffusionSchedulers +from ...utils import ( + is_accelerate_available, + is_accelerate_version, + is_compiled_module, + logging, + randn_tensor, + replace_example_docstring, +) +from ..pipeline_utils import DiffusionPipeline +from ..stable_diffusion_xl import StableDiffusionXLPipelineOutput +from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker +from .multicontrolnet import MultiControlNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> # To be updated when there's a useful ControlNet checkpoint + >>> # compatible with SDXL. + ``` +""" + + +class StableDiffusionXLControlNetPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): + r""" + Pipeline for text-to-image generation using Stable Diffusion XL with ControlNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + In addition the pipeline inherits the following loading methods: + - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`] + - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`] + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + text_encoder_2 ([` CLIPTextModelWithProjection`]): + Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection), + specifically the + [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k) + variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + tokenizer_2 (`CLIPTokenizer`): + Second Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + controlnet ([`ControlNetModel`] or `List[ControlNetModel]`): + Provides additional conditioning to the unet during the denoising process. If you set multiple ControlNets + as a list, the outputs from each ControlNet are added together to create one combined additional + conditioning. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + """ + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + text_encoder_2: CLIPTextModelWithProjection, + tokenizer: CLIPTokenizer, + tokenizer_2: CLIPTokenizer, + unet: UNet2DConditionModel, + controlnet: ControlNetModel, + scheduler: KarrasDiffusionSchedulers, + force_zeros_for_empty_prompt: bool = True, + ): + super().__init__() + + if isinstance(controlnet, (list, tuple)): + raise ValueError("MultiControlNet is not yet supported.") + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + text_encoder_2=text_encoder_2, + tokenizer=tokenizer, + tokenizer_2=tokenizer_2, + unet=unet, + controlnet=controlnet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.control_image_processor = VaeImageProcessor( + vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False + ) + self.watermark = StableDiffusionXLWatermarker() + self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # control net hook has be manually offloaded as it alternates with unet + cpu_offload_with_hook(self.controlnet, device) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + """ + device = device or self._execution_device + + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # Define tokenizers and text encoders + tokenizers = [self.tokenizer, self.tokenizer_2] if self.tokenizer is not None else [self.tokenizer_2] + text_encoders = ( + [self.text_encoder, self.text_encoder_2] if self.text_encoder is not None else [self.text_encoder_2] + ) + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, tokenizer) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1]) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {tokenizer.model_max_length} tokens: {removed_text}" + ) + + prompt_embeds = text_encoder( + text_input_ids.to(device), + output_hidden_states=True, + ) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = negative_prompt is None and self.config.force_zeros_for_empty_prompt + if do_classifier_free_guidance and negative_prompt_embeds is None and zero_out_negative_prompt: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + uncond_tokens: List[str] + if prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder( + uncond_input.input_ids.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + bs_embed = pooled_prompt_embeds.shape[0] + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + image, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + controlnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.controlnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + else: + assert False + + # Check `controlnet_conditioning_scale` + if ( + isinstance(self.controlnet, ControlNetModel) + or is_compiled + and isinstance(self.controlnet._orig_mod, ControlNetModel) + ): + if not isinstance(controlnet_conditioning_scale, float): + raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.") + else: + assert False + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + image, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + if do_classifier_free_guidance and not guess_mode: + image = torch.cat([image] * 2) + + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids + def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + add_time_ids = list(original_size + crops_coords_top_left + target_size) + + passed_add_embed_dim = ( + self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + ) + expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features + + if expected_add_embed_dim != passed_add_embed_dim: + raise ValueError( + f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`." + ) + + add_time_ids = torch.tensor([add_time_ids], dtype=dtype) + return add_time_ids + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_upscale.StableDiffusionUpscalePipeline.upcast_vae + def upcast_vae(self): + dtype = self.vae.dtype + self.vae.to(dtype=torch.float32) + use_torch_2_0_or_xformers = isinstance( + self.vae.decoder.mid_block.attentions[0].processor, + ( + AttnProcessor2_0, + XFormersAttnProcessor, + LoRAXFormersAttnProcessor, + LoRAAttnProcessor2_0, + ), + ) + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if use_torch_2_0_or_xformers: + self.vae.post_quant_conv.to(dtype) + self.vae.decoder.conv_in.to(dtype) + self.vae.decoder.mid_block.to(dtype) + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[ + torch.FloatTensor, + PIL.Image.Image, + np.ndarray, + List[torch.FloatTensor], + List[PIL.Image.Image], + List[np.ndarray], + ] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + controlnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + original_size: Tuple[int, int] = (1024, 1024), + crops_coords_top_left: Tuple[int, int] = (0, 0), + target_size: Tuple[int, int] = (1024, 1024), + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If + the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can + also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If + height and/or width are passed, `image` is resized according to them. If multiple ControlNets are + specified in init, images must be passed as a list such that each element of the list can be correctly + batched for input to a single controlnet. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added + to the residual in the original unet. If multiple ControlNets are specified in init, you can set the + corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + In this mode, the ControlNet encoder will try best to recognize the content of the input image even if + you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the controlnet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the controlnet stops applying. + original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)): + TODO + target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)): + TODO + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple` + containing the output images. + """ + controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1 + control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [ + control_guidance_end + ] + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + image, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + controlnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + global_pool_conditions = ( + controlnet.config.global_pool_conditions + if isinstance(controlnet, ControlNetModel) + else controlnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + + # 3. Encode input prompt + text_encoder_lora_scale = ( + cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None + ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + ) + + # 4. Prepare image + if isinstance(controlnet, ControlNetModel): + image = self.prepare_image( + image=image, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=controlnet.dtype, + do_classifier_free_guidance=do_classifier_free_guidance, + guess_mode=guess_mode, + ) + height, width = image.shape[-2:] + else: + assert False + + # 5. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Create tensor stating which controlnets to keep + controlnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + controlnet_keep.append(keeps[0] if len(keeps) == 1 else keeps) + + # 7.2 Prepare added time ids & embeddings + add_text_embeds = pooled_prompt_embeds + add_time_ids = self._get_add_time_ids( + original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + ) + + if do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) + add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0) + + prompt_embeds = prompt_embeds.to(device) + add_text_embeds = add_text_embeds.to(device) + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # controlnet(s) inference + if guess_mode and do_classifier_free_guidance: + # Infer ControlNet only for the conditional batch. + control_model_input = latents + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + controlnet_prompt_embeds = prompt_embeds.chunk(2)[1] + else: + control_model_input = latent_model_input + controlnet_prompt_embeds = prompt_embeds + + if isinstance(controlnet_keep[i], list): + cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])] + else: + cond_scale = controlnet_conditioning_scale * controlnet_keep[i] + + added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + down_block_res_samples, mid_block_res_sample = self.controlnet( + control_model_input, + t, + encoder_hidden_states=controlnet_prompt_embeds, + controlnet_cond=image, + conditioning_scale=cond_scale, + guess_mode=guess_mode, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + ) + + if guess_mode and do_classifier_free_guidance: + # Infered ControlNet only for the conditional batch. + # To apply the output of ControlNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + cross_attention_kwargs=cross_attention_kwargs, + down_block_additional_residuals=down_block_res_samples, + mid_block_additional_residual=mid_block_res_sample, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # If we do sequential model offloading, let's offload unet and controlnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.controlnet.to("cpu") + torch.cuda.empty_cache() + + # make sure the VAE is in float32 mode, as it overflows in float16 + if self.vae.dtype == torch.float16 and self.vae.config.force_upcast: + self.upcast_vae() + latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + else: + image = latents + return StableDiffusionXLPipelineOutput(images=image) + + image = self.watermark.apply_watermark(image) + image = self.image_processor.postprocess(image, output_type=output_type) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image,) + + return StableDiffusionXLPipelineOutput(images=image) diff --git a/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py b/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py index 15ef1989a58b..3b4d59c537fa 100644 --- a/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +++ b/src/diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py @@ -2,6 +2,21 @@ from ..utils import DummyObject, requires_backends +class StableDiffusionXLControlNetPipeline(metaclass=DummyObject): + _backends = ["torch", "transformers", "invisible_watermark"] + + def __init__(self, *args, **kwargs): + requires_backends(self, ["torch", "transformers", "invisible_watermark"]) + + @classmethod + def from_config(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) + + @classmethod + def from_pretrained(cls, *args, **kwargs): + requires_backends(cls, ["torch", "transformers", "invisible_watermark"]) + + class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject): _backends = ["torch", "transformers", "invisible_watermark"] diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py new file mode 100644 index 000000000000..196bdc303a49 --- /dev/null +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -0,0 +1,179 @@ +# coding=utf-8 +# Copyright 2023 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import torch +from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + ControlNetModel, + EulerDiscreteScheduler, + StableDiffusionXLControlNetPipeline, + UNet2DConditionModel, +) +from diffusers.utils import randn_tensor, torch_device +from diffusers.utils.import_utils import is_xformers_available +from diffusers.utils.testing_utils import ( + enable_full_determinism, +) + +from ..pipeline_params import ( + IMAGE_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_BATCH_PARAMS, + TEXT_TO_IMAGE_IMAGE_PARAMS, + TEXT_TO_IMAGE_PARAMS, +) +from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, +) + + +enable_full_determinism() + + +class ControlNetPipelineSDXLFastTests( + PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase +): + pipeline_class = StableDiffusionXLControlNetPipeline + params = TEXT_TO_IMAGE_PARAMS + batch_params = TEXT_TO_IMAGE_BATCH_PARAMS + image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS + image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS + + def get_dummy_components(self): + torch.manual_seed(0) + unet = UNet2DConditionModel( + block_out_channels=(32, 64), + layers_per_block=2, + sample_size=32, + in_channels=4, + out_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"), + # SD2-specific config below + attention_head_dim=(2, 4), + use_linear_projection=True, + addition_embed_type="text_time", + addition_time_embed_dim=8, + transformer_layers_per_block=(1, 2), + projection_class_embeddings_input_dim=80, # 6 * 8 + 32 + cross_attention_dim=64, + ) + torch.manual_seed(0) + controlnet = ControlNetModel( + block_out_channels=(32, 64), + layers_per_block=2, + in_channels=4, + down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"), + conditioning_embedding_out_channels=(16, 32), + # SD2-specific config below + attention_head_dim=(2, 4), + use_linear_projection=True, + addition_embed_type="text_time", + addition_time_embed_dim=8, + transformer_layers_per_block=(1, 2), + projection_class_embeddings_input_dim=80, # 6 * 8 + 32 + cross_attention_dim=64, + ) + torch.manual_seed(0) + scheduler = EulerDiscreteScheduler( + beta_start=0.00085, + beta_end=0.012, + steps_offset=1, + beta_schedule="scaled_linear", + timestep_spacing="leading", + ) + torch.manual_seed(0) + vae = AutoencoderKL( + block_out_channels=[32, 64], + in_channels=3, + out_channels=3, + down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"], + up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"], + latent_channels=4, + ) + torch.manual_seed(0) + text_encoder_config = CLIPTextConfig( + bos_token_id=0, + eos_token_id=2, + hidden_size=32, + intermediate_size=37, + layer_norm_eps=1e-05, + num_attention_heads=4, + num_hidden_layers=5, + pad_token_id=1, + vocab_size=1000, + # SD2-specific config below + hidden_act="gelu", + projection_dim=32, + ) + text_encoder = CLIPTextModel(text_encoder_config) + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) + + text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config) + tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True) + + components = { + "unet": unet, + "controlnet": controlnet, + "scheduler": scheduler, + "vae": vae, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "text_encoder_2": text_encoder_2, + "tokenizer_2": tokenizer_2, + } + return components + + def get_dummy_inputs(self, device, seed=0): + if str(device).startswith("mps"): + generator = torch.manual_seed(seed) + else: + generator = torch.Generator(device=device).manual_seed(seed) + + controlnet_embedder_scale_factor = 2 + image = randn_tensor( + (1, 3, 32 * controlnet_embedder_scale_factor, 32 * controlnet_embedder_scale_factor), + generator=generator, + device=torch.device(device), + ) + + inputs = { + "prompt": "A painting of a squirrel eating a burger", + "generator": generator, + "num_inference_steps": 2, + "guidance_scale": 6.0, + "output_type": "numpy", + "image": image, + } + + return inputs + + def test_attention_slicing_forward_pass(self): + return self._test_attention_slicing_forward_pass(expected_max_diff=2e-3) + + @unittest.skipIf( + torch_device != "cuda" or not is_xformers_available(), + reason="XFormers attention is only available with CUDA and `xformers` installed", + ) + def test_xformers_attention_forwardGenerator_pass(self): + self._test_xformers_attention_forwardGenerator_pass(expected_max_diff=2e-3) + + def test_inference_batch_single_identical(self): + self._test_inference_batch_single_identical(expected_max_diff=2e-3)