
[Feature] Finetune text encoder in train_text_to_image_lora #3912


Closed
wants to merge 15 commits
50 changes: 50 additions & 0 deletions docs/source/en/training/lora.mdx
@@ -89,6 +89,40 @@ accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
--seed=1337
```

### Finetuning the text encoder and UNet

The script also allows you to finetune the `text_encoder` along with the `unet`.

<Tip warning={true}>

Training the text encoder requires additional memory, so training won't fit on a 16GB GPU. You'll need at least 24GB of VRAM to use this option.

</Tip>

Pass the `--train_text_encoder` argument to the training script to enable finetuning the `text_encoder` and `unet`:

```bash
accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_NAME \
--dataloader_num_workers=8 \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--max_train_steps=15000 \
--learning_rate=1e-04 \
--max_grad_norm=1 \
--lr_scheduler="cosine" --lr_warmup_steps=0 \
--output_dir=${OUTPUT_DIR} \
--push_to_hub \
--hub_model_id=${HUB_MODEL_ID} \
--report_to=wandb \
--checkpointing_steps=500 \
--validation_prompt="A pokemon with blue eyes." \
--train_text_encoder \
--seed=1337
```
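
If you are short of the roughly 24GB of VRAM mentioned above, the usual memory savers can be combined with `--train_text_encoder`. The sketch below is a hedged variant: it assumes your copy of the script exposes `--gradient_checkpointing`, `--use_8bit_adam` (needs `bitsandbytes`), and `--enable_xformers_memory_efficient_attention` (needs `xformers`), so check `python train_text_to_image_lora.py --help` before relying on them:

```bash
# memory-saving variant (assumed flags, verify with --help); remaining arguments as in the command above
accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_NAME \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--gradient_checkpointing \
--use_8bit_adam \
--enable_xformers_memory_efficient_attention \
--train_text_encoder \
--output_dir=${OUTPUT_DIR} \
--seed=1337
```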

### Inference[[text-to-image-inference]]

Now you can use the model for inference by loading the base model in the [`StableDiffusionPipeline`] and then the [`DPMSolverMultistepScheduler`]:
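
A minimal sketch of that setup, assuming the Stable Diffusion v1-5 base model and a UNet-only LoRA checkpoint (the identifiers below are placeholders, not values from this PR):

```python
from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline
import torch

# placeholder identifiers: the base model you trained against and your LoRA output directory
base_model_id = "runwayml/stable-diffusion-v1-5"
lora_model_path = "path/to/your/lora/output_dir"

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)

# loads only the UNet attention-processor LoRA layers; see below for the
# `load_lora_weights` variant used when the text encoder was also trained
pipe.unet.load_attn_procs(lora_model_path)
pipe = pipe.to("cuda")

image = pipe("A pokemon with blue eyes.", num_inference_steps=25).images[0]
```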
@@ -144,6 +178,22 @@ pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.

</Tip>

If you used `--train_text_encoder` during training, then use `pipe.load_lora_weights()` to load the LoRA
weights. For example:

```python
from diffusers import StableDiffusionPipeline
import torch

lora_model_id = "takuoko/classic-anime-expressions-lora"
base_model_id = "stablediffusionapi/anything-v5"

pipe = StableDiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.load_lora_weights(lora_model_id, weight_name="pytorch_lora_weights.bin")
image = pipe("1girl, >_<", num_inference_steps=50).images[0]
```
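
To reduce how strongly the LoRA layers influence the output, you can pass a `scale` between 0 and 1 through `cross_attention_kwargs`. A hedged note: this rescales the UNet LoRA layers, and whether the text encoder LoRA is also scaled depends on the diffusers version you have installed:

```python
# a scale of 0.0 ignores the LoRA layers, 1.0 applies them at full strength
image = pipe(
    "1girl, >_<",
    num_inference_steps=50,
    cross_attention_kwargs={"scale": 0.5},
).images[0]
```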


## DreamBooth

41 changes: 41 additions & 0 deletions examples/test_examples.py
@@ -849,6 +849,47 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multip
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
)

    def test_text_to_image_lora_with_text_encoder(self):
        pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"

        with tempfile.TemporaryDirectory() as tmpdir:
            initial_run_args = f"""
                examples/text_to_image/train_text_to_image_lora.py
                --pretrained_model_name_or_path {pretrained_model_name_or_path}
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --center_crop
                --random_flip
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 7
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --checkpointing_steps=2
                --checkpoints_total_limit=2
                --seed=0
                --train_text_encoder
                --num_validation_images=0
                """.split()

            run_command(self._launch_args + initial_run_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))

            # check that the `text_encoder` LoRA layers are present at all
            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
            keys = lora_state_dict.keys()
            is_text_encoder_present = any(k.startswith("text_encoder") for k in keys)
            self.assertTrue(is_text_encoder_present)

            # the keys of the state dict should start with either `unet` or `text_encoder`
            is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
            self.assertTrue(is_correct_naming)

    def test_unconditional_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            initial_run_args = f"""