Support Kohya-ss style LoRA file format (in a limited capacity) #3437
Changes from all commits
@@ -72,8 +72,8 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]):
         self.mapping = dict(enumerate(state_dict.keys()))
         self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

-        # .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder
-        self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
+        # .processor for unet, .self_attn for text encoder
+        self.split_keys = [".processor", ".self_attn"]

         # we add a hook to state_dict() and load_state_dict() so that the
         # naming fits with `unet.attn_processors`
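`split_keys` lets the state-dict hooks recover an attention processor's prefix from a full parameter key. A minimal sketch of that idea (the helper and the sample key are illustrative, not the library's internal hook):

```python
def remap_to_processor_prefix(key: str, split_keys) -> str:
    # Cut the key right after the first split marker it contains,
    # e.g. ".self_attn" for text-encoder keys or ".processor" for UNet keys.
    for marker in split_keys:
        if marker in key:
            return key.split(marker)[0] + marker
    raise ValueError(f"none of {split_keys} found in {key!r}")


split_keys = [".processor", ".self_attn"]
key = "text_model.encoder.layers.0.self_attn.to_q_lora.up.weight"
print(remap_to_processor_prefix(key, split_keys))
# -> text_model.encoder.layers.0.self_attn
```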
@@ -182,6 +182,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+        network_alpha = kwargs.pop("network_alpha", None)

         if use_safetensors and not is_safetensors_available():
             raise ValueError(
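Since `network_alpha` is now a regular kwarg of `load_attn_procs`, it can also be passed when loading UNet LoRA weights directly. A minimal sketch of that usage (the local `unet_lora.safetensors` file is hypothetical):

```python
import torch
from safetensors.torch import load_file
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Hypothetical local file containing diffusers-style UNet LoRA attention weights.
state_dict = load_file("unet_lora.safetensors")

# `network_alpha` carries the same meaning as kohya-ss's `--network_alpha`:
# the LoRA update is scaled by network_alpha / rank at inference time.
pipe.unet.load_attn_procs(state_dict, network_alpha=4.0)
```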
@@ -287,7 +290,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                     attn_processor_class = LoRAAttnProcessor

                 attn_processors[key] = attn_processor_class(
-                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
+                    hidden_size=hidden_size,
+                    cross_attention_dim=cross_attention_dim,
+                    rank=rank,
+                    network_alpha=network_alpha,
                 )
                 attn_processors[key].load_state_dict(value_dict)
         elif is_custom_diffusion:
@@ -774,6 +780,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict

        <Tip warning={true}>

+       We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
        This function is experimental and might change in the future.

        </Tip>
@@ -898,6 +906,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         else:
             state_dict = pretrained_model_name_or_path_or_dict

+        # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
+        network_alpha = None
+        if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
+            state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)
+
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.

Review comment (on the kohya-ss detection block): this works for me!
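With this conversion in place, a kohya-ss / A1111 style checkpoint can be passed straight to `load_lora_weights`. A sketch of the intended usage (the weight file name is hypothetical):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# "my_kohya_lora.safetensors" is a hypothetical kohya-ss/A1111 style LoRA file;
# its keys start with "lora_unet_" / "lora_te_", so the loader converts it to the
# diffusers format (and picks up the stored network alpha) automatically.
pipe.load_lora_weights(".", weight_name="my_kohya_lora.safetensors")

image = pipe("masterpiece, best quality, 1girl", num_inference_steps=30).images[0]
```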
@@ -909,7 +922,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
             unet_lora_state_dict = {
                 k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
             }
-            self.unet.load_attn_procs(unet_lora_state_dict)
+            self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)

             # Load the layers corresponding to text encoder and make necessary adjustments.
             text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
@@ -918,7 +931,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                 k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
             }
             if len(text_encoder_lora_state_dict) > 0:
-                attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
+                attn_procs_text_encoder = self._load_text_encoder_attn_procs(
+                    text_encoder_lora_state_dict, network_alpha=network_alpha
+                )
                 self._modify_text_encoder(attn_procs_text_encoder)

                 # save lora attn procs of text encoder so that it can be easily retrieved
@@ -954,14 +969,20 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
             module = self.text_encoder.get_submodule(name)
             # Construct a new function that performs the LoRA merging. We will monkey patch
             # this forward pass.
-            lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
+            attn_processor_name = ".".join(name.split(".")[:-1])
+            lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
             old_forward = module.forward

-            def new_forward(x):
-                return old_forward(x) + lora_layer(x)
+            # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
+            # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
+            def make_new_forward(old_forward, lora_layer):
+                def new_forward(x):
+                    return old_forward(x) + lora_layer(x)
+
+                return new_forward

             # Monkey-patch.
-            module.forward = new_forward
+            module.forward = make_new_forward(old_forward, lora_layer)

     def _get_lora_layer_attribute(self, name: str) -> str:
         if "q_proj" in name:
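The `make_new_forward` factory is needed because Python closures capture variables by reference (late binding); without it, every patched module would end up calling the `lora_layer` from the last loop iteration. A standalone illustration of the pitfall and the fix:

```python
# Late binding: all three closures see the final value of `i`.
bad = [lambda: i for i in range(3)]
print([f() for f in bad])  # [2, 2, 2]


# A factory function creates a new scope per iteration, locking in each value.
def make_closure(i):
    return lambda: i


good = [make_closure(i) for i in range(3)]
print([f() for f in good])  # [0, 1, 2]
```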
@@ -1048,6 +1069,7 @@ def _load_text_encoder_attn_procs(
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        network_alpha = kwargs.pop("network_alpha", None)

         if use_safetensors and not is_safetensors_available():
             raise ValueError(
@@ -1125,7 +1147,10 @@ def _load_text_encoder_attn_procs(
             hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

             attn_processors[key] = LoRAAttnProcessor(
-                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
+                hidden_size=hidden_size,
+                cross_attention_dim=cross_attention_dim,
+                rank=rank,
+                network_alpha=network_alpha,
             )
             attn_processors[key].load_state_dict(value_dict)
@@ -1219,6 +1244,56 @@ def save_function(weights, filename):
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")

+    def _convert_kohya_lora_to_diffusers(self, state_dict):
+        unet_state_dict = {}
+        te_state_dict = {}
+        network_alpha = None
+
+        for key, value in state_dict.items():
+            if "lora_down" in key:
+                lora_name = key.split(".")[0]
+                lora_name_up = lora_name + ".lora_up.weight"
+                lora_name_alpha = lora_name + ".alpha"
+                if lora_name_alpha in state_dict:
+                    alpha = state_dict[lora_name_alpha].item()
+                    if network_alpha is None:
+                        network_alpha = alpha
+                    elif network_alpha != alpha:
+                        raise ValueError("Network alpha is not consistent")
+
+                if lora_name.startswith("lora_unet_"):
+                    diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+                    diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+                    diffusers_name = diffusers_name.replace("mid.block", "mid_block")
+                    diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
+                    diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
+                    diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
+                    diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
+                    diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
+                    diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+                    if "transformer_blocks" in diffusers_name:
+                        if "attn1" in diffusers_name or "attn2" in diffusers_name:
+                            diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
+                            diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
+                            unet_state_dict[diffusers_name] = value
+                            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                elif lora_name.startswith("lora_te_"):
+                    diffusers_name = key.replace("lora_te_", "").replace("_", ".")
+                    diffusers_name = diffusers_name.replace("text.model", "text_model")
+                    diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+                    diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+                    diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+                    diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+                    diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+                    if "self_attn" in diffusers_name:
+                        te_state_dict[diffusers_name] = value
+                        te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+
+        unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
+        te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
+        new_state_dict = {**unet_state_dict, **te_state_dict}
+        return new_state_dict, network_alpha


 class FromCkptMixin:
     """This helper class allows to directly load .ckpt stable diffusion file_extension

Review comment (on the `lora_name_up` derivation): Clever and clean!
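To make the key renaming concrete, here is a standalone sketch that applies the same string replacements to one sample kohya-ss UNet key (the key itself is just an illustrative example):

```python
key = "lora_unet_up_blocks_3_attentions_2_transformer_blocks_0_attn2_to_k.lora_down.weight"

# Mirror the replacements performed by _convert_kohya_lora_to_diffusers for a UNet key.
name = key.replace("lora_unet_", "").replace("_", ".")
for old, new in [
    ("down.blocks", "down_blocks"),
    ("mid.block", "mid_block"),
    ("up.blocks", "up_blocks"),
    ("transformer.blocks", "transformer_blocks"),
    ("to.q.lora", "to_q_lora"),
    ("to.k.lora", "to_k_lora"),
    ("to.v.lora", "to_v_lora"),
    ("to.out.0.lora", "to_out_lora"),
    ("attn1", "attn1.processor"),
    ("attn2", "attn2.processor"),
]:
    name = name.replace(old, new)

print(name)
# -> up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor.to_k_lora.down.weight
```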
(The remaining hunks are from a second changed file.)
@@ -508,14 +508,18 @@ def __call__(


 class LoRALinearLayer(nn.Module):
-    def __init__(self, in_features, out_features, rank=4):
+    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
         super().__init__()

         if rank > min(in_features, out_features):
             raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

         self.down = nn.Linear(in_features, rank, bias=False)
         self.up = nn.Linear(rank, out_features, bias=False)
+        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+        self.network_alpha = network_alpha
+        self.rank = rank

         nn.init.normal_(self.down.weight, std=1 / rank)
         nn.init.zeros_(self.up.weight)
@@ -527,6 +531,9 @@ def forward(self, hidden_states):
         down_hidden_states = self.down(hidden_states.to(dtype))
         up_hidden_states = self.up(down_hidden_states)

+        if self.network_alpha is not None:
+            up_hidden_states *= self.network_alpha / self.rank
+
         return up_hidden_states.to(orig_dtype)

Review comment (on lines +534 to +536): Okay for me!
@@ -543,17 +550,17 @@ class LoRAAttnProcessor(nn.Module):
         The dimension of the LoRA update matrices.
     """

-    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
         super().__init__()

         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
         self.rank = rank

-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

     def __call__(
         self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
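A quick check that `network_alpha` reaches every sub-layer of the processor (assuming the usual `diffusers.models.attention_processor` import path; the sizes match the first down block of SD 1.5):

```python
from diffusers.models.attention_processor import LoRAAttnProcessor

proc = LoRAAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4, network_alpha=4.0)

# Every LoRALinearLayer created by the processor carries the same alpha and rank.
for name in ("to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora"):
    layer = getattr(proc, name)
    print(name, layer.network_alpha, layer.rank)  # 4.0, 4
```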
@@ -838,19 +845,19 @@ class LoRAAttnAddedKVProcessor(nn.Module):
         The dimension of the LoRA update matrices.
     """

-    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
         super().__init__()

         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
         self.rank = rank

-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
-        self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
-        self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
-        self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
-        self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
         residual = hidden_states
@@ -1157,18 +1164,20 @@ class LoRAXFormersAttnProcessor(nn.Module):
             operator.
     """

-    def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
+    def __init__(
+        self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
+    ):
         super().__init__()

         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
         self.rank = rank
         self.attention_op = attention_op

-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

     def __call__(
         self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
Review comment: Why the change here? Is this necessary? It makes this script a bit out of sync with train_text_to_image_lora.py.

Reply: Refer to #3437 (comment). Note that this is only for the text encoder training part. For train_text_to_image_lora.py, we don't train the text encoder anyway, so I think it's okay.