Skip to content

[Core] add: controlnet support for SDXL #4038

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 54 commits into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from 37 commits
Commits
Show all changes
54 commits
Select commit Hold shift + click to select a range
c6c9f3a
add: controlnet sdxl.
sayakpaul Jul 11, 2023
185b67b
modifications to controlnet.
sayakpaul Jul 11, 2023
db78a4c
run styling.
sayakpaul Jul 11, 2023
af8273a
add: __init__.pys
sayakpaul Jul 11, 2023
c8b00de
incorporate https://github.com/huggingface/diffusers/pull/4019 changes.
sayakpaul Jul 11, 2023
68f2c38
run make fix-copies.
sayakpaul Jul 11, 2023
4b86d63
resize the conditioning images.
sayakpaul Jul 12, 2023
dade681
remove autocast.
sayakpaul Jul 12, 2023
15d4afd
run styling.
sayakpaul Jul 12, 2023
5ed7d3e
disable autocast.
sayakpaul Jul 12, 2023
64b0e20
debugging
sayakpaul Jul 12, 2023
f482f46
device placement.
sayakpaul Jul 12, 2023
c2bbb2b
back to autocast.
sayakpaul Jul 12, 2023
ccb6210
remove comment.
sayakpaul Jul 12, 2023
be13ef5
save some memory by reusing the vae and unet in the pipeline.
sayakpaul Jul 12, 2023
fbb086a
apply styling.
sayakpaul Jul 12, 2023
65a5c45
Allow low precision sd xl
patrickvonplaten Jul 13, 2023
43f842c
finish
patrickvonplaten Jul 13, 2023
7171d42
finish
patrickvonplaten Jul 13, 2023
97f69a7
Merge branch 'main' into allow_low_precision_vae_sd_xl
patrickvonplaten Jul 13, 2023
d706b2b
Merge branch 'main' into feat/sd-xl-controlnet-2
sayakpaul Jul 14, 2023
e2ca903
Merge remote-tracking branch 'origin/allow_low_precision_vae_sd_xl' i…
sayakpaul Jul 14, 2023
fa22868
changes to accommodate the improved VAE.
sayakpaul Jul 14, 2023
b1c83fd
modifications to how we handle vae encoding in the training.
sayakpaul Jul 14, 2023
a4e10d8
make style
sayakpaul Jul 14, 2023
2d7c454
make existing controlnet fast tests pass.
sayakpaul Jul 14, 2023
7296a7f
change vae checkpoint cli arg.
sayakpaul Jul 14, 2023
74f7485
fix: vae pretrained paths.
sayakpaul Jul 14, 2023
73ccae8
Merge branch 'main' into feat/sd-xl-controlnet-2
sayakpaul Jul 16, 2023
7d27b1d
fix: steps in get_scheduler().
sayakpaul Jul 16, 2023
3fa4809
debugging.
sayakpaul Jul 17, 2023
115b30d
Merge branch 'feat/sd-xl-controlnet-2' of https://github.com/huggingf…
sayakpaul Jul 17, 2023
523c1eb
debugging./
sayakpaul Jul 17, 2023
98ed3c9
fix: weight conversion.
sayakpaul Jul 17, 2023
a397ead
add: docs.
sayakpaul Jul 17, 2023
4c14e88
add: limited tests./
sayakpaul Jul 17, 2023
156dd27
add: datasets to the requirements.
sayakpaul Jul 17, 2023
02a9d88
Merge branch 'main' into feat/sd-xl-controlnet-2
sayakpaul Jul 18, 2023
6611259
update docstrings and incorporate the usage of watermarking.
sayakpaul Jul 18, 2023
695c3ac
incorporate fix from #4083
sayakpaul Jul 18, 2023
680e241
fix watermarking dependency handling.
sayakpaul Jul 18, 2023
842f5dd
run make-fix-copies.
sayakpaul Jul 18, 2023
0c904ce
Empty-Commit
sayakpaul Jul 18, 2023
ae48efc
Update requirements_sdxl.txt
sayakpaul Jul 18, 2023
74dda65
remove vae upcasting part.
sayakpaul Jul 18, 2023
08153a9
Apply suggestions from code review
sayakpaul Jul 18, 2023
dc4839e
run make style
sayakpaul Jul 18, 2023
54fdb56
run make fix-copies.
sayakpaul Jul 18, 2023
5d53cb8
disable suppot for multicontrolnet.
sayakpaul Jul 18, 2023
e430d1a
Merge branch 'main' into feat/sd-xl-controlnet-2
sayakpaul Jul 18, 2023
740944a
Apply suggestions from code review
sayakpaul Jul 18, 2023
e35cf2b
run make fix-copies.
sayakpaul Jul 18, 2023
d7aecd2
dtyle/.
sayakpaul Jul 18, 2023
f129bc4
fix-copies.
sayakpaul Jul 18, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions examples/controlnet/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,9 @@ pipe = StableDiffusionControlNetPipeline.from_pretrained(

# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# remove following line if xformers is not installed
# remove following line if xformers is not installed or when using Torch 2.0.
pipe.enable_xformers_memory_efficient_attention()

# memory optimization.
pipe.enable_model_cpu_offload()

control_image = load_image("./conditioning_image_1.png")
Expand All @@ -285,9 +285,8 @@ prompt = "pale golden rod circle with old lace background"
# generate image
generator = torch.manual_seed(0)
image = pipe(
prompt, num_inference_steps=20, generator=generator, image=control_image
prompt, num_inference_steps=20, generator=generator, image=control_image
).images[0]

image.save("./output.png")
```

Expand Down Expand Up @@ -460,3 +459,7 @@ The profile can then be inspected at http://localhost:6006/#profile
Sometimes you'll get version conflicts (error messages like `Duplicate plugins for name projector`), which means that you have to uninstall and reinstall all versions of Tensorflow/Tensorboard (e.g. with `pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile && pip install tf-nightly tbp-nightly tensorboard-plugin-profile`).

Note that the debugging functionality of the Tensorboard `profile` plugin is still under active development. Not all views are fully functional, and for example the `trace_viewer` cuts off events after 1M (which can result in all your device traces getting lost if you for example profile the compilation step by accident).

## Support for Stable Diffusion XL

We provide a training script for training a ControlNet with [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). Please refer to [README_sdxl.md](./README_sdxl.md) for more details.
131 changes: 131 additions & 0 deletions examples/controlnet/README_sdxl.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
# DreamBooth training example for Stable Diffusion XL (SDXL)

The `train_controlnet_sdxl.py` script shows how to implement the training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then cd in the `examples/controlnet` folder and run
```bash
pip install -r requirements_sdxl.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or for a default accelerate configuration without answering questions about your environment

```bash
accelerate config default
```

Or if your environment doesn't support an interactive shell (e.g., a notebook)

```python
from accelerate.utils import write_basic_config
write_basic_config()
```

When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.

## Circle filling dataset

The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.

## Training

Our training examples use two test conditioning images. They can be downloaded by running

```sh
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png

wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
```

Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub.

```bash
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-0.9"
export OUTPUT_DIR="path to save model"

accelerate launch train_controlnet_sdxl.py \
--pretrained_model_name_or_path=$MODEL_DIR \
--output_dir=$OUTPUT_DIR \
--dataset_name=fusing/fill50k \
--mixed_precision="fp16" \
--resolution=1024 \
--learning_rate=1e-5 \
--max_train_steps=15000 \
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
--validation_steps=100 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--report_to="wandb" \
--seed=42 \
--push_to_hub
```

To better track our training experiments, we're using the following flags in the command above:

* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.

Our experiments were conducted on a single 40GB A100 GPU.

### Inference

Once training is done, we can perform inference like so:

```python
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import torch

base_model_path = "stabilityai/stable-diffusion-xl-base-0.9"
controlnet_path = "path to controlnet"

controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
base_model_path, controlnet=controlnet, torch_dtype=torch.float16
)

# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# remove following line if xformers is not installed or when using Torch 2.0.
pipe.enable_xformers_memory_efficient_attention()
# memory optimization.
pipe.enable_model_cpu_offload()

control_image = load_image("./conditioning_image_1.png")
prompt = "pale golden rod circle with old lace background"

# generate image
generator = torch.manual_seed(0)
image = pipe(
prompt, num_inference_steps=20, generator=generator, image=control_image
).images[0]
image.save("./output.png")
```

## Notes

### Specifying a better VAE

SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
8 changes: 8 additions & 0 deletions examples/controlnet/requirements_sdxl.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
invisible-watermark>=0.2.0
datasets
Loading