Commit 3eb498e

[Core] add: controlnet support for SDXL (#4038)
* add: controlnet sdxl.
* modifications to controlnet.
* run styling.
* add: __init__.pys
* incorporate #4019 changes.
* run make fix-copies.
* resize the conditioning images.
* remove autocast.
* run styling.
* disable autocast.
* debugging
* device placement.
* back to autocast.
* remove comment.
* save some memory by reusing the vae and unet in the pipeline.
* apply styling.
* Allow low precision sd xl
* finish
* finish
* changes to accommodate the improved VAE.
* modifications to how we handle vae encoding in the training.
* make style
* make existing controlnet fast tests pass.
* change vae checkpoint cli arg.
* fix: vae pretrained paths.
* fix: steps in get_scheduler().
* debugging.
* debugging./
* fix: weight conversion.
* add: docs.
* add: limited tests./
* add: datasets to the requirements.
* update docstrings and incorporate the usage of watermarking.
* incorporate fix from #4083
* fix watermarking dependency handling.
* run make-fix-copies.
* Empty-Commit
* Update requirements_sdxl.txt
* remove vae upcasting part.
* Apply suggestions from code review (Co-authored-by: Patrick von Platen <[email protected]>)
* run make style
* run make fix-copies.
* disable suppot for multicontrolnet.
* Apply suggestions from code review (Co-authored-by: Patrick von Platen <[email protected]>)
* run make fix-copies.
* dtyle/.
* fix-copies.

Co-authored-by: Patrick von Platen <[email protected]>
1 parent c6e56e9 commit 3eb498e

12 files changed (+2686 / -19 lines)

examples/controlnet/README.md

Lines changed: 7 additions & 4 deletions
@@ -274,9 +274,9 @@ pipe = StableDiffusionControlNetPipeline.from_pretrained(
 
 # speed up diffusion process with faster scheduler and memory optimization
 pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
-# remove following line if xformers is not installed
+# remove following line if xformers is not installed or when using Torch 2.0.
 pipe.enable_xformers_memory_efficient_attention()
-
+# memory optimization.
 pipe.enable_model_cpu_offload()
 
 control_image = load_image("./conditioning_image_1.png")
@@ -285,9 +285,8 @@ prompt = "pale golden rod circle with old lace background"
 # generate image
 generator = torch.manual_seed(0)
 image = pipe(
-    prompt, num_inference_steps=20, generator=generator, image=control_image
+    prompt, num_inference_steps=20, generator=generator, image=control_image
 ).images[0]
-
 image.save("./output.png")
 ```

@@ -460,3 +459,7 @@ The profile can then be inspected at http://localhost:6006/#profile
 Sometimes you'll get version conflicts (error messages like `Duplicate plugins for name projector`), which means that you have to uninstall and reinstall all versions of Tensorflow/Tensorboard (e.g. with `pip uninstall tensorflow tf-nightly tensorboard tb-nightly tensorboard-plugin-profile && pip install tf-nightly tbp-nightly tensorboard-plugin-profile`).
 
 Note that the debugging functionality of the Tensorboard `profile` plugin is still under active development. Not all views are fully functional, and for example the `trace_viewer` cuts off events after 1M (which can result in all your device traces getting lost if you for example profile the compilation step by accident).
+
+## Support for Stable Diffusion XL
+
+We provide a training script for training a ControlNet with [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). Please refer to [README_sdxl.md](./README_sdxl.md) for more details.

examples/controlnet/README_sdxl.md

Lines changed: 131 additions & 0 deletions
# ControlNet training example for Stable Diffusion XL (SDXL)

The `train_controlnet_sdxl.py` script shows how to implement the training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).
## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date, as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then cd into the `examples/controlnet` folder and run
```bash
pip install -r requirements_sdxl.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or, for a default accelerate configuration without answering questions about your environment:

```bash
accelerate config default
```

Or, if your environment doesn't support an interactive shell (e.g., a notebook):

```python
from accelerate.utils import write_basic_config

write_basic_config()
```
When running `accelerate config`, setting torch compile mode to True can yield dramatic speedups.
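If you prefer to skip the interactive prompt, one option is to request torch compile at launch time instead. This is a sketch, assuming your installed `accelerate` supports the `--dynamo_backend` launch flag:

```bash
# Sketch: enable torch compile with the inductor backend at launch time;
# combine this with the remaining flags from the full training command
# shown in the Training section below.
accelerate launch --dynamo_backend="inductor" train_controlnet_sdxl.py \
 --pretrained_model_name_or_path=$MODEL_DIR \
 --output_dir=$OUTPUT_DIR
```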
## Circle filling dataset

The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
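To peek at what the training script will consume, you can load the dataset yourself. A minimal sketch; the column names we expect (image / conditioning image / caption) are an assumption to verify against `dataset.column_names`:

```python
from datasets import load_dataset

# Download (or reuse the cached copy of) the circle-filling dataset.
dataset = load_dataset("fusing/fill50k", split="train")

# Inspect the schema and one sample record.
print(dataset.column_names)
print(dataset[0])
```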
## Training

Our training examples use two test conditioning images. They can be downloaded by running

```sh
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png

wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
```

Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to the Hugging Face Hub.
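For reference, logging in is a single command; it will prompt for a Hub token, which needs write access for `--push_to_hub` to work:

```bash
huggingface-cli login
```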
```bash
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-0.9"
export OUTPUT_DIR="path to save model"

accelerate launch train_controlnet_sdxl.py \
 --pretrained_model_name_or_path=$MODEL_DIR \
 --output_dir=$OUTPUT_DIR \
 --dataset_name=fusing/fill50k \
 --mixed_precision="fp16" \
 --resolution=1024 \
 --learning_rate=1e-5 \
 --max_train_steps=15000 \
 --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
 --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
 --validation_steps=100 \
 --train_batch_size=1 \
 --gradient_accumulation_steps=4 \
 --report_to="wandb" \
 --seed=42 \
 --push_to_hub
```
To better track our training experiments, we're using the following flags in the command above:

* `report_to="wandb"` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_image`, `validation_prompt`, and `validation_steps` allow the script to do a few validation inference runs. This lets us qualitatively check whether training is progressing as expected.

Our experiments were conducted on a single 40GB A100 GPU.
### Inference

Once training is done, we can perform inference like so:

```python
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import torch

base_model_path = "stabilityai/stable-diffusion-xl-base-0.9"
controlnet_path = "path to controlnet"

controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path, controlnet=controlnet, torch_dtype=torch.float16
)

# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# remove following line if xformers is not installed or when using Torch 2.0.
pipe.enable_xformers_memory_efficient_attention()
# memory optimization.
pipe.enable_model_cpu_offload()

control_image = load_image("./conditioning_image_1.png")
prompt = "pale golden rod circle with old lace background"

# generate image
generator = torch.manual_seed(0)
image = pipe(
    prompt, num_inference_steps=20, generator=generator, image=control_image
).images[0]
image.save("./output.png")
```
## Notes

### Specifying a better VAE

SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
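For example, the training invocation might be extended like so. A sketch: only the VAE flag is new relative to the full command above, and the checkpoint name is the community fp16 fix linked above:

```bash
accelerate launch train_controlnet_sdxl.py \
 --pretrained_model_name_or_path=$MODEL_DIR \
 --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
 --output_dir=$OUTPUT_DIR \
 --dataset_name=fusing/fill50k \
 --mixed_precision="fp16" \
 --resolution=1024 \
 --max_train_steps=15000
```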
examples/controlnet/requirements_sdxl.txt

Lines changed: 9 additions & 0 deletions

accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
invisible-watermark>=0.2.0
datasets
wandb
