Skip to content

Commit fb98acf

Browse files
authored
[lora] Fix bug with training without validation (#2106)
1 parent 180841b commit fb98acf

File tree

1 file changed

+13
-13
lines changed

1 file changed

+13
-13
lines changed

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -984,19 +984,19 @@ def main(args):
984984
prompt = args.num_validation_images * [args.validation_prompt]
985985
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
986986

987-
for tracker in accelerator.trackers:
988-
if tracker.name == "tensorboard":
989-
np_images = np.stack([np.asarray(img) for img in images])
990-
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
991-
if tracker.name == "wandb":
992-
tracker.log(
993-
{
994-
"test": [
995-
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
996-
for i, image in enumerate(images)
997-
]
998-
}
999-
)
987+
for tracker in accelerator.trackers:
988+
if tracker.name == "tensorboard":
989+
np_images = np.stack([np.asarray(img) for img in images])
990+
tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC")
991+
if tracker.name == "wandb":
992+
tracker.log(
993+
{
994+
"test": [
995+
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
996+
for i, image in enumerate(images)
997+
]
998+
}
999+
)
10001000

10011001
if args.push_to_hub:
10021002
save_model_card(

0 commit comments

Comments
 (0)