
Commit ed616bd

Authored by patrickvonplaten, cloneofsimo, pcuenca, and patil-suraj
[LoRA] Add LoRA training script (#1884)
* [Lora] first upload
* add first lora version
* upload
* more
* first training
* up
* correct
* improve
* finish loaders and inference
* up
* up
* fix more
* up
* finish more
* finish more
* up
* up
* change year
* revert year change
* Change lines
* Add cloneofsimo as co-author.
  Co-authored-by: Simo Ryu <[email protected]>
* finish
* fix docs
* Apply suggestions from code review
  Co-authored-by: Pedro Cuenca <[email protected]>
  Co-authored-by: Suraj Patil <[email protected]>
* upload
* finish

Co-authored-by: Simo Ryu <[email protected]>
Co-authored-by: Pedro Cuenca <[email protected]>
Co-authored-by: Suraj Patil <[email protected]>
1 parent ac3fc64 commit ed616bd

24 files changed (+2287, −663 lines)

docs/source/en/_toctree.yml

Lines changed: 2 additions & 0 deletions
@@ -90,6 +90,8 @@
     title: Configuration
   - local: api/outputs
     title: Outputs
+  - local: api/loaders
+    title: Loaders
   title: Main Classes
 - sections:
   - local: api/pipelines/overview

docs/source/en/api/loaders.mdx

Lines changed: 30 additions & 0 deletions
@@ -0,0 +1,30 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Loaders

There are many ways to train adapter neural networks for diffusion models, such as
- [Textual Inversion](./training/text_inversion.mdx)
- [LoRA](https://github.com/cloneofsimo/lora)
- [Hypernetworks](https://arxiv.org/abs/1609.09106)

Such adapter neural networks often consist of only a fraction of the weights of the pretrained model and are therefore very portable. The Diffusers library offers an easy-to-use
API to load such adapter neural networks via the [`loaders.py` module](https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders.py).

**Note**: This module is still highly experimental and prone to future changes.

## LoaderMixins

### UNet2DConditionLoadersMixin

[[autodoc]] loaders.UNet2DConditionLoadersMixin
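
As a quick illustration of what this mixin enables, here is a minimal sketch of loading trained LoRA attention processors onto a UNet. The repository id is the example LoRA checkpoint referenced in this commit's DreamBooth README; the snippet itself is an editorial illustration, not part of the committed file.

```python
# Minimal sketch: UNet2DConditionModel inherits from UNet2DConditionLoadersMixin,
# so LoRA attention processors can be loaded directly onto the UNet.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet.load_attn_procs("patrickvonplaten/lora")
```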

docs/source/en/api/logging.mdx

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-<!--Copyright 2020 The HuggingFace Team. All rights reserved.
+<!--Copyright 2022 The HuggingFace Team. All rights reserved.
 
 Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
 the License. You may obtain a copy of the License at

examples/dreambooth/README.md

Lines changed: 97 additions & 0 deletions
@@ -5,6 +5,7 @@ The `train_dreambooth.py` script shows how to implement the training procedure a


 ## Running locally with PyTorch
+
 ### Installing the dependencies

 Before running the scripts, make sure to install the library's training dependencies:

@@ -235,6 +236,102 @@ image.save("dog-bucket.png")

You can also perform inference from one of the checkpoints saved during the training process, if you used the `--checkpointing_steps` argument. Please refer to [the documentation](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint) to see how to do it.

## Training with Low-Rank Adaptation of Large Language Models (LoRA)

Low-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.

In a nutshell, LoRA adapts a pretrained model by adding pairs of rank-decomposition matrices to existing weights and training **only** those newly added weights (a minimal sketch of the idea follows the list below). This has a couple of advantages:
- The previous pretrained weights are kept frozen, so the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers make it possible to control the extent to which the model is adapted towards new training images via a `scale` parameter.
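
As an illustration of the idea, here is a minimal PyTorch sketch of a LoRA-augmented linear layer. This is not the Diffusers implementation; the class and attribute names are made up for clarity.

```python
import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    """Hypothetical example of a LoRA-augmented linear layer: the frozen base
    layer is combined with a trainable low-rank update, W x + scale * B(A x)."""

    def __init__(self, base_layer: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base_layer = base_layer
        self.base_layer.requires_grad_(False)  # pretrained weights stay frozen
        # The rank-decomposition pair: only these two small matrices are trained.
        self.lora_down = nn.Linear(base_layer.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, base_layer.out_features, bias=False)
        nn.init.normal_(self.lora_down.weight, std=1.0 / rank)
        nn.init.zeros_(self.lora_up.weight)  # the update starts as a no-op
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.base_layer(x) + self.scale * self.lora_up(self.lora_down(x))
```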

[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.

### Training

Let's get started with a simple example. We will reuse the dog example from the [previous section](#dog-toy-example).

First, you need to set up your DreamBooth training example as explained in the [installation section](#Installing-the-dependencies).
Next, let's download the dog dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. Make sure to set `INSTANCE_DIR` to the name of your directory further below. This will be our training data.

Now you can launch the training. Here we will use [Stable Diffusion 1.5](https://huggingface.co/runwayml/stable-diffusion-v1-5).

**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**

**___Note: It is quite useful to monitor training progress by regularly generating sample images during training. [wandb](https://docs.wandb.ai/quickstart) is a nice solution for easily viewing generated images during training. All you need to do is run `pip install wandb` before training and pass `--report_to="wandb"` to automatically log images.___**

```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"
```

For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the `--push_to_hub` flag.

```bash
huggingface-cli login
```
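
If you prefer, you can also log in from Python instead of the CLI. This is a brief sketch and assumes a `huggingface_hub` version that exposes `login`; it is not part of this diff.

```python
# Hedged alternative to `huggingface-cli login`: authenticate from Python.
from huggingface_hub import login

login()  # prompts for a Hugging Face token with write access
```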

Now we can start training!

```bash
accelerate launch train_dreambooth_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --checkpointing_steps=100 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=50 \
  --seed="0" \
  --push_to_hub
```

**___Note: When using LoRA we can use a much higher learning rate compared to vanilla DreamBooth. Here we use *1e-4* instead of the usual *2e-6*.___**

The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dreambooth_dog_example](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example). **___Note: [The final weights](https://huggingface.co/patrickvonplaten/lora/blob/main/pytorch_attn_procs.bin) are only 3 MB in size, which is orders of magnitude smaller than the original model.___**

The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
You can use the `Step` slider to see how the model learned the features of our subject as it trained.

### Inference

After training, LoRA weights can be loaded very easily into the original pipeline. First, you need to
load the original pipeline:

```python
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
```

Next, we can load the adapter layers into the UNet with the [`load_attn_procs` function](https://huggingface.co/docs/diffusers/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs).

```python
pipe.unet.load_attn_procs("patrickvonplaten/lora")
```

Finally, we can run the model for inference.

```python
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
```
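
The `scale` parameter mentioned in the bullet list above controls how strongly the LoRA update influences the output. Below is a hedged sketch of how one might adjust it at inference time; the `cross_attention_kwargs` pass-through shown here is an assumption about the pipeline API and is not demonstrated in this diff.

```python
# Hypothetical sketch: dial the LoRA contribution up or down at inference time.
# scale=0.0 would behave like the original model, scale=1.0 applies the full
# LoRA update. The `cross_attention_kwargs` argument is an assumption and may
# not exist in every Diffusers version.
image = pipe(
    "A picture of a sks dog in a bucket",
    num_inference_steps=25,
    cross_attention_kwargs={"scale": 0.5},
).images[0]
image.save("dog-bucket-lora-0.5.png")
```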

## Training with Flax/JAX

For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.

examples/dreambooth/train_dreambooth.py

Lines changed: 15 additions & 0 deletions
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
 import argparse
 import hashlib
 import itertools
