
Commit fde2bf8

refactor to support patching LoRA into T5
1 parent ce55049 commit fde2bf8

File tree

6 files changed: +107 -83 lines changed


examples/dreambooth/train_dreambooth_lora.py

Lines changed: 18 additions & 16 deletions
@@ -38,7 +38,7 @@
 from torch.utils.data import Dataset
 from torchvision import transforms
 from tqdm.auto import tqdm
-from transformers import AutoTokenizer, PretrainedConfig
+from transformers import AutoTokenizer, CLIPTextModel, PretrainedConfig, T5EncoderModel
 
 import diffusers
 from diffusers import (
@@ -49,7 +49,7 @@
     StableDiffusionPipeline,
     UNet2DConditionModel,
 )
-from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
+from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin, text_encoder_attn_modules
 from diffusers.models.attention_processor import (
     AttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
@@ -59,7 +59,7 @@
     SlicedAttnAddedKVProcessor,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
 
@@ -857,23 +857,25 @@ def main(args):
     unet_lora_layers = AttnProcsLayers(unet.attn_processors)
 
     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
-    # So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
-    # we first load a dummy pipeline with the text encoder and then do the monkey-patching.
+    # So, instead, we monkey-patch the forward calls of its attention-blocks.
     text_encoder_lora_layers = None
     if args.train_text_encoder:
         text_lora_attn_procs = {}
-        for name, module in text_encoder.named_modules():
-            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-                text_lora_attn_procs[name] = LoRAAttnProcessor(
-                    hidden_size=module.out_proj.out_features, cross_attention_dim=None
-                )
+
+        for name, module in text_encoder_attn_modules(text_encoder):
+            if isinstance(text_encoder, CLIPTextModel):
+                hidden_size = module.out_proj.out_features
+                inner_dim = None
+            elif isinstance(text_encoder, T5EncoderModel):
+                hidden_size = module.d_model
+                inner_dim = module.inner_dim
+            else:
+                raise ValueError(f"{text_encoder.__class__.__name__} does not support LoRA training")
+
+            text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, inner_dim=inner_dim)
+
         text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
-        temp_pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path, text_encoder=text_encoder
-        )
-        temp_pipeline._modify_text_encoder(text_lora_attn_procs)
-        text_encoder = temp_pipeline.text_encoder
-        del temp_pipeline
+        LoraLoaderMixin._modify_text_encoder(text_lora_attn_procs, text_encoder)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
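
For orientation, a condensed sketch of the new text-encoder branch above. A small randomly
initialized T5EncoderModel stands in for the checkpoint's text encoder (the script loads the real
one with from_pretrained), and the config values below are arbitrary:

from transformers import T5Config, T5EncoderModel

from diffusers.loaders import AttnProcsLayers, text_encoder_attn_modules
from diffusers.models.attention_processor import LoRAAttnProcessor

# Small random T5 encoder as a stand-in for the checkpoint's text encoder.
text_encoder = T5EncoderModel(T5Config(d_model=64, d_kv=8, d_ff=128, num_layers=2, num_heads=4))

text_lora_attn_procs = {}
for name, module in text_encoder_attn_modules(text_encoder):
    # T5 projections map d_model -> inner_dim (= num_heads * d_kv), so both sizes are passed.
    text_lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=module.d_model, inner_dim=module.inner_dim)

# The trainable parameters the script hands to the optimizer and serializes.
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)

# The script then patches the LoRA layers into the encoder in place via
# LoraLoaderMixin._modify_text_encoder(text_lora_attn_procs, text_encoder).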

src/diffusers/loaders.py

Lines changed: 73 additions & 50 deletions
@@ -20,6 +20,7 @@
 import torch
 import torch.nn.functional as F
 from huggingface_hub import hf_hub_download
+from torch import nn
 
 from .models.attention_processor import (
     AttnAddedKVProcessor,
@@ -36,7 +37,6 @@
 from .utils import (
     DIFFUSERS_CACHE,
     HF_HUB_OFFLINE,
-    TEXT_ENCODER_ATTN_MODULE,
     _get_model_file,
     deprecate,
     is_safetensors_available,
@@ -49,7 +49,7 @@
     import safetensors
 
 if is_transformers_available():
-    from transformers import PreTrainedModel, PreTrainedTokenizer
+    from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer, T5EncoderModel
 
 
 logger = logging.get_logger(__name__)
@@ -67,6 +67,36 @@
 CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
 
 
+class PatchedLoraProjection(nn.Module):
+    def __init__(self, regular_linear_layer, lora_linear_layer, lora_scale=1):
+        super().__init__()
+        self.regular_linear_layer = regular_linear_layer
+        self.lora_linear_layer = lora_linear_layer
+        self.lora_scale = lora_scale
+
+    def forward(self, input):
+        return self.regular_linear_layer(input) + self.lora_scale * self.lora_linear_layer(input)
+
+
+def text_encoder_attn_modules(text_encoder):
+    attn_modules = []
+
+    if isinstance(text_encoder, CLIPTextModel):
+        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
+            name = f"text_model.encoder.layers.{i}.self_attn"
+            mod = layer.self_attn
+            attn_modules.append((name, mod))
+    elif isinstance(text_encoder, T5EncoderModel):
+        for i, block in enumerate(text_encoder.encoder.block):
+            name = f"encoder.block.{i}.layer.0.SelfAttention"
+            mod = block.layer[0].SelfAttention
+            attn_modules.append((name, mod))
+    else:
+        raise ValueError(f"do not know how to get attention modules for: {text_encoder.__class__.__name__}")
+
+    return attn_modules
+
+
 class AttnProcsLayers(torch.nn.Module):
     def __init__(self, state_dict: Dict[str, torch.Tensor]):
         super().__init__()
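
As a quick illustration of what the new helper returns for the two supported encoder types, a
sketch using tiny randomly initialized models (all config values below are arbitrary):

from transformers import CLIPTextConfig, CLIPTextModel, T5Config, T5EncoderModel

from diffusers.loaders import text_encoder_attn_modules

clip = CLIPTextModel(
    CLIPTextConfig(hidden_size=32, intermediate_size=64, num_attention_heads=4, num_hidden_layers=2)
)
t5 = T5EncoderModel(T5Config(d_model=32, d_kv=8, d_ff=64, num_heads=4, num_layers=2))

# (name, module) pairs; the names double as keys of the LoRA attention-processor dicts.
print([name for name, _ in text_encoder_attn_modules(clip)])
# ['text_model.encoder.layers.0.self_attn', 'text_model.encoder.layers.1.self_attn']
print([name for name, _ in text_encoder_attn_modules(t5)])
# ['encoder.block.0.layer.0.SelfAttention', 'encoder.block.1.layer.0.SelfAttention']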
@@ -942,7 +972,7 @@ def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Di
             attn_procs_text_encoder = self._load_text_encoder_attn_procs(
                 text_encoder_lora_state_dict, network_alpha=network_alpha
             )
-            self._modify_text_encoder(attn_procs_text_encoder)
+            self._modify_text_encoder(attn_procs_text_encoder, self.text_encoder, self.lora_scale)
 
             # save lora attn procs of text encoder so that it can be easily retrieved
             self._text_encoder_lora_attn_procs = attn_procs_text_encoder
@@ -968,20 +998,24 @@ def text_encoder_lora_attn_procs(self):
             return self._text_encoder_lora_attn_procs
         return
 
-    def _remove_text_encoder_monkey_patch(self):
-        # Loop over the CLIPAttention module of text_encoder
-        for name, attn_module in self.text_encoder.named_modules():
-            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-                # Loop over the LoRA layers
-                for _, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
-                    # Retrieve the q/k/v/out projection of CLIPAttention
-                    module = attn_module.get_submodule(text_encoder_attr)
-                    if hasattr(module, "old_forward"):
-                        # restore original `forward` to remove monkey-patch
-                        module.forward = module.old_forward
-                        delattr(module, "old_forward")
-
-    def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
+    @classmethod
+    def _remove_text_encoder_monkey_patch(cls, text_encoder):
+        for _, attn_module in text_encoder_attn_modules(text_encoder):
+            if isinstance(text_encoder, CLIPTextModel):
+                attn_module.q_proj = attn_module.q_proj.regular_linear_layer
+                attn_module.k_proj = attn_module.k_proj.regular_linear_layer
+                attn_module.v_proj = attn_module.v_proj.regular_linear_layer
+                attn_module.out_proj = attn_module.out_proj.regular_linear_layer
+            elif isinstance(text_encoder, T5EncoderModel):
+                attn_module.q = attn_module.q.regular_linear_layer
+                attn_module.k = attn_module.k.regular_linear_layer
+                attn_module.v = attn_module.v.regular_linear_layer
+                attn_module.o = attn_module.o.regular_linear_layer
+            else:
+                raise ValueError(f"{text_encoder.__class__.__name__} does not support LoRA training")
+
+    @classmethod
+    def _modify_text_encoder(cls, attn_processors: Dict[str, LoRAAttnProcessor], text_encoder, lora_scale=1):
         r"""
         Monkey-patches the forward passes of attention modules of the text encoder.
 
@@ -991,40 +1025,29 @@ def _modify_text_encoder(self, attn_processors: Dict[str, LoRAAttnProcessor]):
         """
 
         # First, remove any monkey-patch that might have been applied before
-        self._remove_text_encoder_monkey_patch()
+        cls._remove_text_encoder_monkey_patch(text_encoder)
 
-        # Loop over the CLIPAttention module of text_encoder
-        for name, attn_module in self.text_encoder.named_modules():
-            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-                # Loop over the LoRA layers
-                for attn_proc_attr, text_encoder_attr in self._lora_attn_processor_attr_to_text_encoder_attr.items():
-                    # Retrieve the q/k/v/out projection of CLIPAttention and its corresponding LoRA layer.
-                    module = attn_module.get_submodule(text_encoder_attr)
-                    lora_layer = attn_processors[name].get_submodule(attn_proc_attr)
-
-                    # save old_forward to module that can be used to remove monkey-patch
-                    old_forward = module.old_forward = module.forward
-
-                    # create a new scope that locks in the old_forward, lora_layer value for each new_forward function
-                    # for more detail, see https://github.com/huggingface/diffusers/pull/3490#issuecomment-1555059060
-                    def make_new_forward(old_forward, lora_layer):
-                        def new_forward(x):
-                            result = old_forward(x) + self.lora_scale * lora_layer(x)
-                            return result
-
-                        return new_forward
-
-                    # Monkey-patch.
-                    module.forward = make_new_forward(old_forward, lora_layer)
-
-    @property
-    def _lora_attn_processor_attr_to_text_encoder_attr(self):
-        return {
-            "to_q_lora": "q_proj",
-            "to_k_lora": "k_proj",
-            "to_v_lora": "v_proj",
-            "to_out_lora": "out_proj",
-        }
+        for name, attn_module in text_encoder_attn_modules(text_encoder):
+            if isinstance(text_encoder, CLIPTextModel):
+                attn_module.q_proj = PatchedLoraProjection(
+                    attn_module.q_proj, attn_processors[name].to_q_lora, lora_scale
+                )
+                attn_module.k_proj = PatchedLoraProjection(
+                    attn_module.k_proj, attn_processors[name].to_k_lora, lora_scale
+                )
+                attn_module.v_proj = PatchedLoraProjection(
+                    attn_module.v_proj, attn_processors[name].to_v_lora, lora_scale
+                )
+                attn_module.out_proj = PatchedLoraProjection(
+                    attn_module.out_proj, attn_processors[name].to_out_lora, lora_scale
+                )
+            elif isinstance(text_encoder, T5EncoderModel):
+                attn_module.q = PatchedLoraProjection(attn_module.q, attn_processors[name].to_q_lora, lora_scale)
+                attn_module.k = PatchedLoraProjection(attn_module.k, attn_processors[name].to_k_lora, lora_scale)
+                attn_module.v = PatchedLoraProjection(attn_module.v, attn_processors[name].to_v_lora, lora_scale)
+                attn_module.o = PatchedLoraProjection(attn_module.o, attn_processors[name].to_out_lora, lora_scale)
+            else:
+                raise ValueError(f"{text_encoder.__class__.__name__} does not support LoRA training")
 
     def _load_text_encoder_attn_procs(
         self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs
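
To make the per-module effect of the two classmethods above concrete, here is a hand-written
sketch of the wrap/unwrap they apply to each module returned by text_encoder_attn_modules, shown
for a single self-attention block of a tiny, randomly initialized CLIP text encoder (illustrative
only, not part of this commit):

import torch
from transformers import CLIPTextConfig, CLIPTextModel

from diffusers.loaders import PatchedLoraProjection, text_encoder_attn_modules
from diffusers.models.attention_processor import LoRAAttnProcessor

text_encoder = CLIPTextModel(
    CLIPTextConfig(hidden_size=32, intermediate_size=64, num_attention_heads=4, num_hidden_layers=2)
)
name, attn_module = text_encoder_attn_modules(text_encoder)[0]
proc = LoRAAttnProcessor(hidden_size=attn_module.out_proj.out_features)

hidden = torch.randn(1, 5, 32)
before = attn_module.q_proj(hidden)

# Patch (what _modify_text_encoder does per module): wrap the original Linear so the projection
# returns base(x) + lora_scale * lora(x).
attn_module.q_proj = PatchedLoraProjection(attn_module.q_proj, proc.to_q_lora, lora_scale=1.0)
attn_module.q_proj(hidden)  # now the base projection plus the LoRA term

# Unpatch (what _remove_text_encoder_monkey_patch does): put the wrapped Linear back.
attn_module.q_proj = attn_module.q_proj.regular_linear_layer
assert torch.allclose(attn_module.q_proj(hidden), before)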

src/diffusers/models/attention_processor.py

Lines changed: 7 additions & 5 deletions
@@ -549,17 +549,19 @@ class LoRAAttnProcessor(nn.Module):
             Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs.
     """
 
-    def __init__(self, hidden_size, cross_attention_dim=None, rank=4, network_alpha=None):
+    def __init__(self, hidden_size, cross_attention_dim=None, inner_dim=None, rank=4, network_alpha=None):
         super().__init__()
 
         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
         self.rank = rank
+        if inner_dim is None:
+            inner_dim = hidden_size
 
-        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
-        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
-        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
+        self.to_q_lora = LoRALinearLayer(hidden_size, inner_dim, rank, network_alpha)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, inner_dim, rank, network_alpha)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, inner_dim, rank, network_alpha)
+        self.to_out_lora = LoRALinearLayer(inner_dim, hidden_size, rank, network_alpha)
 
     def __call__(
         self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0, temb=None
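
The new inner_dim argument matters for T5-style attention, where the q/k/v/o projections map
d_model to inner_dim = num_heads * d_kv rather than hidden_size to hidden_size. A small shape
check with made-up sizes (LoRALinearLayer exposes its down/up nn.Linear sublayers):

from diffusers.models.attention_processor import LoRAAttnProcessor

# Hypothetical T5-like sizes where d_model (512) differs from inner_dim (384 = 6 heads * 64 d_kv).
proc = LoRAAttnProcessor(hidden_size=512, inner_dim=384, rank=4)
assert proc.to_q_lora.down.in_features == 512 and proc.to_q_lora.up.out_features == 384
assert proc.to_out_lora.down.in_features == 384 and proc.to_out_lora.up.out_features == 512

# With inner_dim left as None, the previous square behaviour is preserved (CLIP / UNet case).
square = LoRAAttnProcessor(hidden_size=512)
assert square.to_q_lora.up.out_features == 512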

src/diffusers/utils/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@
     ONNX_EXTERNAL_WEIGHTS_NAME,
     ONNX_WEIGHTS_NAME,
     SAFETENSORS_WEIGHTS_NAME,
-    TEXT_ENCODER_ATTN_MODULE,
     WEIGHTS_NAME,
 )
 from .deprecation_utils import deprecate

src/diffusers/utils/constants.py

Lines changed: 0 additions & 1 deletion
@@ -30,4 +30,3 @@
 DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
 HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
 DEPRECATED_REVISION_ARGS = ["fp16", "non-ema"]
-TEXT_ENCODER_ATTN_MODULE = ".self_attn"

tests/models/test_lora_layers.py

Lines changed: 9 additions & 10 deletions
@@ -23,7 +23,7 @@
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
 from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
-from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
+from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin, text_encoder_attn_modules
 from diffusers.models.attention_processor import (
     Attention,
     AttnProcessor,
@@ -33,7 +33,7 @@
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
-from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, floats_tensor, torch_device
+from diffusers.utils import floats_tensor, torch_device
 
 
 def create_unet_lora_layers(unet: nn.Module):
@@ -63,11 +63,10 @@ def create_text_encoder_lora_attn_procs(text_encoder: nn.Module):
     lora_attn_processor_class = (
         LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
     )
-    for name, module in text_encoder.named_modules():
-        if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-            text_lora_attn_procs[name] = lora_attn_processor_class(
-                hidden_size=module.out_proj.out_features, cross_attention_dim=None
-            )
+    for name, module in text_encoder_attn_modules(text_encoder):
+        text_lora_attn_procs[name] = lora_attn_processor_class(
+            hidden_size=module.out_proj.out_features, cross_attention_dim=None
+        )
     return text_lora_attn_procs
 
 
@@ -286,7 +285,7 @@ def test_text_encoder_lora_monkey_patch(self):
         set_lora_up_weights(text_attn_procs, randn_weight=False)
 
         # monkey patch
-        pipe._modify_text_encoder(text_attn_procs)
+        pipe._modify_text_encoder(text_attn_procs, pipe.text_encoder, pipe.lora_scale)
 
         # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
         del text_attn_procs
@@ -305,7 +304,7 @@ def test_text_encoder_lora_monkey_patch(self):
         set_lora_up_weights(text_attn_procs, randn_weight=True)
 
         # monkey patch
-        pipe._modify_text_encoder(text_attn_procs)
+        pipe._modify_text_encoder(text_attn_procs, pipe.text_encoder, pipe.lora_scale)
 
         # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
         del text_attn_procs
@@ -334,7 +333,7 @@ def test_text_encoder_lora_remove_monkey_patch(self):
        set_lora_up_weights(text_attn_procs, randn_weight=True)
 
         # monkey patch
-        pipe._modify_text_encoder(text_attn_procs)
+        pipe._modify_text_encoder(text_attn_procs, pipe.text_encoder, pipe.lora_scale)
 
         # verify that it's okay to release the text_attn_procs which holds the LoRAAttnProcessor.
         del text_attn_procs
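
Note that create_text_encoder_lora_attn_procs above still reads module.out_proj.out_features,
which exists only on CLIP attention modules, so these tests exercise the CLIP path. A T5-aware
variant would branch on the encoder type the same way the training script now does; a hypothetical
sketch (not part of this commit):

from torch import nn
from transformers import CLIPTextModel, T5EncoderModel

from diffusers.loaders import text_encoder_attn_modules
from diffusers.models.attention_processor import LoRAAttnProcessor


def create_text_encoder_lora_attn_procs_t5_aware(text_encoder: nn.Module):
    # Hypothetical helper mirroring the training script's branching; not in the test file.
    text_lora_attn_procs = {}
    for name, module in text_encoder_attn_modules(text_encoder):
        if isinstance(text_encoder, CLIPTextModel):
            kwargs = {"hidden_size": module.out_proj.out_features}
        elif isinstance(text_encoder, T5EncoderModel):
            kwargs = {"hidden_size": module.d_model, "inner_dim": module.inner_dim}
        else:
            raise ValueError(f"unsupported text encoder: {text_encoder.__class__.__name__}")
        text_lora_attn_procs[name] = LoRAAttnProcessor(**kwargs)
    return text_lora_attn_procs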
