-
Notifications
You must be signed in to change notification settings - Fork 6.1k
[docs] Adds a doc on LoRA support for diffusers #2086
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
58a930f
233d646
72814aa
aab46f4
7f23db6
76562ea
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
<!--Copyright 2023 The HuggingFace Team. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with | ||
the License. You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on | ||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the | ||
specific language governing permissions and limitations under the License. | ||
--> | ||
|
||
# LoRA Support in Diffusers | ||
|
||
Diffusers support LoRA for Stable Diffusion for faster fine-tuning allowing greater memory efficiency and easier portability. | ||
|
||
Low-Rank Adaption of Large Language Models was first introduced by Microsoft in | ||
[LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*. | ||
|
||
In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition weight matrices (called **update marrices**) | ||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||
to existing weights and **only** training those newly added weights. This has a couple of advantages: | ||
|
||
- Previous pretrained weights are kept frozen so that model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114). | ||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||
- Rank-decomposition matrices have significantly fewer parameters than original model, which means that trained LoRA weights are easily portable. | ||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||
- LoRA attention layers allow to control to which extent the model is adapted toward new training images via a `scale` parameter. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: so far we've only mentioned "update matrices", but not how they work or whether they contain attention layers. Maybe we should very briefly introduce the concept? Something simple like "LoRA matrices are added to the model attention layers and they control ..." could work. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See if the current edits make sense. |
||
|
||
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository. | ||
|
||
<Tip> | ||
|
||
LoRA also allows us to achieve greater memory efficiency since the pretrained weights are kept frozen, only the LoRA weights are trained, thereby | ||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||
allowing us to run fine-tuning on consumer GPUs like Tesla T4. | ||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
</Tip> | ||
|
||
## Getting started with LoRA for fine-tuning | ||
|
||
Stable Diffusion can be fine-tuned in different ways: | ||
|
||
* [Textual inversion](https://huggingface.co/docs/diffusers/main/en/training/text_inversion) | ||
* [DreamBooth](https://huggingface.co/docs/diffusers/main/en/training/dreambooth) | ||
* [Text2Image fine-tuning](https://huggingface.co/docs/diffusers/main/en/training/text2image) | ||
|
||
We provide two end-to-end examples that show how to run fine-tuning with LoRA: | ||
|
||
* [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-with-low-rank-adaptation-of-large-language-models-lora) | ||
* [Text2Image](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora) | ||
|
||
If you want to perform DreamBooth training with LoRA, for instance, you would run: | ||
|
||
```bash | ||
export MODEL_NAME="runwayml/stable-diffusion-v1-5" | ||
export INSTANCE_DIR="path-to-instance-images" | ||
export OUTPUT_DIR="path-to-save-model" | ||
|
||
accelerate launch train_dreambooth_lora.py \ | ||
--pretrained_model_name_or_path=$MODEL_NAME \ | ||
--instance_data_dir=$INSTANCE_DIR \ | ||
--output_dir=$OUTPUT_DIR \ | ||
--instance_prompt="a photo of sks dog" \ | ||
--resolution=512 \ | ||
--train_batch_size=1 \ | ||
--gradient_accumulation_steps=1 \ | ||
--checkpointing_steps=100 \ | ||
--learning_rate=1e-4 \ | ||
--report_to="wandb" \ | ||
--lr_scheduler="constant" \ | ||
--lr_warmup_steps=0 \ | ||
--max_train_steps=500 \ | ||
--validation_prompt="A photo of sks dog in a bucket" \ | ||
--validation_epochs=50 \ | ||
--seed="0" \ | ||
--push_to_hub | ||
``` | ||
|
||
Refer to the respective examples linked above to learn more. | ||
|
||
<Tip> | ||
|
||
When using LoRA we can use a much higher learning rate (typically 1e-4 as opposed to 1e-5) compared to non-LoRA fine-tuning. | ||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
</Tip> | ||
|
||
But there is no free lunch. For the given dataset and expected generation quality, you'd still need to experiment with | ||
different hyperparameters. Here are some important ones: | ||
|
||
* Training time | ||
* Learning rate | ||
* Number of training steps | ||
* Inference time | ||
* Number of steps | ||
* Scheduler type | ||
|
||
Additionally, you can follow [this blog](https://huggingface.co/blog/dreambooth) that documents some of our experimental | ||
findings for performing DreamBooth training Stable Diffusion. | ||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
When fine-tuning, the LoRA update matrices are only added to the attention layers. To enable this, we added new weight | ||
loading functionalities. Their details are available [here](https://huggingface.co/docs/diffusers/main/en/api/loaders). | ||
pcuenca marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
## Inference | ||
|
||
Assuming, you used the `examples/text_to_image/train_text_to_image_lora.py` to fine-tune Stable Diffusion on the [Pokemons | ||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||
dataset](https://huggingface.co/lambdalabs/pokemon-blip-captions), you can perform inference like so: | ||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
```py | ||
from diffusers import StableDiffusionPipeline | ||
import torch | ||
|
||
model_path = "sayakpaul/sd-model-finetuned-lora-t4" | ||
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (nit) maybe we can show how to retrieve the base_model from the model card by loading the yaml code via There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. from huggingface_hub.repocard import RepoCard
card = RepoCard.load("sayakpaul/sd-model-finetuned-lora-t4")
card.data.to_dict()["base_model"]
# 'CompVis/stable-diffusion-v1-4' I guess we would want to show it in a separate code snippet from the doc? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nice! Maybe include it as a tip below the current snippet? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For me it's fine in the same code snippet There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. See if the current changes make sense. |
||
pipe.unet.load_attn_procs(model_path) | ||
pipe.to("cuda") | ||
|
||
prompt = "A pokemon with green eyes and red legs." | ||
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0] | ||
image.save("pokemon.png") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just wondering, maybe display the image here? We never do it in the docs, what's your opinion about starting doing it to make things more visual? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Diffusion for computer vision is definitely about visuals. I like the idea and I think we should definitely add it :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added an image. |
||
``` | ||
|
||
[`sayakpaul/sd-model-finetuned-lora-t4`](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4) contains [LoRA fine-tuned update matrices](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/pytorch_lora_weights.bin) | ||
which is only 3 MBs in size. During inference, the pre-trained Stable Diffusion checkpoints loaded alongside these update | ||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||
matrices and then they are combined to run inference. | ||
|
||
Inference for DreamBooth training remains the same. Check | ||
[this section](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#inference-1) for more details. | ||
|
||
## Known limitations | ||
|
||
* Currently, we only support LoRA for the attention layers of [`UNet2DConditionModel`](https://huggingface.co/docs/diffusers/main/en/api/models#diffusers.UNet2DConditionModel). |
Uh oh!
There was an error while loading. Please reload this page.