@@ -403,6 +403,19 @@ def parse_args(input_args=None):
403
403
action = "store_true" ,
404
404
help = "Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`." ,
405
405
)
406
+ parser .add_argument (
407
+ "--tokenizer_max_length" ,
408
+ type = int ,
409
+ default = None ,
410
+ required = False ,
411
+ help = "The maximum length of the tokenizer. If not set, will default to the tokenizer's max length." ,
412
+ )
413
+ parser .add_argument (
414
+ "--text_encoder_use_attention_mask" ,
415
+ action = "store_true" ,
416
+ required = False ,
417
+ help = "Whether to use attention mask for the text encoder" ,
418
+ )
406
419
407
420
if input_args is not None :
408
421
args = parser .parse_args (input_args )
@@ -445,11 +458,15 @@ def __init__(
445
458
size = 512 ,
446
459
center_crop = False ,
447
460
encoder_hidden_states = None ,
461
+ instance_prompt_encoder_hidden_states = None ,
462
+ tokenizer_max_length = None ,
448
463
):
449
464
self .size = size
450
465
self .center_crop = center_crop
451
466
self .tokenizer = tokenizer
452
467
self .encoder_hidden_states = encoder_hidden_states
468
+ self .instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states
469
+ self .tokenizer_max_length = tokenizer_max_length
453
470
454
471
self .instance_data_root = Path (instance_data_root )
455
472
if not self .instance_data_root .exists ():
@@ -495,39 +512,46 @@ def __getitem__(self, index):
495
512
if self .encoder_hidden_states is not None :
496
513
example ["instance_prompt_ids" ] = self .encoder_hidden_states
497
514
else :
498
- example ["instance_prompt_ids" ] = self .tokenizer (
499
- self .instance_prompt ,
500
- truncation = True ,
501
- padding = "max_length" ,
502
- max_length = self .tokenizer .model_max_length ,
503
- return_tensors = "pt" ,
504
- ).input_ids
515
+ text_inputs = tokenize_prompt (
516
+ self .tokenizer , self .instance_prompt , tokenizer_max_length = self .tokenizer_max_length
517
+ )
518
+ example ["instance_prompt_ids" ] = text_inputs .input_ids
519
+ example ["instance_attention_mask" ] = text_inputs .attention_mask
505
520
506
521
if self .class_data_root :
507
522
class_image = Image .open (self .class_images_path [index % self .num_class_images ])
508
523
if not class_image .mode == "RGB" :
509
524
class_image = class_image .convert ("RGB" )
510
525
example ["class_images" ] = self .image_transforms (class_image )
511
- example ["class_prompt_ids" ] = self .tokenizer (
512
- self .class_prompt ,
513
- truncation = True ,
514
- padding = "max_length" ,
515
- max_length = self .tokenizer .model_max_length ,
516
- return_tensors = "pt" ,
517
- ).input_ids
526
+
527
+ if self .instance_prompt_encoder_hidden_states is not None :
528
+ example ["class_prompt_ids" ] = self .instance_prompt_encoder_hidden_states
529
+ else :
530
+ class_text_inputs = tokenize_prompt (
531
+ self .tokenizer , self .class_prompt , tokenizer_max_length = self .tokenizer_max_length
532
+ )
533
+ example ["class_prompt_ids" ] = class_text_inputs .input_ids
534
+ example ["class_attention_mask" ] = class_text_inputs .attention_mask
518
535
519
536
return example
520
537
521
538
522
539
def collate_fn (examples , with_prior_preservation = False ):
540
+ has_attention_mask = "instance_attention_mask" in examples [0 ]
541
+
523
542
input_ids = [example ["instance_prompt_ids" ] for example in examples ]
524
543
pixel_values = [example ["instance_images" ] for example in examples ]
525
544
545
+ if has_attention_mask :
546
+ attention_mask = [example ["instance_attention_mask" ] for example in examples ]
547
+
526
548
# Concat class and instance examples for prior preservation.
527
549
# We do this to avoid doing two forward passes.
528
550
if with_prior_preservation :
529
551
input_ids += [example ["class_prompt_ids" ] for example in examples ]
530
552
pixel_values += [example ["class_images" ] for example in examples ]
553
+ if has_attention_mask :
554
+ attention_mask += [example ["class_attention_mask" ] for example in examples ]
531
555
532
556
pixel_values = torch .stack (pixel_values )
533
557
pixel_values = pixel_values .to (memory_format = torch .contiguous_format ).float ()
@@ -538,6 +562,10 @@ def collate_fn(examples, with_prior_preservation=False):
538
562
"input_ids" : input_ids ,
539
563
"pixel_values" : pixel_values ,
540
564
}
565
+
566
+ if has_attention_mask :
567
+ batch ["attention_mask" ] = attention_mask
568
+
541
569
return batch
542
570
543
571
@@ -568,6 +596,40 @@ def model_has_vae(args):
568
596
return any (file .rfilename == config_file_name for file in files_in_repo )
569
597
570
598
599
def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
    """Tokenize ``prompt`` into padded, truncated tokenizer output.

    Args:
        tokenizer: Callable tokenizer. It is invoked with the prompt plus the
            ``truncation``, ``padding``, ``max_length`` and ``return_tensors``
            keyword arguments, and must expose ``model_max_length``.
        prompt: Text to tokenize.
        tokenizer_max_length: Optional override for the padded/truncated
            length. When ``None``, ``tokenizer.model_max_length`` is used.

    Returns:
        The object returned by the tokenizer call (callers in this file read
        ``input_ids`` and ``attention_mask`` from it).
    """
    # An explicit override wins; otherwise fall back to the tokenizer's own limit.
    length = tokenizer.model_max_length if tokenizer_max_length is None else tokenizer_max_length

    return tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=length,
        return_tensors="pt",
    )
