import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
+ from torch import nn

from .models.attention_processor import (
    AttnAddedKVProcessor,
    ...
from .utils import (
    DIFFUSERS_CACHE,
    HF_HUB_OFFLINE,
-     TEXT_ENCODER_ATTN_MODULE,
    _get_model_file,
    deprecate,
    is_safetensors_available,
    ...
    import safetensors


if is_transformers_available():
-     from transformers import PreTrainedModel, PreTrainedTokenizer
+     from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer, T5EncoderModel


logger = logging.get_logger(__name__)
...
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"


+ class PatchedLoraProjection(nn.Module):
+     def __init__(self, regular_linear_layer, lora_linear_layer, lora_scale=1):
+         super().__init__()
+         self.regular_linear_layer = regular_linear_layer
+         self.lora_linear_layer = lora_linear_layer
+         self.lora_scale = lora_scale
+
+     def forward(self, input):
+         return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input)
+
+
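The wrapper simply adds the scaled LoRA output to the original projection. A minimal sketch of that behavior, using an illustrative 32-dim layer and a rank-4 stand-in for the LoRA layer (these sizes are not from this diff):

import torch
from torch import nn

base = nn.Linear(32, 32)
# Rank-4 down/up pair standing in for a LoRALinearLayer.
lora = nn.Sequential(nn.Linear(32, 4, bias=False), nn.Linear(4, 32, bias=False))
patched = PatchedLoraProjection(base, lora, lora_scale=0.5)

x = torch.randn(1, 32)
# forward() returns regular_linear_layer(x) + lora_scale * lora_linear_layer(x)
assert torch.allclose(patched(x), base(x) + 0.5 * lora(x))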
+ def text_encoder_attn_modules(text_encoder):
+     attn_modules = []
+
+     if isinstance(text_encoder, CLIPTextModel):
+         for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+             name = f"text_model.encoder.layers.{i}.self_attn"
+             mod = layer.self_attn
+             attn_modules.append((name, mod))
+     elif isinstance(text_encoder, T5EncoderModel):
+         for i, block in enumerate(text_encoder.encoder.block):
+             name = f"encoder.block.{i}.layer.0.SelfAttention"
+             mod = block.layer[0].SelfAttention
+             attn_modules.append((name, mod))
+     else:
+         raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
+
+     return attn_modules
+
+
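To see what the helper yields, here is a hedged sketch that builds a tiny, randomly initialized CLIP text encoder (the small config values are illustrative only) and prints the attention-module names it discovers:

from transformers import CLIPTextConfig, CLIPTextModel

config = CLIPTextConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2, num_attention_heads=4)
text_encoder = CLIPTextModel(config)

for name, _ in text_encoder_attn_modules(text_encoder):
    print(name)
# text_model.encoder.layers.0.self_attn
# text_model.encoder.layers.1.self_attn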
class AttnProcsLayers(torch.nn.Module):
    def __init__(self, state_dict: Dict[str, torch.Tensor]):
        super().__init__()
@@ -942,7 +972,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
            attn_procs_text_encoder = self._load_text_encoder_attn_procs(
                text_encoder_lora_state_dict, network_alpha=network_alpha
            )
-             self._modify_text_encoder(attn_procs_text_encoder)
+             self._modify_text_encoder(attn_procs_text_encoder, self.text_encoder, self.lora_scale)

            # save lora attn procs of text encoder so that it can be easily retrieved
            self._text_encoder_lora_attn_procs = attn_procs_text_encoder
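For context, this is the code path exercised when a pipeline loads a LoRA checkpoint that also carries text-encoder weights; a hedged sketch of the entry point (the base model id and LoRA path below are placeholders, not taken from this diff):

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # placeholder base model
pipe.load_lora_weights("path/to/lora")  # placeholder checkpoint; text-encoder projections get patched here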
@@ -968,20 +998,24 @@ def text_encoder_lora_attn_procs(self):
            return self._text_encoder_lora_attn_procs
        return

-     def _remove_text_encoder_monkey_patch(self):
-         # Loop over the CLIPAttention module of text_encoder
-         for name, attn_module in self.text_encoder.named_modules():
-             if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-                 # Loop over the LoRA layers
-                 for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
-                     # Retrieve the q/k/v/out projection of CLIPAttention
-                     module = attn_module.get_submodule(text_encoder_attr)
-                     if hasattr(module, "old_forward"):
-                         # restore original `forward` to remove monkey-patch
-                         module.forward = module.old_forward
-                         delattr(module, "old_forward")
-
-     def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
+     @classmethod
+     def _remove_text_encoder_monkey_patch(cls, text_encoder):
+         for _, attn_module in text_encoder_attn_modules(text_encoder):
+             if isinstance(text_encoder, CLIPTextModel):
+                 attn_module.q_proj = attn_module.q_proj.regular_linear_layer
+                 attn_module.k_proj = attn_module.k_proj.regular_linear_layer
+                 attn_module.v_proj = attn_module.v_proj.regular_linear_layer
+                 attn_module.out_proj = attn_module.out_proj.regular_linear_layer
+             elif isinstance(text_encoder, T5EncoderModel):
+                 attn_module.q = attn_module.q.regular_linear_layer
+                 attn_module.k = attn_module.k.regular_linear_layer
+                 attn_module.v = attn_module.v.regular_linear_layer
+                 attn_module.o = attn_module.o.regular_linear_layer
+             else:
+                 raise ValueError(f"{text_encoder.__class__.__name__} does not support LoRA training")
+
+     @classmethod
+     def _modify_text_encoder(cls, attn_processors: Dict[str, LoRAAttnProcessor], text_encoder, lora_scale=1):
        r"""
        Monkey-patches the forward passes of attention modules of the text encoder.

@@ -991,40 +1025,29 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
        """

        # First, remove any monkey-patch that might have been applied before
-         self._remove_text_encoder_monkey_patch()
+         cls._remove_text_encoder_monkey_patch(text_encoder)

-         # Loop over the CLIPAttention module of text_encoder
-         for name, attn_module in self.text_encoder.named_modules():
-             if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-                 # Loop over the LoRA layers
-                 for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
-                     # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
-                     module = attn_module.get_submodule(text_encoder_attr)
-                     lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
-
-                     # save old_forward to module that can be used to remove monkey-patch
-                     old_forward = module.old_forward = module.forward
-
-                     # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
-                     # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
-                     def make_new_forward(old_forward, lora_layer):
-                         def new_forward(x):
-                             result = old_forward(x) + self.lora_scale * lora_layer(x)
-                             return result
-
-                         return new_forward
-
-                     # Monkey-patch.
-                     module.forward = make_new_forward(old_forward, lora_layer)
-
-     @property
-     def _lora_attn_processor_attr_to_text_encoder_attr(self):
-         return {
-             "to_q_lora": "q_proj",
-             "to_k_lora": "k_proj",
-             "to_v_lora": "v_proj",
-             "to_out_lora": "out_proj",
-         }
+         for name, attn_module in text_encoder_attn_modules(text_encoder):
+             if isinstance(text_encoder, CLIPTextModel):
+                 attn_module.q_proj = PatchedLoraProjection(
+                     attn_module.q_proj, attn_processors[name].to_q_lora, lora_scale
+                 )
+                 attn_module.k_proj = PatchedLoraProjection(
+                     attn_module.k_proj, attn_processors[name].to_k_lora, lora_scale
+                 )
+                 attn_module.v_proj = PatchedLoraProjection(
+                     attn_module.v_proj, attn_processors[name].to_v_lora, lora_scale
+                 )
+                 attn_module.out_proj = PatchedLoraProjection(
+                     attn_module.out_proj, attn_processors[name].to_out_lora, lora_scale
+                 )
+             elif isinstance(text_encoder, T5EncoderModel):
+                 attn_module.q = PatchedLoraProjection(attn_module.q, attn_processors[name].to_q_lora, lora_scale)
+                 attn_module.k = PatchedLoraProjection(attn_module.k, attn_processors[name].to_k_lora, lora_scale)
+                 attn_module.v = PatchedLoraProjection(attn_module.v, attn_processors[name].to_v_lora, lora_scale)
+                 attn_module.o = PatchedLoraProjection(attn_module.o, attn_processors[name].to_out_lora, lora_scale)
+             else:
+                 raise ValueError(f"{text_encoder.__class__.__name__} does not support LoRA training")

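To make the patch/unpatch mechanics concrete, here is a sketch that applies the same attribute swap by hand to one attention module of a tiny, randomly initialized CLIP text encoder and then restores it the way _remove_text_encoder_monkey_patch does for CLIP. The config values and the Sequential stand-in for the real LoRA layer are assumptions for illustration only:

from torch import nn
from transformers import CLIPTextConfig, CLIPTextModel

text_encoder = CLIPTextModel(
    CLIPTextConfig(hidden_size=64, intermediate_size=128, num_hidden_layers=2, num_attention_heads=4)
)
name, attn_module = text_encoder_attn_modules(text_encoder)[0]

original_q_proj = attn_module.q_proj
lora_layer = nn.Sequential(nn.Linear(64, 4, bias=False), nn.Linear(4, 64, bias=False))  # stand-in for to_q_lora

# Patch: the same swap _modify_text_encoder performs for CLIP's q_proj.
attn_module.q_proj = PatchedLoraProjection(original_q_proj, lora_layer, lora_scale=1.0)

# Unpatch: restore the plain Linear, as _remove_text_encoder_monkey_patch does.
attn_module.q_proj = attn_module.q_proj.regular_linear_layer
assert attn_module.q_proj is original_q_proj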
    def _load_text_encoder_attn_procs(
        self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs