Commit 99e35de

refactor to support patching LoRA into T5
instantiate the lora linear layer on the same device as the regular linear layer
get lora rank from state dict
tests
fmt
can create lora layer in float32 even when rest of model is float16
fix loading model hook
remove load_lora_weights_ and T5 dispatching
remove Unet#attn_processors_state_dict
docstrings
1 parent aed7499 commit 99e35de
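
Three of the bullets above (same-device instantiation, float32 LoRA weights inside a float16 model, rank recovered from a state dict) boil down to a small amount of plain PyTorch. A rough sketch with made-up sizes and variable names (`base`, `lora_down`, `lora_up`), not the diffusers implementation:

    # Illustrative sketch only; names and sizes are invented.
    import torch
    import torch.nn as nn

    base = nn.Linear(768, 768).to(dtype=torch.float16)  # frozen base layer, possibly fp16
    rank = 4

    # created on the same device as the wrapped linear layer, but with float32 weights
    lora_down = nn.Linear(768, rank, bias=False, device=base.weight.device, dtype=torch.float32)
    lora_up = nn.Linear(rank, 768, bias=False, device=base.weight.device, dtype=torch.float32)
    nn.init.normal_(lora_down.weight, std=1 / rank)
    nn.init.zeros_(lora_up.weight)

    # "get lora rank from state dict": the rank is the leading dim of the down-projection weight
    state_dict = {"down.weight": lora_down.weight.detach(), "up.weight": lora_up.weight.detach()}
    assert state_dict["down.weight"].shape[0] == rank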

File tree

6 files changed, +435 -375 lines changed

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 73 additions & 73 deletions
@@ -23,6 +23,7 @@
 import shutil
 import warnings
 from pathlib import Path
+from typing import Dict
 
 import numpy as np
 import torch
@@ -50,7 +51,10 @@
     StableDiffusionPipeline,
     UNet2DConditionModel,
 )
-from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
+from diffusers.loaders import (
+    LoraLoaderMixin,
+    text_encoder_lora_state_dict,
+)
 from diffusers.models.attention_processor import (
     AttnAddedKVProcessor,
     AttnAddedKVProcessor2_0,
@@ -60,7 +64,7 @@
     SlicedAttnAddedKVProcessor,
 )
 from diffusers.optimization import get_scheduler
-from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
 
@@ -647,6 +651,22 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
     return prompt_embeds
 
 
+def unet_attn_processors_state_dict(unet) -> Dict[str, torch.tensor]:
+    r"""
+    Returns:
+        a state dict containing just the attention processor parameters.
+    """
+    attn_processors = unet.attn_processors
+
+    attn_processors_state_dict = {}
+
+    for attn_processor_key, attn_processor in attn_processors.items():
+        for parameter_key, parameter in attn_processor.state_dict().items():
+            attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
+
+    return attn_processors_state_dict
+
+
 def main(args):
     logging_dir = Path(args.output_dir, args.logging_dir)
 
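The new unet_attn_processors_state_dict helper flattens every attention processor's parameters under the processor's name, which is the key format expected when saving the UNet LoRA layers. A minimal sketch of that flattening with a single made-up processor entry (the block name and hidden_size are invented; LoRAAttnProcessor is the class the script already imports):

    import torch
    from diffusers.models.attention_processor import LoRAAttnProcessor

    # one invented processor entry, keyed the way unet.attn_processors keys its modules
    attn_processors = {
        "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor": LoRAAttnProcessor(
            hidden_size=320, cross_attention_dim=None
        )
    }

    flat = {}
    for proc_name, proc in attn_processors.items():
        for param_name, param in proc.state_dict().items():
            flat[f"{proc_name}.{param_name}"] = param

    # keys come out as "<processor name>.<parameter name>", roughly like
    # "down_blocks.0....attn1.processor.to_q_lora.down.weight"
    print(list(flat))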
@@ -827,6 +847,7 @@ def main(args):
 
     # Set correct lora layers
     unet_lora_attn_procs = {}
+    unet_lora_parameters = []
     for name, attn_processor in unet.attn_processors.items():
         cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
         if name.startswith("mid_block"):
@@ -844,31 +865,17 @@
         lora_attn_processor_class = (
             LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
         )
-        unet_lora_attn_procs[name] = lora_attn_processor_class(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-        )
+        module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        unet_lora_attn_procs[name] = module
+        unet_lora_parameters.extend(module.parameters())
 
     unet.set_attn_processor(unet_lora_attn_procs)
-    unet_lora_layers = AttnProcsLayers(unet.attn_processors)
 
     # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
-    # So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
-    # we first load a dummy pipeline with the text encoder and then do the monkey-patching.
-    text_encoder_lora_layers = None
+    # So, instead, we monkey-patch the forward calls of its attention-blocks.
     if args.train_text_encoder:
-        text_lora_attn_procs = {}
-        for name, module in text_encoder.named_modules():
-            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-                text_lora_attn_procs[name] = LoRAAttnProcessor(
-                    hidden_size=module.out_proj.out_features, cross_attention_dim=None
-                )
-        text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
-        temp_pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path, text_encoder=text_encoder
-        )
-        temp_pipeline._modify_text_encoder(text_lora_attn_procs)
-        text_encoder = temp_pipeline.text_encoder
-        del temp_pipeline
+        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32)
 
     # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
     def save_model_hook(models, weights, output_dir):
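The comment kept in this hunk describes the approach for the text encoder: its attention forward calls are monkey-patched rather than its modules replaced, and the LoRA weights are created in float32 even when the encoder itself is loaded in fp16. A generic sketch of that idea, not the actual LoraLoaderMixin._modify_text_encoder implementation (the helper name `patch_linear_with_lora` is invented):

    import torch
    import torch.nn as nn

    def patch_linear_with_lora(linear: nn.Linear, rank: int = 4):
        """Wrap linear.forward so a float32 low-rank path is added to its output."""
        down = nn.Linear(linear.in_features, rank, bias=False,
                         device=linear.weight.device, dtype=torch.float32)
        up = nn.Linear(rank, linear.out_features, bias=False,
                       device=linear.weight.device, dtype=torch.float32)
        nn.init.zeros_(up.weight)  # zero init keeps the patched layer a no-op at the start

        original_forward = linear.forward

        def lora_forward(x):
            # base path runs in the model's dtype, LoRA path in float32, cast back at the end
            return original_forward(x) + up(down(x.to(torch.float32))).to(x.dtype)

        linear.forward = lora_forward
        return list(down.parameters()) + list(up.parameters())  # the new trainable parameters

    proj = nn.Linear(512, 512)
    trainable_params = patch_linear_with_lora(proj)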
@@ -877,23 +884,13 @@ def save_model_hook(models, weights, output_dir):
         unet_lora_layers_to_save = None
         text_encoder_lora_layers_to_save = None
 
-        if args.train_text_encoder:
-            text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
-            unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()
-
         for model in models:
-            state_dict = model.state_dict()
-
-            if (
-                text_encoder_lora_layers is not None
-                and text_encoder_keys is not None
-                and state_dict.keys() == text_encoder_keys
-            ):
-                # text encoder
-                text_encoder_lora_layers_to_save = state_dict
-            elif state_dict.keys() == unet_keys:
-                # unet
-                unet_lora_layers_to_save = state_dict
+            if isinstance(model, type(accelerator.unwrap_model(unet))):
+                unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
 
             # make sure to pop weight so that corresponding model is not saved again
             weights.pop()
@@ -905,27 +902,24 @@ def save_model_hook(models, weights, output_dir):
         )
 
     def load_model_hook(models, input_dir):
-        # Note we DON'T pass the unet and text encoder here an purpose
-        # so that the we don't accidentally override the LoRA layers of
-        # unet_lora_layers and text_encoder_lora_layers which are stored in `models`
-        # with new torch.nn.Modules / weights. We simply use the pipeline class as
-        # an easy way to load the lora checkpoints
-        temp_pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path,
-            revision=args.revision,
-            torch_dtype=weight_dtype,
-        )
-        temp_pipeline.load_lora_weights(input_dir)
+        unet_ = None
+        text_encoder_ = None
 
-        # load lora weights into models
-        models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
-        if len(models) > 1:
-            models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())
+        while len(models) > 0:
+            model = models.pop()
 
-        # delete temporary pipeline and pop models
-        del temp_pipeline
-        for _ in range(len(models)):
-            models.pop()
+            if isinstance(model, type(accelerator.unwrap_model(unet))):
+                unet_ = model
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                text_encoder_ = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
+        LoraLoaderMixin.load_lora_into_text_encoder(
+            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
+        )
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
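These hooks are standard accelerate pre-hooks: accelerator.save_state(...) hands the prepared models and their collected weights to the save hook, and accelerator.load_state(...) hands the models and checkpoint directory to the load hook. A self-contained toy sketch of the mechanism (the linear model and the checkpoint path are placeholders; only the registration calls mirror the script):

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()

    def save_model_hook(models, weights, output_dir):
        # models: the accelerator-prepared models; weights: their state dicts.
        # Popping a weight tells accelerate not to serialize that model itself.
        print(f"saving {len(models)} model(s) to {output_dir}")
        while len(weights) > 0:
            weights.pop()

    def load_model_hook(models, input_dir):
        print(f"loading {len(models)} model(s) from {input_dir}")

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)

    model = accelerator.prepare(torch.nn.Linear(4, 4))
    accelerator.save_state("toy-checkpoint")  # triggers save_model_hook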
@@ -955,9 +949,9 @@ def load_model_hook(models, input_dir):
 
     # Optimizer creation
     params_to_optimize = (
-        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+        itertools.chain(unet_lora_parameters, text_lora_parameters)
         if args.train_text_encoder
-        else unet_lora_layers.parameters()
+        else unet_lora_parameters
     )
     optimizer = optimizer_class(
         params_to_optimize,
@@ -1046,12 +1040,12 @@ def compute_text_embeddings(prompt):
 
     # Prepare everything with our `accelerator`.
     if args.train_text_encoder:
-        unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler
+        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
         )
     else:
-        unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet_lora_layers, optimizer, train_dataloader, lr_scheduler
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, optimizer, train_dataloader, lr_scheduler
         )
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -1200,9 +1194,9 @@ def compute_text_embeddings(prompt):
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
                     params_to_clip = (
-                        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+                        itertools.chain(unet_lora_parameters, text_lora_parameters)
                         if args.train_text_encoder
-                        else unet_lora_layers.parameters()
+                        else unet_lora_parameters
                     )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
@@ -1291,15 +1285,17 @@ def compute_text_embeddings(prompt):
                pipeline_args = {"prompt": args.validation_prompt}

                if args.validation_images is None:
-                    images = [
-                        pipeline(**pipeline_args, generator=generator).images[0]
-                        for _ in range(args.num_validation_images)
-                    ]
+                    images = []
+                    for _ in range(args.num_validation_images):
+                        with torch.cuda.amp.autocast():
+                            image = pipeline(**pipeline_args, generator=generator).images[0]
+                            images.append(image)
                else:
                    images = []
                    for image in args.validation_images:
                        image = Image.open(image)
-                        image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+                        with torch.cuda.amp.autocast():
+                            image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
                        images.append(image)

                for tracker in accelerator.trackers:
@@ -1322,12 +1318,16 @@ def compute_text_embeddings(prompt):
     # Save the lora layers
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
+        unet = accelerator.unwrap_model(unet)
         unet = unet.to(torch.float32)
-        unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
+        unet_lora_layers = unet_attn_processors_state_dict(unet)
 
-        if text_encoder is not None:
+        if text_encoder is not None and args.train_text_encoder:
+            text_encoder = accelerator.unwrap_model(text_encoder)
             text_encoder = text_encoder.to(torch.float32)
-            text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
+            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
+        else:
+            text_encoder_lora_layers = None
 
         LoraLoaderMixin.save_lora_weights(
             save_directory=args.output_dir,
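
Once training finishes, save_lora_weights writes the collected UNet (and, when trained, text encoder) LoRA layers to args.output_dir, and any Stable Diffusion pipeline can load them back for inference. A usage sketch with placeholder model id, path, and prompt:

    import torch
    from diffusers import DiffusionPipeline

    # placeholder base model and output path; swap in whatever was trained against
    pipe = DiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")
    pipe.load_lora_weights("path/to/output_dir")  # directory written by save_lora_weights
    image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
    image.save("lora_sample.png")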