614
+
615
+
616
def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
    """Run the text encoder on tokenized input and return its first output.

    Args:
        text_encoder: Callable encoder exposing a ``device`` attribute. It is
            called as ``text_encoder(input_ids, attention_mask=...)``.
        input_ids: Token ids; moved to ``text_encoder.device`` before the call.
        attention_mask: Mask matching ``input_ids``. Only used (and moved to
            the encoder's device) when ``text_encoder_use_attention_mask`` is
            truthy; otherwise it is ignored and ``None`` is passed instead.
        text_encoder_use_attention_mask: Whether to forward the attention mask
            to the encoder.

    Returns:
        Element ``[0]`` of the encoder's output.

    Raises:
        ValueError: If the mask is requested but ``attention_mask`` is ``None``.
    """
    device = text_encoder.device
    text_input_ids = input_ids.to(device)

    if text_encoder_use_attention_mask:
        # Fail with a clear message instead of an AttributeError on `None.to`.
        if attention_mask is None:
            raise ValueError(
                "`text_encoder_use_attention_mask` is set but no `attention_mask` was provided."
            )
        attention_mask = attention_mask.to(device)
    else:
        attention_mask = None

    prompt_embeds = text_encoder(
        text_input_ids,
        attention_mask=attention_mask,
    )

    return prompt_embeds[0]
631
+
632
+
571
633
def main (args ):
572
634
logging_dir = Path (args .output_dir , args .logging_dir )
573
635
@@ -832,30 +894,25 @@ def main(args):
832
894
833
895
def compute_text_embeddings(prompt):
    """Tokenize ``prompt`` and encode it with the text encoder, without gradients.

    Uses the enclosing scope's ``tokenizer``, ``text_encoder`` and ``args``
    (``args.tokenizer_max_length`` and ``args.text_encoder_use_attention_mask``).
    Returns whatever ``encode_prompt`` returns for the tokenized prompt.
    """
    # Gradient tracking is disabled for both tokenization and encoding.
    with torch.no_grad():
        tokenized = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
        return encode_prompt(
            text_encoder,
            tokenized.input_ids,
            tokenized.attention_mask,
            text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
        )
854
906
855
907
pre_computed_encoder_hidden_states = compute_text_embeddings (args .instance_prompt )
856
908
validation_prompt_encoder_hidden_states = compute_text_embeddings (args .validation_prompt )
857
909
validation_prompt_negative_prompt_embeds = compute_text_embeddings ("" )
858
910
911
# These embeddings are handed to the dataset, which serves them as
# `class_prompt_ids`; they must therefore be computed from the class prompt.
# (The instance prompt is already embedded separately above.)
# NOTE(review): assumes `args.class_prompt` exists and is None when unset — confirm in parse_args.
if args.class_prompt is not None:
    pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.class_prompt)
else:
    pre_computed_instance_prompt_encoder_hidden_states = None
915
+
859
916
text_encoder = None
860
917
tokenizer = None
861
918
@@ -865,6 +922,7 @@ def compute_text_embeddings(prompt):
865
922
pre_computed_encoder_hidden_states = None
866
923
validation_prompt_encoder_hidden_states = None
867
924
validation_prompt_negative_prompt_embeds = None
925
+ pre_computed_instance_prompt_encoder_hidden_states = None
868
926
869
927
# Dataset and DataLoaders creation:
870
928
train_dataset = DreamBoothDataset (
@@ -877,6 +935,8 @@ def compute_text_embeddings(prompt):
877
935
size = args .resolution ,
878
936
center_crop = args .center_crop ,
879
937
encoder_hidden_states = pre_computed_encoder_hidden_states ,
938
# Keyword must match the `instance_prompt_encoder_hidden_states` parameter of DreamBoothDataset.__init__.
instance_prompt_encoder_hidden_states=pre_computed_instance_prompt_encoder_hidden_states,
939
+ tokenizer_max_length = args .tokenizer_max_length ,
880
940
)
881
941
882
942
train_dataloader = torch .utils .data .DataLoader (
@@ -1006,7 +1066,12 @@ def compute_text_embeddings(prompt):
1006
1066
if args .pre_compute_text_embeddings :
1007
1067
encoder_hidden_states = batch ["input_ids" ]
1008
1068
else :
1009
- encoder_hidden_states = text_encoder (batch ["input_ids" ])[0 ]
1069
+ encoder_hidden_states = encode_prompt (
1070
+ text_encoder ,
1071
+ batch ["input_ids" ],
1072
+ batch ["attention_mask" ],
1073
+ text_encoder_use_attention_mask = args .text_encoder_use_attention_mask ,
1074
+ )
1010
1075
1011
1076
# Predict the noise residual
1012
1077
model_pred = unet (noisy_model_input , timesteps , encoder_hidden_states ).sample
0 commit comments