
Commit 8e552bb

takuma104 and sayakpaul authored
Support Kohya-ss style LoRA file format (in a limited capacity) (#3437)
* add _convert_kohya_lora_to_diffusers
* make style
* add scaffold
* match result: unet attention only
* fix monkey-patch for text_encoder
* with CLIPAttention

  While the terrible images are no longer produced, the results do not match those from the hook ver. This may be due to not setting the network_alpha value.
* add to support network_alpha
* generate diff image
* fix monkey-patch for text_encoder
* add test_text_encoder_lora_monkey_patch()
* verify that it's okay to release the attn_procs
* fix closure version
* add comment
* Revert "fix monkey-patch for text_encoder"

  This reverts commit bb9c61e.
* Fix to reuse utility functions
* make LoRAAttnProcessor targets to self_attn
* fix LoRAAttnProcessor target
* make style
* fix split key
* Update src/diffusers/loaders.py
* remove TEXT_ENCODER_TARGET_MODULES loop
* add print memory usage
* remove test_kohya_loras_scaffold.py
* add: doc on LoRA civitai
* remove print statement and refactor in the doc.
* fix state_dict test for kohya-ss style lora
* Apply suggestions from code review

Co-authored-by: Takuma Mori <[email protected]>

---------

Co-authored-by: Sayak Paul <[email protected]>
1 parent 32ea214 commit 8e552bb

File tree

7 files changed (+272 lines, -36 lines)


docs/source/en/training/lora.mdx

Lines changed: 72 additions & 1 deletion
@@ -272,4 +272,75 @@ Note that the use of [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] is
 * LoRA parameters that have separate identifiers for the UNet and the text encoder such as: [`"sayakpaul/dreambooth"`](https://huggingface.co/sayakpaul/dreambooth).
 
 **Note** that it is possible to provide a local directory path to [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] as well as [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`]. To know about the supported inputs,
-refer to the respective docstrings.
+refer to the respective docstrings.
+
+## Supporting A1111 themed LoRA checkpoints from Diffusers
+
+To provide seamless interoperability with A1111 to our users, we support loading A1111 formatted
+LoRA checkpoints using [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] in a limited capacity.
+In this section, we explain how to load an A1111 formatted LoRA checkpoint from [CivitAI](https://civitai.com/)
+in Diffusers and perform inference with it.
+
+First, download a checkpoint. We'll use
+[this one](https://civitai.com/models/13239/light-and-shadow) for demonstration purposes.
+
+```bash
+wget https://civitai.com/api/download/models/15603 -O light_and_shadow.safetensors
+```
+
+Next, we initialize a [`~DiffusionPipeline`]:
+
+```python
+import torch
+
+from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
+
+pipeline = StableDiffusionPipeline.from_pretrained(
+    "gsdf/Counterfeit-V2.5", torch_dtype=torch.float16, safety_checker=None
+).to("cuda")
+pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
+    pipeline.scheduler.config, use_karras_sigmas=True
+)
+```
+
+We then load the checkpoint downloaded from CivitAI:
+
+```python
+pipeline.load_lora_weights(".", weight_name="light_and_shadow.safetensors")
+```
+
+<Tip warning={true}>
+
+If you're loading a checkpoint in the `safetensors` format, please ensure you have `safetensors` installed.
+
+</Tip>
+
+And then it's time for running inference:
+
+```python
+prompt = "masterpiece, best quality, 1girl, at dusk"
+negative_prompt = ("(low quality, worst quality:1.4), (bad anatomy), (inaccurate limb:1.2), "
+                   "bad composition, inaccurate eyes, extra digit, fewer digits, (extra arms:1.2), large breasts")
+
+images = pipeline(prompt=prompt,
+    negative_prompt=negative_prompt,
+    width=512,
+    height=768,
+    num_inference_steps=15,
+    num_images_per_prompt=4,
+    generator=torch.manual_seed(0)
+).images
+```
+
+Below is a comparison between the LoRA and the non-LoRA results:
+
+![lora_non_lora](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/lora_non_lora_comparison.png)
+
+If you have a similar checkpoint stored on the Hugging Face Hub, you can load it
+directly with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] like so:
+
+```python
+lora_model_id = "sayakpaul/civitai-light-shadow-lora"
+lora_filename = "light_and_shadow.safetensors"
+pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
+```
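
As background for the loading step documented above: the CivitAI file is stored in the Kohya-ss key layout, which is what triggers the conversion added in this commit. Below is a minimal, illustrative sketch for inspecting those keys, assuming `safetensors` is installed and the `light_and_shadow.safetensors` file has been downloaded as shown:

```python
# Minimal sketch: peek at the keys of the downloaded LoRA file to confirm it uses
# the Kohya-ss naming scheme (lora_unet_* / lora_te_*) that load_lora_weights detects.
from safetensors.torch import load_file

state_dict = load_file("light_and_shadow.safetensors")
for key in sorted(state_dict)[:5]:
    print(key, tuple(state_dict[key].shape))

# True when every key carries a Kohya-ss prefix, which is the condition the loader checks.
print(all(k.startswith(("lora_unet_", "lora_te_")) for k in state_dict))
```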

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 3 additions & 3 deletions
@@ -58,7 +58,7 @@
     SlicedAttnAddedKVProcessor,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
+from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.torch_utils import randn_tensor
 
@@ -861,9 +861,9 @@ def main(args):
     if args.train_text_encoder:
         text_lora_attn_procs = {}
         for name, module in text_encoder.named_modules():
-            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
+            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
                 text_lora_attn_procs[name] = LoRAAttnProcessor(
-                    hidden_size=module.out_features, cross_attention_dim=None
+                    hidden_size=module.out_proj.out_features, cross_attention_dim=None
                 )
         text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
         temp_pipeline = DiffusionPipeline.from_pretrained(
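
For context on the second hunk: the loop now keys off module names ending with the text encoder's self-attention block and takes the hidden size from that block's `out_proj` projection. A minimal illustrative sketch of the same matching logic, assuming the `openai/clip-vit-base-patch32` text encoder purely as an example checkpoint (it is not used by the training script itself):

```python
# Illustrative sketch: find the CLIP self-attention modules that the DreamBooth
# LoRA script now targets, and the hidden size used for their LoRA processors.
from transformers import CLIPTextModel

text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")

for name, module in text_encoder.named_modules():
    if name.endswith("self_attn"):  # stands in for the TEXT_ENCODER_ATTN_MODULE check
        # e.g. text_model.encoder.layers.0.self_attn 512
        print(name, module.out_proj.out_features)
```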

src/diffusers/loaders.py

Lines changed: 85 additions & 10 deletions
@@ -72,8 +72,8 @@ def __init__(self, state_dict: Dict[str, torch.Tensor]):
         self.mapping = dict(enumerate(state_dict.keys()))
         self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
 
-        # .processor for unet, .k_proj, ".q_proj", ".v_proj", and ".out_proj" for text encoder
-        self.split_keys = [".processor", ".k_proj", ".q_proj", ".v_proj", ".out_proj"]
+        # .processor for unet, .self_attn for text encoder
+        self.split_keys = [".processor", ".self_attn"]
 
         # we add a hook to state_dict() and load_state_dict() so that the
         # naming fits with `unet.attn_processors`
@@ -182,6 +182,9 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+        network_alpha = kwargs.pop("network_alpha", None)
 
         if use_safetensors and not is_safetensors_available():
             raise ValueError(
@@ -287,7 +290,10 @@ def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict
                     attn_processor_class = LoRAAttnProcessor
 
                 attn_processors[key] = attn_processor_class(
-                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
+                    hidden_size=hidden_size,
+                    cross_attention_dim=cross_attention_dim,
+                    rank=rank,
+                    network_alpha=network_alpha,
                 )
                 attn_processors[key].load_state_dict(value_dict)
         elif is_custom_diffusion:
@@ -774,6 +780,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
 
         <Tip warning={true}>
 
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
         This function is experimental and might change in the future.
 
         </Tip>
@@ -898,6 +906,11 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
         else:
             state_dict = pretrained_model_name_or_path_or_dict
 
+        # Convert kohya-ss Style LoRA attn procs to diffusers attn procs
+        network_alpha = None
+        if all((k.startswith("lora_te_") or k.startswith("lora_unet_")) for k in state_dict.keys()):
+            state_dict, network_alpha = self._convert_kohya_lora_to_diffusers(state_dict)
+
         # If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
         # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
         # their prefixes.
@@ -909,7 +922,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             unet_lora_state_dict = {
                 k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys
             }
-            self.unet.load_attn_procs(unet_lora_state_dict)
+            self.unet.load_attn_procs(unet_lora_state_dict, network_alpha=network_alpha)
 
             # Load the layers corresponding to text encoder and make necessary adjustments.
             text_encoder_keys = [k for k in keys if k.startswith(self.text_encoder_name)]
@@ -918,7 +931,9 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
                 k.replace(f"{self.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
             }
             if len(text_encoder_lora_state_dict) > 0:
-                attn_procs_text_encoder = self._load_text_encoder_attn_procs(text_encoder_lora_state_dict)
+                attn_procs_text_encoder = self._load_text_encoder_attn_procs(
+                    text_encoder_lora_state_dict, network_alpha=network_alpha
+                )
                 self._modify_text_encoder(attn_procs_text_encoder)
 
                 # save lora attn procs of text encoder so that it can be easily retrieved
@@ -954,14 +969,20 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
                 module = self.text_encoder.get_submodule(name)
                 # Construct a new function that performs the LoRA merging. We will monkey patch
                 # this forward pass.
-                lora_layer = getattr(attn_processors[name], self._get_lora_layer_attribute(name))
+                attn_processor_name = ".".join(name.split(".")[:-1])
+                lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
                 old_forward = module.forward
 
-                def new_forward(x):
-                    return old_forward(x) + lora_layer(x)
+                # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
+                # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
+                def make_new_forward(old_forward, lora_layer):
+                    def new_forward(x):
+                        return old_forward(x) + lora_layer(x)
+
+                    return new_forward
 
                 # Monkey-patch.
-                module.forward = new_forward
+                module.forward = make_new_forward(old_forward, lora_layer)
 
     def _get_lora_layer_attribute(self, name: str) -> str:
         if "q_proj" in name:
@@ -1048,6 +1069,7 @@ def _load_text_encoder_attn_procs(
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        network_alpha = kwargs.pop("network_alpha", None)
 
         if use_safetensors and not is_safetensors_available():
             raise ValueError(
@@ -1125,7 +1147,10 @@
                 hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
 
                 attn_processors[key] = LoRAAttnProcessor(
-                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
+                    hidden_size=hidden_size,
+                    cross_attention_dim=cross_attention_dim,
+                    rank=rank,
+                    network_alpha=network_alpha,
                 )
                 attn_processors[key].load_state_dict(value_dict)
 
@@ -1219,6 +1244,56 @@ def save_function(weights, filename):
         save_function(state_dict, os.path.join(save_directory, weight_name))
         logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
 
+    def _convert_kohya_lora_to_diffusers(self, state_dict):
+        unet_state_dict = {}
+        te_state_dict = {}
+        network_alpha = None
+
+        for key, value in state_dict.items():
+            if "lora_down" in key:
+                lora_name = key.split(".")[0]
+                lora_name_up = lora_name + ".lora_up.weight"
+                lora_name_alpha = lora_name + ".alpha"
+                if lora_name_alpha in state_dict:
+                    alpha = state_dict[lora_name_alpha].item()
+                    if network_alpha is None:
+                        network_alpha = alpha
+                    elif network_alpha != alpha:
+                        raise ValueError("Network alpha is not consistent")
+
+                if lora_name.startswith("lora_unet_"):
+                    diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
+                    diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
+                    diffusers_name = diffusers_name.replace("mid.block", "mid_block")
+                    diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
+                    diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
+                    diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
+                    diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
+                    diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
+                    diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
+                    if "transformer_blocks" in diffusers_name:
+                        if "attn1" in diffusers_name or "attn2" in diffusers_name:
+                            diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
+                            diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
+                            unet_state_dict[diffusers_name] = value
+                            unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+                elif lora_name.startswith("lora_te_"):
+                    diffusers_name = key.replace("lora_te_", "").replace("_", ".")
+                    diffusers_name = diffusers_name.replace("text.model", "text_model")
+                    diffusers_name = diffusers_name.replace("self.attn", "self_attn")
+                    diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
+                    diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
+                    diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
+                    diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
+                    if "self_attn" in diffusers_name:
+                        te_state_dict[diffusers_name] = value
+                        te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict[lora_name_up]
+
+        unet_state_dict = {f"{UNET_NAME}.{module_name}": params for module_name, params in unet_state_dict.items()}
+        te_state_dict = {f"{TEXT_ENCODER_NAME}.{module_name}": params for module_name, params in te_state_dict.items()}
+        new_state_dict = {**unet_state_dict, **te_state_dict}
+        return new_state_dict, network_alpha
+
 
 class FromCkptMixin:
     """This helper class allows to directly load .ckpt stable diffusion file_extension

src/diffusers/models/attention_processor.py

Lines changed: 27 additions & 18 deletions
@@ -508,14 +508,18 @@ def __call__(
 
 
 class LoRALinearLayer(nn.Module):
-    def __init__(self, in_features, out_features, rank=4):
+    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
         super().__init__()
 
         if rank > min(in_features, out_features):
             raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
 
         self.down = nn.Linear(in_features, rank, bias=False)
         self.up = nn.Linear(rank, out_features, bias=False)
+        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+        self.network_alpha = network_alpha
+        self.rank = rank
 
         nn.init.normal_(self.down.weight, std=1 / rank)
         nn.init.zeros_(self.up.weight)
@@ -527,6 +531,9 @@ def forward(self, hidden_states):
         down_hidden_states = self.down(hidden_states.to(dtype))
         up_hidden_states = self.up(down_hidden_states)
 
+        if self.network_alpha is not None:
+            up_hidden_states *= self.network_alpha / self.rank
+
         return up_hidden_states.to(orig_dtype)
 
 
@@ -543,17 +550,17 @@ class LoRAAttnProcessor(nn.Module):
             The dimension of the LoRA update matrices.
     """
 
-    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
         super().__init__()
 
         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
         self.rank = rank
 
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
 
     def __call__(
         self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
@@ -838,19 +845,19 @@ class LoRAAttnAddedKVProcessor(nn.Module):
             The dimension of the LoRA update matrices.
     """
 
-    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
         super().__init__()
 
         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
         self.rank = rank
 
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
-        self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
-        self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
-        self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
-        self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.add_k_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.add_v_proj_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_k_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
 
     def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
         residual = hidden_states
@@ -1157,18 +1164,20 @@ class LoRAXFormersAttnProcessor(nn.Module):
             operator.
     """
 
-    def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
+    def __init__(
+        self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None, network_alpha=None
+    ):
         super().__init__()
 
         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
         self.rank = rank
         self.attention_op = attention_op
 
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
 
     def __call__(
         self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
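
For intuition about the `network_alpha` changes above: the LoRA update is simply rescaled by `network_alpha / rank` before being added back. A minimal standalone re-implementation of that forward pass, for illustration only (not the diffusers class itself):

```python
import torch
import torch.nn as nn


class TinyLoRALinear(nn.Module):
    """Minimal re-implementation of the pattern above: out = up(down(x)),
    scaled by network_alpha / rank when an alpha is provided."""

    def __init__(self, in_features, out_features, rank=4, network_alpha=None):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        self.network_alpha = network_alpha
        self.rank = rank
        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        out = self.up(self.down(hidden_states))
        if self.network_alpha is not None:
            out = out * (self.network_alpha / self.rank)  # e.g. alpha=2, rank=4 -> scale 0.5
        return out


layer = TinyLoRALinear(16, 16, rank=4, network_alpha=2.0)
print(layer(torch.randn(1, 16)).shape)  # torch.Size([1, 16])
```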
