Skip to content

Commit 68dd140

Browse files
committed
refactor all prompt embedding code
class prompts are now included in pre-encoding code max tokenizer length is now configurable embedding attention mask is now configurable
1 parent b86ddc4 commit 68dd140

File tree

1 file changed

+96
-31
lines changed

1 file changed

+96
-31
lines changed

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 96 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,19 @@ def parse_args(input_args=None):
403403
action="store_true",
404404
help="Whether or not to pre-compute text embeddings. If text embeddings are pre-computed, the text encoder will not be kept in memory during training and will leave more GPU memory available for training the rest of the model. This is not compatible with `--train_text_encoder`.",
405405
)
406+
parser.add_argument(
407+
"--tokenizer_max_length",
408+
type=int,
409+
default=None,
410+
required=False,
411+
help="The maximum length of the tokenizer. If not set, will default to the tokenizer's max length.",
412+
)
413+
parser.add_argument(
414+
"--text_encoder_use_attention_mask",
415+
action="store_true",
416+
required=False,
417+
help="Whether to use attention mask for the text encoder",
418+
)
406419

407420
if input_args is not None:
408421
args = parser.parse_args(input_args)
@@ -445,11 +458,15 @@ def __init__(
445458
size=512,
446459
center_crop=False,
447460
encoder_hidden_states=None,
461+
instance_prompt_encoder_hidden_states=None,
462+
tokenizer_max_length=None,
448463
):
449464
self.size = size
450465
self.center_crop = center_crop
451466
self.tokenizer = tokenizer
452467
self.encoder_hidden_states = encoder_hidden_states
468+
self.instance_prompt_encoder_hidden_states = instance_prompt_encoder_hidden_states
469+
self.tokenizer_max_length = tokenizer_max_length
453470

454471
self.instance_data_root = Path(instance_data_root)
455472
if not self.instance_data_root.exists():
@@ -495,39 +512,46 @@ def __getitem__(self, index):
495512
if self.encoder_hidden_states is not None:
496513
example["instance_prompt_ids"] = self.encoder_hidden_states
497514
else:
498-
example["instance_prompt_ids"] = self.tokenizer(
499-
self.instance_prompt,
500-
truncation=True,
501-
padding="max_length",
502-
max_length=self.tokenizer.model_max_length,
503-
return_tensors="pt",
504-
).input_ids
515+
text_inputs = tokenize_prompt(
516+
self.tokenizer, self.instance_prompt, tokenizer_max_length=self.tokenizer_max_length
517+
)
518+
example["instance_prompt_ids"] = text_inputs.input_ids
519+
example["instance_attention_mask"] = text_inputs.attention_mask
505520

506521
if self.class_data_root:
507522
class_image = Image.open(self.class_images_path[index % self.num_class_images])
508523
if not class_image.mode == "RGB":
509524
class_image = class_image.convert("RGB")
510525
example["class_images"] = self.image_transforms(class_image)
511-
example["class_prompt_ids"] = self.tokenizer(
512-
self.class_prompt,
513-
truncation=True,
514-
padding="max_length",
515-
max_length=self.tokenizer.model_max_length,
516-
return_tensors="pt",
517-
).input_ids
526+
527+
if self.instance_prompt_encoder_hidden_states is not None:
528+
example["class_prompt_ids"] = self.instance_prompt_encoder_hidden_states
529+
else:
530+
class_text_inputs = tokenize_prompt(
531+
self.tokenizer, self.class_prompt, tokenizer_max_length=self.tokenizer_max_length
532+
)
533+
example["class_prompt_ids"] = class_text_inputs.input_ids
534+
example["class_attention_mask"] = class_text_inputs.attention_mask
518535

519536
return example
520537

521538

522539
def collate_fn(examples, with_prior_preservation=False):
540+
has_attention_mask = "instance_attention_mask" in examples[0]
541+
523542
input_ids = [example["instance_prompt_ids"] for example in examples]
524543
pixel_values = [example["instance_images"] for example in examples]
525544

545+
if has_attention_mask:
546+
attention_mask = [example["instance_attention_mask"] for example in examples]
547+
526548
# Concat class and instance examples for prior preservation.
527549
# We do this to avoid doing two forward passes.
528550
if with_prior_preservation:
529551
input_ids += [example["class_prompt_ids"] for example in examples]
530552
pixel_values += [example["class_images"] for example in examples]
553+
if has_attention_mask:
554+
attention_mask += [example["class_attention_mask"] for example in examples]
531555

532556
pixel_values = torch.stack(pixel_values)
533557
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
@@ -538,6 +562,10 @@ def collate_fn(examples, with_prior_preservation=False):
538562
"input_ids": input_ids,
539563
"pixel_values": pixel_values,
540564
}
565+
566+
if has_attention_mask:
567+
batch["attention_mask"] = attention_mask
568+
541569
return batch
542570

