
Fix various bugs with LoRA Dreambooth and Dreambooth script #3353


Merged
merged 20 commits on May 11, 2023

55 changes: 26 additions & 29 deletions examples/dreambooth/train_dreambooth.py
@@ -22,7 +22,6 @@
import warnings
from pathlib import Path

import accelerate
import numpy as np
import torch
import torch.nn.functional as F
@@ -733,36 +732,34 @@ def main(args):
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)

# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
Contributor Author:
The minimum version of accelerate is already 0.16, so let's remove this check here.

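As a hedged illustration of what the hooks below buy us (the output path and checkpoint name are assumptions, not part of this PR): `accelerator.save_state(...)` now writes plain `save_pretrained`-style subfolders that can be reloaded directly outside the trainer.

from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel  # assumes the default CLIP text encoder

# "output/checkpoint-500" is a hypothetical directory written by
# accelerator.save_state("output/checkpoint-500"); the save hook below turns it
# into `unet/` and `text_encoder/` subfolders in the standard layout.
unet = UNet2DConditionModel.from_pretrained("output/checkpoint-500", subfolder="unet")
text_encoder = CLIPTextModel.from_pretrained("output/checkpoint-500", subfolder="text_encoder")
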
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
for model in models:
sub_dir = "unet" if type(model) == type(unet) else "text_encoder"
model.save_pretrained(os.path.join(output_dir, sub_dir))

# make sure to pop weight so that corresponding model is not saved again
weights.pop()

def load_model_hook(models, input_dir):
while len(models) > 0:
# pop models so that they are not loaded again
model = models.pop()

if type(model) == type(text_encoder):
# load transformers style into model
load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
model.config = load_model.config
else:
# load diffusers style into model
load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
model.register_to_config(**load_model.config)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
for model in models:
sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder"
model.save_pretrained(os.path.join(output_dir, sub_dir))

# make sure to pop weight so that corresponding model is not saved again
weights.pop()

def load_model_hook(models, input_dir):
while len(models) > 0:
# pop models so that they are not loaded again
model = models.pop()

if isinstance(model, type(accelerator.unwrap_model(text_encoder))):
# load transformers style into model
load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder")
model.config = load_model.config
else:
# load diffusers style into model
load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet")
model.register_to_config(**load_model.config)

model.load_state_dict(load_model.state_dict())
del load_model
model.load_state_dict(load_model.state_dict())
del load_model

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

vae.requires_grad_(False)
if not args.train_text_encoder:
80 changes: 68 additions & 12 deletions examples/dreambooth/train_dreambooth_lora.py
@@ -834,7 +834,6 @@ def main(args):

unet.set_attn_processor(unet_lora_attn_procs)
unet_lora_layers = AttnProcsLayers(unet.attn_processors)
accelerator.register_for_checkpointing(unet_lora_layers)
Contributor Author:
We now save the LoRA layers in a better format.

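In practice this means an intermediate checkpoint written by `accelerator.save_state(...)` is now an ordinary LoRA weights directory. A minimal sketch of consuming one (the base model id, checkpoint path, and prompt are illustrative assumptions):

import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16  # assumed base model
).to("cuda")
# directory written by accelerator.save_state(...) during LoRA DreamBooth training
pipe.load_lora_weights("output/checkpoint-500")
image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
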

# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
@@ -853,9 +852,68 @@ def main(args):
)
temp_pipeline._modify_text_encoder(text_lora_attn_procs)
text_encoder = temp_pipeline.text_encoder
accelerator.register_for_checkpointing(text_encoder_lora_layers)
del temp_pipeline

# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
Contributor Author:
Let's make sure the LoRA layers are saved in a more user-friendly format here.

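accelerate hands the hook an anonymous list of models, so the hook identifies each one by comparing its state_dict keys against the known LoRA layer collections. A toy, self-contained sketch of that matching trick (the Linear/Sequential modules are stand-ins, not the real layer containers):

import torch

unet_stand_in = torch.nn.Linear(4, 4)                       # stands in for unet_lora_layers
text_stand_in = torch.nn.Sequential(torch.nn.Linear(4, 4))  # stands in for text_encoder_lora_layers

unet_keys = unet_stand_in.state_dict().keys()
for model in (text_stand_in, unet_stand_in):
    kind = "unet" if model.state_dict().keys() == unet_keys else "text_encoder"
    print(kind)  # prints "text_encoder", then "unet"
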
# there are only two options here. Either there are just the unet attn processor layers
# or there are both the unet and text encoder attn layers
unet_lora_layers_to_save = None
text_encoder_lora_layers_to_save = None

if args.train_text_encoder:
text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()

for model in models:
state_dict = model.state_dict()

if (
text_encoder_lora_layers is not None
and text_encoder_keys is not None
and state_dict.keys() == text_encoder_keys
):
# text encoder
text_encoder_lora_layers_to_save = state_dict
elif state_dict.keys() == unet_keys:
# unet
unet_lora_layers_to_save = state_dict

# make sure to pop weight so that corresponding model is not saved again
weights.pop()

LoraLoaderMixin.save_lora_weights(
output_dir,
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=text_encoder_lora_layers_to_save,
)

def load_model_hook(models, input_dir):
# Note we DON'T pass the unet and text encoder here on purpose
# so that we don't accidentally override the LoRA layers of
# unet_lora_layers and text_encoder_lora_layers which are stored in `models`
# with new torch.nn.Modules / weights. We simply use the pipeline class as
# an easy way to load the lora checkpoints
temp_pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
revision=args.revision,
torch_dtype=weight_dtype,
)
temp_pipeline.load_lora_weights(input_dir)
Member:
Clean and tight 🎸


# load lora weights into models
models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
if len(models) > 1:
models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())

# delete temporary pipeline and pop models
del temp_pipeline
for _ in range(len(models)):
models.pop()

accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
@@ -1130,17 +1188,10 @@ def compute_text_embeddings(prompt):
progress_bar.update(1)
global_step += 1

if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
# We combine the text encoder and UNet LoRA parameters with a simple
# custom logic. `accelerator.save_state()` won't know that. So,
# use `LoraLoaderMixin.save_lora_weights()`.
LoraLoaderMixin.save_lora_weights(
save_directory=save_path,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
)
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")

logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -1217,8 +1268,12 @@ def compute_text_embeddings(prompt):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unet.to(torch.float32)
unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)

if text_encoder is not None:
text_encoder = text_encoder.to(torch.float32)
text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)

LoraLoaderMixin.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_layers,
@@ -1250,6 +1305,7 @@ def compute_text_embeddings(prompt):
pipeline.load_lora_weights(args.output_dir)

# run inference
images = []
if args.validation_prompt and args.num_validation_images > 0:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
images = [
53 changes: 41 additions & 12 deletions src/diffusers/loaders.py
@@ -70,6 +70,9 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]):
self.mapping = dict(enumerate(state_dict.keys()))
self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

# .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder
self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]

# we add a hook to state_dict() and load_state_dict() so that the
# naming fits with `unet.attn_processors`
def map_to(module, state_dict, *args, **kwargs):
@@ -81,10 +84,19 @@ def map_to(module, state_dict, *args, **kwargs):

return new_state_dict

def remap_key(key, state_dict):
for k in self.split_keys:
if k in key:
return key.split(k)[0] + k

raise ValueError(
f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
)

def map_from(module, state_dict, *args, **kwargs):
all_keys = list(state_dict.keys())
for key in all_keys:
replace_key = key.split(".processor")[0] + ".processor"
replace_key = remap_key(key, state_dict)
Contributor Author:
We need to be able to split the keys on the text encoder attention projections as well, not only on `.processor`.

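A standalone toy reproduction of the new remapping, to show why splitting only on `.processor` was not enough (the sample key below is an assumed text encoder LoRA key, for illustration):

split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]

def remap_key(key):
    # return everything up to and including the first split suffix found in the key
    for k in split_keys:
        if k in key:
            return key.split(k)[0] + k
    raise ValueError(f"{key} has to have one of {split_keys}.")

sample = "text_model.encoder.layers.0.self_attn.out_proj.lora_linear_layer.up.weight"
print(remap_key(sample))  # text_model.encoder.layers.0.self_attn.out_proj
# the previous code split only on ".processor", which text encoder LoRA keys never contain,
# so the lookup in rev_mapping would have failed for them
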
new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
state_dict[new_key] = state_dict[key]
del state_dict[key]
@@ -898,6 +910,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
self._modify_text_encoder(attn_procs_text_encoder)

# save lora attn procs of text encoder so that it can be easily retrieved
self._text_encoder_lora_attn_procs = attn_procs_text_encoder

# Otherwise, we're dealing with the old format. This means the `state_dict` should only
# contain the module names of the `unet` as its keys WITHOUT any prefix.
elif not all(
@@ -907,6 +922,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet.{module_name}': params for module_name, params in old_state_dict.items()}`."
warnings.warn(warn_message)

@property
def text_encoder_lora_attn_procs(self):
if hasattr(self, "_text_encoder_lora_attn_procs"):
return self._text_encoder_lora_attn_procs
return

def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
r"""
Monkey-patches the forward passes of attention modules of the text encoder.
@@ -1110,7 +1131,7 @@ def _load_text_encoder_attn_procs(
def save_lora_weights(
self,
save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, torch.nn.Module] = None,
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
Contributor Author:
Let's also allow passing the LoRA layers as low-level torch.Tensor state dicts.

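A toy illustration of the normalization this enables: `save_lora_weights` now reduces either an `nn.Module` container or a plain `{name: tensor}` dict to raw weights before prefixing the keys (the stand-in modules and the printed prefixes are illustrative only):

import torch

def to_weights(lora_layers):
    # mirrors the new isinstance check in save_lora_weights
    return lora_layers.state_dict() if isinstance(lora_layers, torch.nn.Module) else lora_layers

module_style = torch.nn.Linear(2, 2)           # e.g. an AttnProcsLayers container
tensor_style = {"weight": torch.zeros(2, 2)}   # e.g. a state dict popped inside the training hook

for layers in (module_style, tensor_style):
    print(sorted(f"unet.{k}" for k in to_weights(layers)))
# ['unet.bias', 'unet.weight'] and ['unet.weight']
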
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
is_main_process: bool = True,
weight_name: str = None,
@@ -1123,13 +1144,14 @@
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
unet_lora_layers (`Dict[str, torch.nn.Module`]):
unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the UNet. Specifying this helps to make the
serialization process easier and cleaner.
text_encoder_lora_layers (`Dict[str, torch.nn.Module`]):
serialization process easier and cleaner. Values can be either LoRA torch.nn.Module layers or torch
weights.
text_encoder_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
State dict of the LoRA layers corresponding to the `text_encoder`. Since the `text_encoder` comes from
`transformers`, we cannot rejig it. That is why we have to explicitly pass the text encoder LoRA state
dict.
dict. Values can be either LoRA torch.nn.Module layers or torch weights.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful when in distributed training like
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
@@ -1157,15 +1179,22 @@ def save_function(weights, filename):
# Create a flat dictionary.
state_dict = {}
if unet_lora_layers is not None:
unet_lora_state_dict = {
f"{self.unet_name}.{module_name}": param
for module_name, param in unet_lora_layers.state_dict().items()
}
weights = (
unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
)

unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
state_dict.update(unet_lora_state_dict)

if text_encoder_lora_layers is not None:
weights = (
text_encoder_lora_layers.state_dict()
if isinstance(text_encoder_lora_layers, torch.nn.Module)
else text_encoder_lora_layers
)

text_encoder_lora_state_dict = {
f"{self.text_encoder_name}.{module_name}": param
for module_name, param in text_encoder_lora_layers.state_dict().items()
f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
}
state_dict.update(text_encoder_lora_state_dict)
