Add flag to disable LoRA monkey-patching #1

Open · wants to merge 2 commits into base: kohya-lora-loader
9 changes: 7 additions & 2 deletions src/diffusers/loaders.py
@@ -821,11 +821,13 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here.

mirror (`str`, *optional*):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information.
text_encoder_modify_forwards(`bool`, *optional*, defaults to `True`):
Whether or not to monkey-patch the forward pass of the text encoder to use the LoRA layers.
Monkey-patching should only happen once, so set this flag to False if you call this function more than once.

<Tip>

@@ -846,6 +848,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", None)
use_safetensors = kwargs.pop("use_safetensors", None)
text_encoder_modify_forwards = kwargs.pop("text_encoder_modify_forwards", True)


if use_safetensors and not is_safetensors_available():
    raise ValueError(
@@ -934,7 +938,8 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
attn_procs_text_encoder = self._load_text_encoder_attn_procs(
    text_encoder_lora_state_dict, network_alpha=network_alpha
)
self._modify_text_encoder(attn_procs_text_encoder)
if text_encoder_modify_forwards:
    self._modify_text_encoder(attn_procs_text_encoder)

# save lora attn procs of text encoder so that it can be easily retrieved
self._text_encoder_lora_attn_procs = attn_procs_text_encoder
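The guard exists because wrapping the text encoder's forward pass a second time stacks the LoRA contribution on top of the already-patched forward. A standalone sketch of that failure mode, using a stub module rather than the real diffusers text encoder:

```python
# Illustration only (not diffusers code): patching a forward method twice
# means the LoRA-style delta is applied twice.
class TextEncoderStub:
    def forward(self, x):
        return x

def patch_with_lora(module, delta):
    old_forward = module.forward
    def new_forward(x):
        return old_forward(x) + delta  # LoRA contribution added on top
    module.forward = new_forward

enc = TextEncoderStub()
patch_with_lora(enc, delta=1)
patch_with_lora(enc, delta=1)   # second patch wraps the already-patched forward
print(enc.forward(0))           # prints 2, not 1: the update is applied twice
```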
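With the flag in place, a later call to `load_lora_weights` can skip the re-wrap. A minimal usage sketch, assuming a diffusers build that includes this branch; the checkpoint name and LoRA file names are placeholders, not part of the PR:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# First call: the text encoder forward pass is monkey-patched to route through
# the LoRA layers (default behaviour, text_encoder_modify_forwards=True).
pipe.load_lora_weights(".", weight_name="style_lora.safetensors")

# Later call: the forward pass is already patched, so skip the patching step
# and only load the new LoRA attention processors.
pipe.load_lora_weights(
    ".",
    weight_name="character_lora.safetensors",
    text_encoder_modify_forwards=False,
)
```

Because the flag defaults to `True`, existing single-call code paths keep their current behaviour; only repeated calls need to opt out.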