Skip to content

Commit a076b7a

Browse files
committed
https://github.com/huggingface/diffusers/issues/3064#issuecomment-1510653978
Signed-off-by: Sinri Edogawa <[email protected]>
1 parent 05cc8c7 commit a076b7a

File tree

1 file changed

+74
-2
lines changed

1 file changed

+74
-2
lines changed

taiyi/drawer/TaiyiDrawer.py

Lines changed: 74 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
from typing import Union, List
22

3+
import torch
34
from diffusers import StableDiffusionPipeline, AutoencoderKL, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
4-
from tokenizers import Tokenizer
5-
from transformers import CLIPTokenizer
5+
from safetensors.torch import load_file
66

77

88
class TaiyiDrawer:
@@ -43,6 +43,78 @@ def load_textual_inversion(self, pretrained_model_name_or_path):
4343
self.__stable_diffusion.load_textual_inversion(pretrained_model_name_or_path)
4444
return self
4545

46+
def load_lora_weights(self, checkpoint_path):
47+
"""
48+
https://github.com/huggingface/diffusers/issues/3064#issuecomment-1510653978
49+
USAGE:
50+
lora_model = lora_models + "/" + opt.lora + ".safetensors"
51+
self.pipe = load_lora_weights(self.pipe, lora_model)
52+
"""
53+
# First, load base model
54+
# You should USE CUDA.
55+
# self.__stable_diffusion.to("cuda")
56+
LORA_PREFIX_UNET = "lora_unet"
57+
LORA_PREFIX_TEXT_ENCODER = "lora_te"
58+
alpha = 0.75
59+
# load LoRA weight from .safetensors
60+
state_dict = load_file(checkpoint_path, device="cuda")
61+
visited = []
62+
63+
# directly update weight in diffusers model
64+
for key in state_dict:
65+
# it is suggested to print out the key, it usually will be something like below
66+
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
67+
68+
# as we have set the alpha beforehand, so just skip
69+
if ".alpha" in key or key in visited:
70+
continue
71+
72+
if "text" in key:
73+
layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
74+
curr_layer = self.__stable_diffusion.text_encoder
75+
else:
76+
layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
77+
curr_layer = self.__stable_diffusion.unet
78+
79+
# find the target layer
80+
temp_name = layer_infos.pop(0)
81+
while len(layer_infos) > -1:
82+
try:
83+
curr_layer = curr_layer.__getattr__(temp_name)
84+
if len(layer_infos) > 0:
85+
temp_name = layer_infos.pop(0)
86+
elif len(layer_infos) == 0:
87+
break
88+
except Exception:
89+
if len(temp_name) > 0:
90+
temp_name += "_" + layer_infos.pop(0)
91+
else:
92+
temp_name = layer_infos.pop(0)
93+
94+
pair_keys = []
95+
if "lora_down" in key:
96+
pair_keys.append(key.replace("lora_down", "lora_up"))
97+
pair_keys.append(key)
98+
else:
99+
pair_keys.append(key)
100+
pair_keys.append(key.replace("lora_up", "lora_down"))
101+
102+
# update weight
103+
if len(state_dict[pair_keys[0]].shape) == 4:
104+
weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
105+
weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
106+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
107+
else:
108+
weight_up = state_dict[pair_keys[0]].to(torch.float32)
109+
weight_down = state_dict[pair_keys[1]].to(torch.float32)
110+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down)
111+
112+
# update visited list
113+
for item in pair_keys:
114+
visited.append(item)
115+
116+
return self
117+
46118
@staticmethod
47119
def build_dummy_safety_checker():
48120
return lambda images, clip_input: (images, False)

0 commit comments

Comments
 (0)