diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 190f4625a16c..5d2107f024d1 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -22,7 +22,6 @@ import warnings from pathlib import Path -import accelerate import numpy as np import torch import torch.nn.functional as F @@ -733,36 +732,34 @@ def main(args): args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision ) - # `accelerate` 0.16.0 will have better support for customized saving - if version.parse(accelerate.__version__) >= version.parse("0.16.0"): - # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format - def save_model_hook(models, weights, output_dir): - for model in models: - sub_dir = "unet" if type(model) == type(unet) else "text_encoder" - model.save_pretrained(os.path.join(output_dir, sub_dir)) - - # make sure to pop weight so that corresponding model is not saved again - weights.pop() - - def load_model_hook(models, input_dir): - while len(models) > 0: - # pop models so that they are not loaded again - model = models.pop() - - if type(model) == type(text_encoder): - # load transformers style into model - load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") - model.config = load_model.config - else: - # load diffusers style into model - load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") - model.register_to_config(**load_model.config) + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + for model in models: + sub_dir = "unet" if isinstance(model, type(accelerator.unwrap_model(unet))) else "text_encoder" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + if isinstance(model, type(accelerator.unwrap_model(text_encoder))): + # load transformers style into model + load_model = text_encoder_cls.from_pretrained(input_dir, subfolder="text_encoder") + model.config = load_model.config + else: + # load diffusers style into model + load_model = UNet2DConditionModel.from_pretrained(input_dir, subfolder="unet") + model.register_to_config(**load_model.config) - model.load_state_dict(load_model.state_dict()) - del load_model + model.load_state_dict(load_model.state_dict()) + del load_model - accelerator.register_save_state_pre_hook(save_model_hook) - accelerator.register_load_state_pre_hook(load_model_hook) + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) vae.requires_grad_(False) if not args.train_text_encoder: diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 0bf3333a6209..16adfe4b83fc 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -834,7 +834,6 @@ def main(args): unet.set_attn_processor(unet_lora_attn_procs) unet_lora_layers = AttnProcsLayers(unet.attn_processors) - accelerator.register_for_checkpointing(unet_lora_layers) # The text encoder comes from 🤗 transformers, so we cannot directly modify it. # So, instead, we monkey-patch the forward calls of its attention-blocks. For this, @@ -853,9 +852,68 @@ def main(args): ) temp_pipeline._modify_text_encoder(text_lora_attn_procs) text_encoder = temp_pipeline.text_encoder - accelerator.register_for_checkpointing(text_encoder_lora_layers) del temp_pipeline + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + # there are only two options here. Either are just the unet attn processor layers + # or there are the unet and text encoder atten layers + unet_lora_layers_to_save = None + text_encoder_lora_layers_to_save = None + + if args.train_text_encoder: + text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys() + unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys() + + for model in models: + state_dict = model.state_dict() + + if ( + text_encoder_lora_layers is not None + and text_encoder_keys is not None + and state_dict.keys() == text_encoder_keys + ): + # text encoder + text_encoder_lora_layers_to_save = state_dict + elif state_dict.keys() == unet_keys: + # unet + unet_lora_layers_to_save = state_dict + + # make sure to pop weight so that corresponding model is not saved again + weights.pop() + + LoraLoaderMixin.save_lora_weights( + output_dir, + unet_lora_layers=unet_lora_layers_to_save, + text_encoder_lora_layers=text_encoder_lora_layers_to_save, + ) + + def load_model_hook(models, input_dir): + # Note we DON'T pass the unet and text encoder here an purpose + # so that the we don't accidentally override the LoRA layers of + # unet_lora_layers and text_encoder_lora_layers which are stored in `models` + # with new torch.nn.Modules / weights. We simply use the pipeline class as + # an easy way to load the lora checkpoints + temp_pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + revision=args.revision, + torch_dtype=weight_dtype, + ) + temp_pipeline.load_lora_weights(input_dir) + + # load lora weights into models + models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict()) + if len(models) > 1: + models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict()) + + # delete temporary pipeline and pop models + del temp_pipeline + for _ in range(len(models)): + models.pop() + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + # Enable TF32 for faster training on Ampere GPUs, # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices if args.allow_tf32: @@ -1130,17 +1188,10 @@ def compute_text_embeddings(prompt): progress_bar.update(1) global_step += 1 - if global_step % args.checkpointing_steps == 0: - if accelerator.is_main_process: + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") - # We combine the text encoder and UNet LoRA parameters with a simple - # custom logic. `accelerator.save_state()` won't know that. So, - # use `LoraLoaderMixin.save_lora_weights()`. - LoraLoaderMixin.save_lora_weights( - save_directory=save_path, - unet_lora_layers=unet_lora_layers, - text_encoder_lora_layers=text_encoder_lora_layers, - ) + accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} @@ -1217,8 +1268,12 @@ def compute_text_embeddings(prompt): accelerator.wait_for_everyone() if accelerator.is_main_process: unet = unet.to(torch.float32) + unet_lora_layers = accelerator.unwrap_model(unet_lora_layers) + if text_encoder is not None: text_encoder = text_encoder.to(torch.float32) + text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers) + LoraLoaderMixin.save_lora_weights( save_directory=args.output_dir, unet_lora_layers=unet_lora_layers, @@ -1250,6 +1305,7 @@ def compute_text_embeddings(prompt): pipeline.load_lora_weights(args.output_dir) # run inference + images = [] if args.validation_prompt and args.num_validation_images > 0: generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None images = [ diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 7513fa2732ba..a1f0d8ec2a52 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -70,6 +70,9 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]): self.mapping = dict(enumerate(state_dict.keys())) self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} + # .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder + self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"] + # we add a hook to state_dict() and load_state_dict() so that the # naming fits with `unet.attn_processors` def map_to(module, state_dict, *args, **kwargs): @@ -81,10 +84,19 @@ def map_to(module, state_dict, *args, **kwargs): return new_state_dict + def remap_key(key, state_dict): + for k in self.split_keys: + if k in key: + return key.split(k)[0] + k + + raise ValueError( + f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}." + ) + def map_from(module, state_dict, *args, **kwargs): all_keys = list(state_dict.keys()) for key in all_keys: - replace_key = key.split(".processor")[0] + ".processor" + replace_key = remap_key(key, state_dict) new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") state_dict[new_key] = state_dict[key] del state_dict[key] @@ -898,6 +910,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict) self._modify_text_encoder(attn_procs_text_encoder) + # save lora attn procs of text encoder so that it can be easily retrieved + self._text_encoder_lora_attn_procs = attn_procs_text_encoder + # Otherwise, we're dealing with the old format. This means the `state_dict` should only # contain the module names of the `unet` as its keys WITHOUT any prefix. elif not all( @@ -907,6 +922,12 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di warn_message = "You have saved the LoRA weights using the old format. To convert the old LoRA weights to the new format, you can first load them in a dictionary and then create a new dictionary like the following: `new_state_dict = {f'unet'.{module_name}: params for module_name, params in old_state_dict.items()}`." warnings.warn(warn_message) + @property + def text_encoder_lora_attn_procs(self): + if hasattr(self, "_text_encoder_lora_attn_procs"): + return self._text_encoder_lora_attn_procs + return + def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): r""" Monkey-patches the forward passes of attention modules of the text encoder. @@ -1110,7 +1131,7 @@ def _load_text_encoder_attn_procs( def save_lora_weights( self, save_directory: Union[str, os.PathLike], - unet_lora_layers: Dict[str, torch.nn.Module] = None, + unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None, text_encoder_lora_layers: Dict[str, torch.nn.Module] = None, is_main_process: bool = True, weight_name: str = None, @@ -1123,13 +1144,14 @@ def save_lora_weights( Arguments: save_directory (`str` or `os.PathLike`): Directory to which to save. Will be created if it doesn't exist. - unet_lora_layers (`Dict[str, torch.nn.Module`]): + unet_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the UNet. Specifying this helps to make the - serialization process easier and cleaner. - text_encoder_lora_layers (`Dict[str, torch.nn.Module`]): + serialization process easier and cleaner. Values can be both LoRA torch.nn.Modules layers or torch + weights. + text_encoder_lora_layers (`Dict[str, torch.nn.Module] or `Dict[str, torch.Tensor]`): State dict of the LoRA layers corresponding to the `text_encoder`. Since the `text_encoder` comes from `transformers`, we cannot rejig it. That is why we have to explicitly pass the text encoder LoRA state - dict. + dict. Values can be both LoRA torch.nn.Modules layers or torch weights. is_main_process (`bool`, *optional*, defaults to `True`): Whether the process calling this is the main process or not. Useful when in distributed training like TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on @@ -1157,15 +1179,22 @@ def save_function(weights, filename): # Create a flat dictionary. state_dict = {} if unet_lora_layers is not None: - unet_lora_state_dict = { - f"{self.unet_name}.{module_name}": param - for module_name, param in unet_lora_layers.state_dict().items() - } + weights = ( + unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers + ) + + unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()} state_dict.update(unet_lora_state_dict) + if text_encoder_lora_layers is not None: + weights = ( + text_encoder_lora_layers.state_dict() + if isinstance(text_encoder_lora_layers, torch.nn.Module) + else text_encoder_lora_layers + ) + text_encoder_lora_state_dict = { - f"{self.text_encoder_name}.{module_name}": param - for module_name, param in text_encoder_lora_layers.state_dict().items() + f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items() } state_dict.update(text_encoder_lora_state_dict)