8 | 8 | from library import sai_model_spec, sdxl_model_util, train_util
9 | 9 | import library.model_util as model_util
10 | 10 | import lora
| 11 | +import oft
11 | 12 | from library.utils import setup_logging
12 | 13 | setup_logging()
13 | 14 | import logging
14 | 15 | logger = logging.getLogger(__name__)
| 16 | +import concurrent.futures
15 | 17 |
16 | 18 | def load_state_dict(file_name, dtype):
17 | 19 |     if os.path.splitext(file_name)[1] == ".safetensors":
@@ -39,82 +41,176 @@ def save_to_file(file_name, model, state_dict, dtype, metadata):
39 | 41 |     else:
40 | 42 |         torch.save(model, file_name)
41 | 43 |
| 44 | +def detect_method_from_training_model(models, dtype):
| 45 | +    for model in models:
| 46 | +        lora_sd, _ = load_state_dict(model, dtype)
| 47 | +        for key in tqdm(lora_sd.keys()):
| 48 | +            if 'lora_up' in key or 'lora_down' in key:
| 49 | +                return 'LoRA'
| 50 | +            elif "oft_blocks" in key:
| 51 | +                return 'OFT'
42 | 52 |
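Note: the new detect_method_from_training_model helper decides between the two merge paths purely from key names in the saved state dict (lora_up/lora_down for LoRA, oft_blocks for OFT); if neither pattern appears it returns None and neither merge branch runs. A rough standalone sketch of the same heuristic, not part of the patch (the checkpoint path is a placeholder):

    from safetensors.torch import load_file

    sd = load_file("my_network.safetensors")  # hypothetical checkpoint path
    if any("lora_up" in k or "lora_down" in k for k in sd.keys()):
        print("LoRA weights")
    elif any("oft_blocks" in k for k in sd.keys()):
        print("OFT weights")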
43 | 53 | def merge_to_sd_model(text_encoder1, text_encoder2, unet, models, ratios, merge_dtype):
44 | 54 |     text_encoder1.to(merge_dtype)
45 | 55 |     text_encoder1.to(merge_dtype)
46 | 56 |     unet.to(merge_dtype)
| 57 | +
| 58 | +    # detect the method: OFT or LoRA_module
| 59 | +    method = detect_method_from_training_model(models, merge_dtype)
| 60 | +    logger.info(f"method:{method}")
47 | 61 |
48 | 62 |     # create module map
49 | 63 |     name_to_module = {}
50 | 64 |     for i, root_module in enumerate([text_encoder1, text_encoder2, unet]):
51 | | -        if i <= 1:
52 | | -            if i == 0:
53 | | -                prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
| 65 | +        if method == 'LoRA':
| 66 | +            if i <= 1:
| 67 | +                if i == 0:
| 68 | +                    prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER1
| 69 | +                else:
| 70 | +                    prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
| 71 | +                target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
54 | 72 |             else:
55 | | -                prefix = lora.LoRANetwork.LORA_PREFIX_TEXT_ENCODER2
56 | | -            target_replace_modules = lora.LoRANetwork.TEXT_ENCODER_TARGET_REPLACE_MODULE
57 | | -        else:
58 | | -            prefix = lora.LoRANetwork.LORA_PREFIX_UNET
59 | | -            target_replace_modules = (
| 73 | +                prefix = lora.LoRANetwork.LORA_PREFIX_UNET
| 74 | +                target_replace_modules = (
60 | 75 |                 lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE + lora.LoRANetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
| 76 | +                )
| 77 | +        elif method == 'OFT':
| 78 | +            prefix = oft.OFTNetwork.OFT_PREFIX_UNET
| 79 | +            target_replace_modules = (
| 80 | +                oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_ALL_LINEAR + oft.OFTNetwork.UNET_TARGET_REPLACE_MODULE_CONV2D_3X3
61 | 81 |             )
62 | 82 |
63 | 83 |         for name, module in root_module.named_modules():
64 | 84 |             if module.__class__.__name__ in target_replace_modules:
65 | 85 |                 for child_name, child_module in module.named_modules():
66 | | -                    if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
67 | | -                        lora_name = prefix + "." + name + "." + child_name
68 | | -                        lora_name = lora_name.replace(".", "_")
69 | | -                        name_to_module[lora_name] = child_module
70 | | -
| 86 | +                    if method == 'LoRA':
| 87 | +                        if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
| 88 | +                            lora_name = prefix + "." + name + "." + child_name
| 89 | +                            lora_name = lora_name.replace(".", "_")
| 90 | +                            name_to_module[lora_name] = child_module
| 91 | +                    elif method == 'OFT':
| 92 | +                        if child_module.__class__.__name__ == "Linear" or child_module.__class__.__name__ == "Conv2d":
| 93 | +                            oft_name = prefix + "." + name + "." + child_name
| 94 | +                            oft_name = oft_name.replace(".", "_")
| 95 | +                            name_to_module[oft_name] = child_module
| 96 | +
| 97 | +
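Note: the keys of name_to_module are the same flattened names used in the saved state dict: prefix, module path, and child name are joined with "." and every "." is then replaced by "_". A quick illustration, not part of the patch (the module path is made up, and "lora_unet" is the assumed value of lora.LoRANetwork.LORA_PREFIX_UNET):

    prefix = "lora_unet"  # assumed value of lora.LoRANetwork.LORA_PREFIX_UNET
    name = "down_blocks.0.attentions.0.transformer_blocks.0.attn1"  # illustrative UNet module path
    child_name = "to_q"
    key = (prefix + "." + name + "." + child_name).replace(".", "_")
    # -> "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q"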
71 | 98 |     for model, ratio in zip(models, ratios):
72 | 99 |         logger.info(f"loading: {model}")
73 | 100 |         lora_sd, _ = load_state_dict(model, merge_dtype)
74 | 101 |
75 | 102 |         logger.info(f"merging...")
76 | | -        for key in tqdm(lora_sd.keys()):
77 | | -            if "lora_down" in key:
78 | | -                up_key = key.replace("lora_down", "lora_up")
79 | | -                alpha_key = key[: key.index("lora_down")] + "alpha"
80 | 103 |
81 | | -                # find original module for this lora
82 | | -                module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
| 104 | +        if method == 'LoRA':
| 105 | +            for key in tqdm(lora_sd.keys()):
| 106 | +                if "lora_down" in key:
| 107 | +                    up_key = key.replace("lora_down", "lora_up")
| 108 | +                    alpha_key = key[: key.index("lora_down")] + "alpha"
| 109 | +
| 110 | +                    # find original module for this lora
| 111 | +                    module_name = ".".join(key.split(".")[:-2])  # remove trailing ".lora_down.weight"
| 112 | +                    if module_name not in name_to_module:
| 113 | +                        logger.info(f"no module found for LoRA weight: {key}")
| 114 | +                        continue
| 115 | +                    module = name_to_module[module_name]
| 116 | +                    # logger.info(f"apply {key} to {module}")
| 117 | +
| 118 | +                    down_weight = lora_sd[key]
| 119 | +                    up_weight = lora_sd[up_key]
| 120 | +
| 121 | +                    dim = down_weight.size()[0]
| 122 | +                    alpha = lora_sd.get(alpha_key, dim)
| 123 | +                    scale = alpha / dim
| 124 | +
| 125 | +                    # W <- W + U * D
| 126 | +                    weight = module.weight
| 127 | +                    # logger.info(module_name, down_weight.size(), up_weight.size())
| 128 | +                    if len(weight.size()) == 2:
| 129 | +                        # linear
| 130 | +                        weight = weight + ratio * (up_weight @ down_weight) * scale
| 131 | +                    elif down_weight.size()[2:4] == (1, 1):
| 132 | +                        # conv2d 1x1
| 133 | +                        weight = (
| 134 | +                            weight
| 135 | +                            + ratio
| 136 | +                            * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
| 137 | +                            * scale
| 138 | +                        )
| 139 | +                    else:
| 140 | +                        # conv2d 3x3
| 141 | +                        conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
| 142 | +                        # logger.info(conved.size(), weight.size(), module.stride, module.padding)
| 143 | +                        weight = weight + ratio * conved * scale
| 144 | +
| 145 | +                    module.weight = torch.nn.Parameter(weight)
| 146 | +
| 147 | +
| 148 | +        elif method == 'OFT':
| 149 | +
| 150 | +            multiplier=1.0
| 151 | +            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
| 152 | +
| 153 | +            for key in tqdm(lora_sd.keys()):
| 154 | +                if "oft_blocks" in key:
| 155 | +                    oft_blocks = lora_sd[key]
| 156 | +                    dim = oft_blocks.shape[0]
| 157 | +                    break
| 158 | +            for key in tqdm(lora_sd.keys()):
| 159 | +                if "alpha" in key:
| 160 | +                    oft_blocks = lora_sd[key]
| 161 | +                    alpha = oft_blocks.item()
| 162 | +                    break
| 163 | +
| 164 | +            def merge_to(key):
| 165 | +                if "alpha" in key:
| 166 | +                    return
| 167 | +
| 168 | +                # find original module for this OFT
| 169 | +                module_name = ".".join(key.split(".")[:-1])
83 | 170 |                 if module_name not in name_to_module:
84 | | -                    logger.info(f"no module found for LoRA weight: {key}")
85 | | -                    continue
| 171 | +                    return
86 | 172 |                 module = name_to_module[module_name]
87 | | -                # logger.info(f"apply {key} to {module}")
88 | 173 |
89 | | -                down_weight = lora_sd[key]
90 | | -                up_weight = lora_sd[up_key]
91 | | -
92 | | -                dim = down_weight.size()[0]
93 | | -                alpha = lora_sd.get(alpha_key, dim)
94 | | -                scale = alpha / dim
95 | | -
96 | | -                # W <- W + U * D
97 | | -                weight = module.weight
98 | | -                # logger.info(module_name, down_weight.size(), up_weight.size())
99 | | -                if len(weight.size()) == 2:
100 | | -                    # linear
101 | | -                    weight = weight + ratio * (up_weight @ down_weight) * scale
102 | | -                elif down_weight.size()[2:4] == (1, 1):
103 | | -                    # conv2d 1x1
104 | | -                    weight = (
105 | | -                        weight
106 | | -                        + ratio
107 | | -                        * (up_weight.squeeze(3).squeeze(2) @ down_weight.squeeze(3).squeeze(2)).unsqueeze(2).unsqueeze(3)
108 | | -                        * scale
109 | | -                    )
| 174 | +                # logger.info(f"apply {key} to {module}")
| 175 | +
| 176 | +                oft_blocks = lora_sd[key]
| 177 | +
| 178 | +                if isinstance(module, torch.nn.Linear):
| 179 | +                    out_dim = module.out_features
| 180 | +                elif isinstance(module, torch.nn.Conv2d):
| 181 | +                    out_dim = module.out_channels
| 182 | +
| 183 | +                num_blocks = dim
| 184 | +                block_size = out_dim // dim
| 185 | +                constraint = (0 if alpha is None else alpha) * out_dim
| 186 | +
| 187 | +                block_Q = oft_blocks - oft_blocks.transpose(1, 2)
| 188 | +                norm_Q = torch.norm(block_Q.flatten())
| 189 | +                new_norm_Q = torch.clamp(norm_Q, max=constraint)
| 190 | +                block_Q = block_Q * ((new_norm_Q + 1e-8) / (norm_Q + 1e-8))
| 191 | +                I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1)
| 192 | +                block_R = torch.matmul(I + block_Q, (I - block_Q).inverse())
| 193 | +                block_R_weighted = multiplier * block_R + (1 - multiplier) * I
| 194 | +                R = torch.block_diag(*block_R_weighted)
| 195 | +
| 196 | +                # get org weight
| 197 | +                org_sd = module.state_dict()
| 198 | +                org_weight = org_sd["weight"].to(device)
| 199 | +
| 200 | +                R = R.to(org_weight.device, dtype=org_weight.dtype)
| 201 | +
| 202 | +                if org_weight.dim() == 4:
| 203 | +                    weight = torch.einsum("oihw, op -> pihw", org_weight, R)
110 | 204 |                 else:
111 | | -                    # conv2d 3x3
112 | | -                    conved = torch.nn.functional.conv2d(down_weight.permute(1, 0, 2, 3), up_weight).permute(1, 0, 2, 3)
113 | | -                    # logger.info(conved.size(), weight.size(), module.stride, module.padding)
114 | | -                    weight = weight + ratio * conved * scale
115 | | -
| 205 | +                    weight = torch.einsum("oi, op -> pi", org_weight, R)
| 206 | +
| 207 | +                weight = weight.contiguous()  # Make Tensor contiguous; required due to ThreadPoolExecutor
| 208 | +
116 | 209 |                 module.weight = torch.nn.Parameter(weight)
117 | 210 |
| 211 | +            with concurrent.futures.ThreadPoolExecutor() as executor:
| 212 | +                list(tqdm(executor.map(merge_to, lora_sd.keys()), total=len(lora_sd.keys())))
| 213 | +
118 | 214 |
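Note: unlike the LoRA path, which adds a scaled low-rank delta to each weight, the OFT path folds an orthogonal transform into the weight: the stored oft_blocks are made skew-symmetric, norm-limited by alpha, turned into orthogonal blocks via the Cayley transform R = (I + Q)(I - Q)^-1, assembled into a block-diagonal matrix, and applied to the weight's output channels. A condensed sketch of that per-module computation, mirroring the merge_to closure above (a reading aid, not part of the patch; the function and argument names are illustrative):

    import torch

    def apply_oft_to_weight(org_weight, oft_blocks, alpha, multiplier=1.0):
        # oft_blocks: (num_blocks, block_size, block_size); org_weight: (out, in) or (out, in, kh, kw)
        num_blocks, block_size, _ = oft_blocks.shape
        out_dim = org_weight.shape[0]
        # skew-symmetric blocks, with their overall norm limited by alpha * out_dim
        Q = oft_blocks - oft_blocks.transpose(1, 2)
        norm_Q = torch.norm(Q.flatten())
        constraint = (0 if alpha is None else alpha) * out_dim
        Q = Q * ((torch.clamp(norm_Q, max=constraint) + 1e-8) / (norm_Q + 1e-8))
        # Cayley transform: each block (I + Q_i)(I - Q_i)^-1 is orthogonal
        I = torch.eye(block_size, device=oft_blocks.device).unsqueeze(0).repeat(num_blocks, 1, 1)
        R = torch.matmul(I + Q, (I - Q).inverse())
        R = torch.block_diag(*(multiplier * R + (1 - multiplier) * I))
        R = R.to(org_weight.device, dtype=org_weight.dtype)
        # rotate the output channels of the original weight with the block-diagonal R
        if org_weight.dim() == 4:  # Conv2d
            return torch.einsum("oihw, op -> pihw", org_weight, R)
        return torch.einsum("oi, op -> pi", org_weight, R)  # Linear

Also worth noting: the ratio argument is not used on this path; the patch fixes multiplier at 1.0, so OFT weights are always merged at full strength.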
119 | 215 | def merge_lora_models(models, ratios, merge_dtype, concat=False, shuffle=False):
120 | 216 |     base_alphas = {}  # alpha for merged model