-
Notifications
You must be signed in to change notification settings - Fork 6k
[Core] add: controlnet support for SDXL #4038
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 48 commits
Commits
Show all changes
54 commits
Select commit
Hold shift + click to select a range
c6c9f3a
add: controlnet sdxl.
sayakpaul 185b67b
modifications to controlnet.
sayakpaul db78a4c
run styling.
sayakpaul af8273a
add: __init__.pys
sayakpaul c8b00de
incorporate https://github.com/huggingface/diffusers/pull/4019 changes.
sayakpaul 68f2c38
run make fix-copies.
sayakpaul 4b86d63
resize the conditioning images.
sayakpaul dade681
remove autocast.
sayakpaul 15d4afd
run styling.
sayakpaul 5ed7d3e
disable autocast.
sayakpaul 64b0e20
debugging
sayakpaul f482f46
device placement.
sayakpaul c2bbb2b
back to autocast.
sayakpaul ccb6210
remove comment.
sayakpaul be13ef5
save some memory by reusing the vae and unet in the pipeline.
sayakpaul fbb086a
apply styling.
sayakpaul 65a5c45
Allow low precision sd xl
patrickvonplaten 43f842c
finish
patrickvonplaten 7171d42
finish
patrickvonplaten 97f69a7
Merge branch 'main' into allow_low_precision_vae_sd_xl
patrickvonplaten d706b2b
Merge branch 'main' into feat/sd-xl-controlnet-2
sayakpaul e2ca903
Merge remote-tracking branch 'origin/allow_low_precision_vae_sd_xl' i…
sayakpaul fa22868
changes to accommodate the improved VAE.
sayakpaul b1c83fd
modifications to how we handle vae encoding in the training.
sayakpaul a4e10d8
make style
sayakpaul 2d7c454
make existing controlnet fast tests pass.
sayakpaul 7296a7f
change vae checkpoint cli arg.
sayakpaul 74f7485
fix: vae pretrained paths.
sayakpaul 73ccae8
Merge branch 'main' into feat/sd-xl-controlnet-2
sayakpaul 7d27b1d
fix: steps in get_scheduler().
sayakpaul 3fa4809
debugging.
sayakpaul 115b30d
Merge branch 'feat/sd-xl-controlnet-2' of https://github.com/huggingf…
sayakpaul 523c1eb
debugging./
sayakpaul 98ed3c9
fix: weight conversion.
sayakpaul a397ead
add: docs.
sayakpaul 4c14e88
add: limited tests./
sayakpaul 156dd27
add: datasets to the requirements.
sayakpaul 02a9d88
Merge branch 'main' into feat/sd-xl-controlnet-2
sayakpaul 6611259
update docstrings and incorporate the usage of watermarking.
sayakpaul 695c3ac
incorporate fix from #4083
sayakpaul 680e241
fix watermarking dependency handling.
sayakpaul 842f5dd
run make-fix-copies.
sayakpaul 0c904ce
Empty-Commit
sayakpaul ae48efc
Update requirements_sdxl.txt
sayakpaul 74dda65
remove vae upcasting part.
sayakpaul 08153a9
Apply suggestions from code review
sayakpaul dc4839e
run make style
sayakpaul 54fdb56
run make fix-copies.
sayakpaul 5d53cb8
disable suppot for multicontrolnet.
sayakpaul e430d1a
Merge branch 'main' into feat/sd-xl-controlnet-2
sayakpaul 740944a
Apply suggestions from code review
sayakpaul e35cf2b
run make fix-copies.
sayakpaul d7aecd2
dtyle/.
sayakpaul f129bc4
fix-copies.
sayakpaul File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
# DreamBooth training example for Stable Diffusion XL (SDXL) | ||
|
||
The `train_controlnet_sdxl.py` script shows how to implement the training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952). | ||
|
||
## Running locally with PyTorch | ||
|
||
### Installing the dependencies | ||
|
||
Before running the scripts, make sure to install the library's training dependencies: | ||
|
||
**Important** | ||
|
||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: | ||
|
||
```bash | ||
git clone https://github.com/huggingface/diffusers | ||
cd diffusers | ||
pip install -e . | ||
``` | ||
|
||
Then cd in the `examples/controlnet` folder and run | ||
```bash | ||
pip install -r requirements_sdxl.txt | ||
``` | ||
|
||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: | ||
|
||
```bash | ||
accelerate config | ||
``` | ||
|
||
Or for a default accelerate configuration without answering questions about your environment | ||
|
||
```bash | ||
accelerate config default | ||
``` | ||
|
||
Or if your environment doesn't support an interactive shell (e.g., a notebook) | ||
|
||
```python | ||
from accelerate.utils import write_basic_config | ||
write_basic_config() | ||
``` | ||
|
||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. | ||
|
||
## Circle filling dataset | ||
|
||
The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script. | ||
|
||
## Training | ||
|
||
Our training examples use two test conditioning images. They can be downloaded by running | ||
|
||
```sh | ||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png | ||
|
||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png | ||
``` | ||
|
||
Then run `huggingface-cli login` to log into your Hugging Face account. This is needed to be able to push the trained ControlNet parameters to Hugging Face Hub. | ||
|
||
```bash | ||
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-0.9" | ||
export OUTPUT_DIR="path to save model" | ||
|
||
accelerate launch train_controlnet_sdxl.py \ | ||
--pretrained_model_name_or_path=$MODEL_DIR \ | ||
--output_dir=$OUTPUT_DIR \ | ||
--dataset_name=fusing/fill50k \ | ||
--mixed_precision="fp16" \ | ||
--resolution=1024 \ | ||
--learning_rate=1e-5 \ | ||
--max_train_steps=15000 \ | ||
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \ | ||
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ | ||
--validation_steps=100 \ | ||
--train_batch_size=1 \ | ||
--gradient_accumulation_steps=4 \ | ||
--report_to="wandb" \ | ||
--seed=42 \ | ||
--push_to_hub | ||
``` | ||
|
||
To better track our training experiments, we're using the following flags in the command above: | ||
|
||
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`. | ||
* `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected. | ||
|
||
Our experiments were conducted on a single 40GB A100 GPU. | ||
|
||
### Inference | ||
|
||
Once training is done, we can perform inference like so: | ||
|
||
```python | ||
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler | ||
from diffusers.utils import load_image | ||
import torch | ||
|
||
base_model_path = "stabilityai/stable-diffusion-xl-base-0.9" | ||
controlnet_path = "path to controlnet" | ||
|
||
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) | ||
pipe = StableDiffusionXLControlNetPipeline.from_pretrained( | ||
base_model_path, controlnet=controlnet, torch_dtype=torch.float16 | ||
) | ||
|
||
# speed up diffusion process with faster scheduler and memory optimization | ||
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) | ||
# remove following line if xformers is not installed or when using Torch 2.0. | ||
pipe.enable_xformers_memory_efficient_attention() | ||
# memory optimization. | ||
pipe.enable_model_cpu_offload() | ||
|
||
control_image = load_image("./conditioning_image_1.png") | ||
prompt = "pale golden rod circle with old lace background" | ||
|
||
# generate image | ||
generator = torch.manual_seed(0) | ||
image = pipe( | ||
prompt, num_inference_steps=20, generator=generator, image=control_image | ||
).images[0] | ||
image.save("./output.png") | ||
``` | ||
|
||
## Notes | ||
|
||
### Specifying a better VAE | ||
|
||
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)). |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
accelerate>=0.16.0 | ||
torchvision | ||
transformers>=4.25.1 | ||
ftfy | ||
tensorboard | ||
Jinja2 | ||
invisible-watermark>=0.2.0 | ||
datasets | ||
wandb |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.