Commit 730f698

Support: OFT merge to base model (kohya-ss#1580)

* Support: OFT merge to base model
* Fix typo
* Fix typo_2
* Delete unused parameter 'eye'
1 parent 185bc5d commit 730f698

1 file changed: networks/sdxl_merge_lora.py (+144 -48 lines)
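The heart of the change: an OFT module stores rotation parameters ("oft_blocks") rather than a low-rank delta, so merging means rotating the base weights instead of adding to them. For each block, the new code builds a skew-symmetric matrix Q, clamps its norm to the constraint alpha * out_dim, maps it to an orthogonal matrix with the Cayley transform R = (I + Q)(I - Q)^-1, and applies the block-diagonal R to the output channels of the original weight. Below is a minimal, runnable sketch of that math; the shapes (4 blocks over a 16-output Linear) are illustrative assumptions, not values from the commit, and the multiplier blending is omitted since the diff fixes it at 1.0.

# A self-contained sketch (not part of the commit) of the orthogonal merge
# performed by the new OFT branch. Shapes are illustrative assumptions; the
# real values come from the "oft_blocks" and "alpha" entries of the OFT file.
import torch

num_blocks, out_dim, in_dim = 4, 16, 8
block_size = out_dim // num_blocks
alpha = 1e-3  # corresponds to the "alpha" key read in the diff below

oft_blocks = 0.01 * torch.randn(num_blocks, block_size, block_size)
org_weight = torch.randn(out_dim, in_dim)  # a Linear weight (out, in)

# Skew-symmetric generator, norm-clamped by constraint = alpha * out_dim
block_Q = oft_blocks - oft_blocks.transpose(1, 2)
constraint = alpha * out_dim
norm_Q = torch.norm(block_Q.flatten())
new_norm_Q = torch.clamp(norm_Q, max=constraint)
block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))

# Cayley transform: R = (I + Q)(I - Q)^-1 is orthogonal for skew-symmetric Q
I = torch.eye(block_size).unsqueeze(0).repeat(num_blocks, 1, 1)
block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
R = torch.block_diag(*block_R)  # (out_dim, out_dim), block-diagonal

# The merge rotates the output channels of the original weight
merged = torch.einsum("oi, op -> pi", org_weight, R)

# Orthogonality means the rotation preserves the weight's norm
assert torch.allclose(R @ R.T, torch.eye(out_dim), atol=1e-5)
assert torch.allclose(merged.norm(), org_weight.norm(), atol=1e-4)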
@@ -8,10 +8,12 @@
 from library import sai_model_spec, sdxl_model_util, train_util
 import library.model_util as model_util
 import lora
+import oft
 from library.utils import setup_logging
 setup_logging()
 import logging
 logger = logging.getLogger(__name__)
+import concurrent.futures
 
 def load_state_dict(file_name, dtype):
     if os.path.splitext(file_name)[1] == ".safetensors":
@@ -39,82 +41,176 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
     else:
         torch.save(model, file_name)
 
+def detect_method_from_training_model(models, dtype):
+    for model in models:
+        lora_sd, _ = load_state_dict(model, dtype)
+        for key in tqdm(lora_sd.keys()):
+            if 'lora_up' in key or 'lora_down' in key:
+                return 'LoRA'
+            elif "oft_blocks" in key:
+                return 'OFT'
 
 def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
     text_encoder1.to(merge_dtype)
     text_encoder2.to(merge_dtype)
     unet.to(merge_dtype)
+
+    # detect the method: OFT or LoRA
+    method = detect_method_from_training_model(models, merge_dtype)
+    logger.info(f"method: {method}")
 
     # create module map
     name_to_module = {}
     for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
-        if i <= 1:
-            if i == 0:
-                prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
+        if method == 'LoRA':
+            if i <= 1:
+                if i == 0:
+                    prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
+                else:
+                    prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
+                target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
             else:
-                prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
-            target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
-        else:
-            prefix = lora.LoRANetwork.LORA_PREFIX_UNET
-            target_replace_modules = (
+                prefix = lora.LoRANetwork.LORA_PREFIX_UNET
+                target_replace_modules = (
                 lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
+                )
+        elif method == 'OFT':
+            prefix = oft.OFTNetwork.OFT_PREFIX_UNET
+            target_replace_modules = (
+                oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
             )
 
         for name, module in root_module.named_modules():
             if module.__class__.__name__ in target_replace_modules:
                 for child_name, child_module in module.named_modules():
-                    if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
-                        lora_name = prefix + "." + name + "." + child_name
-                        lora_name = lora_name.replace(".", "_")
-                        name_to_module[lora_name] = child_module
-
+                    if method == 'LoRA':
+                        if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
+                            lora_name = prefix + "." + name + "." + child_name
+                            lora_name = lora_name.replace(".", "_")
+                            name_to_module[lora_name] = child_module
+                    elif method == 'OFT':
+                        if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
+                            oft_name = prefix + "." + name + "." + child_name
+                            oft_name = oft_name.replace(".", "_")
+                            name_to_module[oft_name] = child_module
+
+
     for model, ratio in zip(models, ratios):
         logger.info(f"loading: {model}")
         lora_sd, _ = load_state_dict(model, merge_dtype)
 
         logger.info(f"merging...")
-        for key in tqdm(lora_sd.keys()):
-            if "lora_down" in key:
-                up_key = key.replace("lora_down", "lora_up")
-                alpha_key = key[: key.index("lora_down")] + "alpha"
 
-                # find original module for this lora
-                module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
+        if method == 'LoRA':
+            for key in tqdm(lora_sd.keys()):
+                if "lora_down" in key:
+                    up_key = key.replace("lora_down", "lora_up")
+                    alpha_key = key[: key.index("lora_down")] + "alpha"
+
+                    # find original module for this lora
+                    module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
+                    if module_name not in name_to_module:
+                        logger.info(f"no module found for LoRA weight: {key}")
+                        continue
+                    module = name_to_module[module_name]
+                    # logger.info(f"apply {key} to {module}")
+
+                    down_weight = lora_sd[key]
+                    up_weight = lora_sd[up_key]
+
+                    dim = down_weight.size()[0]
+                    alpha = lora_sd.get(alpha_key, dim)
+                    scale = alpha / dim
+
+                    # W <- W + U * D
+                    weight = module.weight
+                    # logger.info(module_name, down_weight.size(), up_weight.size())
+                    if len(weight.size()) == 2:
+                        # linear
+                        weight = weight + ratio * (up_weight @ down_weight) * scale
+                    elif down_weight.size()[2:4] == (1, 1):
+                        # conv2d 1x1
+                        weight = (
+                            weight
+                            + ratio
+                            * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
+                            * scale
+                        )
+                    else:
+                        # conv2d 3x3
+                        conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
+                        # logger.info(conved.size(), weight.size(), module.stride, module.padding)
+                        weight = weight + ratio * conved * scale
+
+                    module.weight = torch.nn.Parameter(weight)
+
+        elif method == 'OFT':
+            multiplier = 1.0
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+            for key in tqdm(lora_sd.keys()):
+                if "oft_blocks" in key:
+                    oft_blocks = lora_sd[key]
+                    dim = oft_blocks.shape[0]
+                    break
+            alpha = None
+            for key in tqdm(lora_sd.keys()):
+                if "alpha" in key:
+                    alpha = lora_sd[key].item()
+                    break
+
+            def merge_to(key):
+                if "alpha" in key:
+                    return
+
+                # find original module for this OFT
+                module_name = ".".join(key.split(".")[:-1])
                 if module_name not in name_to_module:
-                    logger.info(f"no module found for LoRA weight: {key}")
-                    continue
+                    return
                 module = name_to_module[module_name]
-                # logger.info(f"apply {key} to {module}")
+                # logger.info(f"apply {key} to {module}")
 
-                down_weight = lora_sd[key]
-                up_weight = lora_sd[up_key]
-
-                dim = down_weight.size()[0]
-                alpha = lora_sd.get(alpha_key, dim)
-                scale = alpha / dim
-
-                # W <- W + U * D
-                weight = module.weight
-                # logger.info(module_name, down_weight.size(), up_weight.size())
-                if len(weight.size()) == 2:
-                    # linear
-                    weight = weight + ratio * (up_weight @ down_weight) * scale
-                elif down_weight.size()[2:4] == (1, 1):
-                    # conv2d 1x1
-                    weight = (
-                        weight
-                        + ratio
-                        * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
-                        * scale
-                    )
+                oft_blocks = lora_sd[key]
+
+                if isinstance(module, torch.nn.Linear):
+                    out_dim = module.out_features
+                elif isinstance(module, torch.nn.Conv2d):
+                    out_dim = module.out_channels
+
+                num_blocks = dim
+                block_size = out_dim // dim
+                constraint = (0 if alpha is None else alpha) * out_dim
+
+                block_Q = oft_blocks - oft_blocks.transpose(1, 2)
+                norm_Q = torch.norm(block_Q.flatten())
+                new_norm_Q = torch.clamp(norm_Q, max=constraint)
+                block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
+                I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1)
+                block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
+                block_R_weighted = multiplier * block_R + (1 - multiplier) * I
+                R = torch.block_diag(*block_R_weighted)
+
+                # get org weight
+                org_sd = module.state_dict()
+                org_weight = org_sd["weight"].to(device)
+
+                R = R.to(org_weight.device, dtype=org_weight.dtype)
+
+                if org_weight.dim() == 4:
+                    weight = torch.einsum("oihw, op -> pihw", org_weight, R)
                 else:
-                    # conv2d 3x3
-                    conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
-                    # logger.info(conved.size(), weight.size(), module.stride, module.padding)
-                    weight = weight + ratio * conved * scale
-
+                    weight = torch.einsum("oi, op -> pi", org_weight, R)
+
+                weight = weight.contiguous()  # make the tensor contiguous; required due to ThreadPoolExecutor
+
                 module.weight = torch.nn.Parameter(weight)
 
+            with concurrent.futures.ThreadPoolExecutor() as executor:
+                list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
+
 
 def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
     base_alphas = {}  # alpha for merged model
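The new detection helper is a plain key-name scan: it loads each model's state dict and returns on the first LoRA- or OFT-style key it sees. A toy illustration with hypothetical in-memory state dicts (the key names and tensors are placeholders; real ones come from load_state_dict):

# Toy illustration (not part of the commit) of the key-name heuristic used by
# detect_method_from_training_model, applied to hypothetical state dicts.
import torch

def detect(sd):
    for key in sd.keys():
        if "lora_up" in key or "lora_down" in key:
            return "LoRA"
        elif "oft_blocks" in key:
            return "OFT"

lora_like = {"lora_unet_x.lora_down.weight": torch.zeros(4, 320)}  # hypothetical key
oft_like = {"oft_unet_x.oft_blocks": torch.zeros(4, 80, 80)}       # hypothetical key

assert detect(lora_like) == "LoRA"
assert detect(oft_like) == "OFT"

Two details worth noting about the OFT path: the per-model ratio is never consulted there, so the rotation is always applied at full strength (multiplier = 1.0), and merge_to is fanned out across the state-dict keys with a ThreadPoolExecutor, which is why (per the code's own comment) each merged weight is made contiguous before being wrapped in a Parameter. Invocation should be unchanged from the LoRA case, e.g. python networks/sdxl_merge_lora.py --sd_model base.safetensors --models my_oft.safetensors --ratios 1.0 --save_to merged.safetensors, assuming the script's existing argument parser, which this commit does not touch.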
