Support Kohya-ss style LoRA file format (in a limited capacity) #3437

Merged · 33 commits · merged Jun 2, 2023
Changes from 27 commits
Commits
02111bc
add _convert_kohya_lora_to_diffusers
takuma104 May 14, 2023
7110e9a
make style
takuma104 May 14, 2023
21c5979
add scaffold
takuma104 May 15, 2023
8858ebb
match result: unet attention only
takuma104 May 15, 2023
bb9c61e
fix monkey-patch for text_encoder
takuma104 May 15, 2023
aa1d644
with CLIPAttention
takuma104 May 16, 2023
043da51
add to support network_alpha
takuma104 May 16, 2023
2317576
generate diff image
takuma104 May 16, 2023
8b6ac7b
Merge branch 'huggingface:main' into kohya-lora-loader
takuma104 May 17, 2023
fb708fb
fix monkey-patch for text_encoder
takuma104 May 15, 2023
6e8f3ab
add test_text_encoder_lora_monkey_patch()
takuma104 May 19, 2023
8511755
verify that it's okay to release the attn_procs
takuma104 May 19, 2023
81915f4
fix closure version
takuma104 May 19, 2023
88db546
add comment
takuma104 May 19, 2023
d22916e
Revert "fix monkey-patch for text_encoder"
takuma104 May 19, 2023
5c1024f
Merge branch 'text-encoder-lora-monkeypatch' into kohya-lora-loader
takuma104 May 19, 2023
1da772b
Fix to reuse utility functions
takuma104 May 22, 2023
8c0926c
Merge branch 'huggingface:main' into text-encoder-lora-monkeypatch
takuma104 May 22, 2023
8a26848
make LoRAAttnProcessor targets to self_attn
takuma104 May 22, 2023
28c69ee
fix LoRAAttnProcessor target
takuma104 May 22, 2023
3a74c7e
make style
takuma104 May 22, 2023
160a4d3
fix split key
takuma104 May 23, 2023
f14329d
Update src/diffusers/loaders.py
takuma104 May 23, 2023
6a17470
Merge branch 'text-encoder-lora-target' into kohya-lora-loader
takuma104 May 23, 2023
c3304f2
remove TEXT_ENCODER_TARGET_MODULES loop
takuma104 May 23, 2023
639171f
add print memory usage
takuma104 May 23, 2023
29ec4ca
remove test_kohya_loras_scaffold.py
takuma104 May 24, 2023
38d520b
add: doc on LoRA civitai
sayakpaul May 25, 2023
748dc67
remove print statement and refactor in the doc.
sayakpaul May 25, 2023
80e4b75
Merge branch 'main' into kohya-lora-loader
takuma104 May 29, 2023
08964d7
fix state_dict test for kohya-ss style lora
takuma104 May 30, 2023
23f12b7
Apply suggestions from code review
sayakpaul May 31, 2023
aaec8e1
Merge branch 'main' into kohya-lora-loader
sayakpaul May 31, 2023
6 changes: 3 additions & 3 deletions examples/dreambooth/train_dreambooth_lora.py
Expand Up @@ -58,7 +58,7 @@
SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


Expand Down Expand Up @@ -839,9 +839,9 @@ def main(args):
if args.train_text_encoder:
text_lora_attn_procs = {}
for name, module in text_encoder.named_modules():
if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
Contributor:

Why the change here? Is this necessary? It makes this script a bit out of sync with train_text_to_image_lora.py.

Member:

Refer to #3437 (comment). Note that this only affects the text-encoder training part. train_text_to_image_lora.py doesn't train the text encoder anyway, so I think it's okay.

text_lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=module.out_features, cross_attention_dim=None
hidden_size=module.out_proj.out_features, cross_attention_dim=None
)
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
temp_pipeline = StableDiffusionPipeline.from_pretrained(
Expand Down
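The review exchange above asks why this matching changed. With TEXT_ENCODER_ATTN_MODULE the script attaches one LoRAAttnProcessor per CLIP self-attention block (module names ending in ".self_attn") and sizes it from that block's out_proj projection, rather than one processor per q/k/v/out projection as TEXT_ENCODER_TARGET_MODULES did. A minimal sketch of the new loop, assuming a standard Stable Diffusion 1.5 text encoder; the model id below is illustrative and not taken from this PR:

from transformers import CLIPTextModel
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.utils import TEXT_ENCODER_ATTN_MODULE  # ".self_attn"

# illustrative base model, not part of the PR
text_encoder = CLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder"
)

text_lora_attn_procs = {}
for name, module in text_encoder.named_modules():
    # one processor per attention block, e.g. "text_model.encoder.layers.0.self_attn"
    if name.endswith(TEXT_ENCODER_ATTN_MODULE):
        text_lora_attn_procs[name] = LoRAAttnProcessor(
            hidden_size=module.out_proj.out_features,  # CLIPAttention exposes out_proj
            cross_attention_dim=None,
        )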
92 changes: 82 additions & 10 deletions src/diffusers/loaders.py
Expand Up @@ -70,8 +70,8 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]):
self.mapping = dict(enumerate(state_dict.keys()))
self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

# .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder
self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
# .processor for unet, .self_attn for text encoder
self.split_keys = [".processor", ".self_attn"]

# we add a hook to state_dict() and load_state_dict() so that the
# naming fits with `unet.attn_processors`
Expand Down Expand Up @@ -180,6 +180,7 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
network_alpha = kwargs.pop("network_alpha", None)

if use_safetensors and not is_safetensors_available():
raise ValueError(
Expand Down Expand Up @@ -282,7 +283,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
attn_processor_class = LoRAAttnProcessor

attn_processors[key] = attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=rank,
network_alpha=network_alpha,
)
attn_processors[key].load_state_dict(value_dict)
elif is_custom_diffusion:
Expand Down Expand Up @@ -887,6 +891,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
else:
state_dict = pretrained_model_name_or_path_or_dict

# Convert kohya-ss Style LoRA attn procs to diffusers attn procs
network_alpha = None
if any("alpha" in k for k in state_dict.keys()):
Contributor:

Is that the "safest" way to determine whether a checkpoint is in kohya-ss style or not? Do all checkpoints have an "alpha" parameter?

Contributor (@StAlKeR7779, May 27, 2023):

According to the conversion code, it is also possible to check for "to_out_0.lora" in the unet keys and "out_proj.lora" in the text_encoder keys.
Also, for safetensors files it is possible to check fields in the metadata, but that is not available when the LoRA is saved as a pickle.

Contributor Author:

The idea of reading the metadata from the safetensors file seems reliable and good! Although there is an API for saving safetensors metadata, ironically there is no API for loading it, so we need to read the file directly, which feels a bit awkward. I posted the code and example output here:
https://gist.github.com/takuma104/aaedd400e21b37d1e4297fe483e90695

I think checking 'ss_network_module' == 'networks.lora' could be one way.
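For reference, a minimal sketch of reading that metadata by parsing the safetensors header directly, since there is no loading API for it; the file name is illustrative. The format begins with an 8-byte little-endian header length, followed by a JSON header that may contain a "__metadata__" dict of string key/value pairs:

import json
import struct

def read_safetensors_metadata(path):
    with open(path, "rb") as f:
        header_len = struct.unpack("<Q", f.read(8))[0]  # little-endian uint64
        header = json.loads(f.read(header_len))          # JSON header of the file
    return header.get("__metadata__", {})

# illustrative file name
metadata = read_safetensors_metadata("some_kohya_lora.safetensors")
print(metadata.get("ss_network_module"))  # e.g. "networks.lora" for kohya-ss LoRAs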

Contributor Author:

I investigated 97 LoRA files from CivitAI that I had on hand. 62 of them had metadata with metadata['ss_network_module'] == 'networks.lora'; the remaining 35 did not have metadata. Metadata seems to be saved by default, but there is an option not to save it, which might be why it's missing. The precision is about 63%. It's a rather ambiguous number. Hmm.

Member:

The precision is about 63%.

Could you elaborate on this a bit?

So, based on your investigation, is it safe to say that the current way of determining whether a checkpoint is from CivitAI is our best bet?

Cc: @patrickvonplaten

Contributor Author:

The files I investigated were ones previously downloaded from CivitAI, sorted in descending order of download count. I mistakenly wrote that there were 97 in total, but there were actually 100 files (in about 3 cases metadata['ss_network_module'] != 'networks.lora'). While it's not entirely certain that all of these files were created with the kohya-ss script, given that a fair number of people seem to use it, I think we can at least assume it's a format supported by A1111.

As for whether the current check is good enough, I tried the following code and got these results.

import safetensors.torch

# filenames: list of local .safetensors LoRA paths downloaded from CivitAI
available_alpha = 0
available_te_unet = 0

for filename in filenames:
    state_dict = safetensors.torch.load_file(filename)

    if any("alpha" in k for k in state_dict.keys()):
        available_alpha += 1

    if all((k.startswith('lora_te_') or k.startswith('lora_unet_')) for k in state_dict.keys()):
        available_te_unet += 1

Output:
available_alpha: 86
available_te_unet: 100

It seems that there are also files that do not contain network_alpha.
Therefore, I have changed the condition so that every key must start with either lora_te_ or lora_unet_. 08964d7
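To make the key naming concrete, here is an illustrative before/after pair for each prefix, as handled by _convert_kohya_lora_to_diffusers below; the block and layer indices are made up, only the pattern matters:

# kohya-ss UNet key -> diffusers key
# lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight
#   -> unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight

# kohya-ss text-encoder key -> diffusers key
# lora_te_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight
#   -> text_encoder.text_model.encoder.layers.0.self_attn.to_q_lora.down.weight

# The matching ".lora_up.weight" tensor maps to the same name with ".down." replaced by
# ".up.", and each module's ".alpha" scalar (if present) becomes the single network_alpha
# value passed on to the attention processors.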

state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)

# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
Expand All @@ -898,7 +907,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
unet_lora_state_dict = {
k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
}
self.unet.load_attn_procs(unet_lora_state_dict)
self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)

# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
Expand All @@ -907,7 +916,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
attn_procs_text_encoder = self._load_text_encoder_attn_procs(
text_encoder_lora_state_dict, network_alpha=network_alpha
)
self._modify_text_encoder(attn_procs_text_encoder)

# save lora attn procs of text encoder so that it can be easily retrieved
Expand Down Expand Up @@ -943,14 +954,20 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
module = self.text_encoder.get_submodule(name)
# Construct a new function that performs the LoRA merging. We will monkey patch
# this forward pass.
lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
attn_processor_name = ".".join(name.split(".")[:-1])
lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
old_forward = module.forward

def new_forward(x):
return old_forward(x) + lora_layer(x)
# create a new scope that locks in the old_forward, lora_layer value for each new_forward function
# for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
def make_new_forward(old_forward, lora_layer):
def new_forward(x):
return old_forward(x) + lora_layer(x)

return new_forward

# Monkey-patch.
module.forward = new_forward
module.forward = make_new_forward(old_forward, lora_layer)

def _get_lora_layer_attribute(self, name: str) -> str:
if "q_proj" in name:
Expand Down Expand Up @@ -1037,6 +1054,7 @@ def _load_text_encoder_attn_procs(
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
network_alpha = kwargs.pop("network_alpha", None)

if use_safetensors and not is_safetensors_available():
raise ValueError(
Expand Down Expand Up @@ -1114,7 +1132,10 @@ def _load_text_encoder_attn_procs(
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

attn_processors[key] = LoRAAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=rank,
network_alpha=network_alpha,
)
attn_processors[key].load_state_dict(value_dict)

Expand Down Expand Up @@ -1208,6 +1229,57 @@ def save_function(weights, filename):
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")

def _convert_kohya_lora_to_diffusers(self, state_dict):
unet_state_dict = {}
te_state_dict = {}
network_alpha = None

for key, value in state_dict.items():
if "lora_down" in key:
lora_name = key.split(".")[0]
lora_name_up = lora_name + ".lora_up.weight"
Member:

Clever and clean!

lora_name_alpha = lora_name + ".alpha"
if lora_name_alpha in state_dict:
alpha = state_dict[lora_name_alpha].item()
if network_alpha is None:
network_alpha = alpha
elif network_alpha != alpha:
raise ValueError("Network alpha is not consistent")

if lora_name.startswith("lora_unet_"):
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
unet_state_dict[diffusers_name] = value
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif lora_name.startswith("lora_te_"):
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te_state_dict[diffusers_name] = value
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]

unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
new_state_dict = {**unet_state_dict, **te_state_dict}
print("converted", len(new_state_dict), "keys")
return new_state_dict, network_alpha


class FromCkptMixin:
"""This helper class allows to directly load .ckpt stable diffusion file_extension
Expand Down
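One note on the _modify_text_encoder change above: the make_new_forward factory is needed because Python closures capture variables, not values, so without it every monkey-patched forward would end up using the old_forward and lora_layer from the last loop iteration. A standalone sketch of the pitfall, not diffusers code:

def broken_patches(layers):
    patched = []
    for layer in layers:
        def new_forward(x):
            return layer(x)  # "layer" is looked up at call time -> always the last one
        patched.append(new_forward)
    return patched

def fixed_patches(layers):
    patched = []
    for layer in layers:
        def make_new_forward(layer):  # a new scope locks in the current layer
            def new_forward(x):
                return layer(x)
            return new_forward
        patched.append(make_new_forward(layer))
    return patched

layers = [lambda x, i=i: x + i for i in range(3)]
print([f(0) for f in broken_patches(layers)])  # [2, 2, 2]
print([f(0) for f in fixed_patches(layers)])   # [0, 1, 2]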
43 changes: 25 additions & 18 deletions src/diffusers/models/attention_processor.py
Expand Up @@ -478,14 +478,16 @@ def __call__(


class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4):
def __init__(self, in_features, out_features, rank=4, network_alpha=None):
super().__init__()

if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

self.down = nn.Linear(in_features, rank, bias=False)
self.up = nn.Linear(rank, out_features, bias=False)
self.network_alpha = network_alpha
self.rank = rank

nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
Expand All @@ -497,21 +499,24 @@ def forward(self, hidden_states):
down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)

if self.network_alpha is not None:
up_hidden_states *= self.network_alpha / self.rank

Comment on lines +534 to +536
Member:

Okay for me!

return up_hidden_states.to(orig_dtype)


class LoRAAttnProcessor(nn.Module):
def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
super().__init__()

self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank

self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
residual = hidden_states
Expand Down Expand Up @@ -750,19 +755,19 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a


class LoRAAttnAddedKVProcessor(nn.Module):
def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
super().__init__()

self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank

self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
residual = hidden_states
Expand Down Expand Up @@ -943,18 +948,20 @@ def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, a


class LoRAXFormersAttnProcessor(nn.Module):
def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
def __init__(
self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
):
super().__init__()

self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
self.attention_op = attention_op

self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
residual = hidden_states
Expand Down
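The network_alpha argument threaded through the processors above reproduces kohya-ss's alpha scaling: the LoRA branch is multiplied by alpha / rank, so checkpoints trained with an alpha smaller than the rank keep their intended effective strength. A minimal numeric sketch with illustrative shapes and values:

import torch
import torch.nn as nn

rank, network_alpha = 4, 1.0      # kohya-ss checkpoints often use alpha < rank
in_features = out_features = 8    # illustrative sizes

down = nn.Linear(in_features, rank, bias=False)
up = nn.Linear(rank, out_features, bias=False)

x = torch.randn(2, in_features)
lora_out = up(down(x))
if network_alpha is not None:
    lora_out = lora_out * (network_alpha / rank)  # same scaling as LoRALinearLayer.forward

# the attention processor then adds scale * lora_out to the frozen projection's output
print(lora_out.shape)  # torch.Size([2, 8])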
1 change: 1 addition & 0 deletions src/diffusers/utils/__init__.py
Expand Up @@ -30,6 +30,7 @@
ONNX_EXTERNAL_WEIGHTS_NAME,
ONNX_WEIGHTS_NAME,
SAFETENSORS_WEIGHTS_NAME,
TEXT_ENCODER_ATTN_MODULE,
TEXT_ENCODER_TARGET_MODULES,
WEIGHTS_NAME,
)
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/utils/constants.py
Expand Up @@ -31,3 +31,4 @@
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
TEXT_ENCODER_ATTN_MODULE = ".self_attn"
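Putting it together, a hedged usage sketch of what this PR enables: a kohya-ss / CivitAI LoRA saved as safetensors can be passed to the regular load_lora_weights entry point. The base model id and file name below are illustrative, and, as the PR title says, support is limited to the common UNet/text-encoder attention LoRAs:

import torch
from diffusers import StableDiffusionPipeline

# illustrative base model
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# local directory "." plus weight_name pointing at a kohya-ss style .safetensors file (illustrative name)
pipe.load_lora_weights(".", weight_name="my_kohya_style_lora.safetensors")

image = pipe("a photo of a cat, best quality", num_inference_steps=25).images[0]
image.save("out.png")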