Skip to content

Commit 8e69708

Browse files
authored
[Examples/DreamBooth] refactor save_model_card utility in dreambooth examples (#3543)
refactor save_model_card utility in dreambooth examples.
1 parent db56f8a commit 8e69708

File tree

2 files changed

+26
-7
lines changed

2 files changed

+26
-7
lines changed

examples/dreambooth/train_dreambooth.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
DDPMScheduler,
4747
DiffusionPipeline,
4848
DPMSolverMultistepScheduler,
49+
StableDiffusionPipeline,
4950
UNet2DConditionModel,
5051
)
5152
from diffusers.optimization import get_scheduler
@@ -62,7 +63,15 @@
6263
logger = get_logger(__name__)
6364

6465

65-
def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None):
66+
def save_model_card(
67+
repo_id: str,
68+
images=None,
69+
base_model=str,
70+
train_text_encoder=False,
71+
prompt=str,
72+
repo_folder=None,
73+
pipeline: DiffusionPipeline = None,
74+
):
6675
img_str = ""
6776
for i, image in enumerate(images):
6877
image.save(os.path.join(repo_folder, f"image_{i}.png"))
@@ -74,8 +83,8 @@ def save_model_card(repo_id: str, images=None, base_model=str, train_text_encode
7483
base_model: {base_model}
7584
instance_prompt: {prompt}
7685
tags:
77-
- stable-diffusion
78-
- stable-diffusion-diffusers
86+
- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
87+
- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
7988
- text-to-image
8089
- diffusers
8190
- dreambooth
@@ -1297,6 +1306,7 @@ def compute_text_embeddings(prompt):
12971306
train_text_encoder=args.train_text_encoder,
12981307
prompt=args.instance_prompt,
12991308
repo_folder=args.output_dir,
1309+
pipeline=pipeline,
13001310
)
13011311
upload_folder(
13021312
repo_id=repo_id,

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,15 @@
6868
logger = get_logger(__name__)
6969

7070

71-
def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None):
71+
def save_model_card(
72+
repo_id: str,
73+
images=None,
74+
base_model=str,
75+
train_text_encoder=False,
76+
prompt=str,
77+
repo_folder=None,
78+
pipeline: DiffusionPipeline = None,
79+
):
7280
img_str = ""
7381
for i, image in enumerate(images):
7482
image.save(os.path.join(repo_folder, f"image_{i}.png"))
@@ -80,8 +88,8 @@ def save_model_card(repo_id: str, images=None, base_model=str, train_text_encode
8088
base_model: {base_model}
8189
instance_prompt: {prompt}
8290
tags:
83-
- stable-diffusion
84-
- stable-diffusion-diffusers
91+
- {'stable-diffusion' if isinstance(pipeline, StableDiffusionPipeline) else 'if'}
92+
- {'stable-diffusion-diffusers' if isinstance(pipeline, StableDiffusionPipeline) else 'if-diffusers'}
8593
- text-to-image
8694
- diffusers
8795
- lora
@@ -844,7 +852,7 @@ def main(args):
844852
hidden_size=module.out_features, cross_attention_dim=None
845853
)
846854
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
847-
temp_pipeline = StableDiffusionPipeline.from_pretrained(
855+
temp_pipeline = DiffusionPipeline.from_pretrained(
848856
args.pretrained_model_name_or_path, text_encoder=text_encoder
849857
)
850858
temp_pipeline._modify_text_encoder(text_lora_attn_procs)
@@ -1332,6 +1340,7 @@ def compute_text_embeddings(prompt):
13321340
train_text_encoder=args.train_text_encoder,
13331341
prompt=args.instance_prompt,
13341342
repo_folder=args.output_dir,
1343+
pipeline=pipeline,
13351344
)
13361345
upload_folder(
13371346
repo_id=repo_id,

0 commit comments

Comments
 (0)