@@ -20,6 +20,7 @@
 import torch
 import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
+from torch import nn

 from .models.attention_processor import (
     AttnAddedKVProcessor,
@@ -36,7 +37,6 @@
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
-    TEXT_ENCODER_ATTN_MODULE,
     _get_model_file,
     deprecate,
     is_safetensors_available,
@@ -49,7 +49,7 @@
     import safetensors

 if is_transformers_available():
-    from transformers import PreTrainedModel, PreTrainedTokenizer
+    from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer, T5EncoderModel


 logger = logging.get_logger(__name__)
@@ -67,6 +67,36 @@
 CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"


+class PatchedLoraProjection(nn.Module):
+    def __init__(self, regular_linear_layer, lora_linear_layer, lora_scale=1):
+        super().__init__()
+        self.regular_linear_layer = regular_linear_layer
+        self.lora_linear_layer = lora_linear_layer
+        self.lora_scale = lora_scale
+
+    def forward(self, input):
+        return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input)
+
+
+def text_encoder_attn_modules(text_encoder):
+    attn_modules = []
+
+    if isinstance(text_encoder, CLIPTextModel):
+        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+            name = f"text_model.encoder.layers.{i}.self_attn"
+            mod = layer.self_attn
+            attn_modules.append((name, mod))
+    elif isinstance(text_encoder, T5EncoderModel):
+        for i, block in enumerate(text_encoder.encoder.block):
+            name = f"encoder.block.{i}.layer.0.SelfAttention"
+            mod = block.layer[0].SelfAttention
+            attn_modules.append((name, mod))
+    else:
+        raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
+
+    return attn_modules
+
+
 class AttnProcsLayers(torch.nn.Module):
     def __init__(self, state_dict: Dict[str, torch.Tensor]):
         super().__init__()
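For reference, a minimal sketch of what the new `PatchedLoraProjection` computes: it keeps the original linear layer and adds a scaled LoRA branch on top. Only `PatchedLoraProjection` comes from this diff; the `TinyLoraLinear` stand-in below is hypothetical (the real code wraps the LoRA layers held by the attention processors).

import torch
from torch import nn

class TinyLoraLinear(nn.Module):
    # Hypothetical low-rank adapter used only for this sketch: down-project to rank r, then up-project.
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()
        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        nn.init.zeros_(self.up.weight)  # start as a no-op, the usual LoRA initialization

    def forward(self, x):
        return self.up(self.down(x))

base = nn.Linear(768, 768)
lora = TinyLoraLinear(768, 768)
patched = PatchedLoraProjection(base, lora, lora_scale=0.5)

x = torch.randn(2, 77, 768)
# The patched projection is the base layer's output plus the scaled LoRA output.
assert torch.allclose(patched(x), base(x) + 0.5 * lora(x))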
@@ -942,7 +972,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             attn_procs_text_encoder = self._load_text_encoder_attn_procs(
                 text_encoder_lora_state_dict, network_alpha=network_alpha
             )
-            self._modify_text_encoder(attn_procs_text_encoder)
+            self._modify_text_encoder(attn_procs_text_encoder, self.text_encoder, self.lora_scale)

             # save lora attn procs of text encoder so that it can be easily retrieved
             self._text_encoder_lora_attn_procs = attn_procs_text_encoder
@@ -968,20 +998,24 @@ def text_encoder_lora_attn_procs(self):
             return self._text_encoder_lora_attn_procs
         return

-    def _remove_text_encoder_monkey_patch(self):
-        # Loop over the CLIPAttention module of text_encoder
-        for name, attn_module in self.text_encoder.named_modules():
-            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-                # Loop over the LoRA layers
-                for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
-                    # Retrieve the q/k/v/out projection of CLIPAttention
-                    module = attn_module.get_submodule(text_encoder_attr)
-                    if hasattr(module, "old_forward"):
-                        # restore original `forward` to remove monkey-patch
-                        module.forward = module.old_forward
-                        delattr(module, "old_forward")
-
-    def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
+    @classmethod
+    def _remove_text_encoder_monkey_patch(cls, text_encoder):
+        for _, attn_module in text_encoder_attn_modules(text_encoder):
+            if isinstance(text_encoder, CLIPTextModel):
+                attn_module.q_proj = attn_module.q_proj.regular_linear_layer
+                attn_module.k_proj = attn_module.k_proj.regular_linear_layer
+                attn_module.v_proj = attn_module.v_proj.regular_linear_layer
+                attn_module.out_proj = attn_module.out_proj.regular_linear_layer
+            elif isinstance(text_encoder, T5EncoderModel):
+                attn_module.q = attn_module.q.regular_linear_layer
+                attn_module.k = attn_module.k.regular_linear_layer
+                attn_module.v = attn_module.v.regular_linear_layer
+                attn_module.o = attn_module.o.regular_linear_layer
+            else:
+                raise ValueError(f"{text_encoder.__class__.__name__} does not support LoRA training")
+
+    @classmethod
+    def _modify_text_encoder(cls, attn_processors: Dict[str, LoRAAttnProcessor], text_encoder, lora_scale=1):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.

@@ -991,40 +1025,29 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         """

         # First, remove any monkey-patch that might have been applied before
-        self._remove_text_encoder_monkey_patch()
+        cls._remove_text_encoder_monkey_patch(text_encoder)

-        # Loop over the CLIPAttention module of text_encoder
-        for name, attn_module in self.text_encoder.named_modules():
-            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-                # Loop over the LoRA layers
-                for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
-                    # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
-                    module = attn_module.get_submodule(text_encoder_attr)
-                    lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
-
-                    # save old_forward to module that can be used to remove monkey-patch
-                    old_forward = module.old_forward = module.forward
-
-                    # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
-                    # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
-                    def make_new_forward(old_forward, lora_layer):
-                        def new_forward(x):
-                            result = old_forward(x) + self.lora_scale * lora_layer(x)
-                            return result
-
-                        return new_forward
-
-                    # Monkey-patch.
-                    module.forward = make_new_forward(old_forward, lora_layer)
-
-    @property
-    def _lora_attn_processor_attr_to_text_encoder_attr(self):
-        return {
-            "to_q_lora": "q_proj",
-            "to_k_lora": "k_proj",
-            "to_v_lora": "v_proj",
-            "to_out_lora": "out_proj",
-        }
+        for name, attn_module in text_encoder_attn_modules(text_encoder):
+            if isinstance(text_encoder, CLIPTextModel):
+                attn_module.q_proj = PatchedLoraProjection(
+                    attn_module.q_proj, attn_processors[name].to_q_lora, lora_scale
+                )
+                attn_module.k_proj = PatchedLoraProjection(
+                    attn_module.k_proj, attn_processors[name].to_k_lora, lora_scale
+                )
+                attn_module.v_proj = PatchedLoraProjection(
+                    attn_module.v_proj, attn_processors[name].to_v_lora, lora_scale
+                )
+                attn_module.out_proj = PatchedLoraProjection(
+                    attn_module.out_proj, attn_processors[name].to_out_lora, lora_scale
+                )
+            elif isinstance(text_encoder, T5EncoderModel):
+                attn_module.q = PatchedLoraProjection(attn_module.q, attn_processors[name].to_q_lora, lora_scale)
+                attn_module.k = PatchedLoraProjection(attn_module.k, attn_processors[name].to_k_lora, lora_scale)
+                attn_module.v = PatchedLoraProjection(attn_module.v, attn_processors[name].to_v_lora, lora_scale)
+                attn_module.o = PatchedLoraProjection(attn_module.o, attn_processors[name].to_out_lora, lora_scale)
+            else:
+                raise ValueError(f"{text_encoder.__class__.__name__} does not support LoRA training")

     def _load_text_encoder_attn_procs(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
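For a rough end-to-end picture, here is a sketch (not part of the PR) of the patch/un-patch round trip that the reworked `_modify_text_encoder` / `_remove_text_encoder_monkey_patch` classmethods perform, done by hand on a tiny, randomly initialized `CLIPTextModel`. It assumes the `PatchedLoraProjection` and `text_encoder_attn_modules` helpers added above are in scope; the zeroed `nn.Linear` stands in for a real LoRA layer, and the config sizes are arbitrary.

import torch
from torch import nn
from transformers import CLIPTextConfig, CLIPTextModel

# Tiny random CLIP text encoder so the example stays fast; sizes are arbitrary.
config = CLIPTextConfig(hidden_size=32, intermediate_size=64, num_hidden_layers=2, num_attention_heads=4)
text_encoder = CLIPTextModel(config)

# Patch: wrap a projection of every self-attention block, as _modify_text_encoder does for q/k/v/out.
for name, attn_module in text_encoder_attn_modules(text_encoder):
    lora = nn.Linear(config.hidden_size, config.hidden_size, bias=False)  # stand-in for a LoRA layer
    nn.init.zeros_(lora.weight)
    attn_module.q_proj = PatchedLoraProjection(attn_module.q_proj, lora, lora_scale=1.0)

assert isinstance(text_encoder.text_model.encoder.layers[0].self_attn.q_proj, PatchedLoraProjection)

# Un-patch: reassigning the wrapped module restores the plain nn.Linear,
# which is all _remove_text_encoder_monkey_patch now has to do.
for _, attn_module in text_encoder_attn_modules(text_encoder):
    attn_module.q_proj = attn_module.q_proj.regular_linear_layer

assert isinstance(text_encoder.text_model.encoder.layers[0].self_attn.q_proj, nn.Linear)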