From 2e27eb338d7fe6fd6205f2cfa1011977794d06d0 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Sat, 6 May 2023 12:38:54 +0000 Subject: [PATCH 01/16] Improve checkpointing lora --- examples/dreambooth/train_dreambooth_lora.py | 54 ++++++++++++++++---- src/diffusers/loaders.py | 10 ++-- 2 files changed, 50 insertions(+), 14 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 9af81aa5a95d..a60192570c8b 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -27,6 +27,7 @@ import torch.nn.functional as F import torch.utils.checkpoint import transformers +import accelerate from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed @@ -746,6 +747,45 @@ def main(args): accelerator.register_for_checkpointing(text_encoder_lora_layers) del temp_pipeline + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + # there are only two options here. Either are just the unet attn processor layers + # or there are the unet and text encoder atten layers + unet_lora_layers_to_save = None + text_encoder_lora_layers_to_save = None + + for model in models: + state_dict = model.state_dict() + + if text_encoder_lora_layers is not None and state_dict.keys() == text_encoder_lora_layers.state_dict().keys(): + # text encoder + text_encoder_lora_layers_to_save = state_dict + elif state_dict.keys() == unet_lora_layers.state_dict().keys(): + # unet + unet_lora_layers_to_save = state_dict + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + LoraLoaderMixin.save_lora_weights(output_dir, unet_lora_layers=unet_lora_layers_to_save, text_encoder_lora_layers=text_encoder_lora_layers_to_save) + + def load_model_hook(models, input_dir): + temp_pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=text_encoder if args.train_text_encoder else None, + revision=args.revision, + torch_dtype=weight_dtype, + ) + temp_pipeline.load_lora_weights(input_dir) + del temp_pipeline + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: @@ -959,17 +999,9 @@ def main(args): global_step += 1 if global_step % args.checkpointing_steps == 0: - if accelerator.is_main_process: - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - # We combine the text encoder and UNet LoRA parameters with a simple - # custom logic. `accelerator.save_state()` won't know that. So, - # use `LoraLoaderMixin.save_lora_weights()`. 
- LoraLoaderMixin.save_lora_weights( - save_directory=save_path, - unet_lora_layers=unet_lora_layers, - text_encoder_lora_layers=text_encoder_lora_layers, - ) - logger.info(f"Saved state to {save_path}") + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index b4b0f4bb3bd6..a88919ec9b27 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1059,7 +1059,7 @@ def _load_text_encoder_attn_procs( def save_lora_weights( self, save_directory: Union[str, os.PathLike], - unet_lora_layers: Dict[str, torch.nn.Module] = None, + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, is_main_process: bool = True, weight_name: str = None, @@ -1106,15 +1106,19 @@ def save_function(weights, filename): # Create a flat dictionary. state_dict = {} if unet_lora_layers is not None: + weights = unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers + unet_lora_state_dict = { f"{self.unet_name}.{module_name}": param - for module_name, param in unet_lora_layers.state_dict().items() + for module_name, param in weights.items() } state_dict.update(unet_lora_state_dict) if text_encoder_lora_layers is not None: + weights = text_encoder_lora_layers.state_dict() if isinstance(text_encoder_lora_layers, torch.nn.Module) else unet_lora_layers + text_encoder_lora_state_dict = { f"{self.text_encoder_name}.{module_name}": param - for module_name, param in text_encoder_lora_layers.state_dict().items() + for module_name, param in weights.items() } state_dict.update(text_encoder_lora_state_dict) From 7a7ffd35d188919a0f2007a3f5cc245768f285c7 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 May 2023 11:08:06 +0000 Subject: [PATCH 02/16] fix more --- examples/dreambooth/train_dreambooth.py | 54 ++++++++------- examples/dreambooth/train_dreambooth_lora.py | 70 +++++++++++--------- src/diffusers/loaders.py | 33 ++++++++- 3 files changed, 93 insertions(+), 64 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 190f4625a16c..1080f1afbc56 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -733,36 +733,34 @@ def main(args): args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) - # `accelerate` 0.16.0 will have better support for customized saving - if version.parse(accelerate.__version__) >= version.parse("0.16.0"): - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - for model in models: - sub_dir = "unet" if type(model) == type(unet) else "text_encoder" - model.save_pretrained(os.path.join(output_dir, sub_dir)) - - # make sure to pop weight so that corresponding model is not saved again - weights.pop() - - def load_model_hook(models, input_dir): - while len(models) > 0: - # pop models so that they are not loaded again - model = models.pop() - - if type(model) == type(text_encoder): - # load transformers style into model - load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") - model.config = load_model.config - else: - # load diffusers 
style into model - load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") - model.register_to_config(**load_model.config) + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + for model in models: + sub_dir = "unet" if type(model) == type(unet) else "text_encoder" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + if type(model) == type(text_encoder): + # load transformers style into model + load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") + model.config = load_model.config + else: + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) - model.load_state_dict(load_model.state_dict()) - del load_model + model.load_state_dict(load_model.state_dict()) + del load_model - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) vae.requires_grad_(False) if not args.train_text_encoder: diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index a60192570c8b..f00a88719b89 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -725,7 +725,6 @@ def main(args): unet.set_attn_processor(unet_lora_attn_procs) unet_lora_layers = AttnProcsLayers(unet.attn_processors) - accelerator.register_for_checkpointing(unet_lora_layers) # The text encoder comes from 🤗 transformers, so we cannot directly modify it. # So, instead, we monkey-patch the forward calls of its attention-blocks. For this, @@ -744,47 +743,52 @@ def main(args): ) temp_pipeline._modify_text_encoder(text_lora_attn_procs) text_encoder = temp_pipeline.text_encoder - accelerator.register_for_checkpointing(text_encoder_lora_layers) del temp_pipeline - # `accelerate` 0.16.0 will have better support for customized saving - if version.parse(accelerate.__version__) >= version.parse("0.16.0"): - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - # there are only two options here. Either are just the unet attn processor layers - # or there are the unet and text encoder atten layers - unet_lora_layers_to_save = None - text_encoder_lora_layers_to_save = None + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + # there are only two options here. 
Either are just the unet attn processor layers + # or there are the unet and text encoder atten layers + unet_lora_layers_to_save = None + text_encoder_lora_layers_to_save = None - for model in models: - state_dict = model.state_dict() + for model in models: + state_dict = model.state_dict() - if text_encoder_lora_layers is not None and state_dict.keys() == text_encoder_lora_layers.state_dict().keys(): - # text encoder - text_encoder_lora_layers_to_save = state_dict - elif state_dict.keys() == unet_lora_layers.state_dict().keys(): - # unet - unet_lora_layers_to_save = state_dict + if text_encoder_lora_layers is not None and state_dict.keys() == text_encoder_lora_layers.state_dict().keys(): + # text encoder + text_encoder_lora_layers_to_save = state_dict + elif state_dict.keys() == unet_lora_layers.state_dict().keys(): + # unet + unet_lora_layers_to_save = state_dict - # make sure to pop weight so that corresponding model is not saved again - weights.pop() + # make sure to pop weight so that corresponding model is not saved again + weights.pop() - LoraLoaderMixin.save_lora_weights(output_dir, unet_lora_layers=unet_lora_layers_to_save, text_encoder_lora_layers=text_encoder_lora_layers_to_save) + LoraLoaderMixin.save_lora_weights(output_dir, unet_lora_layers=unet_lora_layers_to_save, text_encoder_lora_layers=text_encoder_lora_layers_to_save) - def load_model_hook(models, input_dir): - temp_pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), - text_encoder=text_encoder if args.train_text_encoder else None, - revision=args.revision, - torch_dtype=weight_dtype, - ) - temp_pipeline.load_lora_weights(input_dir) - del temp_pipeline + def load_model_hook(models, input_dir): + temp_pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder) if args.train_text_encoder else None, + revision=args.revision, + torch_dtype=weight_dtype, + ) + temp_pipeline.load_lora_weights(input_dir) - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) + # load lora weights into models + models[0].load_state_dict(AttnProcsLayers(unet.attn_processors).state_dict()) + if len(models) > 1: + models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict()) + + # delete temporary pipeline and pop models + del temp_pipeline + for _ in range(len(models)): + models.pop() + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index a88919ec9b27..409f37f3925c 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -66,6 +66,9 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]): self.mapping = dict(enumerate(state_dict.keys())) self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} + # .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder + self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"] + # we add a hook to state_dict() and load_state_dict() so that the # naming fits with `unet.attn_processors` def map_to(module, state_dict, *args, **kwargs): @@ -77,10 +80,17 @@ def map_to(module, 
state_dict, *args, **kwargs): return new_state_dict + def remap_key(key, state_dict): + for k in self.split_keys: + if k in key: + return key.split(k)[0] + k + + raise ValueError(f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}.") + def map_from(module, state_dict, *args, **kwargs): all_keys = list(state_dict.keys()) for key in all_keys: - replace_key = key.split(".processor")[0] + ".processor" + replace_key = remap_key(key, state_dict) new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") state_dict[new_key] = state_dict[key] del state_dict[key] @@ -843,10 +853,13 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di text_encoder_lora_state_dict = { k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys } + attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict) if len(text_encoder_lora_state_dict) > 0: - attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict) self._modify_text_encoder(attn_procs_text_encoder) + # save lora attn procs of text encoder so that it can be easily retrieved + self._text_encoder_lora_attn_procs = attn_procs_text_encoder + # Otherwise, we're dealing with the old format. This means the `state_dict` should only # contain the module names of the `unet` as its keys WITHOUT any prefix. elif not all( @@ -856,6 +869,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`." warnings.warn(warn_message) + @property + def text_encoder_lora_attn_procs(self): + if hasattr(self, "_text_encoder_lora_attn_procs"): + return self._text_encoder_lora_attn_procs + return + def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): r""" Monkey-patches the forward passes of attention modules of the text encoder. @@ -880,6 +899,13 @@ def new_forward(x): # Monkey-patch. module.forward = new_forward + @property + def text_encoder_lora_procs(self): + for name, _ in self.text_encoder.named_modules(): + if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): + # Retrieve the module and its corresponding LoRA processor. 
+ module = self.text_encoder.get_submodule(name) + def _get_lora_layer_attribute(self, name: str) -> str: if "q_proj" in name: return "to_q_lora" @@ -1113,8 +1139,9 @@ def save_function(weights, filename): for module_name, param in weights.items() } state_dict.update(unet_lora_state_dict) + if text_encoder_lora_layers is not None: - weights = text_encoder_lora_layers.state_dict() if isinstance(text_encoder_lora_layers, torch.nn.Module) else unet_lora_layers + weights = text_encoder_lora_layers.state_dict() if isinstance(text_encoder_lora_layers, torch.nn.Module) else text_encoder_lora_layers text_encoder_lora_state_dict = { f"{self.text_encoder_name}.{module_name}": param From b462b4ec7e0be1828cc7ceb054cd31f1dfcd3495 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 May 2023 11:12:01 +0000 Subject: [PATCH 03/16] Improve doc string --- src/diffusers/loaders.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 409f37f3925c..17e1ef22ada3 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1098,13 +1098,13 @@ def save_lora_weights( Arguments: save_directory (`str` or `os.PathLike`): Directory to which to save. Will be created if it doesn't exist. - unet_lora_layers (`Dict[str, torch.nn.Module`]): + unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the UNet. Specifying this helps to make the - serialization process easier and cleaner. - text_encoder_lora_layers (`Dict[str, torch.nn.Module`]): + serialization process easier and cleaner. Values can be both LoRA torch.nn.Modules layers or torch weights. + text_encoder_lora_layers (`Dict[str, torch.nn.Module] or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder`. Since the `text_encoder` comes from `transformers`, we cannot rejig it. That is why we have to explicitly pass the text encoder LoRA state - dict. + dict. Values can be both LoRA torch.nn.Modules layers or torch weights. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful when in distributed training like TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on From 3e81c47f6435412ed8e7cb0537f68e8889db52f5 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 May 2023 12:13:23 +0100 Subject: [PATCH 04/16] Update src/diffusers/loaders.py --- src/diffusers/loaders.py | 7 ------- 1 file changed, 7 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 17e1ef22ada3..c171d641dee6 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -899,13 +899,6 @@ def new_forward(x): # Monkey-patch. module.forward = new_forward - @property - def text_encoder_lora_procs(self): - for name, _ in self.text_encoder.named_modules(): - if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): - # Retrieve the module and its corresponding LoRA processor. 
- module = self.text_encoder.get_submodule(name) - def _get_lora_layer_attribute(self, name: str) -> str: if "q_proj" in name: return "to_q_lora" From 8d52ed795d3330faba14e3060146acf39a1f074d Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 May 2023 11:14:42 +0000 Subject: [PATCH 05/16] make stytle --- examples/dreambooth/train_dreambooth.py | 1 - examples/dreambooth/train_dreambooth_lora.py | 14 +++++++--- src/diffusers/loaders.py | 27 ++++++++++++-------- 3 files changed, 26 insertions(+), 16 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 1080f1afbc56..8792d4f6c1b9 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -22,7 +22,6 @@ import warnings from pathlib import Path -import accelerate import numpy as np import torch import torch.nn.functional as F diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index f00a88719b89..1e6d4149957e 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -27,7 +27,6 @@ import torch.nn.functional as F import torch.utils.checkpoint import transformers -import accelerate from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import ProjectConfiguration, set_seed @@ -755,17 +754,24 @@ def save_model_hook(models, weights, output_dir): for model in models: state_dict = model.state_dict() - if text_encoder_lora_layers is not None and state_dict.keys() == text_encoder_lora_layers.state_dict().keys(): + if ( + text_encoder_lora_layers is not None + and state_dict.keys() == text_encoder_lora_layers.state_dict().keys() + ): # text encoder text_encoder_lora_layers_to_save = state_dict - elif state_dict.keys() == unet_lora_layers.state_dict().keys(): + elif state_dict.keys() == unet_lora_layers.state_dict().keys(): # unet unet_lora_layers_to_save = state_dict # make sure to pop weight so that corresponding model is not saved again weights.pop() - LoraLoaderMixin.save_lora_weights(output_dir, unet_lora_layers=unet_lora_layers_to_save, text_encoder_lora_layers=text_encoder_lora_layers_to_save) + LoraLoaderMixin.save_lora_weights( + output_dir, + unet_lora_layers=unet_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_lora_layers_to_save, + ) def load_model_hook(models, input_dir): temp_pipeline = DiffusionPipeline.from_pretrained( diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 17e1ef22ada3..019856aa81c3 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -85,7 +85,9 @@ def remap_key(key, state_dict): if k in key: return key.split(k)[0] + k - raise ValueError(f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}.") + raise ValueError( + f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}." + ) def map_from(module, state_dict, *args, **kwargs): all_keys = list(state_dict.keys()) @@ -904,7 +906,7 @@ def text_encoder_lora_procs(self): for name, _ in self.text_encoder.named_modules(): if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): # Retrieve the module and its corresponding LoRA processor. 
- module = self.text_encoder.get_submodule(name) + self.text_encoder.get_submodule(name) def _get_lora_layer_attribute(self, name: str) -> str: if "q_proj" in name: @@ -1100,7 +1102,8 @@ def save_lora_weights( Directory to which to save. Will be created if it doesn't exist. unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the UNet. Specifying this helps to make the - serialization process easier and cleaner. Values can be both LoRA torch.nn.Modules layers or torch weights. + serialization process easier and cleaner. Values can be both LoRA torch.nn.Modules layers or torch + weights. text_encoder_lora_layers (`Dict[str, torch.nn.Module] or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder`. Since the `text_encoder` comes from `transformers`, we cannot rejig it. That is why we have to explicitly pass the text encoder LoRA state @@ -1132,20 +1135,22 @@ def save_function(weights, filename): # Create a flat dictionary. state_dict = {} if unet_lora_layers is not None: - weights = unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers + weights = ( + unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers + ) - unet_lora_state_dict = { - f"{self.unet_name}.{module_name}": param - for module_name, param in weights.items() - } + unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()} state_dict.update(unet_lora_state_dict) if text_encoder_lora_layers is not None: - weights = text_encoder_lora_layers.state_dict() if isinstance(text_encoder_lora_layers, torch.nn.Module) else text_encoder_lora_layers + weights = ( + text_encoder_lora_layers.state_dict() + if isinstance(text_encoder_lora_layers, torch.nn.Module) + else text_encoder_lora_layers + ) text_encoder_lora_state_dict = { - f"{self.text_encoder_name}.{module_name}": param - for module_name, param in weights.items() + f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items() } state_dict.update(text_encoder_lora_state_dict) From d6c4872d5024734d5224227a0a3f24ba5ea18eec Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 May 2023 12:16:54 +0100 Subject: [PATCH 06/16] Apply suggestions from code review --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 019856aa81c3..474aac8b114a 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -855,8 +855,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di text_encoder_lora_state_dict = { k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys } - attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict) if len(text_encoder_lora_state_dict) > 0: + attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict) self._modify_text_encoder(attn_procs_text_encoder) # save lora attn procs of text encoder so that it can be easily retrieved From d4b6502bd7dc96d42d6d807232eaae8a64a83315 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 May 2023 12:17:37 +0100 Subject: [PATCH 07/16] Update src/diffusers/loaders.py --- src/diffusers/loaders.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py 
index 474aac8b114a..5245d4a181a4 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -900,14 +900,6 @@ def new_forward(x): # Monkey-patch. module.forward = new_forward - - @property - def text_encoder_lora_procs(self): - for name, _ in self.text_encoder.named_modules(): - if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): - # Retrieve the module and its corresponding LoRA processor. - self.text_encoder.get_submodule(name) - def _get_lora_layer_attribute(self, name: str) -> str: if "q_proj" in name: return "to_q_lora" From 22fff9ddd717ef377d4fb0677b0400b6940ff70f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 May 2023 12:18:00 +0100 Subject: [PATCH 08/16] Apply suggestions from code review --- src/diffusers/loaders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 5245d4a181a4..d605a99d48c8 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -900,6 +900,7 @@ def new_forward(x): # Monkey-patch. module.forward = new_forward + def _get_lora_layer_attribute(self, name: str) -> str: if "q_proj" in name: return "to_q_lora" From 9b38e2c702b0c7a5c989c707a9e7e13fc0852e3b Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 May 2023 12:18:21 +0100 Subject: [PATCH 09/16] Apply suggestions from code review --- src/diffusers/loaders.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index d605a99d48c8..5245d4a181a4 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -900,7 +900,6 @@ def new_forward(x): # Monkey-patch. module.forward = new_forward - def _get_lora_layer_attribute(self, name: str) -> str: if "q_proj" in name: return "to_q_lora" From 25c937a37697c91d41f8b85ffac226039a4692c6 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 May 2023 11:18:48 +0000 Subject: [PATCH 10/16] better --- src/diffusers/loaders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 5245d4a181a4..606abad97227 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -900,6 +900,7 @@ def new_forward(x): # Monkey-patch. 
module.forward = new_forward + def _get_lora_layer_attribute(self, name: str) -> str: if "q_proj" in name: return "to_q_lora" From 4874e4926e8c2905cb8c428ac63069734843587f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 May 2023 11:49:23 +0000 Subject: [PATCH 11/16] Fix all --- examples/dreambooth/train_dreambooth_lora.py | 27 ++++++++++++-------- 1 file changed, 16 insertions(+), 11 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 1e6d4149957e..073e68b2f838 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -754,13 +754,13 @@ def save_model_hook(models, weights, output_dir): for model in models: state_dict = model.state_dict() - if ( - text_encoder_lora_layers is not None - and state_dict.keys() == text_encoder_lora_layers.state_dict().keys() - ): + text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys() + unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys() + + if text_encoder_lora_layers is not None and state_dict.keys() == text_encoder_keys: # text encoder text_encoder_lora_layers_to_save = state_dict - elif state_dict.keys() == unet_lora_layers.state_dict().keys(): + elif state_dict.keys() == unet_keys: # unet unet_lora_layers_to_save = state_dict @@ -774,10 +774,13 @@ def save_model_hook(models, weights, output_dir): ) def load_model_hook(models, input_dir): + # Note we DON'T pass the unet and text encoder here an purpose + # so that the we don't accidentally override the LoRA layers of + # unet_lora_layers and text_encoder_lora_layers which are stored in `models` + # with new torch.nn.Modules / weights. We simply use the pipeline class as + # an easy way to load the lora checkpoints temp_pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), - text_encoder=accelerator.unwrap_model(text_encoder) if args.train_text_encoder else None, revision=args.revision, torch_dtype=weight_dtype, ) @@ -1008,10 +1011,11 @@ def load_model_hook(models, input_dir): progress_bar.update(1) global_step += 1 - if global_step % args.checkpointing_steps == 0: - save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - accelerator.save_state(save_path) - logger.info(f"Saved state to {save_path}") + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) @@ -1085,6 +1089,7 @@ def load_model_hook(models, input_dir): pipeline.load_lora_weights(args.output_dir) # run inference + images = [] if args.validation_prompt and args.num_validation_images > 0: generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None images = [ From c3b5f53e9cbc90f93a3c7a92c224b0383b4f1bb9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Mon, 8 May 2023 12:00:11 +0000 Subject: [PATCH 12/16] Fix multi-GPU dreambooth --- examples/dreambooth/train_dreambooth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 8792d4f6c1b9..54078c1cf06d 100644 --- a/examples/dreambooth/train_dreambooth.py +++ 
b/examples/dreambooth/train_dreambooth.py @@ -735,7 +735,7 @@ def main(args): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): for model in models: - sub_dir = "unet" if type(model) == type(unet) else "text_encoder" + sub_dir = "unet" if type(model) == type(accelerator.unwrap_model(unet)) else "text_encoder" model.save_pretrained(os.path.join(output_dir, sub_dir)) # make sure to pop weight so that corresponding model is not saved again @@ -746,7 +746,7 @@ def load_model_hook(models, input_dir): # pop models so that they are not loaded again model = models.pop() - if type(model) == type(text_encoder): + if type(model) == type(accelerator.unwrap_model(text_encoder)): # load transformers style into model load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") model.config = load_model.config From 97a140b35bcb46aa37c180bf933a64c855718f2a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 11 May 2023 18:33:08 +0100 Subject: [PATCH 13/16] Apply suggestions from code review Co-authored-by: Pedro Cuenca --- examples/dreambooth/train_dreambooth_lora.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 073e68b2f838..60957d683e93 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -751,13 +751,14 @@ def save_model_hook(models, weights, output_dir): unet_lora_layers_to_save = None text_encoder_lora_layers_to_save = None + if args.train_text_encoder: + text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys() + unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys() + for model in models: state_dict = model.state_dict() - text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys() - unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys() - - if text_encoder_lora_layers is not None and state_dict.keys() == text_encoder_keys: + if text_encoder_lora_layers is not None and text_encoder_keys is not None and state_dict.keys() == text_encoder_keys: # text encoder text_encoder_lora_layers_to_save = state_dict elif state_dict.keys() == unet_keys: @@ -774,7 +775,7 @@ def save_model_hook(models, weights, output_dir): ) def load_model_hook(models, input_dir): - # Note we DON'T pass the unet and text encoder here an purpose + # Note we DON'T pass the unet and text encoder here on purpose # so that the we don't accidentally override the LoRA layers of # unet_lora_layers and text_encoder_lora_layers which are stored in `models` # with new torch.nn.Modules / weights. 
We simply use the pipeline class as From 27c7191da27d5af5e833c6864af9346da43810cf Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 11 May 2023 17:50:46 +0000 Subject: [PATCH 14/16] Fix all --- examples/dreambooth/train_dreambooth.py | 4 ++-- examples/dreambooth/train_dreambooth_lora.py | 10 +++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 54078c1cf06d..5d2107f024d1 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -735,7 +735,7 @@ def main(args): # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format def save_model_hook(models, weights, output_dir): for model in models: - sub_dir = "unet" if type(model) == type(accelerator.unwrap_model(unet)) else "text_encoder" + sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder" model.save_pretrained(os.path.join(output_dir, sub_dir)) # make sure to pop weight so that corresponding model is not saved again @@ -746,7 +746,7 @@ def load_model_hook(models, input_dir): # pop models so that they are not loaded again model = models.pop() - if type(model) == type(accelerator.unwrap_model(text_encoder)): + if isinstance(model, type(accelerator.unwrap_model(text_encoder))): # load transformers style into model load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") model.config = load_model.config diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 073e68b2f838..b940c68a84d1 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -774,10 +774,10 @@ def save_model_hook(models, weights, output_dir): ) def load_model_hook(models, input_dir): - # Note we DON'T pass the unet and text encoder here an purpose + # Note we DON'T pass the unet and text encoder here an purpose # so that the we don't accidentally override the LoRA layers of # unet_lora_layers and text_encoder_lora_layers which are stored in `models` - # with new torch.nn.Modules / weights. We simply use the pipeline class as + # with new torch.nn.Modules / weights. 
We simply use the pipeline class as # an easy way to load the lora checkpoints temp_pipeline = DiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, @@ -787,7 +787,7 @@ def load_model_hook(models, input_dir): temp_pipeline.load_lora_weights(input_dir) # load lora weights into models - models[0].load_state_dict(AttnProcsLayers(unet.attn_processors).state_dict()) + models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict()) if len(models) > 1: models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict()) @@ -1071,6 +1071,10 @@ def load_model_hook(models, input_dir): if accelerator.is_main_process: unet = unet.to(torch.float32) text_encoder = text_encoder.to(torch.float32) + unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) + text_encoder_lora_layers = ( + accelerator.unwrap_model(text_encoder_lora_layers) if args.train_text_encoder else None + ) LoraLoaderMixin.save_lora_weights( save_directory=args.output_dir, unet_lora_layers=unet_lora_layers, From df87aaf6b9df529eaeaa0b7d3fead23b8b6c25bb Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 11 May 2023 17:55:00 +0000 Subject: [PATCH 15/16] make style --- examples/dreambooth/train_dreambooth_lora.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 09387ef11eb2..b1b39c52cc07 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -758,7 +758,11 @@ def save_model_hook(models, weights, output_dir): for model in models: state_dict = model.state_dict() - if text_encoder_lora_layers is not None and text_encoder_keys is not None and state_dict.keys() == text_encoder_keys: + if ( + text_encoder_lora_layers is not None + and text_encoder_keys is not None + and state_dict.keys() == text_encoder_keys + ): # text encoder text_encoder_lora_layers_to_save = state_dict elif state_dict.keys() == unet_keys: From eb4e4560809ccf1bf18d69c70d6f1b478a67c025 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 11 May 2023 17:59:37 +0000 Subject: [PATCH 16/16] make style --- examples/dreambooth/train_dreambooth_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index d80bbcf1f35c..16adfe4b83fc 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -1269,7 +1269,7 @@ def compute_text_embeddings(prompt): if accelerator.is_main_process: unet = unet.to(torch.float32) unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) - + if text_encoder is not None: text_encoder = text_encoder.to(torch.float32) text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
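
Note: the checkpointing changes above all build on `accelerate`'s save/load pre-hooks (available from `accelerate` 0.16.0, hence the version gate in the first patch). `accelerator.register_save_state_pre_hook(...)` and `accelerator.register_load_state_pre_hook(...)` let `accelerator.save_state(...)` / `accelerator.load_state(...)` hand model (de)serialization over to custom code; in these patches that custom code is `LoraLoaderMixin.save_lora_weights(...)` and `pipeline.load_lora_weights(...)`, so a checkpoint stores the combined UNet/text-encoder LoRA state dict in the usual LoRA format instead of accelerate's default per-model files. The sketch below illustrates only the generic hook pattern, not the LoRA-specific logic; `TinyModel` and the checkpoint paths are invented for the example.

import os

import torch
from accelerate import Accelerator


class TinyModel(torch.nn.Module):
    # Stand-in for the UNet / text-encoder LoRA layer containers used in the patches.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


accelerator = Accelerator()
model = TinyModel()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
model, optimizer = accelerator.prepare(model, optimizer)


def save_model_hook(models, weights, output_dir):
    # `models` are the prepared models tracked by the accelerator and `weights`
    # their state dicts. Serialize in a custom format, then pop each weight so
    # that `accelerator.save_state` does not also write it in its default format.
    os.makedirs(output_dir, exist_ok=True)
    for model in models:
        torch.save(model.state_dict(), os.path.join(output_dir, "custom_model.bin"))
        weights.pop()


def load_model_hook(models, input_dir):
    # Pop each model so that `accelerator.load_state` skips its default loading
    # path, and restore the weights from the custom format instead.
    while len(models) > 0:
        model = models.pop()
        state_dict = torch.load(os.path.join(input_dir, "custom_model.bin"), map_location="cpu")
        model.load_state_dict(state_dict)


accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)

# Optimizer, scheduler and RNG states still go through accelerate's default path;
# only the model weights are routed through the hooks above.
accelerator.save_state("checkpoint-0")
accelerator.load_state("checkpoint-0")

The same split between custom model serialization via hooks and everything else via `accelerator.save_state` is what lets the LoRA training script drop its hand-rolled checkpoint saving and simply call `accelerator.save_state(save_path)` at each checkpointing step.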