import shutil
import warnings
from pathlib import Path
+from typing import Dict

import numpy as np
import torch

    StableDiffusionPipeline,
    UNet2DConditionModel,
)
-from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
+from diffusers.loaders import (
+    LoraLoaderMixin,
+    text_encoder_lora_state_dict,
+)
from diffusers.models.attention_processor import (
    AttnAddedKVProcessor,
    AttnAddedKVProcessor2_0,
    SlicedAttnAddedKVProcessor,
)
from diffusers.optimization import get_scheduler
-from diffusers.utils import TEXT_ENCODER_ATTN_MODULE, check_min_version, is_wandb_available
+from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -647,6 +651,22 @@ def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_atte
    return prompt_embeds


+def unet_attn_processors_state_dict(unet) -> Dict[str, torch.Tensor]:
+    r"""
+    Returns:
+        a state dict containing just the attention processor parameters.
+    """
+    attn_processors = unet.attn_processors
+
+    attn_processors_state_dict = {}
+
+    for attn_processor_key, attn_processor in attn_processors.items():
+        for parameter_key, parameter in attn_processor.state_dict().items():
+            attn_processors_state_dict[f"{attn_processor_key}.{parameter_key}"] = parameter
+
+    return attn_processors_state_dict
+
+
def main(args):
    logging_dir = Path(args.output_dir, args.logging_dir)
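For illustration, a minimal self-contained sketch (not taken from the diff) of the flattening the new helper performs: each entry of `unet.attn_processors` is walked and its parameter names are joined onto the processor's module path, which is the flat dict the script later hands to `LoraLoaderMixin.save_lora_weights`. `ToyLoRAProcessor` and the single processor key below are hypothetical stand-ins for a real `LoRAAttnProcessor` and a real UNet attention module path.

import torch


class ToyLoRAProcessor(torch.nn.Module):
    # Hypothetical stand-in for a LoRA attention processor with two low-rank layers.
    def __init__(self, hidden: int = 8, rank: int = 4):
        super().__init__()
        self.to_q_lora_down = torch.nn.Linear(hidden, rank, bias=False)
        self.to_q_lora_up = torch.nn.Linear(rank, hidden, bias=False)


# Stand-in for `unet.attn_processors`: a mapping from module path to processor.
attn_processors = {
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor": ToyLoRAProcessor(),
}

# Same flattening as unet_attn_processors_state_dict(unet) in the hunk above.
flat_state_dict = {}
for processor_key, processor in attn_processors.items():
    for parameter_key, parameter in processor.state_dict().items():
        flat_state_dict[f"{processor_key}.{parameter_key}"] = parameter

# Keys look like:
# down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora_down.weight
print(list(flat_state_dict.keys()))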
@@ -827,6 +847,7 @@ def main(args):
    # Set correct lora layers
    unet_lora_attn_procs = {}
+    unet_lora_parameters = []
    for name, attn_processor in unet.attn_processors.items():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
@@ -844,31 +865,17 @@ def main(args):
        lora_attn_processor_class = (
            LoRAAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else LoRAAttnProcessor
        )
-        unet_lora_attn_procs[name] = lora_attn_processor_class(
-            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
-        )
+        module = lora_attn_processor_class(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        unet_lora_attn_procs[name] = module
+        unet_lora_parameters.extend(module.parameters())

    unet.set_attn_processor(unet_lora_attn_procs)
-    unet_lora_layers = AttnProcsLayers(unet.attn_processors)

    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
-    # So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
-    # we first load a dummy pipeline with the text encoder and then do the monkey-patching.
-    text_encoder_lora_layers = None
+    # So, instead, we monkey-patch the forward calls of its attention-blocks.
    if args.train_text_encoder:
-        text_lora_attn_procs = {}
-        for name, module in text_encoder.named_modules():
-            if name.endswith(TEXT_ENCODER_ATTN_MODULE):
-                text_lora_attn_procs[name] = LoRAAttnProcessor(
-                    hidden_size=module.out_proj.out_features, cross_attention_dim=None
-                )
-        text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
-        temp_pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path, text_encoder=text_encoder
-        )
-        temp_pipeline._modify_text_encoder(text_lora_attn_procs)
-        text_encoder = temp_pipeline.text_encoder
-        del temp_pipeline
+        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
+        text_lora_parameters = LoraLoaderMixin._modify_text_encoder(text_encoder, dtype=torch.float32)

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
@@ -877,23 +884,13 @@ def save_model_hook(models, weights, output_dir):
        unet_lora_layers_to_save = None
        text_encoder_lora_layers_to_save = None

-        if args.train_text_encoder:
-            text_encoder_keys = accelerator.unwrap_model(text_encoder_lora_layers).state_dict().keys()
-            unet_keys = accelerator.unwrap_model(unet_lora_layers).state_dict().keys()
-
        for model in models:
-            state_dict = model.state_dict()
-
-            if (
-                text_encoder_lora_layers is not None
-                and text_encoder_keys is not None
-                and state_dict.keys() == text_encoder_keys
-            ):
-                # text encoder
-                text_encoder_lora_layers_to_save = state_dict
-            elif state_dict.keys() == unet_keys:
-                # unet
-                unet_lora_layers_to_save = state_dict
+            if isinstance(model, type(accelerator.unwrap_model(unet))):
+                unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                text_encoder_lora_layers_to_save = text_encoder_lora_state_dict(model)
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")

            # make sure to pop weight so that corresponding model is not saved again
            weights.pop()
@@ -905,27 +902,24 @@ def save_model_hook(models, weights, output_dir):
        )

    def load_model_hook(models, input_dir):
-        # Note we DON'T pass the unet and text encoder here an purpose
-        # so that the we don't accidentally override the LoRA layers of
-        # unet_lora_layers and text_encoder_lora_layers which are stored in `models`
-        # with new torch.nn.Modules / weights. We simply use the pipeline class as
-        # an easy way to load the lora checkpoints
-        temp_pipeline = DiffusionPipeline.from_pretrained(
-            args.pretrained_model_name_or_path,
-            revision=args.revision,
-            torch_dtype=weight_dtype,
-        )
-        temp_pipeline.load_lora_weights(input_dir)
+        unet_ = None
+        text_encoder_ = None

-        # load lora weights into models
-        models[0].load_state_dict(AttnProcsLayers(temp_pipeline.unet.attn_processors).state_dict())
-        if len(models) > 1:
-            models[1].load_state_dict(AttnProcsLayers(temp_pipeline.text_encoder_lora_attn_procs).state_dict())
+        while len(models) > 0:
+            model = models.pop()

-        # delete temporary pipeline and pop models
-        del temp_pipeline
-        for _ in range(len(models)):
-            models.pop()
+            if isinstance(model, type(accelerator.unwrap_model(unet))):
+                unet_ = model
+            elif isinstance(model, type(accelerator.unwrap_model(text_encoder))):
+                text_encoder_ = model
+            else:
+                raise ValueError(f"unexpected save model: {model.__class__}")
+
+        lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
+        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
+        LoraLoaderMixin.load_lora_into_text_encoder(
+            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
+        )

    accelerator.register_save_state_pre_hook(save_model_hook)
    accelerator.register_load_state_pre_hook(load_model_hook)
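A quick sketch (not from the diff) of the dispatch idea both hooks now rely on: the hooks receive the prepared models as a plain list and identify each one by comparing its class against the unwrapped `unet` / `text_encoder`, rather than by matching state-dict keys as before. `ToyUNet` and `ToyTextEncoder` below are hypothetical stand-ins for the real model classes.

import torch


class ToyUNet(torch.nn.Module):
    pass


class ToyTextEncoder(torch.nn.Module):
    pass


def split_models(models, unet, text_encoder):
    # Mirrors the load hook: pop each entry and sort it by class.
    unet_, text_encoder_ = None, None
    while len(models) > 0:
        model = models.pop()
        if isinstance(model, type(unet)):
            unet_ = model
        elif isinstance(model, type(text_encoder)):
            text_encoder_ = model
        else:
            raise ValueError(f"unexpected model: {model.__class__}")
    return unet_, text_encoder_


unet, text_encoder = ToyUNet(), ToyTextEncoder()
unet_, text_encoder_ = split_models([unet, text_encoder], unet, text_encoder)
assert unet_ is unet and text_encoder_ is text_encoder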
@@ -955,9 +949,9 @@ def load_model_hook(models, input_dir):
    # Optimizer creation
    params_to_optimize = (
-        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+        itertools.chain(unet_lora_parameters, text_lora_parameters)
        if args.train_text_encoder
-        else unet_lora_layers.parameters()
+        else unet_lora_parameters
    )
    optimizer = optimizer_class(
        params_to_optimize,
@@ -1046,12 +1040,12 @@ def compute_text_embeddings(prompt):
    # Prepare everything with our `accelerator`.
    if args.train_text_encoder:
-        unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler
+        unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, text_encoder, optimizer, train_dataloader, lr_scheduler
        )
    else:
-        unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            unet_lora_layers, optimizer, train_dataloader, lr_scheduler
+        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet, optimizer, train_dataloader, lr_scheduler
        )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
@@ -1200,9 +1194,9 @@ def compute_text_embeddings(prompt):
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = (
-                        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+                        itertools.chain(unet_lora_parameters, text_lora_parameters)
                        if args.train_text_encoder
-                        else unet_lora_layers.parameters()
+                        else unet_lora_parameters
                    )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
@@ -1291,15 +1285,17 @@ def compute_text_embeddings(prompt):
                pipeline_args = {"prompt": args.validation_prompt}

                if args.validation_images is None:
-                    images = [
-                        pipeline(**pipeline_args, generator=generator).images[0]
-                        for _ in range(args.num_validation_images)
-                    ]
+                    images = []
+                    for _ in range(args.num_validation_images):
+                        with torch.cuda.amp.autocast():
+                            image = pipeline(**pipeline_args, generator=generator).images[0]
+                            images.append(image)
                else:
                    images = []
                    for image in args.validation_images:
                        image = Image.open(image)
-                        image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
+                        with torch.cuda.amp.autocast():
+                            image = pipeline(**pipeline_args, image=image, generator=generator).images[0]
                        images.append(image)

                for tracker in accelerator.trackers:
@@ -1322,12 +1318,16 @@ def compute_text_embeddings(prompt):
    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
+        unet = accelerator.unwrap_model(unet)
        unet = unet.to(torch.float32)
-        unet_lora_layers = accelerator.unwrap_model(unet_lora_layers)
+        unet_lora_layers = unet_attn_processors_state_dict(unet)

-        if text_encoder is not None:
+        if text_encoder is not None and args.train_text_encoder:
+            text_encoder = accelerator.unwrap_model(text_encoder)
            text_encoder = text_encoder.to(torch.float32)
-            text_encoder_lora_layers = accelerator.unwrap_model(text_encoder_lora_layers)
+            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder)
+        else:
+            text_encoder_lora_layers = None

        LoraLoaderMixin.save_lora_weights(
            save_directory=args.output_dir,
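Once training finishes, `LoraLoaderMixin.save_lora_weights` writes the combined UNet and text-encoder LoRA weights into `args.output_dir`, and they can be loaded back onto a plain pipeline with `load_lora_weights` (the same LoRA loading path the old hook used). A hedged inference sketch, assuming a Stable Diffusion v1-5 base model; the output path and the `sks` instance prompt are placeholders, not values from the diff.

import torch
from diffusers import DiffusionPipeline

# "path/to/output_dir" stands in for the --output_dir used during training.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.load_lora_weights("path/to/output_dir")
pipe.to("cuda")

# "a photo of sks dog" assumes the usual DreamBooth instance prompt; adjust to your own.
image = pipe("a photo of sks dog", num_inference_steps=25).images[0]
image.save("lora_sample.png")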