Commit 285b490

sayakpaul, patrickvonplaten, and isidentical authored and committed
[Feat] Support SDXL Kohya-style LoRA (huggingface#4287)
* sdxl lora changes. * better name replacement. * better replacement. * debugging * debugging * debugging * debugging * debugging * remove print. * print state dict keys. * print * distingisuih better * debuggable. * fxi: tyests * fix: arg from training script. * access from class. * run style * debug * save intermediate * some simplifications for SDXL LoRA * styling * unet config is not needed in diffusers format. * fix: dynamic SGM block mapping for SDXL kohya loras (huggingface#4322) * Use lora compatible layers for linear proj_in/proj_out (huggingface#4323) * improve condition for using the sgm_diffusers mapping * informative comment. * load compatible keys and embedding layer maaping. * Get SDXL 1.0 example lora to load * simplify * specif ranks and hidden sizes. * better handling of k rank and hidden * debug * debug * debug * debug * debug * fix: alpha keys * add check for handling LoRAAttnAddedKVProcessor * sanity comment * modifications for text encoder SDXL * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * denugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * debugging * up * up * up * up * up * up * unneeded comments. * unneeded comments. * kwargs for the other attention processors. * kwargs for the other attention processors. * debugging * debugging * debugging * debugging * improve * debugging * debugging * more print * Fix alphas * debugging * debugging * debugging * debugging * debugging * debugging * clean up * clean up. * debugging * fix: text --------- Co-authored-by: Patrick von Platen <[email protected]> Co-authored-by: Batuhan Taskaya <[email protected]>
1 parent e668811 commit 285b490

10 files changed: +553 -173 lines changed

docs/source/en/training/lora.md

Lines changed: 49 additions & 1 deletion
@@ -354,4 +354,52 @@ directly with [`~diffusers.loaders.LoraLoaderMixin.load_lora_weights`] like so:
 lora_model_id = "sayakpaul/civitai-light-shadow-lora"
 lora_filename = "light_and_shadow.safetensors"
 pipeline.load_lora_weights(lora_model_id, weight_name=lora_filename)
-```
+```
+
+### Supporting Stable Diffusion XL LoRAs trained using the Kohya-trainer
+
+With this [PR](https://github.com/huggingface/diffusers/pull/4287), there should now be better support for loading Kohya-style LoRAs trained on Stable Diffusion XL (SDXL).
+
+Here are some example checkpoints we tried out:
+
+* SDXL 0.9:
+    * https://civitai.com/models/22279?modelVersionId=118556
+    * https://civitai.com/models/104515/sdxlor30costumesrevue-starlight-saijoclaudine-lora
+    * https://civitai.com/models/108448/daiton-sdxl-test
+    * https://filebin.net/2ntfqqnapiu9q3zx/pixelbuildings128-v1.safetensors
+* SDXL 1.0:
+    * https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0/blob/main/sd_xl_offset_example-lora_1.0.safetensors
+
+Here is an example of how to perform inference with these checkpoints in `diffusers`:
+
+```python
+from diffusers import DiffusionPipeline
+import torch
+
+base_model_id = "stabilityai/stable-diffusion-xl-base-0.9"
+pipeline = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
+pipeline.load_lora_weights(".", weight_name="Kamepan.safetensors")
+
+prompt = "anime screencap, glint, drawing, best quality, light smile, shy, a full body of a girl wearing wedding dress in the middle of the forest beneath the trees, fireflies, big eyes, 2d, cute, anime girl, waifu, cel shading, magical girl, vivid colors, (outline:1.1), manga anime artstyle, masterpiece, offical wallpaper, glint <lora:kame_sdxl_v2:1>"
+negative_prompt = "(deformed, bad quality, sketch, depth of field, blurry:1.1), grainy, bad anatomy, bad perspective, old, ugly, realistic, cartoon, disney, bad propotions"
+generator = torch.manual_seed(2947883060)
+num_inference_steps = 30
+guidance_scale = 7
+
+image = pipeline(
+    prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=num_inference_steps,
+    generator=generator, guidance_scale=guidance_scale
+).images[0]
+image.save("Kamepan.png")
+```
+
+`Kamepan.safetensors` comes from https://civitai.com/models/22279?modelVersionId=118556.
+
+Note that the inference UX is identical to what we presented in the sections above.
+
+Thanks to [@isidentical](https://github.com/isidentical) for helping us integrate this feature.
+
+### Known limitations specific to the Kohya-style LoRAs
+
+* SDXL LoRAs that include weights for both text encoders currently lead to unexpected results. We're actively investigating the issue.
+* When images don't look similar to those produced by other UIs such as ComfyUI, there can be multiple reasons, as explained [here](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736).
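The first known limitation concerns checkpoints that carry LoRA weights for both SDXL text encoders. As a rough way to check whether a given checkpoint falls into that group, the sketch below counts keys per component. It is not code from this commit. The prefixes `lora_unet_`, `lora_te1_`, and `lora_te2_` are an assumption about Kohya-style key naming, and the helper `count_lora_targets` and the example key names are hypothetical.

```python
from collections import Counter

# Assumed Kohya-style key prefixes (not taken from this commit): "lora_unet_"
# for the UNet, "lora_te1_" / "lora_te2_" for the two SDXL text encoders.
ASSUMED_PREFIXES = {
    "lora_unet_": "unet",
    "lora_te1_": "text_encoder",
    "lora_te2_": "text_encoder_2",
}


def count_lora_targets(keys):
    """Count how many keys target each component, based on the assumed prefixes."""
    counts = Counter()
    for key in keys:
        for prefix, component in ASSUMED_PREFIXES.items():
            if key.startswith(prefix):
                counts[component] += 1
                break
        else:
            # No assumed prefix matched this key.
            counts["unknown"] += 1
    return counts


# Hypothetical key names, for illustration only.
example_keys = [
    "lora_unet_input_blocks_4_1_proj_in.lora_down.weight",
    "lora_unet_input_blocks_4_1_proj_in.lora_up.weight",
    "lora_unet_input_blocks_4_1_proj_in.alpha",
    "lora_te1_text_model_encoder_layers_0_mlp_fc1.lora_down.weight",
    "lora_te2_text_model_encoder_layers_0_mlp_fc1.lora_down.weight",
]

counts = count_lora_targets(example_keys)
has_both_text_encoders = counts["text_encoder"] > 0 and counts["text_encoder_2"] > 0
print(dict(counts), has_both_text_encoders)
```

For a real file, the key names could be read with `safetensors.safe_open(path, framework="pt")` and its `keys()` method before passing them to the helper. Whether a particular checkpoint follows the assumed prefixes has to be verified against the file itself.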

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 3 additions & 3 deletions
@@ -925,10 +925,10 @@ def load_model_hook(models, input_dir):
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")

-        lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
+        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
         LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
+            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_
         )

     accelerator.register_save_state_pre_hook(save_model_hook)
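The rename from `network_alpha` to `network_alphas` suggests that the loader now carries one alpha per LoRA module rather than a single value. The sketch below illustrates why that matters, assuming the usual Kohya-style convention that the low-rank update is scaled by `alpha / rank`. It is a plain-Python illustration, not the `diffusers` implementation. The helpers `lora_scale`, `matmul`, and `merged_weight`, the module names, and the toy matrices are all hypothetical.

```python
def lora_scale(alpha, rank):
    """Scale applied to the low-rank update; alpha=None means no rescaling."""
    return 1.0 if alpha is None else alpha / rank


def matmul(a, b):
    """Plain-Python matrix product for small nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def merged_weight(weight, up, down, alpha):
    """Return W + (alpha / rank) * (up @ down)."""
    rank = len(down)  # `down` has shape (rank, in_features)
    scale = lora_scale(alpha, rank)
    delta = matmul(up, down)
    return [
        [w + scale * d for w, d in zip(w_row, d_row)]
        for w_row, d_row in zip(weight, delta)
    ]


# Hypothetical per-module alphas: the same LoRA matrices, scaled differently.
network_alphas = {"to_q": 1.0, "to_k": 2.0}

weight = [[1.0, 0.0], [0.0, 1.0]]  # base weight, shape (out_features, in_features)
up = [[1.0, 0.0], [0.0, 1.0]]      # shape (out_features, rank)
down = [[2.0, 0.0], [0.0, 2.0]]    # shape (rank, in_features)

merged = {
    name: merged_weight(weight, up, down, alpha)
    for name, alpha in network_alphas.items()
}
print(merged)
```

With rank 2, an alpha of 1.0 halves the update while an alpha of 2.0 applies it unchanged, so a single shared alpha would give wrong results for any checkpoint whose modules were saved with different values.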

examples/dreambooth/train_dreambooth_lora_sdxl.py

Lines changed: 4 additions & 4 deletions
@@ -825,13 +825,13 @@ def load_model_hook(models, input_dir):
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")

-        lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
-        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
+        lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)
         LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_one_
+            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
         )
         LoraLoaderMixin.load_lora_into_text_encoder(
-            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_two_
+            lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
         )

     accelerator.register_save_state_pre_hook(save_model_hook)
