
[SDXL DreamBooth LoRA] add support for text encoder fine-tuning #4097


Merged

merged 45 commits into main from feat/sdxl-dreambooth-returns on Jul 25, 2023

Conversation

Member

@sayakpaul sayakpaul commented Jul 14, 2023

This PR adds support for text encoder fine-tuning in the DreamBooth LoRA script for SDXL.

Summary of the changes:

  • Major refactor of the dataloader and the collator to accommodate training the text encoders.
  • Support for the numerically stable (fp16-safe) VAE, which needs some type-casting here and there (Allow low precision vae sd xl #4083); see the sketch right after this list.
  • Changes to the LoRA loaders.
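Concretely, the fp16-safe VAE point boils down to swapping a numerically stable VAE into the pipeline so everything can run in half precision. A rough, hypothetical sketch (repo ids are the ones used in the test commands below, not code from this PR):

import torch
from diffusers import AutoencoderKL, StableDiffusionXLPipeline

# Swap in the fp16-safe VAE so the whole pipeline can run in half precision
# without the decoded images degenerating into NaNs / black outputs.
vae = AutoencoderKL.from_pretrained(
    "sayakpaul/sdxl-vae-fp16-fix-testing", torch_dtype=torch.float16
)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "diffusers/stable-diffusion-xl-base-0.9", vae=vae, torch_dtype=torch.float16
).to("cuda")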

To help us maintain sanity, I tested the current training script under three settings:

  1. No text encoder and no better VAE
export MODEL_NAME="diffusers/stable-diffusion-xl-base-0.9"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="lora-trained-xl-no-vae-text-encoder"

accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=75 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub

Artifacts:

  2. No text encoder but better VAE
export MODEL_NAME="diffusers/stable-diffusion-xl-base-0.9"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="lora-trained-xl-no-text-encoder"

accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --pretrained_vae_model_name_or_path="sayakpaul/sdxl-vae-fp16-fix-testing" \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=75 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub

Artifacts:

  3. Better VAE along with the text encoder
export MODEL_NAME="diffusers/stable-diffusion-xl-base-0.9"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="lora-trained-xl-text-encoder-vae"

accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --pretrained_vae_model_name_or_path="sayakpaul/sdxl-vae-fp16-fix-testing" \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="fp16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=75 \
  --train_text_encoder \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=25 \
  --seed="0" \
  --push_to_hub

Artifacts:
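Once one of these runs finishes, the trained LoRA (UNet plus both text encoders when --train_text_encoder is set) can be tried out roughly like this (a minimal sketch; the path is whatever OUTPUT_DIR or Hub repo the run produced):

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "diffusers/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16
).to("cuda")

# The directory/repo produced above holds a single LoRA weights file whose keys
# are prefixed with "unet.", "text_encoder." and "text_encoder_2.".
pipe.load_lora_weights("lora-trained-xl-text-encoder-vae")

image = pipe("A photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("sks_dog.png")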


HuggingFaceDocBuilderDev commented Jul 14, 2023

The documentation is not available anymore as the PR was closed or merged.

@sayakpaul sayakpaul marked this pull request as ready for review July 14, 2023 14:17
@sayakpaul sayakpaul requested review from williamberman and patrickvonplaten and removed request for williamberman July 14, 2023 14:43
@sayakpaul sayakpaul requested a review from williamberman July 21, 2023 09:56
@sayakpaul
Member Author

@patrickvonplaten @williamberman I think I have addressed all your comments:

  • Simplification of the dataloader
  • Less state dict munging

I would suggest taking another, deeper look.

@@ -809,3 +810,66 @@ def __call__(
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)

    # Override to properly handle the loading and unloading of the additional text encoder.
Contributor

ok for me!
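For readers who have not opened the diff: the override boils down to splitting the combined LoRA state dict by prefix and loading each piece into the right component. A condensed sketch of that idea (helper names come from LoraLoaderMixin; the merged code may differ in details such as scaling kwargs):

def load_lora_weights(self, pretrained_model_name_or_path_or_dict, **kwargs):
    # Resolve the checkpoint into a flat state dict (plus any network alphas).
    state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

    # UNet LoRA layers.
    self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)

    # First text encoder ("text_encoder." prefix).
    te_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
    if te_state_dict:
        self.load_lora_into_text_encoder(
            te_state_dict, network_alphas=network_alphas,
            text_encoder=self.text_encoder, prefix="text_encoder",
        )

    # Second text encoder ("text_encoder_2." prefix).
    te_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
    if te_2_state_dict:
        self.load_lora_into_text_encoder(
            te_2_state_dict, network_alphas=network_alphas,
            text_encoder=self.text_encoder_2, prefix="text_encoder_2",
        )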

# needed for the SD XL UNet to operate.
def compute_embeddings(prompt, text_encoders, tokenizers):
def compute_time_ids():
    # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
    original_size = (args.resolution, args.resolution)
Contributor

This should ideally be the original size of the passed image (before resizing), but ok to leave as is for now
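To make the point concrete, here is a rough, hypothetical sketch of what these time ids contain (function name and arguments are illustrative, mirroring the _get_add_time_ids logic rather than the script's exact code):

import torch

# SDXL's UNet is micro-conditioned on (original_size, crops_coords_top_left,
# target_size), packed into a single "add_time_ids" tensor.
def compute_time_ids(original_size, crops_coords_top_left, target_size, dtype=torch.float32):
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    return torch.tensor([add_time_ids], dtype=dtype)

# The script currently passes (args.resolution, args.resolution) as original_size;
# the comment above suggests it should ideally be the image size before resizing.
time_ids = compute_time_ids((1024, 1024), (0, 0), (1024, 1024))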

Contributor

@patrickvonplaten patrickvonplaten left a comment

Nice!

Co-authored-by: Patrick von Platen <[email protected]>
@patrickvonplaten
Contributor

@williamberman ok for you?

Comment on lines +840 to +863

def save_lora_weights(
    self,
    save_directory: Union[str, os.PathLike],
    unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
    text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
    text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
    is_main_process: bool = True,
    weight_name: str = None,
    save_function: Callable = None,
    safe_serialization: bool = False,
):
    state_dict = {}

    def pack_weights(layers, prefix):
        layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
        layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
        return layers_state_dict

    state_dict.update(pack_weights(unet_lora_layers, "unet"))

    if text_encoder_lora_layers and text_encoder_2_lora_layers:
        state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
        state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

Contributor

<3
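For context, the training script is expected to hand its LoRA layers to this method roughly as follows (a sketch; the call site and variable names are placeholders, not the exact merged code). Each set of layers ends up under its own prefix in a single weights file:

# Sketch: pack UNet and both text-encoder LoRA layers into one file whose keys
# are namespaced "unet.", "text_encoder." and "text_encoder_2.".
StableDiffusionXLPipeline.save_lora_weights(
    save_directory=args.output_dir,
    unet_lora_layers=unet_lora_layers_to_save,
    text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
    text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
)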

Comment on lines 430 to 436

 class DreamBoothDataset(Dataset):
     """
     A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
-    It pre-processes the images and the tokenizes prompts.
+    It pre-processes the images.
     """

     def __init__(
Contributor

Magnificent! ("c'est magnifique")
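Since tokenization no longer lives in the dataset, the collator only has to stack images and pass prompt strings through; embeddings are computed outside (once up front, or per step when the text encoders are trained). A rough sketch of that pattern (dictionary keys are illustrative, not necessarily the script's exact names):

import torch

def collate_fn(examples):
    # Stack pre-processed images into a single batch tensor.
    pixel_values = torch.stack([example["instance_images"] for example in examples])
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
    # Prompts stay as raw strings; tokenization/embedding happens elsewhere.
    prompts = [example["instance_prompt"] for example in examples]
    return {"pixel_values": pixel_values, "prompts": prompts}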

Contributor

@williamberman williamberman left a comment

perfect, lgtm!

@sayakpaul sayakpaul merged commit 365e846 into main Jul 25, 2023
@sayakpaul sayakpaul deleted the feat/sdxl-dreambooth-returns branch July 25, 2023 00:05
@sayakpaul
Member Author

Thanks all for your suggestions.

orpatashnik pushed a commit to orpatashnik/diffusers that referenced this pull request Aug 1, 2023
[SDXL DreamBooth LoRA] add support for text encoder fine-tuning (huggingface#4097)

* Allow low precision sd xl

* finish

* finish

* feat: initial draft for supporting text encoder lora finetuning for SDXL DreamBooth

* fix: variable assignments.

* add: autocast block.

* add debugging

* vae dtype hell

* fix: vae dtype hell.

* fix: vae dtype hell 3.

* clean up

* lora text encoder loader.

* fix: unwrapping models.

* add: tests.

* docs.

* handle unexpected keys.

* fix vae dtype in the final inference.

* fix scope problem.

* fix: save_model_card args.

* initialize: prefix to None.

* fix: dtype issues.

* apply gixes.

* debgging.

* debugging

* debugging

* debugging

* debugging

* debugging

* add: fast tests.

* pre-tokenize.

* address: will's comments.

* fix: loader and tests.

* fix: dataloader.

* simplify dataloader.

* length.

* simplification.

* make style && make quality

* simplify state_dict munging

* fix: tests.

* fix: state_dict packing.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <[email protected]>

---------

Co-authored-by: Patrick von Platen <[email protected]>
orpatashnik pushed a commit to orpatashnik/diffusers that referenced this pull request Aug 1, 2023

orpatashnik pushed a commit to orpatashnik/diffusers that referenced this pull request Aug 1, 2023

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023

AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024