diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index c0b52291fc9b..d0e472a57ed7 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -60,6 +60,17 @@ logger = get_logger(__name__) +def image_grid(imgs, rows, cols): + assert len(imgs) == rows * cols + + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid + + def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, accelerator, weight_dtype, step): logger.info("Running validation... ") @@ -156,6 +167,8 @@ def log_validation(vae, text_encoder, tokenizer, unet, controlnet, args, acceler else: logger.warn(f"image logging not implemented for {tracker.name}") + return image_logs + def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): text_encoder_config = PretrainedConfig.from_pretrained( @@ -177,6 +190,43 @@ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: st raise ValueError(f"{model_class} is not supported.") +def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): + img_str = "" + if image_logs is not None: + img_str = "You can find some example images below.\n" + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += f"prompt: {validation_prompt}\n" + images = [validation_image] + images + image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" + + yaml = f""" +--- +license: creativeml-openrail-m +base_model: {base_model} +tags: +- stable-diffusion +- stable-diffusion-diffusers +- text-to-image +- diffusers +- controlnet +inference: true +--- + """ + model_card = f""" +# controlnet-{repo_id} + +These are controlnet weights trained on {base_model} with new type of conditioning. +{img_str} +""" + with open(os.path.join(repo_folder, "README.md"), "w") as f: + f.write(yaml + model_card) + + def parse_args(input_args=None): parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") parser.add_argument( @@ -936,6 +986,7 @@ def load_model_hook(models, input_dir): disable=not accelerator.is_local_main_process, ) + image_logs = None for epoch in range(first_epoch, args.num_train_epochs): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(controlnet): @@ -1007,7 +1058,7 @@ def load_model_hook(models, input_dir): logger.info(f"Saved state to {save_path}") if args.validation_prompt is not None and global_step % args.validation_steps == 0: - log_validation( + image_logs = log_validation( vae, text_encoder, tokenizer, @@ -1033,6 +1084,12 @@ def load_model_hook(models, input_dir): controlnet.save_pretrained(args.output_dir) if args.push_to_hub: + save_model_card( + repo_id, + image_logs=image_logs, + base_model=args.pretrained_model_name_or_path, + repo_folder=args.output_dir, + ) upload_folder( repo_id=repo_id, folder_path=args.output_dir,