@@ -1,8 +1,8 @@
 from typing import Union, List

+import torch
 from diffusers import StableDiffusionPipeline, AutoencoderKL, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
-from tokenizers import Tokenizer
-from transformers import CLIPTokenizer
+from safetensors.torch import load_file


 class TaiyiDrawer:
@@ -43,6 +43,93 @@ def load_textual_inversion(self, pretrained_model_name_or_path):
         self.__stable_diffusion.load_textual_inversion(pretrained_model_name_or_path)
         return self

+    def load_lora_weights(self, checkpoint_path):
+        """
+        Merge LoRA weights from a .safetensors checkpoint directly into the
+        pipeline's UNet and text encoder.
+        Adapted from
+        https://github.com/huggingface/diffusers/issues/3064#issuecomment-1510653978
+        USAGE:
+            lora_model = lora_models + "/" + opt.lora + ".safetensors"
+            drawer.load_lora_weights(lora_model)
+        """
+        # First, make sure the base model is loaded. The merge below casts each
+        # LoRA update to the target layer's device and dtype, but keeping the
+        # pipeline on CUDA (self.__stable_diffusion.to("cuda")) avoids extra copies.
+        LORA_PREFIX_UNET = "lora_unet"
+        LORA_PREFIX_TEXT_ENCODER = "lora_te"
+        alpha = 0.75  # merge strength; 1.0 applies the LoRA at full weight
+        # load the LoRA weights from the .safetensors checkpoint
+        state_dict = load_file(checkpoint_path, device="cuda")
+        visited = []
+
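+        # Each LoRA layer ships as a low-rank pair, <name>.lora_down.weight and
+        # <name>.lora_up.weight, plus an optional per-layer <name>.alpha scale;
+        # every pair is folded into the matching base weight below.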
+        # directly update the weights in the diffusers model
+        for key in state_dict:
+            # keys typically look like
+            # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
+
+            # alpha was fixed above, so skip the stored per-layer alpha entries,
+            # and skip keys already merged as the partner of an earlier key
+            if ".alpha" in key or key in visited:
+                continue
+
+            # route "lora_te_*" keys to the text encoder, everything else to the UNet
+            if "text" in key:
+                layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
+                curr_layer = self.__stable_diffusion.text_encoder
+            else:
+                layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
+                curr_layer = self.__stable_diffusion.unet
+
+            # find the target layer: module names may themselves contain
+            # underscores, so greedily re-join the split parts until getattr succeeds
+            temp_name = layer_infos.pop(0)
+            while True:
+                try:
+                    curr_layer = getattr(curr_layer, temp_name)
+                    if len(layer_infos) > 0:
+                        temp_name = layer_infos.pop(0)
+                    else:
+                        break
+                except AttributeError:
+                    if len(temp_name) > 0:
+                        temp_name += "_" + layer_infos.pop(0)
+                    else:
+                        temp_name = layer_infos.pop(0)
+
+            # collect the (up, down) key pair for this layer, up first
+            pair_keys = []
+            if "lora_down" in key:
+                pair_keys.append(key.replace("lora_down", "lora_up"))
+                pair_keys.append(key)
+            else:
+                pair_keys.append(key)
+                pair_keys.append(key.replace("lora_up", "lora_down"))
+
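+            # update weight: the low-rank product up @ down has the same shape
+            # as the base weight, so the merge is W <- W + alpha * (up @ down)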
+            if len(state_dict[pair_keys[0]].shape) == 4:
+                # 1x1 conv weight: drop the two trailing spatial dims, take the
+                # matrix product, then restore them
+                weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
+                weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
+                update = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
+            else:
+                weight_up = state_dict[pair_keys[0]].to(torch.float32)
+                weight_down = state_dict[pair_keys[1]].to(torch.float32)
+                update = torch.mm(weight_up, weight_down)
+            # match the target layer's device and dtype (e.g. fp16 pipelines)
+            curr_layer.weight.data += alpha * update.to(device=curr_layer.weight.device, dtype=curr_layer.weight.dtype)
+
+            # mark both halves of the pair as merged
+            for item in pair_keys:
+                visited.append(item)
+
+        return self
+
     @staticmethod
     def build_dummy_safety_checker():
         return lambda images, clip_input: (images, False)
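
For context, a minimal sketch of how the new method is meant to be called, assuming TaiyiDrawer is constructed from a pretrained Stable Diffusion model path (that constructor is an assumption; only load_lora_weights and build_dummy_safety_checker appear in this diff):

    # hypothetical usage; the constructor argument and both paths are placeholders
    drawer = TaiyiDrawer("path/to/taiyi-stable-diffusion")
    drawer.load_lora_weights("lora_models/my_style.safetensors")  # merges in place and returns self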