From 02111bc37baa6b4aa7d96755b5028c2e02ae26c2 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Mon, 15 May 2023 02:53:33 +0900 Subject: [PATCH 01/27] add _convert_kohya_lora_to_diffusers --- src/diffusers/loaders.py | 54 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index a1f0d8ec2a52..71b7842cf316 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -887,6 +887,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di else: state_dict = pretrained_model_name_or_path_or_dict + if any('alpha' in k for k in state_dict.keys()): + state_dict = self._convert_kohya_lora_to_diffusers(state_dict) + # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as # their prefixes. @@ -1208,6 +1211,57 @@ def save_function(weights, filename): save_function(state_dict, os.path.join(save_directory, weight_name)) logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}") + def _convert_kohya_lora_to_diffusers(self, state_dict): + unet_state_dict = {} + te_state_dict = {} + + for key, value in state_dict.items(): + if "lora_down" in key: + lora_name = key.split(".")[0] + lora_dim = value.size()[0] + lora_name_up = lora_name + '.lora_up.weight' + lora_name_alpha = lora_name + '.alpha' + if lora_name_alpha in state_dict: + alpha = state_dict[lora_name_alpha].item() + # print(lora_name_alpha, alpha, lora_dim, alpha / lora_dim) + + if lora_name.startswith('lora_unet_'): + diffusers_name = key.replace('lora_unet_', '').replace('_', '.') + diffusers_name = diffusers_name.replace('down.blocks', 'down_blocks') + diffusers_name = diffusers_name.replace('mid.block', 'mid_block') + diffusers_name = diffusers_name.replace('up.blocks', 'up_blocks') + diffusers_name = diffusers_name.replace('transformer.blocks', 'transformer_blocks') + diffusers_name = diffusers_name.replace('to.q.lora', 'to_q_lora') + diffusers_name = diffusers_name.replace('to.k.lora', 'to_k_lora') + diffusers_name = diffusers_name.replace('to.v.lora', 'to_v_lora') + diffusers_name = diffusers_name.replace('to.out.0.lora', 'to_out_lora') + if 'transformer_blocks' in diffusers_name: + if 'attn1' in diffusers_name or 'attn2' in diffusers_name: + diffusers_name = diffusers_name.replace('attn1', 'attn1.processor') + diffusers_name = diffusers_name.replace('attn2', 'attn2.processor') + unet_state_dict[diffusers_name] = value + unet_state_dict[diffusers_name.replace('.down.','.up.')] = state_dict[lora_name_up] + elif lora_name.startswith('lora_te_'): + diffusers_name = key.replace('lora_te_', '').replace('_', '.') + diffusers_name = diffusers_name.replace('text.model', 'text_model') + diffusers_name = diffusers_name.replace('self.attn', 'self_attn') + diffusers_name = diffusers_name.replace('q.proj.lora', 'to_q_lora') + diffusers_name = diffusers_name.replace('k.proj.lora', 'to_k_lora') + diffusers_name = diffusers_name.replace('v.proj.lora', 'to_v_lora') + diffusers_name = diffusers_name.replace('out.proj.lora', 'to_out_lora') + if 'self_attn' in diffusers_name: + prefix = '.'.join(diffusers_name.split('.')[:-3]) # e.g.: text_model.encoder.layers.0.self_attn + suffix = '.'.join(diffusers_name.split('.')[-3:]) # e.g.: to_k_lora.down.weight + for module_name in TEXT_ENCODER_TARGET_MODULES: + diffusers_name = f'{prefix}.{module_name}.{suffix}' + 
te_state_dict[diffusers_name] = value + te_state_dict[diffusers_name.replace('.down.','.up.')] = state_dict[lora_name_up] + + unet_state_dict = {f'{UNET_NAME}.{module_name}': params for module_name, params in unet_state_dict.items()} + te_state_dict = {f'{TEXT_ENCODER_NAME}.{module_name}': params for module_name, params in te_state_dict.items()} + new_state_dict = {**unet_state_dict, **te_state_dict} + print('converted', len(new_state_dict), 'keys') + return new_state_dict class FromCkptMixin: """This helper class allows to directly load .ckpt stable diffusion file_extension From 7110e9a555c310ad946060454eee1d63f6945071 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Mon, 15 May 2023 02:54:14 +0900 Subject: [PATCH 02/27] make style --- src/diffusers/loaders.py | 75 +++++++++++++++++++++------------------- 1 file changed, 39 insertions(+), 36 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 71b7842cf316..63a5381e5898 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -887,7 +887,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di else: state_dict = pretrained_model_name_or_path_or_dict - if any('alpha' in k for k in state_dict.keys()): + if any("alpha" in k for k in state_dict.keys()): state_dict = self._convert_kohya_lora_to_diffusers(state_dict) # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), @@ -1218,51 +1218,54 @@ def _convert_kohya_lora_to_diffusers(self, state_dict): for key, value in state_dict.items(): if "lora_down" in key: lora_name = key.split(".")[0] - lora_dim = value.size()[0] - lora_name_up = lora_name + '.lora_up.weight' - lora_name_alpha = lora_name + '.alpha' + value.size()[0] + lora_name_up = lora_name + ".lora_up.weight" + lora_name_alpha = lora_name + ".alpha" if lora_name_alpha in state_dict: - alpha = state_dict[lora_name_alpha].item() + state_dict[lora_name_alpha].item() # print(lora_name_alpha, alpha, lora_dim, alpha / lora_dim) - if lora_name.startswith('lora_unet_'): - diffusers_name = key.replace('lora_unet_', '').replace('_', '.') - diffusers_name = diffusers_name.replace('down.blocks', 'down_blocks') - diffusers_name = diffusers_name.replace('mid.block', 'mid_block') - diffusers_name = diffusers_name.replace('up.blocks', 'up_blocks') - diffusers_name = diffusers_name.replace('transformer.blocks', 'transformer_blocks') - diffusers_name = diffusers_name.replace('to.q.lora', 'to_q_lora') - diffusers_name = diffusers_name.replace('to.k.lora', 'to_k_lora') - diffusers_name = diffusers_name.replace('to.v.lora', 'to_v_lora') - diffusers_name = diffusers_name.replace('to.out.0.lora', 'to_out_lora') - if 'transformer_blocks' in diffusers_name: - if 'attn1' in diffusers_name or 'attn2' in diffusers_name: - diffusers_name = diffusers_name.replace('attn1', 'attn1.processor') - diffusers_name = diffusers_name.replace('attn2', 'attn2.processor') + if lora_name.startswith("lora_unet_"): + diffusers_name = key.replace("lora_unet_", "").replace("_", ".") + diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") + diffusers_name = diffusers_name.replace("mid.block", "mid_block") + diffusers_name = diffusers_name.replace("up.blocks", "up_blocks") + diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks") + diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") + diffusers_name = 
diffusers_name.replace("to.v.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") + if "transformer_blocks" in diffusers_name: + if "attn1" in diffusers_name or "attn2" in diffusers_name: + diffusers_name = diffusers_name.replace("attn1", "attn1.processor") + diffusers_name = diffusers_name.replace("attn2", "attn2.processor") unet_state_dict[diffusers_name] = value - unet_state_dict[diffusers_name.replace('.down.','.up.')] = state_dict[lora_name_up] - elif lora_name.startswith('lora_te_'): - diffusers_name = key.replace('lora_te_', '').replace('_', '.') - diffusers_name = diffusers_name.replace('text.model', 'text_model') - diffusers_name = diffusers_name.replace('self.attn', 'self_attn') - diffusers_name = diffusers_name.replace('q.proj.lora', 'to_q_lora') - diffusers_name = diffusers_name.replace('k.proj.lora', 'to_k_lora') - diffusers_name = diffusers_name.replace('v.proj.lora', 'to_v_lora') - diffusers_name = diffusers_name.replace('out.proj.lora', 'to_out_lora') - if 'self_attn' in diffusers_name: - prefix = '.'.join(diffusers_name.split('.')[:-3]) # e.g.: text_model.encoder.layers.0.self_attn - suffix = '.'.join(diffusers_name.split('.')[-3:]) # e.g.: to_k_lora.down.weight + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + elif lora_name.startswith("lora_te_"): + diffusers_name = key.replace("lora_te_", "").replace("_", ".") + diffusers_name = diffusers_name.replace("text.model", "text_model") + diffusers_name = diffusers_name.replace("self.attn", "self_attn") + diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + if "self_attn" in diffusers_name: + prefix = ".".join( + diffusers_name.split(".")[:-3] + ) # e.g.: text_model.encoder.layers.0.self_attn + suffix = ".".join(diffusers_name.split(".")[-3:]) # e.g.: to_k_lora.down.weight for module_name in TEXT_ENCODER_TARGET_MODULES: - diffusers_name = f'{prefix}.{module_name}.{suffix}' + diffusers_name = f"{prefix}.{module_name}.{suffix}" te_state_dict[diffusers_name] = value - te_state_dict[diffusers_name.replace('.down.','.up.')] = state_dict[lora_name_up] + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] - unet_state_dict = {f'{UNET_NAME}.{module_name}': params for module_name, params in unet_state_dict.items()} - te_state_dict = {f'{TEXT_ENCODER_NAME}.{module_name}': params for module_name, params in te_state_dict.items()} + unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()} + te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()} new_state_dict = {**unet_state_dict, **te_state_dict} - print('converted', len(new_state_dict), 'keys') + print("converted", len(new_state_dict), "keys") return new_state_dict + class FromCkptMixin: """This helper class allows to directly load .ckpt stable diffusion file_extension into the respective classes.""" From 21c5979790454497ff32ab4757cc21d45a65fcd5 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Tue, 16 May 2023 00:31:55 +0900 Subject: [PATCH 03/27] add scaffold --- tests/test_kohya_loras_scaffold.py | 285 +++++++++++++++++++++++++++++ 1 file changed, 285 insertions(+) create mode 100644 tests/test_kohya_loras_scaffold.py diff --git 
a/tests/test_kohya_loras_scaffold.py b/tests/test_kohya_loras_scaffold.py new file mode 100644 index 000000000000..f06b73da501a --- /dev/null +++ b/tests/test_kohya_loras_scaffold.py @@ -0,0 +1,285 @@ +# +# +# TODO: REMOVE THIS FILE +# This file is intended to be used for initial development of new features. +# +# + +import math + +import safetensors +import torch + +from diffusers import DiffusionPipeline + + +# modified from https://github.com/kohya-ss/sd-scripts/blob/ad5f318d066c52e5b27306b399bc87e41f2eef2b/networks/lora.py#L17 +class LoRAModule(torch.nn.Module): + def __init__(self, org_module: torch.nn.Module, lora_dim=4, alpha=1.0, multiplier=1.0): + """if alpha == 0 or None, alpha is rank (no scaling).""" + super().__init__() + + if org_module.__class__.__name__ == "Conv2d": + in_dim = org_module.in_channels + out_dim = org_module.out_channels + else: + in_dim = org_module.in_features + out_dim = org_module.out_features + + self.lora_dim = lora_dim + + if org_module.__class__.__name__ == "Conv2d": + kernel_size = org_module.kernel_size + stride = org_module.stride + padding = org_module.padding + self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) + self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) + else: + self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) + self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) + + if alpha is None or alpha == 0: + self.alpha = self.lora_dim + else: + if type(alpha) == torch.Tensor: + alpha = alpha.detach().float().numpy() # without casting, bf16 causes error + self.register_buffer("alpha", torch.tensor(alpha)) # Treatable as a constant. + + # same as microsoft's + torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) + torch.nn.init.zeros_(self.lora_up.weight) + + self.multiplier = multiplier + + def forward(self, x): + scale = self.alpha / self.lora_dim + return self.multiplier * scale * self.lora_up(self.lora_down(x)) + + +class LoRAModuleContainer(torch.nn.Module): + def __init__(self, hooks, state_dict, multiplier): + super().__init__() + self.multiplier = multiplier + + # Create LoRAModule from state_dict information + for key, value in state_dict.items(): + if "lora_down" in key: + lora_name = key.split(".")[0] + lora_dim = value.size()[0] + lora_name_alpha = key.split(".")[0] + ".alpha" + alpha = None + if lora_name_alpha in state_dict: + alpha = state_dict[lora_name_alpha].item() + hook = hooks[lora_name] + lora_module = LoRAModule(hook.orig_module, lora_dim=lora_dim, alpha=alpha, multiplier=multiplier) + self.register_module(lora_name, lora_module) + + # Load whole LoRA weights + self.load_state_dict(state_dict) + + # Register LoRAModule to LoRAHook + for name, module in self.named_modules(): + if module.__class__.__name__ == "LoRAModule": + hook = hooks[name] + hook.append_lora(module) + + @property + def alpha(self): + return self.multiplier + + @alpha.setter + def alpha(self, multiplier): + self.multiplier = multiplier + for name, module in self.named_modules(): + if module.__class__.__name__ == "LoRAModule": + module.multiplier = multiplier + + def remove_from_hooks(self, hooks): + for name, module in self.named_modules(): + if module.__class__.__name__ == "LoRAModule": + hook = hooks[name] + hook.remove_lora(module) + del module + + +class LoRAHook(torch.nn.Module): + """ + replaces forward method of the original Linear, + instead of replacing the original Linear module. 
+ """ + + def __init__(self): + super().__init__() + self.lora_modules = [] + + def install(self, orig_module): + assert not hasattr(self, "orig_module") + self.orig_module = orig_module + self.orig_forward = self.orig_module.forward + self.orig_module.forward = self.forward + + def uninstall(self): + assert hasattr(self, "orig_module") + self.orig_module.forward = self.orig_forward + del self.orig_forward + del self.orig_module + + def append_lora(self, lora_module): + self.lora_modules.append(lora_module) + + def remove_lora(self, lora_module): + self.lora_modules.remove(lora_module) + + def forward(self, x): + if len(self.lora_modules) == 0: + return self.orig_forward(x) + lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0) + return self.orig_forward(x) + lora + + +class LoRAHookInjector(object): + def __init__(self): + super().__init__() + self.hooks = {} + self.device = None + self.dtype = None + + def _get_target_modules(self, root_module, prefix, target_replace_modules): + target_modules = [] + for name, module in root_module.named_modules(): + if ( + module.__class__.__name__ in target_replace_modules and "transformer_blocks" not in name + ): # to adapt latest diffusers: + for child_name, child_module in module.named_modules(): + is_linear = child_module.__class__.__name__ == "Linear" + is_conv2d = child_module.__class__.__name__ == "Conv2d" + if is_linear or is_conv2d: + lora_name = prefix + "." + name + "." + child_name + lora_name = lora_name.replace(".", "_") + target_modules.append((lora_name, child_module)) + return target_modules + + def install_hooks(self, pipe): + """Install LoRAHook to the pipe.""" + assert len(self.hooks) == 0 + text_encoder_targets = self._get_target_modules(pipe.text_encoder, "lora_te", ["CLIPAttention", "CLIPMLP"]) + unet_targets = self._get_target_modules(pipe.unet, "lora_unet", ["Transformer2DModel", "Attention"]) + for name, target_module in text_encoder_targets + unet_targets: + hook = LoRAHook() + hook.install(target_module) + self.hooks[name] = hook + + self.device = pipe.device + self.dtype = pipe.unet.dtype + + def uninstall_hooks(self): + """Uninstall LoRAHook from the pipe.""" + for k, v in self.hooks.items(): + v.uninstall() + self.hooks = {} + + def apply_lora(self, filename, alpha=1.0): + """Load LoRA weights and apply LoRA to the pipe.""" + assert len(self.hooks) != 0 + state_dict = safetensors.torch.load_file(filename) + container = LoRAModuleContainer(self.hooks, state_dict, alpha) + container.to(self.device, self.dtype) + return container + + def remove_lora(self, container): + """Remove the individual LoRA from the pipe.""" + container.remove_from_hooks(self.hooks) + + +def install_lora_hook(pipe: DiffusionPipeline): + """Install LoRAHook to the pipe.""" + assert not hasattr(pipe, "lora_injector") + assert not hasattr(pipe, "apply_lora") + assert not hasattr(pipe, "remove_lora") + injector = LoRAHookInjector() + injector.install_hooks(pipe) + pipe.lora_injector = injector + pipe.apply_lora = injector.apply_lora + pipe.remove_lora = injector.remove_lora + + +def uninstall_lora_hook(pipe: DiffusionPipeline): + """Uninstall LoRAHook from the pipe.""" + pipe.lora_injector.uninstall_hooks() + del pipe.lora_injector + del pipe.apply_lora + del pipe.remove_lora + + +from PIL import Image + + +def image_grid(imgs, rows, cols): + assert len(imgs) == rows * cols + + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + grid_w, grid_h = grid.size + + for i, img in enumerate(imgs): + grid.paste(img, 
box=(i % cols * w, i // cols * h)) + return grid + + +if __name__ == "__main__": + import torch + + from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline + + pipe = StableDiffusionPipeline.from_pretrained( + "gsdf/Counterfeit-V2.5", torch_dtype=torch.float16, safety_checker=None + ).to("cuda") + pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True) + pipe.enable_xformers_memory_efficient_attention() + + prompt = "masterpeace, best quality, highres, 1girl, at dusk" + negative_prompt = ( + "(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), " + "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2) " + ) + lora_fn = "../stable-diffusion-study/models/lora/light_and_shadow.safetensors" + + images = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=512, + height=768, + num_inference_steps=15, + num_images_per_prompt=4, + generator=torch.Generator(device="cuda").manual_seed(0), + ).images + image_grid(images, 1, 4).save("test_orig.png") + + install_lora_hook(pipe) + pipe.apply_lora(lora_fn) + + images = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=512, + height=768, + num_inference_steps=15, + num_images_per_prompt=4, + generator=torch.Generator(device="cuda").manual_seed(0), + ).images + image_grid(images, 1, 4).save("test_lora_hook.png") + + uninstall_lora_hook(pipe) + + pipe.load_lora_weights(lora_fn) + + images = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=512, + height=768, + num_inference_steps=15, + num_images_per_prompt=4, + generator=torch.Generator(device="cuda").manual_seed(0), + ).images + image_grid(images, 1, 4).save("test_lora_dev.png") From 8858ebb74624a1d929dfe08997a6deac288a27f4 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Tue, 16 May 2023 01:14:42 +0900 Subject: [PATCH 04/27] match result: unet attention only --- tests/test_kohya_loras_scaffold.py | 108 +++++++++++++++++++++++------ 1 file changed, 85 insertions(+), 23 deletions(-) diff --git a/tests/test_kohya_loras_scaffold.py b/tests/test_kohya_loras_scaffold.py index f06b73da501a..1fa9fa28e931 100644 --- a/tests/test_kohya_loras_scaffold.py +++ b/tests/test_kohya_loras_scaffold.py @@ -9,8 +9,9 @@ import safetensors import torch +from PIL import Image -from diffusers import DiffusionPipeline +from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline # modified from https://github.com/kohya-ss/sd-scripts/blob/ad5f318d066c52e5b27306b399bc87e41f2eef2b/networks/lora.py#L17 @@ -70,18 +71,20 @@ def __init__(self, hooks, state_dict, multiplier): alpha = None if lora_name_alpha in state_dict: alpha = state_dict[lora_name_alpha].item() - hook = hooks[lora_name] - lora_module = LoRAModule(hook.orig_module, lora_dim=lora_dim, alpha=alpha, multiplier=multiplier) - self.register_module(lora_name, lora_module) + if lora_name in hooks: + hook = hooks[lora_name] + lora_module = LoRAModule(hook.orig_module, lora_dim=lora_dim, alpha=alpha, multiplier=multiplier) + self.register_module(lora_name, lora_module) # Load whole LoRA weights - self.load_state_dict(state_dict) + self.load_state_dict(state_dict, strict=False) # Register LoRAModule to LoRAHook for name, module in self.named_modules(): if module.__class__.__name__ == "LoRAModule": - hook = hooks[name] - hook.append_lora(module) + if name in hooks: + hook = hooks[name] + hook.append_lora(module) @property def alpha(self): @@ -153,7 +156,8 @@ def 
_get_target_modules(self, root_module, prefix, target_replace_modules): for child_name, child_module in module.named_modules(): is_linear = child_module.__class__.__name__ == "Linear" is_conv2d = child_module.__class__.__name__ == "Conv2d" - if is_linear or is_conv2d: + # if is_linear or is_conv2d: + if is_linear and not is_conv2d and "ff.net" not in child_name: lora_name = prefix + "." + name + "." + child_name lora_name = lora_name.replace(".", "_") target_modules.append((lora_name, child_module)) @@ -162,12 +166,16 @@ def _get_target_modules(self, root_module, prefix, target_replace_modules): def install_hooks(self, pipe): """Install LoRAHook to the pipe.""" assert len(self.hooks) == 0 - text_encoder_targets = self._get_target_modules(pipe.text_encoder, "lora_te", ["CLIPAttention", "CLIPMLP"]) - unet_targets = self._get_target_modules(pipe.unet, "lora_unet", ["Transformer2DModel", "Attention"]) + # text_encoder_targets = self._get_target_modules(pipe.text_encoder, "lora_te", ["CLIPAttention", "CLIPMLP"]) + # unet_targets = self._get_target_modules(pipe.unet, "lora_unet", ["Transformer2DModel", "Attention"]) + text_encoder_targets = [] + unet_targets = self._get_target_modules(pipe.unet, "lora_unet", ["Transformer2DModel"]) + for name, target_module in text_encoder_targets + unet_targets: hook = LoRAHook() hook.install(target_module) self.hooks[name] = hook + print(name) self.device = pipe.device self.dtype = pipe.unet.dtype @@ -211,9 +219,6 @@ def uninstall_lora_hook(pipe: DiffusionPipeline): del pipe.remove_lora -from PIL import Image - - def image_grid(imgs, rows, cols): assert len(imgs) == rows * cols @@ -226,11 +231,68 @@ def image_grid(imgs, rows, cols): return grid -if __name__ == "__main__": - import torch +TEXT_ENCODER_NAME = "text_encoder" +UNET_NAME = "unet" + + +def convert_kohya_lora_to_diffusers(state_dict): + unet_state_dict = {} + te_state_dict = {} + + for key, value in state_dict.items(): + if "lora_down" in key: + lora_name = key.split(".")[0] + value.size()[0] + lora_name_up = lora_name + ".lora_up.weight" + lora_name_alpha = lora_name + ".alpha" + if lora_name_alpha in state_dict: + state_dict[lora_name_alpha].item() + # print(lora_name_alpha, alpha, lora_dim, alpha / lora_dim) + + if lora_name.startswith("lora_unet_"): + diffusers_name = key.replace("lora_unet_", "").replace("_", ".") + diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") + diffusers_name = diffusers_name.replace("mid.block", "mid_block") + diffusers_name = diffusers_name.replace("up.blocks", "up_blocks") + diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks") + diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") + diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") + diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") + diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") + if "transformer_blocks" in diffusers_name: + if "attn1" in diffusers_name or "attn2" in diffusers_name: + diffusers_name = diffusers_name.replace("attn1", "attn1.processor") + diffusers_name = diffusers_name.replace("attn2", "attn2.processor") + unet_state_dict[diffusers_name] = value + unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + # elif lora_name.startswith("lora_te_"): + # diffusers_name = key.replace("lora_te_", "").replace("_", ".") + # diffusers_name = diffusers_name.replace("text.model", "text_model") + # diffusers_name = diffusers_name.replace("self.attn", "self_attn") + # 
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") + # diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") + # diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") + # diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") + # if "self_attn" in diffusers_name: + # prefix = ".".join( + # diffusers_name.split(".")[:-3] + # ) # e.g.: text_model.encoder.layers.0.self_attn + # suffix = ".".join(diffusers_name.split(".")[-3:]) # e.g.: to_k_lora.down.weight + # for module_name in TEXT_ENCODER_TARGET_MODULES: + # diffusers_name = f"{prefix}.{module_name}.{suffix}" + # te_state_dict[diffusers_name] = value + # te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + + unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()} + te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()} + new_state_dict = {**unet_state_dict, **te_state_dict} + print("converted", len(new_state_dict), "keys") + for k in sorted(new_state_dict.keys()): + print(k) + return new_state_dict - from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline +if __name__ == "__main__": pipe = StableDiffusionPipeline.from_pretrained( "gsdf/Counterfeit-V2.5", torch_dtype=torch.float16, safety_checker=None ).to("cuda") @@ -251,13 +313,12 @@ def image_grid(imgs, rows, cols): height=768, num_inference_steps=15, num_images_per_prompt=4, - generator=torch.Generator(device="cuda").manual_seed(0), + generator=torch.manual_seed(0), ).images image_grid(images, 1, 4).save("test_orig.png") install_lora_hook(pipe) pipe.apply_lora(lora_fn) - images = pipe( prompt=prompt, negative_prompt=negative_prompt, @@ -265,14 +326,14 @@ def image_grid(imgs, rows, cols): height=768, num_inference_steps=15, num_images_per_prompt=4, - generator=torch.Generator(device="cuda").manual_seed(0), + generator=torch.manual_seed(0), ).images image_grid(images, 1, 4).save("test_lora_hook.png") - uninstall_lora_hook(pipe) - pipe.load_lora_weights(lora_fn) - + state_dict = safetensors.torch.load_file(lora_fn) + pipe.load_lora_weights(convert_kohya_lora_to_diffusers(state_dict)) + # pipe.load_lora_weights(lora_fn) images = pipe( prompt=prompt, negative_prompt=negative_prompt, @@ -280,6 +341,7 @@ def image_grid(imgs, rows, cols): height=768, num_inference_steps=15, num_images_per_prompt=4, - generator=torch.Generator(device="cuda").manual_seed(0), + generator=torch.manual_seed(0), + cross_attention_kwargs={"scale": 0.5}, # lora scale ).images image_grid(images, 1, 4).save("test_lora_dev.png") From bb9c61e6faecc1935c9c4319c77065837655d616 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Tue, 16 May 2023 04:13:34 +0900 Subject: [PATCH 05/27] fix monkey-patch for text_encoder --- src/diffusers/loaders.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 63a5381e5898..f591d0c381a0 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -946,14 +946,16 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): module = self.text_encoder.get_submodule(name) # Construct a new function that performs the LoRA merging. We will monkey patch # this forward pass. 
- lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) - old_forward = module.forward - def new_forward(x): - return old_forward(x) + lora_layer(x) + if name in attn_processors: + module.lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) + module.old_forward = module.forward - # Monkey-patch. - module.forward = new_forward + def new_forward(self, x): + return self.old_forward(x) + self.lora_layer(x) + + # Monkey-patch. + module.forward = new_forward.__get__(module) def _get_lora_layer_attribute(self, name: str) -> str: if "q_proj" in name: From aa1d6446f4a5369b24234eb18f978dde65f21b27 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Wed, 17 May 2023 02:24:05 +0900 Subject: [PATCH 06/27] with CLIPAttention While the terrible images are no longer produced, the results do not match those from the hook ver. This may be due to not setting the network_alpha value. --- tests/test_kohya_loras_scaffold.py | 70 +++--------------------------- 1 file changed, 5 insertions(+), 65 deletions(-) diff --git a/tests/test_kohya_loras_scaffold.py b/tests/test_kohya_loras_scaffold.py index 1fa9fa28e931..a385eec8517f 100644 --- a/tests/test_kohya_loras_scaffold.py +++ b/tests/test_kohya_loras_scaffold.py @@ -168,7 +168,7 @@ def install_hooks(self, pipe): assert len(self.hooks) == 0 # text_encoder_targets = self._get_target_modules(pipe.text_encoder, "lora_te", ["CLIPAttention", "CLIPMLP"]) # unet_targets = self._get_target_modules(pipe.unet, "lora_unet", ["Transformer2DModel", "Attention"]) - text_encoder_targets = [] + text_encoder_targets = self._get_target_modules(pipe.text_encoder, "lora_te", ["CLIPAttention"]) unet_targets = self._get_target_modules(pipe.unet, "lora_unet", ["Transformer2DModel"]) for name, target_module in text_encoder_targets + unet_targets: @@ -231,67 +231,6 @@ def image_grid(imgs, rows, cols): return grid -TEXT_ENCODER_NAME = "text_encoder" -UNET_NAME = "unet" - - -def convert_kohya_lora_to_diffusers(state_dict): - unet_state_dict = {} - te_state_dict = {} - - for key, value in state_dict.items(): - if "lora_down" in key: - lora_name = key.split(".")[0] - value.size()[0] - lora_name_up = lora_name + ".lora_up.weight" - lora_name_alpha = lora_name + ".alpha" - if lora_name_alpha in state_dict: - state_dict[lora_name_alpha].item() - # print(lora_name_alpha, alpha, lora_dim, alpha / lora_dim) - - if lora_name.startswith("lora_unet_"): - diffusers_name = key.replace("lora_unet_", "").replace("_", ".") - diffusers_name = diffusers_name.replace("down.blocks", "down_blocks") - diffusers_name = diffusers_name.replace("mid.block", "mid_block") - diffusers_name = diffusers_name.replace("up.blocks", "up_blocks") - diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks") - diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora") - diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora") - diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora") - diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora") - if "transformer_blocks" in diffusers_name: - if "attn1" in diffusers_name or "attn2" in diffusers_name: - diffusers_name = diffusers_name.replace("attn1", "attn1.processor") - diffusers_name = diffusers_name.replace("attn2", "attn2.processor") - unet_state_dict[diffusers_name] = value - unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] - # elif lora_name.startswith("lora_te_"): - # diffusers_name = key.replace("lora_te_", 
"").replace("_", ".") - # diffusers_name = diffusers_name.replace("text.model", "text_model") - # diffusers_name = diffusers_name.replace("self.attn", "self_attn") - # diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora") - # diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora") - # diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") - # diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") - # if "self_attn" in diffusers_name: - # prefix = ".".join( - # diffusers_name.split(".")[:-3] - # ) # e.g.: text_model.encoder.layers.0.self_attn - # suffix = ".".join(diffusers_name.split(".")[-3:]) # e.g.: to_k_lora.down.weight - # for module_name in TEXT_ENCODER_TARGET_MODULES: - # diffusers_name = f"{prefix}.{module_name}.{suffix}" - # te_state_dict[diffusers_name] = value - # te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] - - unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()} - te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()} - new_state_dict = {**unet_state_dict, **te_state_dict} - print("converted", len(new_state_dict), "keys") - for k in sorted(new_state_dict.keys()): - print(k) - return new_state_dict - - if __name__ == "__main__": pipe = StableDiffusionPipeline.from_pretrained( "gsdf/Counterfeit-V2.5", torch_dtype=torch.float16, safety_checker=None @@ -306,6 +245,7 @@ def convert_kohya_lora_to_diffusers(state_dict): ) lora_fn = "../stable-diffusion-study/models/lora/light_and_shadow.safetensors" + # Without Lora images = pipe( prompt=prompt, negative_prompt=negative_prompt, @@ -317,6 +257,7 @@ def convert_kohya_lora_to_diffusers(state_dict): ).images image_grid(images, 1, 4).save("test_orig.png") + # Hook version (some restricted apply) install_lora_hook(pipe) pipe.apply_lora(lora_fn) images = pipe( @@ -331,9 +272,8 @@ def convert_kohya_lora_to_diffusers(state_dict): image_grid(images, 1, 4).save("test_lora_hook.png") uninstall_lora_hook(pipe) - state_dict = safetensors.torch.load_file(lora_fn) - pipe.load_lora_weights(convert_kohya_lora_to_diffusers(state_dict)) - # pipe.load_lora_weights(lora_fn) + # Diffusers dev version + pipe.load_lora_weights(lora_fn) images = pipe( prompt=prompt, negative_prompt=negative_prompt, From 043da515b5f61cbb6cafece137c8e5f698dc1678 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Wed, 17 May 2023 03:42:28 +0900 Subject: [PATCH 07/27] add to support network_alpha --- src/diffusers/loaders.py | 33 +++++++++++----- src/diffusers/models/attention_processor.py | 43 ++++++++++++--------- tests/test_kohya_loras_scaffold.py | 2 +- 3 files changed, 50 insertions(+), 28 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index f591d0c381a0..1051e6280c4f 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -180,6 +180,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + network_alpha = kwargs.pop("network_alpha", None) if use_safetensors and not is_safetensors_available(): raise ValueError( @@ -282,7 +283,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict attn_processor_class = LoRAAttnProcessor attn_processors[key] = attn_processor_class( - hidden_size=hidden_size, 
cross_attention_dim=cross_attention_dim, rank=rank + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=rank, + network_alpha=network_alpha, ) attn_processors[key].load_state_dict(value_dict) elif is_custom_diffusion: @@ -887,8 +891,10 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di else: state_dict = pretrained_model_name_or_path_or_dict + # Convert kohya-ss Style LoRA attn procs to diffusers attn procs + network_alpha = None if any("alpha" in k for k in state_dict.keys()): - state_dict = self._convert_kohya_lora_to_diffusers(state_dict) + state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict) # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as @@ -901,7 +907,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di unet_lora_state_dict = { k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys } - self.unet.load_attn_procs(unet_lora_state_dict) + self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha) # Load the layers corresponding to text encoder and make necessary adjustments. text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)] @@ -910,7 +916,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys } if len(text_encoder_lora_state_dict) > 0: - attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict) + attn_procs_text_encoder = self._load_text_encoder_attn_procs( + text_encoder_lora_state_dict, network_alpha=network_alpha + ) self._modify_text_encoder(attn_procs_text_encoder) # save lora attn procs of text encoder so that it can be easily retrieved @@ -1042,6 +1050,7 @@ def _load_text_encoder_attn_procs( subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + network_alpha = kwargs.pop("network_alpha", None) if use_safetensors and not is_safetensors_available(): raise ValueError( @@ -1119,7 +1128,10 @@ def _load_text_encoder_attn_procs( hidden_size = value_dict["to_k_lora.up.weight"].shape[0] attn_processors[key] = LoRAAttnProcessor( - hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + rank=rank, + network_alpha=network_alpha, ) attn_processors[key].load_state_dict(value_dict) @@ -1216,16 +1228,19 @@ def save_function(weights, filename): def _convert_kohya_lora_to_diffusers(self, state_dict): unet_state_dict = {} te_state_dict = {} + network_alpha = None for key, value in state_dict.items(): if "lora_down" in key: lora_name = key.split(".")[0] - value.size()[0] lora_name_up = lora_name + ".lora_up.weight" lora_name_alpha = lora_name + ".alpha" if lora_name_alpha in state_dict: - state_dict[lora_name_alpha].item() - # print(lora_name_alpha, alpha, lora_dim, alpha / lora_dim) + alpha = state_dict[lora_name_alpha].item() + if network_alpha is None: + network_alpha = alpha + elif network_alpha != alpha: + raise ValueError("Network alpha is not consistent") if lora_name.startswith("lora_unet_"): diffusers_name = key.replace("lora_unet_", "").replace("_", ".") @@ -1265,7 +1280,7 @@ def 
_convert_kohya_lora_to_diffusers(self, state_dict): te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()} new_state_dict = {**unet_state_dict, **te_state_dict} print("converted", len(new_state_dict), "keys") - return new_state_dict + return new_state_dict, network_alpha class FromCkptMixin: diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index f88400da0333..b0c88ee801dc 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -468,7 +468,7 @@ def __call__( class LoRALinearLayer(nn.Module): - def __init__(self, in_features, out_features, rank=4): + def __init__(self, in_features, out_features, rank=4, network_alpha=None): super().__init__() if rank > min(in_features, out_features): @@ -476,6 +476,8 @@ def __init__(self, in_features, out_features, rank=4): self.down = nn.Linear(in_features, rank, bias=False) self.up = nn.Linear(rank, out_features, bias=False) + self.network_alpha = network_alpha + self.rank = rank nn.init.normal_(self.down.weight, std=1 / rank) nn.init.zeros_(self.up.weight) @@ -487,21 +489,24 @@ def forward(self, hidden_states): down_hidden_states = self.down(hidden_states.to(dtype)) up_hidden_states = self.up(down_hidden_states) + if self.network_alpha is not None: + up_hidden_states *= self.network_alpha / self.rank + return up_hidden_states.to(orig_dtype) class LoRAAttnProcessor(nn.Module): - def __init__(self, hidden_size, cross_attention_dim=None, rank=4): + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): super().__init__() self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.rank = rank - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): residual = hidden_states @@ -740,19 +745,19 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a class LoRAAttnAddedKVProcessor(nn.Module): - def __init__(self, hidden_size, cross_attention_dim=None, rank=4): + def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None): super().__init__() self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.rank = rank - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, 
network_alpha) + self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): residual = hidden_states @@ -933,7 +938,9 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a class LoRAXFormersAttnProcessor(nn.Module): - def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None): + def __init__( + self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None + ): super().__init__() self.hidden_size = hidden_size @@ -941,10 +948,10 @@ def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optio self.rank = rank self.attention_op = attention_op - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank) - self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank) + self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) + self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) + self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0): residual = hidden_states diff --git a/tests/test_kohya_loras_scaffold.py b/tests/test_kohya_loras_scaffold.py index a385eec8517f..c1628728a0ec 100644 --- a/tests/test_kohya_loras_scaffold.py +++ b/tests/test_kohya_loras_scaffold.py @@ -282,6 +282,6 @@ def image_grid(imgs, rows, cols): num_inference_steps=15, num_images_per_prompt=4, generator=torch.manual_seed(0), - cross_attention_kwargs={"scale": 0.5}, # lora scale + # cross_attention_kwargs={"scale": 0.5}, # lora scale ).images image_grid(images, 1, 4).save("test_lora_dev.png") From 23175769a5ca1421d9c608c7fe7c757b4362c302 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Wed, 17 May 2023 04:01:45 +0900 Subject: [PATCH 08/27] generate diff image --- tests/test_kohya_loras_scaffold.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_kohya_loras_scaffold.py b/tests/test_kohya_loras_scaffold.py index c1628728a0ec..d20cb68b7929 100644 --- a/tests/test_kohya_loras_scaffold.py +++ b/tests/test_kohya_loras_scaffold.py @@ -7,6 +7,7 @@ import math +import numpy as np import safetensors import torch from PIL import Image @@ -285,3 +286,9 @@ def image_grid(imgs, rows, cols): # cross_attention_kwargs={"scale": 0.5}, # lora scale ).images image_grid(images, 1, 4).save("test_lora_dev.png") + + # abs-difference image + image_hook = np.array(Image.open("test_lora_hook.png"), dtype=np.int16) + image_dev = np.array(Image.open("test_lora_dev.png"), dtype=np.int16) + image_diff = Image.fromarray(np.abs(image_hook - image_dev).astype(np.uint8)) + 
image_diff.save("test_lora_hook_dev_diff.png") From fb708fba19276c08fdce5849447d725293cbf962 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Tue, 16 May 2023 04:13:34 +0900 Subject: [PATCH 09/27] fix monkey-patch for text_encoder --- src/diffusers/loaders.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index e50bc31a5c63..f9840bca626e 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -943,14 +943,16 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): module = self.text_encoder.get_submodule(name) # Construct a new function that performs the LoRA merging. We will monkey patch # this forward pass. - lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) - old_forward = module.forward - def new_forward(x): - return old_forward(x) + lora_layer(x) + if name in attn_processors: + module.lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) + module.old_forward = module.forward - # Monkey-patch. - module.forward = new_forward + def new_forward(self, x): + return self.old_forward(x) + self.lora_layer(x) + + # Monkey-patch. + module.forward = new_forward.__get__(module) def _get_lora_layer_attribute(self, name: str) -> str: if "q_proj" in name: From 6e8f3ab897a6c068b5ac997887cd79dbef6618d0 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 20 May 2023 00:29:14 +0900 Subject: [PATCH 10/27] add test_text_encoder_lora_monkey_patch() --- tests/models/test_lora_layers.py | 62 ++++++++++++++++++++++++++++++++ 1 file changed, 62 insertions(+) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 6f1e85e15558..ffdc7569d2e1 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -212,3 +212,65 @@ def test_lora_save_load_legacy(self): # Outputs shouldn't match. 
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice))) + + # copied from: https://colab.research.google.com/gist/sayakpaul/df2ef6e1ae6d8c10a49d859883b10860/scratchpad.ipynb + def get_dummy_tokens(self): + max_seq_length = 77 + + inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0)).to("cuda") + + prepared_inputs = {} + prepared_inputs["input_ids"] = inputs + return prepared_inputs + + def test_text_encoder_lora_monkey_patch(self): + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda") + + dummy_tokens = self.get_dummy_tokens() + + # inference without lora + outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_without_lora.shape == (1, 77, 768) + + text_lora_attn_procs = {} + for name, module in pipe.text_encoder.named_modules(): + if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): + text_lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=module.out_features, cross_attention_dim=None + ).to("cuda") + + # monkey patch + pipe._modify_text_encoder(text_lora_attn_procs) + + # make sure that the lora_up.weights are zeroed out + for name, attn_proc in text_lora_attn_procs.items(): + for n in ["q", "k", "v", "out"]: + n = f"to_{n}_lora" + lora_linear_layer = getattr(attn_proc, n) + lora_up_weight = lora_linear_layer.up.weight + assert torch.allclose( + lora_up_weight, torch.zeros_like(lora_up_weight) + ), "lora_up_weight should be zeroed out" + + # inference with lora + outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_with_lora.shape == (1, 77, 768) + + assert torch.allclose( + outputs_without_lora, outputs_with_lora + ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" + + # make lora_up.weights as random + for name, attn_proc in text_lora_attn_procs.items(): + for n in ["q", "k", "v", "out"]: + n = f"to_{n}_lora" + lora_linear_layer = getattr(attn_proc, n) + lora_linear_layer.up.weight = torch.nn.Parameter(torch.randn_like(lora_linear_layer.up.weight)) + + # inference with lora + outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] + assert outputs_with_lora.shape == (1, 77, 768) + + assert not torch.allclose( + outputs_without_lora, outputs_with_lora + ), "lora_up_weight are not zero, so the lora outputs should be different to without lora outputs" From 851175565342669deeb59aa95e446f31b4b9b256 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 20 May 2023 01:20:00 +0900 Subject: [PATCH 11/27] verify that it's okay to release the attn_procs --- tests/models/test_lora_layers.py | 42 +++++++++++++++++--------------- 1 file changed, 23 insertions(+), 19 deletions(-) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index ffdc7569d2e1..24043544a74d 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -12,6 +12,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import gc import os import tempfile import unittest @@ -22,7 +23,7 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin -from diffusers.models.attention_processor import LoRAAttnProcessor +from diffusers.models.attention_processor import LoRAAttnProcessor, LoRALinearLayer from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device @@ -232,25 +233,27 @@ def test_text_encoder_lora_monkey_patch(self): outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] assert outputs_without_lora.shape == (1, 77, 768) + # create lora_attn_procs with zeroed out up.weights text_lora_attn_procs = {} for name, module in pipe.text_encoder.named_modules(): if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): - text_lora_attn_procs[name] = LoRAAttnProcessor( - hidden_size=module.out_features, cross_attention_dim=None - ).to("cuda") + attn_proc = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None).to("cuda") + + # make sure that the up.weights are zeroed out + for layer_name, layer_module in attn_proc.named_modules(): + if layer_name.endswith("_lora"): + assert torch.allclose( + layer_module.up.weight, torch.zeros_like(layer_module.up.weight) + ), "lora_up_weight should be zeroed out" + + text_lora_attn_procs[name] = attn_proc # monkey patch pipe._modify_text_encoder(text_lora_attn_procs) - # make sure that the lora_up.weights are zeroed out - for name, attn_proc in text_lora_attn_procs.items(): - for n in ["q", "k", "v", "out"]: - n = f"to_{n}_lora" - lora_linear_layer = getattr(attn_proc, n) - lora_up_weight = lora_linear_layer.up.weight - assert torch.allclose( - lora_up_weight, torch.zeros_like(lora_up_weight) - ), "lora_up_weight should be zeroed out" + # verify that it's okay to release the text_lora_attn_procs which holds the LoRAAttnProcessor. 
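# each patched submodule keeps its own reference to its LoRA layer, so
# deleting the processor dict and running gc.collect() must not break the
# patched forward passes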
+ del text_lora_attn_procs + gc.collect() # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] @@ -260,12 +263,13 @@ def test_text_encoder_lora_monkey_patch(self): outputs_without_lora, outputs_with_lora ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" - # make lora_up.weights as random - for name, attn_proc in text_lora_attn_procs.items(): - for n in ["q", "k", "v", "out"]: - n = f"to_{n}_lora" - lora_linear_layer = getattr(attn_proc, n) - lora_linear_layer.up.weight = torch.nn.Parameter(torch.randn_like(lora_linear_layer.up.weight)) + # set randn to lora_up.weights + for name, _ in pipe.text_encoder.named_modules(): + if any(name.endswith(x) for x in TEXT_ENCODER_TARGET_MODULES): + module = pipe.text_encoder.get_submodule(name) + assert hasattr(module, "lora_layer"), "lora_layer should be added" + assert isinstance(module.lora_layer, LoRALinearLayer), "lora_layer should be LoRALinearLayer" + module.lora_layer.up.weight = torch.nn.Parameter(torch.randn_like(module.lora_layer.up.weight)) # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] From 81915f48dfff3cd2e2654bc820088572f4e8f5db Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 20 May 2023 03:47:58 +0900 Subject: [PATCH 12/27] fix closure version --- src/diffusers/loaders.py | 15 ++++----- tests/models/test_lora_layers.py | 53 ++++++++++++++++++-------------- 2 files changed, 38 insertions(+), 30 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index f9840bca626e..ad1096f65c21 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -943,16 +943,17 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): module = self.text_encoder.get_submodule(name) # Construct a new function that performs the LoRA merging. We will monkey patch # this forward pass. + lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) + old_forward = module.forward - if name in attn_processors: - module.lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) - module.old_forward = module.forward + def make_new_forward(old_forward, lora_layer): + def new_forward(x): + return old_forward(x) + lora_layer(x) - def new_forward(self, x): - return self.old_forward(x) + self.lora_layer(x) + return new_forward - # Monkey-patch. - module.forward = new_forward.__get__(module) + # Monkey-patch. 
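# make_new_forward gives every module its own closure; defining new_forward
# directly inside the loop would late-bind old_forward/lora_layer to the
# values from the final iteration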
+ module.forward = make_new_forward(old_forward, lora_layer) def _get_lora_layer_attribute(self, name: str) -> str: if "q_proj" in name: diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 24043544a74d..6cf79a0c11cb 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -23,7 +23,7 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin -from diffusers.models.attention_processor import LoRAAttnProcessor, LoRALinearLayer +from diffusers.models.attention_processor import LoRAAttnProcessor from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device @@ -218,14 +218,31 @@ def test_lora_save_load_legacy(self): def get_dummy_tokens(self): max_seq_length = 77 - inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0)).to("cuda") + inputs = torch.randint(2, 56, size=(1, max_seq_length), generator=torch.manual_seed(0)) prepared_inputs = {} prepared_inputs["input_ids"] = inputs return prepared_inputs + def get_text_lora_attn_procs(self, text_encoder: nn.Module, randn_weight=False): + text_lora_attn_procs = {} + for name, module in text_encoder.named_modules(): + if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): + attn_proc = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) + # set up.weights + for layer_name, layer_module in attn_proc.named_modules(): + if layer_name.endswith("_lora"): + weight = ( + torch.randn_like(layer_module.up.weight) + if randn_weight + else torch.zeros_like(layer_module.up.weight) + ) + layer_module.up.weight = torch.nn.Parameter(weight) + text_lora_attn_procs[name] = attn_proc + return text_lora_attn_procs + def test_text_encoder_lora_monkey_patch(self): - pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda") + pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") dummy_tokens = self.get_dummy_tokens() @@ -234,19 +251,7 @@ def test_text_encoder_lora_monkey_patch(self): assert outputs_without_lora.shape == (1, 77, 768) # create lora_attn_procs with zeroed out up.weights - text_lora_attn_procs = {} - for name, module in pipe.text_encoder.named_modules(): - if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): - attn_proc = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None).to("cuda") - - # make sure that the up.weights are zeroed out - for layer_name, layer_module in attn_proc.named_modules(): - if layer_name.endswith("_lora"): - assert torch.allclose( - layer_module.up.weight, torch.zeros_like(layer_module.up.weight) - ), "lora_up_weight should be zeroed out" - - text_lora_attn_procs[name] = attn_proc + text_lora_attn_procs = self.get_text_lora_attn_procs(pipe.text_encoder, randn_weight=False) # monkey patch pipe._modify_text_encoder(text_lora_attn_procs) @@ -263,13 +268,15 @@ def test_text_encoder_lora_monkey_patch(self): outputs_without_lora, outputs_with_lora ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" - # set randn to lora_up.weights - for name, _ in pipe.text_encoder.named_modules(): - if any(name.endswith(x) for x in TEXT_ENCODER_TARGET_MODULES): - module = pipe.text_encoder.get_submodule(name) - assert hasattr(module, "lora_layer"), "lora_layer should be added" - assert isinstance(module.lora_layer, LoRALinearLayer), "lora_layer should be LoRALinearLayer" - 
module.lora_layer.up.weight = torch.nn.Parameter(torch.randn_like(module.lora_layer.up.weight)) + # create lora_attn_procs with randn up.weights + text_lora_attn_procs = self.get_text_lora_attn_procs(pipe.text_encoder, randn_weight=True) + + # monkey patch + pipe._modify_text_encoder(text_lora_attn_procs) + + # verify that it's okay to release the text_lora_attn_procs which holds the LoRAAttnProcessor. + del text_lora_attn_procs + gc.collect() # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] From 88db546c01eff271025ab1581f467f41be337c3f Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 20 May 2023 03:53:05 +0900 Subject: [PATCH 13/27] add comment --- src/diffusers/loaders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index ad1096f65c21..7eb389184ed9 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -946,6 +946,7 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) old_forward = module.forward + # create a new scope that locks in the old_forward, lora_layer value for each new_forward function def make_new_forward(old_forward, lora_layer): def new_forward(x): return old_forward(x) + lora_layer(x) From d22916ed95835807f92efe25fa724f8fff2fe6e2 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Sat, 20 May 2023 04:09:00 +0900 Subject: [PATCH 14/27] Revert "fix monkey-patch for text_encoder" This reverts commit bb9c61e6faecc1935c9c4319c77065837655d616. --- src/diffusers/loaders.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 4cc87cd47995..c466231a6641 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -954,16 +954,14 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): module = self.text_encoder.get_submodule(name) # Construct a new function that performs the LoRA merging. We will monkey patch # this forward pass. + lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) + old_forward = module.forward - if name in attn_processors: - module.lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) - module.old_forward = module.forward + def new_forward(x): + return old_forward(x) + lora_layer(x) - def new_forward(self, x): - return self.old_forward(x) + self.lora_layer(x) - - # Monkey-patch. - module.forward = new_forward.__get__(module) + # Monkey-patch. 
+ module.forward = new_forward def _get_lora_layer_attribute(self, name: str) -> str: if "q_proj" in name: From 1da772b9fe8f64702989c1319348dfffd65dc491 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Tue, 23 May 2023 00:02:01 +0900 Subject: [PATCH 15/27] Fix to reuse utility functions --- tests/models/test_lora_layers.py | 64 +++++++++++++++++--------------- 1 file changed, 34 insertions(+), 30 deletions(-) diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 6cf79a0c11cb..528c6e8bc35a 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -44,15 +44,33 @@ def create_unet_lora_layers(unet: nn.Module): return lora_attn_procs, unet_lora_layers -def create_text_encoder_lora_layers(text_encoder: nn.Module): +def create_text_encoder_lora_attn_procs(text_encoder: nn.Module): text_lora_attn_procs = {} for name, module in text_encoder.named_modules(): if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) + return text_lora_attn_procs + + +def create_text_encoder_lora_layers(text_encoder: nn.Module): + text_lora_attn_procs = create_text_encoder_lora_attn_procs(text_encoder) text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) return text_encoder_lora_layers +def set_lora_up_weights(text_lora_attn_procs, randn_weight=False): + for _, attn_proc in text_lora_attn_procs.items(): + # set up.weights + for layer_name, layer_module in attn_proc.named_modules(): + if layer_name.endswith("_lora"): + weight = ( + torch.randn_like(layer_module.up.weight) + if randn_weight + else torch.zeros_like(layer_module.up.weight) + ) + layer_module.up.weight = torch.nn.Parameter(weight) + + class LoraLoaderMixinTests(unittest.TestCase): def get_dummy_components(self): torch.manual_seed(0) @@ -224,63 +242,49 @@ def get_dummy_tokens(self): prepared_inputs["input_ids"] = inputs return prepared_inputs - def get_text_lora_attn_procs(self, text_encoder: nn.Module, randn_weight=False): - text_lora_attn_procs = {} - for name, module in text_encoder.named_modules(): - if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): - attn_proc = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) - # set up.weights - for layer_name, layer_module in attn_proc.named_modules(): - if layer_name.endswith("_lora"): - weight = ( - torch.randn_like(layer_module.up.weight) - if randn_weight - else torch.zeros_like(layer_module.up.weight) - ) - layer_module.up.weight = torch.nn.Parameter(weight) - text_lora_attn_procs[name] = attn_proc - return text_lora_attn_procs - def test_text_encoder_lora_monkey_patch(self): - pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5") + pipeline_components, _ = self.get_dummy_components() + pipe = StableDiffusionPipeline(**pipeline_components) dummy_tokens = self.get_dummy_tokens() # inference without lora outputs_without_lora = pipe.text_encoder(**dummy_tokens)[0] - assert outputs_without_lora.shape == (1, 77, 768) + assert outputs_without_lora.shape == (1, 77, 32) # create lora_attn_procs with zeroed out up.weights - text_lora_attn_procs = self.get_text_lora_attn_procs(pipe.text_encoder, randn_weight=False) + text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder) + set_lora_up_weights(text_attn_procs, randn_weight=False) # monkey patch - pipe._modify_text_encoder(text_lora_attn_procs) + pipe._modify_text_encoder(text_attn_procs) - # verify that it's okay to 
release the text_lora_attn_procs which holds the LoRAAttnProcessor. - del text_lora_attn_procs + # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. + del text_attn_procs gc.collect() # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] - assert outputs_with_lora.shape == (1, 77, 768) + assert outputs_with_lora.shape == (1, 77, 32) assert torch.allclose( outputs_without_lora, outputs_with_lora ), "lora_up_weight are all zero, so the lora outputs should be the same to without lora outputs" # create lora_attn_procs with randn up.weights - text_lora_attn_procs = self.get_text_lora_attn_procs(pipe.text_encoder, randn_weight=True) + text_attn_procs = create_text_encoder_lora_attn_procs(pipe.text_encoder) + set_lora_up_weights(text_attn_procs, randn_weight=True) # monkey patch - pipe._modify_text_encoder(text_lora_attn_procs) + pipe._modify_text_encoder(text_attn_procs) - # verify that it's okay to release the text_lora_attn_procs which holds the LoRAAttnProcessor. - del text_lora_attn_procs + # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor. + del text_attn_procs gc.collect() # inference with lora outputs_with_lora = pipe.text_encoder(**dummy_tokens)[0] - assert outputs_with_lora.shape == (1, 77, 768) + assert outputs_with_lora.shape == (1, 77, 32) assert not torch.allclose( outputs_without_lora, outputs_with_lora From 8a26848d62cc43b71706d1f7028de5771d35d760 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Tue, 23 May 2023 00:47:17 +0900 Subject: [PATCH 16/27] make LoRAAttnProcessor targets to self_attn --- src/diffusers/loaders.py | 4 +++- src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/constants.py | 1 + tests/models/test_lora_layers.py | 8 +++++--- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 7eb389184ed9..5e9e96cbde0d 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -33,6 +33,7 @@ DIFFUSERS_CACHE, HF_HUB_OFFLINE, TEXT_ENCODER_TARGET_MODULES, + TEXT_ENCODER_ATTN_MODULE, _get_model_file, deprecate, is_safetensors_available, @@ -943,7 +944,8 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): module = self.text_encoder.get_submodule(name) # Construct a new function that performs the LoRA merging. We will monkey patch # this forward pass. 
- lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name)) + attn_processor_name = ".".join(name.split(".")[:-1]) + lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name)) old_forward = module.forward # create a new scope that locks in the old_forward, lora_layer value for each new_forward function diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index cd3a1b8f3dd4..772c36b1177b 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -30,6 +30,7 @@ ONNX_EXTERNAL_WEIGHTS_NAME, ONNX_WEIGHTS_NAME, SAFETENSORS_WEIGHTS_NAME, + TEXT_ENCODER_ATTN_MODULE, TEXT_ENCODER_TARGET_MODULES, WEIGHTS_NAME, ) diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 1134ba6fb656..93d5c8cc42cd 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -31,3 +31,4 @@ HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules")) DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"] TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"] +TEXT_ENCODER_ATTN_MODULE = ".self_attn" diff --git a/tests/models/test_lora_layers.py b/tests/models/test_lora_layers.py index 528c6e8bc35a..1c7e07744cd2 100644 --- a/tests/models/test_lora_layers.py +++ b/tests/models/test_lora_layers.py @@ -24,7 +24,7 @@ from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin from diffusers.models.attention_processor import LoRAAttnProcessor -from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, floats_tensor, torch_device +from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, floats_tensor, torch_device def create_unet_lora_layers(unet: nn.Module): @@ -47,8 +47,10 @@ def create_unet_lora_layers(unet: nn.Module): def create_text_encoder_lora_attn_procs(text_encoder: nn.Module): text_lora_attn_procs = {} for name, module in text_encoder.named_modules(): - if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): - text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.out_features, cross_attention_dim=None) + if name.endswith(TEXT_ENCODER_ATTN_MODULE): + text_lora_attn_procs[name] = LoRAAttnProcessor( + hidden_size=module.out_proj.out_features, cross_attention_dim=None + ) return text_lora_attn_procs From 28c69eefe7f80aced892a1b715f3f913d69c124f Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Tue, 23 May 2023 01:13:10 +0900 Subject: [PATCH 17/27] fix LoRAAttnProcessor target --- examples/dreambooth/train_dreambooth_lora.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index e640542e36da..ceb360138f13 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -58,7 +58,7 @@ SlicedAttnAddedKVProcessor, ) from diffusers.optimization import get_scheduler -from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available +from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available from diffusers.utils.import_utils import is_xformers_available @@ -839,9 +839,9 @@ def main(args): if args.train_text_encoder: text_lora_attn_procs = {} for name, module in text_encoder.named_modules(): - if any(x in name for x in TEXT_ENCODER_TARGET_MODULES): + if name.endswith(TEXT_ENCODER_ATTN_MODULE): 
text_lora_attn_procs[name] = LoRAAttnProcessor( - hidden_size=module.out_features, cross_attention_dim=None + hidden_size=module.out_proj.out_features, cross_attention_dim=None ) text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs) temp_pipeline = StableDiffusionPipeline.from_pretrained( From 3a74c7e6d6496351a40cd47c028abde31244991b Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Tue, 23 May 2023 01:40:47 +0900 Subject: [PATCH 18/27] make style --- src/diffusers/loaders.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 5e9e96cbde0d..64a0e942fc77 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -33,7 +33,6 @@ DIFFUSERS_CACHE, HF_HUB_OFFLINE, TEXT_ENCODER_TARGET_MODULES, - TEXT_ENCODER_ATTN_MODULE, _get_model_file, deprecate, is_safetensors_available, From 160a4d356f2b08df4171d15641b1e38389178496 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Wed, 24 May 2023 00:47:59 +0900 Subject: [PATCH 19/27] fix split key --- src/diffusers/loaders.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 64a0e942fc77..6255ff89d5c9 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -70,8 +70,8 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]): self.mapping = dict(enumerate(state_dict.keys())) self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} - # .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder - self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"] + # .processor for unet, .self_attn for text encoder + self.split_keys = [".processor", ".self_attn"] # we add a hook to state_dict() and load_state_dict() so that the # naming fits with `unet.attn_processors` From f14329d26351a3361afb16b42f4a0e34ab3da3c3 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Wed, 24 May 2023 00:58:05 +0900 Subject: [PATCH 20/27] Update src/diffusers/loaders.py --- src/diffusers/loaders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 6255ff89d5c9..3a3db83f62da 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -948,6 +948,7 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]): old_forward = module.forward # create a new scope that locks in the old_forward, lora_layer value for each new_forward function + # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060 def make_new_forward(old_forward, lora_layer): def new_forward(x): return old_forward(x) + lora_layer(x) From c3304f27f856b888844ffec4cd738645dabe6518 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Wed, 24 May 2023 01:29:56 +0900 Subject: [PATCH 21/27] remove TEXT_ENCODER_TARGET_MODULES loop --- src/diffusers/loaders.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 9248488ee9ec..1cf9ee097579 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1271,14 +1271,8 @@ def _convert_kohya_lora_to_diffusers(self, state_dict): diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora") diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora") if "self_attn" in diffusers_name: - prefix = ".".join( - diffusers_name.split(".")[:-3] - ) # e.g.: text_model.encoder.layers.0.self_attn - suffix = 
".".join(diffusers_name.split(".")[-3:]) # e.g.: to_k_lora.down.weight - for module_name in TEXT_ENCODER_TARGET_MODULES: - diffusers_name = f"{prefix}.{module_name}.{suffix}" - te_state_dict[diffusers_name] = value - te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] + te_state_dict[diffusers_name] = value + te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up] unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()} te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()} From 639171fa1e0b3167e82a95da570ec1e56fbdcee0 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Wed, 24 May 2023 01:50:05 +0900 Subject: [PATCH 22/27] add print memory usage --- tests/test_kohya_loras_scaffold.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/test_kohya_loras_scaffold.py b/tests/test_kohya_loras_scaffold.py index d20cb68b7929..db319b180ae9 100644 --- a/tests/test_kohya_loras_scaffold.py +++ b/tests/test_kohya_loras_scaffold.py @@ -176,7 +176,7 @@ def install_hooks(self, pipe): hook = LoRAHook() hook.install(target_module) self.hooks[name] = hook - print(name) + # print(name) self.device = pipe.device self.dtype = pipe.unet.dtype @@ -233,6 +233,8 @@ def image_grid(imgs, rows, cols): if __name__ == "__main__": + torch.cuda.reset_peak_memory_stats() + pipe = StableDiffusionPipeline.from_pretrained( "gsdf/Counterfeit-V2.5", torch_dtype=torch.float16, safety_checker=None ).to("cuda") @@ -258,6 +260,10 @@ def image_grid(imgs, rows, cols): ).images image_grid(images, 1, 4).save("test_orig.png") + mem_bytes = torch.cuda.max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + print(f"Without Lora -> {mem_bytes/(10**6)}MB") + # Hook version (some restricted apply) install_lora_hook(pipe) pipe.apply_lora(lora_fn) @@ -273,6 +279,10 @@ def image_grid(imgs, rows, cols): image_grid(images, 1, 4).save("test_lora_hook.png") uninstall_lora_hook(pipe) + mem_bytes = torch.cuda.max_memory_allocated() + torch.cuda.reset_peak_memory_stats() + print(f"Hook version -> {mem_bytes/(10**6)}MB") + # Diffusers dev version pipe.load_lora_weights(lora_fn) images = pipe( @@ -287,6 +297,9 @@ def image_grid(imgs, rows, cols): ).images image_grid(images, 1, 4).save("test_lora_dev.png") + mem_bytes = torch.cuda.max_memory_allocated() + print(f"Diffusers dev version -> {mem_bytes/(10**6)}MB") + # abs-difference image image_hook = np.array(Image.open("test_lora_hook.png"), dtype=np.int16) image_dev = np.array(Image.open("test_lora_dev.png"), dtype=np.int16) From 29ec4ca8e7caf1600e52f50818930da18d24747e Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Thu, 25 May 2023 00:26:04 +0900 Subject: [PATCH 23/27] remove test_kohya_loras_scaffold.py --- tests/test_kohya_loras_scaffold.py | 307 ----------------------------- 1 file changed, 307 deletions(-) delete mode 100644 tests/test_kohya_loras_scaffold.py diff --git a/tests/test_kohya_loras_scaffold.py b/tests/test_kohya_loras_scaffold.py deleted file mode 100644 index db319b180ae9..000000000000 --- a/tests/test_kohya_loras_scaffold.py +++ /dev/null @@ -1,307 +0,0 @@ -# -# -# TODO: REMOVE THIS FILE -# This file is intended to be used for initial development of new features. 
-# -# - -import math - -import numpy as np -import safetensors -import torch -from PIL import Image - -from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler, StableDiffusionPipeline - - -# modified from https://github.com/kohya-ss/sd-scripts/blob/ad5f318d066c52e5b27306b399bc87e41f2eef2b/networks/lora.py#L17 -class LoRAModule(torch.nn.Module): - def __init__(self, org_module: torch.nn.Module, lora_dim=4, alpha=1.0, multiplier=1.0): - """if alpha == 0 or None, alpha is rank (no scaling).""" - super().__init__() - - if org_module.__class__.__name__ == "Conv2d": - in_dim = org_module.in_channels - out_dim = org_module.out_channels - else: - in_dim = org_module.in_features - out_dim = org_module.out_features - - self.lora_dim = lora_dim - - if org_module.__class__.__name__ == "Conv2d": - kernel_size = org_module.kernel_size - stride = org_module.stride - padding = org_module.padding - self.lora_down = torch.nn.Conv2d(in_dim, self.lora_dim, kernel_size, stride, padding, bias=False) - self.lora_up = torch.nn.Conv2d(self.lora_dim, out_dim, (1, 1), (1, 1), bias=False) - else: - self.lora_down = torch.nn.Linear(in_dim, self.lora_dim, bias=False) - self.lora_up = torch.nn.Linear(self.lora_dim, out_dim, bias=False) - - if alpha is None or alpha == 0: - self.alpha = self.lora_dim - else: - if type(alpha) == torch.Tensor: - alpha = alpha.detach().float().numpy() # without casting, bf16 causes error - self.register_buffer("alpha", torch.tensor(alpha)) # Treatable as a constant. - - # same as microsoft's - torch.nn.init.kaiming_uniform_(self.lora_down.weight, a=math.sqrt(5)) - torch.nn.init.zeros_(self.lora_up.weight) - - self.multiplier = multiplier - - def forward(self, x): - scale = self.alpha / self.lora_dim - return self.multiplier * scale * self.lora_up(self.lora_down(x)) - - -class LoRAModuleContainer(torch.nn.Module): - def __init__(self, hooks, state_dict, multiplier): - super().__init__() - self.multiplier = multiplier - - # Create LoRAModule from state_dict information - for key, value in state_dict.items(): - if "lora_down" in key: - lora_name = key.split(".")[0] - lora_dim = value.size()[0] - lora_name_alpha = key.split(".")[0] + ".alpha" - alpha = None - if lora_name_alpha in state_dict: - alpha = state_dict[lora_name_alpha].item() - if lora_name in hooks: - hook = hooks[lora_name] - lora_module = LoRAModule(hook.orig_module, lora_dim=lora_dim, alpha=alpha, multiplier=multiplier) - self.register_module(lora_name, lora_module) - - # Load whole LoRA weights - self.load_state_dict(state_dict, strict=False) - - # Register LoRAModule to LoRAHook - for name, module in self.named_modules(): - if module.__class__.__name__ == "LoRAModule": - if name in hooks: - hook = hooks[name] - hook.append_lora(module) - - @property - def alpha(self): - return self.multiplier - - @alpha.setter - def alpha(self, multiplier): - self.multiplier = multiplier - for name, module in self.named_modules(): - if module.__class__.__name__ == "LoRAModule": - module.multiplier = multiplier - - def remove_from_hooks(self, hooks): - for name, module in self.named_modules(): - if module.__class__.__name__ == "LoRAModule": - hook = hooks[name] - hook.remove_lora(module) - del module - - -class LoRAHook(torch.nn.Module): - """ - replaces forward method of the original Linear, - instead of replacing the original Linear module. 
- """ - - def __init__(self): - super().__init__() - self.lora_modules = [] - - def install(self, orig_module): - assert not hasattr(self, "orig_module") - self.orig_module = orig_module - self.orig_forward = self.orig_module.forward - self.orig_module.forward = self.forward - - def uninstall(self): - assert hasattr(self, "orig_module") - self.orig_module.forward = self.orig_forward - del self.orig_forward - del self.orig_module - - def append_lora(self, lora_module): - self.lora_modules.append(lora_module) - - def remove_lora(self, lora_module): - self.lora_modules.remove(lora_module) - - def forward(self, x): - if len(self.lora_modules) == 0: - return self.orig_forward(x) - lora = torch.sum(torch.stack([lora(x) for lora in self.lora_modules]), dim=0) - return self.orig_forward(x) + lora - - -class LoRAHookInjector(object): - def __init__(self): - super().__init__() - self.hooks = {} - self.device = None - self.dtype = None - - def _get_target_modules(self, root_module, prefix, target_replace_modules): - target_modules = [] - for name, module in root_module.named_modules(): - if ( - module.__class__.__name__ in target_replace_modules and "transformer_blocks" not in name - ): # to adapt latest diffusers: - for child_name, child_module in module.named_modules(): - is_linear = child_module.__class__.__name__ == "Linear" - is_conv2d = child_module.__class__.__name__ == "Conv2d" - # if is_linear or is_conv2d: - if is_linear and not is_conv2d and "ff.net" not in child_name: - lora_name = prefix + "." + name + "." + child_name - lora_name = lora_name.replace(".", "_") - target_modules.append((lora_name, child_module)) - return target_modules - - def install_hooks(self, pipe): - """Install LoRAHook to the pipe.""" - assert len(self.hooks) == 0 - # text_encoder_targets = self._get_target_modules(pipe.text_encoder, "lora_te", ["CLIPAttention", "CLIPMLP"]) - # unet_targets = self._get_target_modules(pipe.unet, "lora_unet", ["Transformer2DModel", "Attention"]) - text_encoder_targets = self._get_target_modules(pipe.text_encoder, "lora_te", ["CLIPAttention"]) - unet_targets = self._get_target_modules(pipe.unet, "lora_unet", ["Transformer2DModel"]) - - for name, target_module in text_encoder_targets + unet_targets: - hook = LoRAHook() - hook.install(target_module) - self.hooks[name] = hook - # print(name) - - self.device = pipe.device - self.dtype = pipe.unet.dtype - - def uninstall_hooks(self): - """Uninstall LoRAHook from the pipe.""" - for k, v in self.hooks.items(): - v.uninstall() - self.hooks = {} - - def apply_lora(self, filename, alpha=1.0): - """Load LoRA weights and apply LoRA to the pipe.""" - assert len(self.hooks) != 0 - state_dict = safetensors.torch.load_file(filename) - container = LoRAModuleContainer(self.hooks, state_dict, alpha) - container.to(self.device, self.dtype) - return container - - def remove_lora(self, container): - """Remove the individual LoRA from the pipe.""" - container.remove_from_hooks(self.hooks) - - -def install_lora_hook(pipe: DiffusionPipeline): - """Install LoRAHook to the pipe.""" - assert not hasattr(pipe, "lora_injector") - assert not hasattr(pipe, "apply_lora") - assert not hasattr(pipe, "remove_lora") - injector = LoRAHookInjector() - injector.install_hooks(pipe) - pipe.lora_injector = injector - pipe.apply_lora = injector.apply_lora - pipe.remove_lora = injector.remove_lora - - -def uninstall_lora_hook(pipe: DiffusionPipeline): - """Uninstall LoRAHook from the pipe.""" - pipe.lora_injector.uninstall_hooks() - del pipe.lora_injector - del pipe.apply_lora - 
del pipe.remove_lora - - -def image_grid(imgs, rows, cols): - assert len(imgs) == rows * cols - - w, h = imgs[0].size - grid = Image.new("RGB", size=(cols * w, rows * h)) - grid_w, grid_h = grid.size - - for i, img in enumerate(imgs): - grid.paste(img, box=(i % cols * w, i // cols * h)) - return grid - - -if __name__ == "__main__": - torch.cuda.reset_peak_memory_stats() - - pipe = StableDiffusionPipeline.from_pretrained( - "gsdf/Counterfeit-V2.5", torch_dtype=torch.float16, safety_checker=None - ).to("cuda") - pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True) - pipe.enable_xformers_memory_efficient_attention() - - prompt = "masterpeace, best quality, highres, 1girl, at dusk" - negative_prompt = ( - "(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), " - "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2) " - ) - lora_fn = "../stable-diffusion-study/models/lora/light_and_shadow.safetensors" - - # Without Lora - images = pipe( - prompt=prompt, - negative_prompt=negative_prompt, - width=512, - height=768, - num_inference_steps=15, - num_images_per_prompt=4, - generator=torch.manual_seed(0), - ).images - image_grid(images, 1, 4).save("test_orig.png") - - mem_bytes = torch.cuda.max_memory_allocated() - torch.cuda.reset_peak_memory_stats() - print(f"Without Lora -> {mem_bytes/(10**6)}MB") - - # Hook version (some restricted apply) - install_lora_hook(pipe) - pipe.apply_lora(lora_fn) - images = pipe( - prompt=prompt, - negative_prompt=negative_prompt, - width=512, - height=768, - num_inference_steps=15, - num_images_per_prompt=4, - generator=torch.manual_seed(0), - ).images - image_grid(images, 1, 4).save("test_lora_hook.png") - uninstall_lora_hook(pipe) - - mem_bytes = torch.cuda.max_memory_allocated() - torch.cuda.reset_peak_memory_stats() - print(f"Hook version -> {mem_bytes/(10**6)}MB") - - # Diffusers dev version - pipe.load_lora_weights(lora_fn) - images = pipe( - prompt=prompt, - negative_prompt=negative_prompt, - width=512, - height=768, - num_inference_steps=15, - num_images_per_prompt=4, - generator=torch.manual_seed(0), - # cross_attention_kwargs={"scale": 0.5}, # lora scale - ).images - image_grid(images, 1, 4).save("test_lora_dev.png") - - mem_bytes = torch.cuda.max_memory_allocated() - print(f"Diffusers dev version -> {mem_bytes/(10**6)}MB") - - # abs-difference image - image_hook = np.array(Image.open("test_lora_hook.png"), dtype=np.int16) - image_dev = np.array(Image.open("test_lora_dev.png"), dtype=np.int16) - image_diff = Image.fromarray(np.abs(image_hook - image_dev).astype(np.uint8)) - image_diff.save("test_lora_hook_dev_diff.png") From 38d520b702551c77bb9f5ea39de07f089af78649 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Thu, 25 May 2023 10:51:14 +0530 Subject: [PATCH 24/27] add: doc on LoRA civitai --- docs/source/en/training/lora.mdx | 67 +++++++++++++++++++++++++++++++- src/diffusers/loaders.py | 2 + 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/docs/source/en/training/lora.mdx b/docs/source/en/training/lora.mdx index 748d99d5020d..e6cc52952c1c 100644 --- a/docs/source/en/training/lora.mdx +++ b/docs/source/en/training/lora.mdx @@ -272,4 +272,69 @@ Note that the use of [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] is * LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth). 
 **Note** that it is possible to provide a local directory path to [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] as well as
 [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`]. To know about the supported inputs,
-refer to the respective docstrings.
\ No newline at end of file
+refer to the respective docstrings.
+
+## Supporting A1111 formatted LoRA checkpoints in Diffusers
+
+To provide seamless interoperability with A1111 for our users, we support loading A1111 formatted
+LoRA checkpoints using [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] in a limited capacity.
+In this section, we explain how to load an A1111 formatted LoRA checkpoint from [CivitAI](https://civitai.com/)
+in Diffusers and perform inference with it.
+
+First, download a checkpoint. We'll use
+[this one](https://civitai.com/models/13239/light-and-shadow) for demonstration purposes.
+
+```bash
+wget https://civitai.com/api/download/models/15603 -O light_and_shadow.safetensors
+```
+
+Next, we initialize a [`~DiffusionPipeline`]:
+
+```python
+import torch
+
+from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
+
+pipeline = StableDiffusionPipeline.from_pretrained(
+    "gsdf/Counterfeit-V2.5", torch_dtype=torch.float16, safety_checker=None
+).to("cuda")
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+    pipeline.scheduler.config, use_karras_sigmas=True
+)
+```
+
+We then load the checkpoint downloaded from CivitAI:
+
+```python
+pipeline.load_lora_weights(".", weight_name="light_and_shadow.safetensors")
+```
+
+Then it's time to run inference:
+
+```python
+prompt = "masterpiece, best quality, 1girl, at dusk"
+negative_prompt = ("(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), "
+    "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), large breasts")
+
+images = pipeline(prompt=prompt,
+    negative_prompt=negative_prompt,
+    width=512,
+    height=768,
+    num_inference_steps=15,
+    num_images_per_prompt=4,
+    generator=torch.manual_seed(0)
+).images
+```
+
+Below is a comparison between the LoRA and the non-LoRA results:
+
+![lora_non_lora](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lora_non_lora_comparison.png)
+
+If you have a similar checkpoint stored on the Hugging Face Hub, you can load it
+directly with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] like so:
+
+```python
+lora_model_id = "sayakpaul/civitai-light-shadow-lora"
+lora_filename = "light_and_shadow.safetensors"
+pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
+```
\ No newline at end of file
diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py
index 1cf9ee097579..fb77b057b5c1 100644
--- a/src/diffusers/loaders.py
+++ b/src/diffusers/loaders.py
@@ -767,6 +767,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
 
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
         This function is experimental and might change in the future.
 

From 748dc678bc2ff21504099c4161ff92a09c0a22a3 Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Thu, 25 May 2023 11:01:58 +0530
Subject: [PATCH 25/27] remove print statement and refactor in the doc.
--- docs/source/en/training/lora.mdx | 6 ++++++ src/diffusers/loaders.py | 1 - 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/docs/source/en/training/lora.mdx b/docs/source/en/training/lora.mdx index e6cc52952c1c..484b08ce950a 100644 --- a/docs/source/en/training/lora.mdx +++ b/docs/source/en/training/lora.mdx @@ -309,6 +309,12 @@ We then load the checkpoint downloaded from CivitAI: pipeline.load_lora_weights(".", weight_name="light_and_shadow.safetensors") ``` + + +If you're loading a checkpoint in the `safetensors` format, please ensure you have `safetensors` installed. + + + And then it's time for running inference: ```python diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index fb77b057b5c1..beb1c380d6cb 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -1279,7 +1279,6 @@ def _convert_kohya_lora_to_diffusers(self, state_dict): unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()} te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()} new_state_dict = {**unet_state_dict, **te_state_dict} - print("converted", len(new_state_dict), "keys") return new_state_dict, network_alpha From 08964d7a62517c749324b3e7e807bb709db6caa0 Mon Sep 17 00:00:00 2001 From: Takuma Mori Date: Wed, 31 May 2023 01:07:00 +0900 Subject: [PATCH 26/27] fix state_dict test for kohya-ss style lora --- src/diffusers/loaders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 42195be88812..4cbf8995472c 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -901,7 +901,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di # Convert kohya-ss Style LoRA attn procs to diffusers attn procs network_alpha = None - if any("alpha" in k for k in state_dict.keys()): + if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()): state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict) # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918), From 23f12b7edf800e4e26d70fc88e28224ac61a19a1 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 31 May 2023 07:37:09 +0530 Subject: [PATCH 27/27] Apply suggestions from code review Co-authored-by: Takuma Mori --- src/diffusers/loaders.py | 2 ++ src/diffusers/models/attention_processor.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/diffusers/loaders.py b/src/diffusers/loaders.py index 4cbf8995472c..aefe9b6d357b 100644 --- a/src/diffusers/loaders.py +++ b/src/diffusers/loaders.py @@ -182,6 +182,8 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict subfolder = kwargs.pop("subfolder", None) weight_name = kwargs.pop("weight_name", None) use_safetensors = kwargs.pop("use_safetensors", None) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. 
+ # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning network_alpha = kwargs.pop("network_alpha", None) if use_safetensors and not is_safetensors_available(): diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 6d636a66e36c..5fa85062f643 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -501,6 +501,8 @@ def __init__(self, in_features, out_features, rank=4, network_alpha=None): self.down = nn.Linear(in_features, rank, bias=False) self.up = nn.Linear(rank, out_features, bias=False) + # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script. + # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning self.network_alpha = network_alpha self.rank = rank
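The `network_alpha` comments added in this last patch follow the kohya-ss convention: at inference time the LoRA update is scaled by `alpha / rank` before it is added to the frozen layer's output (the scratch file removed in [PATCH 23/27] computed exactly `scale = self.alpha / self.lora_dim`). Below is a minimal, self-contained sketch of that arithmetic; it uses plain `torch` tensors with illustrative shapes and a hypothetical helper name, not the diffusers `LoRALinearLayer` implementation.

```python
import torch
import torch.nn.functional as F


def kohya_lora_delta(x, down_weight, up_weight, alpha=None, multiplier=1.0):
    # rank (kohya-ss "lora_dim") is the first dimension of lora_down.weight
    rank = down_weight.shape[0]
    # kohya-ss scaling: alpha / rank; a missing (or zero) alpha means "alpha == rank", i.e. no scaling
    scale = (alpha / rank) if alpha else 1.0
    # Linear case: down_weight is (rank, in_features), up_weight is (out_features, rank)
    return multiplier * scale * F.linear(F.linear(x, down_weight), up_weight)


# Hypothetical rank-4 LoRA on a 320 -> 320 projection with alpha = 1.0,
# so the effective scale is 1.0 / 4 = 0.25.
x = torch.randn(2, 320)
down = torch.randn(4, 320) * 0.01
up = torch.zeros(320, 4)  # a freshly initialised LoRA keeps lora_up at zero, so the delta is zero
print(kohya_lora_delta(x, down, up, alpha=1.0).abs().max())  # tensor(0.)
```

Carrying `network_alpha` out of `_convert_kohya_lora_to_diffusers` and into `LoRALinearLayer` is what allows the converted weights to keep this `alpha / rank` behaviour instead of silently defaulting to a scale of 1.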