Support Kohya-ss style LoRA file format (in a limited capacity) #3437
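With this change, a LoRA checkpoint trained with kohya-ss's sd-scripts (the format most LoRA files on CivitAI use) can be loaded through the existing `load_lora_weights` API; per the conversion code below, only the attention-module LoRA weights are handled, hence "in a limited capacity". A minimal usage sketch, where the checkpoint filename is a placeholder and the exact call pattern is an assumption based on the loader changes in this diff:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Hypothetical kohya-ss style .safetensors LoRA file in the current directory.
pipe.load_lora_weights(".", weight_name="kohya_style_lora.safetensors")

image = pipe("a photo of a corgi wearing sunglasses", num_inference_steps=30).images[0]
image.save("out.png")
```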
First changed file (the LoRA loader mixins):

@@ -180,6 +180,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        network_alpha = kwargs.pop("network_alpha", None)

         if use_safetensors and not is_safetensors_available():
             raise ValueError(
@@ -282,7 +283,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                     attn_processor_class = LoRAAttnProcessor

                 attn_processors[key] = attn_processor_class(
-                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
+                    hidden_size=hidden_size,
+                    cross_attention_dim=cross_attention_dim,
+                    rank=rank,
+                    network_alpha=network_alpha,
                 )
                 attn_processors[key].load_state_dict(value_dict)
         elif is_custom_diffusion:
@@ -887,6 +891,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         else:
             state_dict = pretrained_model_name_or_path_or_dict

+        # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
+        network_alpha = None
+        if any("alpha" in k for k in state_dict.keys()):
Review thread on the `any("alpha" in k ...)` check:

- Is that the "safest" way to determine whether a checkpoint is in kohya-ss style or not? Do all such checkpoints have an `alpha` key?
- According to the conversion code, it is also possible to check for `to_out_0.lora` keys in the unet part and `out_proj.lora` keys in the text_encoder part.
- The idea of reading the metadata from the safetensors file seems reliable and good! Although an API is provided for saving safetensors metadata, ironically there is no API for loading it, so we have to read the file directly, which feels a bit awkward. I posted the code and output examples here. I think checking `'ss_network_module' == 'networks.lora'` could be one way. (A rough sketch of this header-reading approach follows after this hunk.)
- I investigated the 97 LoRA files from CivitAI that I had on hand. 62 of them had metadata, and ...
- Could you elaborate on this a bit? Based on your investigation, is it safe to say that the current way of detecting a CivitAI/kohya-ss style checkpoint is our best bet?
- The files I investigated were those previously downloaded from CivitAI, sorted in descending order of download count. I mistakenly wrote that there were 97 in total; there were actually 100 files (there were about 3 cases where ...). As for whether the current checks are good, I tried the following code:

```python
import safetensors.torch

available_alpha = 0
available_te_unet = 0
for filename in filenames:  # the downloaded LoRA .safetensors files
    state_dict = safetensors.torch.load_file(filename)
    if any("alpha" in k for k in state_dict.keys()):
        available_alpha += 1
    if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
        available_te_unet += 1
```

It seems that there are also files that do not contain network_alpha.
+            state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)
+
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
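As referenced in the thread above, here is a rough sketch of the alternative metadata-based detection; it is not part of this PR. It reads the safetensors header by hand, since there is a save-side API for metadata but no dedicated load-side helper, and checks the kohya-ss training metadata:

```python
import json
import struct

def read_safetensors_metadata(path):
    # The safetensors format starts with an 8-byte little-endian header length,
    # followed by a JSON header that may contain a "__metadata__" dict of strings.
    with open(path, "rb") as f:
        header_size = struct.unpack("<Q", f.read(8))[0]
        header = json.loads(f.read(header_size))
    return header.get("__metadata__", {})

def looks_like_kohya_lora(path):
    # 'ss_network_module' == 'networks.lora' is the check suggested in the thread;
    # as noted there, not every file from CivitAI actually carries this metadata.
    return read_safetensors_metadata(path).get("ss_network_module") == "networks.lora"
```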
@@ -898,7 +907,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             unet_lora_state_dict = {
                 k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
             }
-            self.unet.load_attn_procs(unet_lora_state_dict)
+            self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)

             # Load the layers corresponding to text encoder and make necessary adjustments.
             text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
@@ -907,7 +916,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
             }
             if len(text_encoder_lora_state_dict) > 0:
-                attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
+                attn_procs_text_encoder = self._load_text_encoder_attn_procs(
+                    text_encoder_lora_state_dict, network_alpha=network_alpha
+                )
                 self._modify_text_encoder(attn_procs_text_encoder)

                 # save lora attn procs of text encoder so that it can be easily retrieved
@@ -943,14 +954,16 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
                 module = self.text_encoder.get_submodule(name)
                 # Construct a new function that performs the LoRA merging. We will monkey patch
                 # this forward pass.
-                lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
-                old_forward = module.forward
-
-                def new_forward(x):
-                    return old_forward(x) + lora_layer(x)
-
-                # Monkey-patch.
-                module.forward = new_forward
+                if name in attn_processors:
+                    module.lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
+                    module.old_forward = module.forward
+
+                    def new_forward(self, x):
+                        return self.old_forward(x) + self.lora_layer(x)
+
+                    # Monkey-patch.
+                    module.forward = new_forward.__get__(module)
Review thread on the monkey-patching change:

- Could I get a brief explanation summarizing the changes?
- This change, which I believe should not be necessary in principle, is one I would like to revert once I understand the underlying cause and can confirm it is not a problem. When I investigated the cause of the severe output distortion that appears when enabling LoRA for the text_encoder, I found that the output was still distorted even when I removed the ... I compared the original closure-based code with the version that binds the function as an instance method, using simplified test code, and both tests pass without any problems. I suspect it might be a memory-related issue with Python or PyTorch garbage collection, but I'm not sure yet. I plan to adjust the test to be closer to the current situation and keep investigating. (A standalone sketch contrasting the two patching styles follows this hunk.)
- Looking into it.

     def _get_lora_layer_attribute(self, name: str) -> str:
         if "q_proj" in name:
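To make the exchange above easier to follow, here is a minimal, standalone illustration (not diffusers code) of the two patching styles being compared: the earlier closure-based override, and the new style that stores the pieces as module attributes and installs a method bound with `__get__`. Both produce the same outputs in this toy setup:

```python
import torch
import torch.nn as nn

module = nn.Linear(4, 4)
lora_layer = nn.Linear(4, 4, bias=False)

# Style 1: a closure capturing `old_forward` and `lora_layer` (the code before this change).
old_forward = module.forward

def new_forward_closure(x):
    return old_forward(x) + lora_layer(x)

# Style 2: attributes plus a bound method (the code after this change).
module.lora_layer = lora_layer
module.old_forward = module.forward

def new_forward(self, x):
    return self.old_forward(x) + self.lora_layer(x)

module.forward = new_forward.__get__(module)  # bind so that `self` is the patched module

x = torch.randn(1, 4)
# Both styles compute base_output + lora_output with the same weights.
assert torch.allclose(new_forward_closure(x), module(x))
```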
@@ -1037,6 +1050,7 @@ def _load_text_encoder_attn_procs(
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        network_alpha = kwargs.pop("network_alpha", None)

         if use_safetensors and not is_safetensors_available():
             raise ValueError(
@@ -1114,7 +1128,10 @@ def _load_text_encoder_attn_procs(
                 hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

                 attn_processors[key] = LoRAAttnProcessor(
-                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
+                    hidden_size=hidden_size,
+                    cross_attention_dim=cross_attention_dim,
+                    rank=rank,
+                    network_alpha=network_alpha,
                 )
                 attn_processors[key].load_state_dict(value_dict)

@@ -1208,6 +1225,63 @@ def save_function(weights, filename):
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")

+    def _convert_kohya_lora_to_diffusers(self, state_dict):
+        unet_state_dict = {}
+        te_state_dict = {}
+        network_alpha = None
+
+        for key, value in state_dict.items():
+            if "lora_down" in key:
+                lora_name = key.split(".")[0]
+                lora_name_up = lora_name + ".lora_up.weight"
Review comment on the key handling above: Clever and clean!
+                lora_name_alpha = lora_name + ".alpha"
+                if lora_name_alpha in state_dict:
+                    alpha = state_dict[lora_name_alpha].item()
+                    if network_alpha is None:
+                        network_alpha = alpha
+                    elif network_alpha != alpha:
+                        raise ValueError("Network alpha is not consistent")
+
+                if lora_name.startswith("lora_unet_"):
+                    diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+                    diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+                    diffusers_name = diffusers_name.replace("mid.block", "mid_block")
+                    diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
+                    diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
+                    diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
+                    diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
+                    diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
+                    diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+                    if "transformer_blocks" in diffusers_name:
+                        if "attn1" in diffusers_name or "attn2" in diffusers_name:
+                            diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
+                            diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
+                            unet_state_dict[diffusers_name] = value
+                            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                elif lora_name.startswith("lora_te_"):
+                    diffusers_name = key.replace("lora_te_", "").replace("_", ".")
+                    diffusers_name = diffusers_name.replace("text.model", "text_model")
+                    diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+                    diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+                    diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+                    diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+                    diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+                    if "self_attn" in diffusers_name:
+                        prefix = ".".join(diffusers_name.split(".")[:-3])  # e.g.: text_model.encoder.layers.0.self_attn
+                        suffix = ".".join(diffusers_name.split(".")[-3:])  # e.g.: to_k_lora.down.weight
+                        for module_name in TEXT_ENCODER_TARGET_MODULES:
+                            diffusers_name = f"{prefix}.{module_name}.{suffix}"
+                            te_state_dict[diffusers_name] = value
+                            te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+
+        unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
+        te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
+        new_state_dict = {**unet_state_dict, **te_state_dict}
+        print("converted", len(new_state_dict), "keys")
+        return new_state_dict, network_alpha


 class FromCkptMixin:
     """This helper class allows to directly load .ckpt stable diffusion file_extension
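To illustrate what `_convert_kohya_lora_to_diffusers` does to individual keys, here are two entries traced by hand through the replacement chain above (the tensors themselves are passed through untouched; the key names are representative examples, not taken from a specific file):

```python
# A UNet attention key, kohya-ss naming -> diffusers naming (with the "unet." prefix added at the end).
# The matching ".lora_up.weight" entry becomes the same key with ".up." instead of ".down.".
kohya_unet_key = "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"
diffusers_unet_key = "unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight"

# A text-encoder key. After renaming, the value is fanned out over every entry in
# TEXT_ENCODER_TARGET_MODULES (q_proj, k_proj, v_proj, out_proj); one of the resulting keys:
kohya_te_key = "lora_te_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight"
diffusers_te_key = "text_encoder.text_model.encoder.layers.0.self_attn.q_proj.to_q_lora.down.weight"
```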
Second changed file (the attention processor classes):
@@ -468,14 +468,16 @@ def __call__(


 class LoRALinearLayer(nn.Module):
-    def __init__(self, in_features, out_features, rank=4):
+    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
         super().__init__()

         if rank > min(in_features, out_features):
             raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

         self.down = nn.Linear(in_features, rank, bias=False)
         self.up = nn.Linear(rank, out_features, bias=False)
+        self.network_alpha = network_alpha
         self.rank = rank

         nn.init.normal_(self.down.weight, std=1 / rank)
         nn.init.zeros_(self.up.weight)
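A side note on the initialization kept here: because `up` starts at zero, a freshly constructed LoRA layer contributes nothing, so attaching it before training (or before loading weights) leaves the wrapped layer's output unchanged. A tiny standalone check mirroring the two `nn.init` calls above:

```python
import torch
import torch.nn as nn

rank, in_features, out_features = 4, 8, 8
down = nn.Linear(in_features, rank, bias=False)
up = nn.Linear(rank, out_features, bias=False)
nn.init.normal_(down.weight, std=1 / rank)
nn.init.zeros_(up.weight)

x = torch.randn(2, in_features)
# The LoRA branch is exactly zero at initialization.
assert torch.equal(up(down(x)), torch.zeros(2, out_features))
```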
@@ -487,21 +489,24 @@ def forward(self, hidden_states):
         down_hidden_states = self.down(hidden_states.to(dtype))
         up_hidden_states = self.up(down_hidden_states)

+        if self.network_alpha is not None:
+            up_hidden_states *= self.network_alpha / self.rank
+
Review comment on lines +534 to +536: Okay for me!
         return up_hidden_states.to(orig_dtype)

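The scaling follows the kohya-ss convention: the effective LoRA update is `(network_alpha / rank) * up(down(x))`, which is equivalent to folding that factor into the up-projection weights. A small standalone check of that equivalence (illustration only, not diffusers code):

```python
import torch
import torch.nn as nn

rank, network_alpha, features = 4, 1.0, 16
down = nn.Linear(features, rank, bias=False)
up = nn.Linear(rank, features, bias=False)
nn.init.normal_(down.weight, std=1 / rank)
nn.init.normal_(up.weight, std=0.02)  # non-zero so the comparison is not trivially 0 == 0

x = torch.randn(3, features)
scaled_output = up(down(x)) * (network_alpha / rank)

# Same result when the factor is baked into the up-projection weight instead.
folded_up = nn.Linear(rank, features, bias=False)
with torch.no_grad():
    folded_up.weight.copy_(up.weight * (network_alpha / rank))
assert torch.allclose(scaled_output, folded_up(down(x)), atol=1e-6)
```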
 class LoRAAttnProcessor(nn.Module):
-    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
         super().__init__()

         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
         self.rank = rank

-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
         residual = hidden_states
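For orientation, this is roughly how one of these processors is built once `network_alpha` is threaded through (a sketch; the sizes below are SD 1.x-style examples I chose for illustration, and in practice `load_attn_procs` / `load_lora_weights` construct one processor per attention module):

```python
from diffusers.models.attention_processor import LoRAAttnProcessor

lora_proc = LoRAAttnProcessor(
    hidden_size=320,           # example: inner dim of a first-down-block attention layer
    cross_attention_dim=768,   # example: SD 1.x text-encoder width; None for self-attention
    rank=4,
    network_alpha=4.0,         # kohya-ss alpha; each LoRALinearLayer scales its output by alpha / rank
)
```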
@@ -740,19 +745,19 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a


 class LoRAAttnAddedKVProcessor(nn.Module):
-    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
         super().__init__()

         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
         self.rank = rank

-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
-        self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
-        self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
-        self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
-        self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
         residual = hidden_states
@@ -933,18 +938,20 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a


 class LoRAXFormersAttnProcessor(nn.Module):
-    def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
+    def __init__(
+        self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
+    ):
         super().__init__()

         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
         self.rank = rank
         self.attention_op = attention_op

-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
         residual = hidden_states