
Commit 9e21c54

Add function to remove monkey-patch for text encoder LoRA (huggingface#3649)
* merge undoable-monkeypatch
* remove TEXT_ENCODER_TARGET_MODULES, refactoring
* move create_lora_weight_file
1 parent 857e7aa commit 9e21c54

File tree

- loaders.py
- utils/__init__.py
- utils/constants.py

3 files changed: +49 -34 lines changed

loaders.py

Lines changed: 49 additions & 32 deletions
@@ -34,7 +34,7 @@
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
-    TEXT_ENCODER_TARGET_MODULES,
+    TEXT_ENCODER_ATTN_MODULE,
     _get_model_file,
     deprecate,
     is_safetensors_available,
@@ -955,6 +955,19 @@ def text_encoder_lora_attn_procs(self):
             return self._text_encoder_lora_attn_procs
         return
 
+    def _remove_text_encoder_monkey_patch(self):
+        # Loop over the CLIPAttention module of text_encoder
+        for name, attn_module in self.text_encoder.named_modules():
+            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
+                # Loop over the LoRA layers
+                for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
+                    # Retrieve the q/k/v/out projection of CLIPAttention
+                    module = attn_module.get_submodule(text_encoder_attr)
+                    if hasattr(module, "old_forward"):
+                        # restore original `forward` to remove monkey-patch
+                        module.forward = module.old_forward
+                        delattr(module, "old_forward")
+
     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
@@ -963,37 +976,41 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
             attn_processors: Dict[str, `LoRAAttnProcessor`]:
                 A dictionary mapping the module names and their corresponding [`~LoRAAttnProcessor`].
         """
-        # Loop over the original attention modules.
-        for name, _ in self.text_encoder.named_modules():
-            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
-                # Retrieve the module and its corresponding LoRA processor.
-                module = self.text_encoder.get_submodule(name)
-                # Construct a new function that performs the LoRA merging. We will monkey patch
-                # this forward pass.
-                attn_processor_name = ".".join(name.split(".")[:-1])
-                lora_layer = getattr(attn_processors[attn_processor_name], self._get_lora_layer_attribute(name))
-                old_forward = module.forward
-
-                # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
-                # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
-                def make_new_forward(old_forward, lora_layer):
-                    def new_forward(x):
-                        return old_forward(x) + lora_layer(x)
-
-                    return new_forward
-
-                # Monkey-patch.
-                module.forward = make_new_forward(old_forward, lora_layer)
-
-    def _get_lora_layer_attribute(self, name: str) -> str:
-        if "q_proj" in name:
-            return "to_q_lora"
-        elif "v_proj" in name:
-            return "to_v_lora"
-        elif "k_proj" in name:
-            return "to_k_lora"
-        else:
-            return "to_out_lora"
+
+        # First, remove any monkey-patch that might have been applied before
+        self._remove_text_encoder_monkey_patch()
+
+        # Loop over the CLIPAttention module of text_encoder
+        for name, attn_module in self.text_encoder.named_modules():
+            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
+                # Loop over the LoRA layers
+                for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
+                    # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
+                    module = attn_module.get_submodule(text_encoder_attr)
+                    lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
+
+                    # save old_forward to module that can be used to remove monkey-patch
+                    old_forward = module.old_forward = module.forward
+
+                    # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
+                    # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
+                    def make_new_forward(old_forward, lora_layer):
+                        def new_forward(x):
+                            return old_forward(x) + lora_layer(x)
+
+                        return new_forward
+
+                    # Monkey-patch.
+                    module.forward = make_new_forward(old_forward, lora_layer)
+
+    @property
+    def _lora_attn_processor_attr_to_text_encoder_attr(self):
+        return {
+            "to_q_lora": "q_proj",
+            "to_k_lora": "k_proj",
+            "to_v_lora": "v_proj",
+            "to_out_lora": "out_proj",
+        }
 
     def _load_text_encoder_attn_procs(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
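
For readers unfamiliar with this pattern, here is a minimal, self-contained sketch of the same patch/undo mechanism on a plain nn.Linear. ToyLoRALayer, patch_linear, and unpatch_linear are invented for the illustration and are not part of this diff; only the technique (stash the original forward as old_forward, wrap it in a per-module closure, restore it later) mirrors the code above.

import torch
import torch.nn as nn


class ToyLoRALayer(nn.Module):
    # A tiny additive LoRA layer, invented for this sketch.
    def __init__(self, dim=16, rank=4):
        super().__init__()
        self.down = nn.Linear(dim, rank, bias=False)
        self.up = nn.Linear(rank, dim, bias=False)
        nn.init.zeros_(self.up.weight)  # standard LoRA init: starts as a no-op

    def forward(self, x):
        return self.up(self.down(x))


def patch_linear(module, lora_layer):
    # Stash the original forward on the module so the patch can be undone later,
    # mirroring `old_forward = module.old_forward = module.forward` in the diff.
    old_forward = module.old_forward = module.forward

    # A new scope per module locks in old_forward/lora_layer, like make_new_forward.
    def new_forward(x):
        return old_forward(x) + lora_layer(x)

    module.forward = new_forward


def unpatch_linear(module):
    # Restore the original forward, as _remove_text_encoder_monkey_patch does.
    if hasattr(module, "old_forward"):
        module.forward = module.old_forward
        delattr(module, "old_forward")


proj = nn.Linear(16, 16)  # stands in for a CLIPAttention q/k/v/out projection
lora = ToyLoRALayer()
x = torch.randn(2, 16)

base = proj(x)
patch_linear(proj, lora)
patched = proj(x)                      # base output + LoRA delta
unpatch_linear(proj)
assert torch.allclose(proj(x), base)   # the monkey-patch is fully removed
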

utils/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -31,7 +31,6 @@
     ONNX_WEIGHTS_NAME,
     SAFETENSORS_WEIGHTS_NAME,
     TEXT_ENCODER_ATTN_MODULE,
-    TEXT_ENCODER_TARGET_MODULES,
     WEIGHTS_NAME,
 )
 from .deprecation_utils import deprecate

utils/constants.py

Lines changed: 0 additions & 1 deletion
@@ -30,5 +30,4 @@
 DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
 DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
-TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj", "k_proj", "out_proj"]
 TEXT_ENCODER_ATTN_MODULE = ".self_attn"
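
A hedged usage sketch of what this change enables at the pipeline level, not part of the commit itself: the model id below is a common public checkpoint, the LoRA paths are placeholders, and the call path is assumed from the diff above. Re-loading LoRA weights that target the text encoder now first undoes the previous monkey-patch instead of stacking a second wrapper.

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Loading LoRA weights that target the text encoder monkey-patches the q/k/v/out
# projections of each CLIPAttention module.
pipe.load_lora_weights("path/to/first_lora")   # placeholder path

# With this commit, _modify_text_encoder calls _remove_text_encoder_monkey_patch
# first, so loading again replaces the previous patch rather than wrapping the
# already-patched forward a second time.
pipe.load_lora_weights("path/to/second_lora")  # placeholder path
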