543571

@@ -568,6 +596,40 @@ def model_has_vae(args):
568596
return any(file.rfilename == config_file_name for file in files_in_repo)
569597

570598

599+
def tokenize_prompt(tokenizer, prompt, tokenizer_max_length=None):
600+
if tokenizer_max_length is not None:
601+
max_length = tokenizer_max_length
602+
else:
603+
max_length = tokenizer.model_max_length
604+
605+
text_inputs = tokenizer(
606+
prompt,
607+
truncation=True,
608+
padding="max_length",
609+
max_length=max_length,
610+
return_tensors="pt",
611+
)
612+
613+
return text_inputs
614+
615+
616+
def encode_prompt(text_encoder, input_ids, attention_mask, text_encoder_use_attention_mask=None):
617+
text_input_ids = input_ids.to(text_encoder.device)
618+
619+
if text_encoder_use_attention_mask:
620+
attention_mask = attention_mask.to(text_encoder.device)
621+
else:
622+
attention_mask = None
623+
624+
prompt_embeds = text_encoder(
625+
text_input_ids,
626+
attention_mask=attention_mask,
627+
)
628+
prompt_embeds = prompt_embeds[0]
629+
630+
return prompt_embeds
631+
632+
571633
def main(args):
572634
logging_dir = Path(args.output_dir, args.logging_dir)
573635

@@ -832,30 +894,25 @@ def main(args):
832894

833895
def compute_text_embeddings(prompt):
834896
with torch.no_grad():
835-
text_inputs = tokenizer(
836-
prompt,
837-
padding="max_length",
838-
max_length=77,
839-
truncation=True,
840-
add_special_tokens=True,
841-
return_tensors="pt",
842-
)
843-
844-
text_input_ids = text_inputs.input_ids
845-
attention_mask = text_inputs.attention_mask.to(text_encoder.device)
846-
847-
prompt_embeds = text_encoder(
848-
text_input_ids.to(text_encoder.device),
849-
attention_mask=attention_mask,
897+
text_inputs = tokenize_prompt(tokenizer, prompt, tokenizer_max_length=args.tokenizer_max_length)
898+
prompt_embeds = encode_prompt(
899+
text_encoder,
900+
text_inputs.input_ids,
901+
text_inputs.attention_mask,
902+
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
850903
)
851-
prompt_embeds = prompt_embeds[0]
852904

853905
return prompt_embeds
854906

855907
pre_computed_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
856908
validation_prompt_encoder_hidden_states = compute_text_embeddings(args.validation_prompt)
857909
validation_prompt_negative_prompt_embeds = compute_text_embeddings("")
858910

911+
if args.instance_prompt is not None:
912+
pre_computed_instance_prompt_encoder_hidden_states = compute_text_embeddings(args.instance_prompt)
913+
else:
914+
pre_computed_instance_prompt_encoder_hidden_states = None
915+
859916
text_encoder = None
860917
tokenizer = None
861918

@@ -865,6 +922,7 @@ def compute_text_embeddings(prompt):
865922
pre_computed_encoder_hidden_states = None
866923
validation_prompt_encoder_hidden_states = None
867924
validation_prompt_negative_prompt_embeds = None
925+
pre_computed_instance_prompt_encoder_hidden_states = None
868926

869927
# Dataset and DataLoaders creation:
870928
train_dataset = DreamBoothDataset(
@@ -877,6 +935,8 @@ def compute_text_embeddings(prompt):
877935
size=args.resolution,
878936
center_crop=args.center_crop,
879937
encoder_hidden_states=pre_computed_encoder_hidden_states,
938+
instance_prompt_hidden_states=pre_computed_instance_prompt_encoder_hidden_states,
939+
tokenizer_max_length=args.tokenizer_max_length,
880940
)
881941

882942
train_dataloader = torch.utils.data.DataLoader(
@@ -1006,7 +1066,12 @@ def compute_text_embeddings(prompt):
10061066
if args.pre_compute_text_embeddings:
10071067
encoder_hidden_states = batch["input_ids"]
10081068
else:
1009-
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
1069+
encoder_hidden_states = encode_prompt(
1070+
text_encoder,
1071+
batch["input_ids"],
1072+
batch["attention_mask"],
1073+
text_encoder_use_attention_mask=args.text_encoder_use_attention_mask,
1074+
)
10101075

10111076
# Predict the noise residual
10121077
model_pred = unet(noisy_model_input, timesteps, encoder_hidden_states).sample

0 commit comments

Comments
 (0)