
Support Kohya-ss style LoRA file format (in a limited capacity) #3437


Merged (33 commits) on Jun 2, 2023
Commits
02111bc
add _convert_kohya_lora_to_diffusers
takuma104 May 14, 2023
7110e9a
make style
takuma104 May 14, 2023
21c5979
add scaffold
takuma104 May 15, 2023
8858ebb
match result: unet attention only
takuma104 May 15, 2023
bb9c61e
fix monkey-patch for text_encoder
takuma104 May 15, 2023
aa1d644
with CLIPAttention
takuma104 May 16, 2023
043da51
add to support network_alpha
takuma104 May 16, 2023
2317576
generate diff image
takuma104 May 16, 2023
8b6ac7b
Merge branch 'huggingface:main' into kohya-lora-loader
takuma104 May 17, 2023
fb708fb
fix monkey-patch for text_encoder
takuma104 May 15, 2023
6e8f3ab
add test_text_encoder_lora_monkey_patch()
takuma104 May 19, 2023
8511755
verify that it's okay to release the attn_procs
takuma104 May 19, 2023
81915f4
fix closure version
takuma104 May 19, 2023
88db546
add comment
takuma104 May 19, 2023
d22916e
Revert "fix monkey-patch for text_encoder"
takuma104 May 19, 2023
5c1024f
Merge branch 'text-encoder-lora-monkeypatch' into kohya-lora-loader
takuma104 May 19, 2023
1da772b
Fix to reuse utility functions
takuma104 May 22, 2023
8c0926c
Merge branch 'huggingface:main' into text-encoder-lora-monkeypatch
takuma104 May 22, 2023
8a26848
make LoRAAttnProcessor targets to self_attn
takuma104 May 22, 2023
28c69ee
fix LoRAAttnProcessor target
takuma104 May 22, 2023
3a74c7e
make style
takuma104 May 22, 2023
160a4d3
fix split key
takuma104 May 23, 2023
f14329d
Update src/diffusers/loaders.py
takuma104 May 23, 2023
6a17470
Merge branch 'text-encoder-lora-target' into kohya-lora-loader
takuma104 May 23, 2023
c3304f2
remove TEXT_ENCODER_TARGET_MODULES loop
takuma104 May 23, 2023
639171f
add print memory usage
takuma104 May 23, 2023
29ec4ca
remove test_kohya_loras_scaffold.py
takuma104 May 24, 2023
38d520b
add: doc on LoRA civitai
sayakpaul May 25, 2023
748dc67
remove print statement and refactor in the doc.
sayakpaul May 25, 2023
80e4b75
Merge branch 'main' into kohya-lora-loader
takuma104 May 29, 2023
08964d7
fix state_dict test for kohya-ss style lora
takuma104 May 30, 2023
23f12b7
Apply suggestions from code review
sayakpaul May 31, 2023
aaec8e1
Merge branch 'main' into kohya-lora-loader
sayakpaul May 31, 2023
73 changes: 72 additions & 1 deletion docs/source/en/training/lora.mdx
@@ -272,4 +272,75 @@ Note that the use of [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] is
* LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth).

**Note** that it is possible to provide a local directory path to [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] as well as [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`]. To know about the supported inputs,
refer to the respective docstrings.

## Supporting A1111 themed LoRA checkpoints from Diffusers

To provide seamless interoperability with A1111 for our users, we support loading A1111 formatted
LoRA checkpoints using [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] in a limited capacity.
In this section, we explain how to load an A1111 formatted LoRA checkpoint from [CivitAI](https://civitai.com/)
in Diffusers and perform inference with it.

First, download a checkpoint. We'll use
[this one](https://civitai.com/models/13239/light-and-shadow) for demonstration purposes.

```bash
wget https://civitai.com/api/download/models/15603 -O light_and_shadow.safetensors
```

Next, we initialize a [`~DiffusionPipeline`]:

```python
import torch

from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

pipeline = StableDiffusionPipeline.from_pretrained(
"gsdf/Counterfeit-V2.5", torch_dtype=torch.float16, safety_checker=None
).to("cuda")
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
pipeline.scheduler.config, use_karras_sigmas=True
)
```

We then load the checkpoint downloaded from CivitAI:

```python
pipeline.load_lora_weights(".", weight_name="light_and_shadow.safetensors")
```

<Tip warning={true}>

If you're loading a checkpoint in the `safetensors` format, please ensure you have `safetensors` installed.

</Tip>
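As a quick sanity check, the snippet below (a minimal sketch, not part of the official docs) fails early with a clearer message if `safetensors` is missing:

```python
# Sketch only: verify that safetensors is importable before loading a .safetensors file.
try:
    import safetensors  # noqa: F401
except ImportError as err:
    raise ImportError("`safetensors` is required; install it with `pip install safetensors`.") from err
```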

And then it's time to run inference:

```python
prompt = "masterpiece, best quality, 1girl, at dusk"
negative_prompt = ("(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), "
"bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), large breasts")

images = pipeline(prompt=prompt,
negative_prompt=negative_prompt,
width=512,
height=768,
num_inference_steps=15,
num_images_per_prompt=4,
generator=torch.manual_seed(0)
).images
```

Below is a comparison between the LoRA and the non-LoRA results:

![lora_non_lora](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lora_non_lora_comparison.png)

If you have a similar checkpoint stored on the Hugging Face Hub, you can load it
directly with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] like so:

```python
lora_model_id = "sayakpaul/civitai-light-shadow-lora"
lora_filename = "light_and_shadow.safetensors"
pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
```
6 changes: 3 additions & 3 deletions examples/dreambooth/train_dreambooth_lora.py
@@ -58,7 +58,7 @@
SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


@@ -847,9 +847,9 @@ def main(args):
if args.train_text_encoder:
text_lora_attn_procs = {}
for name, module in text_encoder.named_modules():
if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
if name.endswith(TEXT_ENCODER_ATTN_MODULE):
Contributor:
why the change here? Is this necessary? This makes this script a bit out of sync with train_text_to_image_lora.py

Member:

Refer to #3437 (comment). Note that this is only for the text encoder training part. For train_text_to_image_lora.py, we essentially don't train the text encoder anyway. So, I think it's okay.

text_lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=module.out_features, cross_attention_dim=None
hidden_size=module.out_proj.out_features, cross_attention_dim=None
)
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
temp_pipeline = DiffusionPipeline.from_pretrained(
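To make the review exchange above concrete, here is a minimal, self-contained sketch contrasting the two matching strategies. The module names mimic CLIP's text encoder layout and are illustrative assumptions, not values copied from diffusers:

```python
# Sketch only: old per-projection matching vs. new per-attention-module matching.
TEXT_ENCODER_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "out_proj"]  # old-style constant
TEXT_ENCODER_ATTN_MODULE = ".self_attn"  # new-style constant

names = [
    "text_model.encoder.layers.0.self_attn",
    "text_model.encoder.layers.0.self_attn.q_proj",
    "text_model.encoder.layers.0.self_attn.k_proj",
    "text_model.encoder.layers.0.self_attn.v_proj",
    "text_model.encoder.layers.0.self_attn.out_proj",
    "text_model.encoder.layers.0.mlp.fc1",
]

# Old behaviour: four matches per attention block, one per projection layer.
old_targets = [n for n in names if any(x in n for x in TEXT_ENCODER_TARGET_MODULES)]

# New behaviour: a single match per attention block, so one LoRAAttnProcessor
# (sized by `module.out_proj.out_features`) covers all four projections.
new_targets = [n for n in names if n.endswith(TEXT_ENCODER_ATTN_MODULE)]

print(old_targets)  # the four *_proj submodules
print(new_targets)  # ['text_model.encoder.layers.0.self_attn']
```

Grouping by attention module also lines up with how `_convert_kohya_lora_to_diffusers` (below) organizes the converted text encoder keys per `self_attn` block.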
95 changes: 85 additions & 10 deletions src/diffusers/loaders.py
@@ -72,8 +72,8 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]):
self.mapping = dict(enumerate(state_dict.keys()))
self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

# .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder
self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
# .processor for unet, .self_attn for text encoder
self.split_keys = [".processor", ".self_attn"]

# we add a hook to state_dict() and load_state_dict() so that the
# naming fits with `unet.attn_processors`
@@ -182,6 +182,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
network_alpha = kwargs.pop("network_alpha", None)

if use_safetensors and not is_safetensors_available():
raise ValueError(
@@ -287,7 +290,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
attn_processor_class = LoRAAttnProcessor

attn_processors[key] = attn_processor_class(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=rank,
network_alpha=network_alpha,
)
attn_processors[key].load_state_dict(value_dict)
elif is_custom_diffusion:
@@ -774,6 +780,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di

<Tip warning={true}>

We support loading A1111 formatted LoRA checkpoints in a limited capacity.

This function is experimental and might change in the future.

</Tip>
@@ -898,6 +906,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
else:
state_dict = pretrained_model_name_or_path_or_dict

# Convert kohya-ss Style LoRA attn procs to diffusers attn procs
network_alpha = None
if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
Contributor:
this works for me!

state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)

# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
Expand All @@ -909,7 +922,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
unet_lora_state_dict = {
k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
}
self.unet.load_attn_procs(unet_lora_state_dict)
self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)

# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
Expand All @@ -918,7 +931,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
attn_procs_text_encoder = self._load_text_encoder_attn_procs(
text_encoder_lora_state_dict, network_alpha=network_alpha
)
self._modify_text_encoder(attn_procs_text_encoder)

# save lora attn procs of text encoder so that it can be easily retrieved
@@ -954,14 +969,20 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
module = self.text_encoder.get_submodule(name)
# Construct a new function that performs the LoRA merging. We will monkey patch
# this forward pass.
lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
attn_processor_name = ".".join(name.split(".")[:-1])
lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
old_forward = module.forward

def new_forward(x):
return old_forward(x) + lora_layer(x)
# create a new scope that locks in the old_forward, lora_layer value for each new_forward function
# for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
def make_new_forward(old_forward, lora_layer):
def new_forward(x):
return old_forward(x) + lora_layer(x)

return new_forward

# Monkey-patch.
module.forward = new_forward
module.forward = make_new_forward(old_forward, lora_layer)

def _get_lora_layer_attribute(self, name: str) -> str:
if "q_proj" in name:
@@ -1048,6 +1069,7 @@ def _load_text_encoder_attn_procs(
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
network_alpha = kwargs.pop("network_alpha", None)

if use_safetensors and not is_safetensors_available():
raise ValueError(
@@ -1125,7 +1147,10 @@ def _load_text_encoder_attn_procs(
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

attn_processors[key] = LoRAAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
rank=rank,
network_alpha=network_alpha,
)
attn_processors[key].load_state_dict(value_dict)

@@ -1219,6 +1244,56 @@ def save_function(weights, filename):
save_function(state_dict, os.path.join(save_directory, weight_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")

def _convert_kohya_lora_to_diffusers(self, state_dict):
unet_state_dict = {}
te_state_dict = {}
network_alpha = None

for key, value in state_dict.items():
if "lora_down" in key:
lora_name = key.split(".")[0]
lora_name_up = lora_name + ".lora_up.weight"
Member:
Clever and clean!

lora_name_alpha = lora_name + ".alpha"
if lora_name_alpha in state_dict:
alpha = state_dict[lora_name_alpha].item()
if network_alpha is None:
network_alpha = alpha
elif network_alpha != alpha:
raise ValueError("Network alpha is not consistent")

if lora_name.startswith("lora_unet_"):
diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
diffusers_name = diffusers_name.replace("mid.block", "mid_block")
diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
if "transformer_blocks" in diffusers_name:
if "attn1" in diffusers_name or "attn2" in diffusers_name:
diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
unet_state_dict[diffusers_name] = value
unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
elif lora_name.startswith("lora_te_"):
diffusers_name = key.replace("lora_te_", "").replace("_", ".")
diffusers_name = diffusers_name.replace("text.model", "text_model")
diffusers_name = diffusers_name.replace("self.attn", "self_attn")
diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
if "self_attn" in diffusers_name:
te_state_dict[diffusers_name] = value
te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]

unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
new_state_dict = {**unet_state_dict, **te_state_dict}
return new_state_dict, network_alpha


class FromCkptMixin:
"""This helper class allows to directly load .ckpt stable diffusion file_extension
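To see what `_convert_kohya_lora_to_diffusers` does to an individual key, here is a standalone sketch of the UNet branch of the renaming logic; the sample key is an assumption chosen for illustration:

```python
# Sketch only: rename one kohya-ss UNet LoRA key into the diffusers attn-processor layout.
def convert_unet_key(key: str) -> str:
    name = key.replace("lora_unet_", "").replace("_", ".")
    for old, new in [
        ("down.blocks", "down_blocks"),
        ("mid.block", "mid_block"),
        ("up.blocks", "up_blocks"),
        ("transformer.blocks", "transformer_blocks"),
        ("to.q.lora", "to_q_lora"),
        ("to.k.lora", "to_k_lora"),
        ("to.v.lora", "to_v_lora"),
        ("to.out.0.lora", "to_out_lora"),
    ]:
        name = name.replace(old, new)
    if "transformer_blocks" in name:
        name = name.replace("attn1", "attn1.processor").replace("attn2", "attn2.processor")
    return f"unet.{name}"


key = "lora_unet_down_blocks_0_attentions_0_transformer_blocks_0_attn1_to_q.lora_down.weight"
print(convert_unet_key(key))
# unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight
```

Text encoder keys follow the same pattern with the `lora_te_` prefix and the `q/k/v/out_proj` projections; every `.lora_down.` key pulls in its matching `.lora_up.` partner, and the per-module `alpha` scalars must all agree and become the single `network_alpha`.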
45 changes: 27 additions & 18 deletions src/diffusers/models/attention_processor.py
@@ -508,14 +508,18 @@ def __call__(


class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4):
def __init__(self, in_features, out_features, rank=4, network_alpha=None):
super().__init__()

if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")

self.down = nn.Linear(in_features, rank, bias=False)
self.up = nn.Linear(rank, out_features, bias=False)
# This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
# See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
self.network_alpha = network_alpha
self.rank = rank

nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
@@ -527,6 +531,9 @@ def forward(self, hidden_states):
down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)

if self.network_alpha is not None:
up_hidden_states *= self.network_alpha / self.rank

Member (on lines +534 to +536):
Okay for me!

return up_hidden_states.to(orig_dtype)


@@ -543,17 +550,17 @@ class LoRAAttnProcessor(nn.Module):
The dimension of the LoRA update matrices.
"""

def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
super().__init__()

self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank

self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

def __call__(
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
@@ -838,19 +845,19 @@ class LoRAAttnAddedKVProcessor(nn.Module):
The dimension of the LoRA update matrices.
"""

def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
super().__init__()

self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank

self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
residual = hidden_states
@@ -1157,18 +1164,20 @@ class LoRAXFormersAttnProcessor(nn.Module):
operator.
"""

def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
def __init__(
self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
):
super().__init__()

self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.rank = rank
self.attention_op = attention_op

self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)

def __call__(
self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
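To make the `network_alpha` scaling concrete, here is a minimal standalone sketch of the residual computed by `LoRALinearLayer.forward` (shapes and values are illustrative assumptions):

```python
import torch
import torch.nn as nn

# Sketch only: the LoRA residual produced by a LoRALinearLayer-style pair of projections.
rank, network_alpha = 4, 1.0
in_features = out_features = 320

down = nn.Linear(in_features, rank, bias=False)
up = nn.Linear(rank, out_features, bias=False)
nn.init.normal_(down.weight, std=1 / rank)  # same init as LoRALinearLayer
nn.init.zeros_(up.weight)                   # up starts at zero, so the residual starts at zero

x = torch.randn(2, in_features)
delta = up(down(x))
if network_alpha is not None:
    # kohya-ss checkpoints expect the LoRA output to be scaled by alpha / rank,
    # mirroring `up_hidden_states *= self.network_alpha / self.rank` above.
    delta = delta * (network_alpha / rank)

# The patched module then returns original_forward(x) + delta; the UNet attention
# processors additionally multiply the LoRA term by their `scale` argument.
```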