diff --git a/scripts/convert_kandinsky_to_diffusers.py b/scripts/convert_kandinsky_to_diffusers.py
index de9879f7f03b..1b5722f5d5f3 100644
--- a/scripts/convert_kandinsky_to_diffusers.py
+++ b/scripts/convert_kandinsky_to_diffusers.py
@@ -8,7 +8,6 @@
 from diffusers import UNet2DConditionModel
 from diffusers.models.prior_transformer import PriorTransformer
 from diffusers.models.vq_model import VQModel
-from diffusers.pipelines.kandinsky.text_proj import KandinskyTextProjModel
 
 
 """
@@ -225,37 +224,55 @@ def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix
 
 UNET_CONFIG = {
     "act_fn": "silu",
+    "addition_embed_type": "text_image",
+    "addition_embed_type_num_heads": 64,
     "attention_head_dim": 64,
-    "block_out_channels": (384, 768, 1152, 1536),
+    "block_out_channels": [384, 768, 1152, 1536],
     "center_input_sample": False,
-    "class_embed_type": "identity",
+    "class_embed_type": None,
+    "class_embeddings_concat": False,
+    "conv_in_kernel": 3,
+    "conv_out_kernel": 3,
     "cross_attention_dim": 768,
-    "down_block_types": (
+    "cross_attention_norm": None,
+    "down_block_types": [
         "ResnetDownsampleBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
-    ),
+    ],
     "downsample_padding": 1,
     "dual_cross_attention": False,
+    "encoder_hid_dim": 1024,
+    "encoder_hid_dim_type": "text_image_proj",
     "flip_sin_to_cos": True,
     "freq_shift": 0,
     "in_channels": 4,
     "layers_per_block": 3,
+    "mid_block_only_cross_attention": None,
     "mid_block_scale_factor": 1,
     "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
     "norm_eps": 1e-05,
     "norm_num_groups": 32,
+    "num_class_embeds": None,
     "only_cross_attention": False,
     "out_channels": 8,
+    "projection_class_embeddings_input_dim": None,
+    "resnet_out_scale_factor": 1.0,
+    "resnet_skip_time_act": False,
     "resnet_time_scale_shift": "scale_shift",
     "sample_size": 64,
-    "up_block_types": (
+    "time_cond_proj_dim": None,
+    "time_embedding_act_fn": None,
+    "time_embedding_dim": None,
+    "time_embedding_type": "positional",
+    "timestep_post_act": None,
+    "up_block_types": [
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "ResnetUpsampleBlock2D",
-    ),
+    ],
     "upcast_attention": False,
     "use_linear_projection": False,
 }
@@ -274,6 +291,8 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
 
     diffusers_checkpoint.update(unet_time_embeddings(checkpoint))
     diffusers_checkpoint.update(unet_conv_in(checkpoint))
+    diffusers_checkpoint.update(unet_add_embedding(checkpoint))
+    diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint))
 
     # .input_blocks -> .down_blocks
 
@@ -336,37 +355,55 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
 
 INPAINT_UNET_CONFIG = {
     "act_fn": "silu",
+    "addition_embed_type": "text_image",
+    "addition_embed_type_num_heads": 64,
     "attention_head_dim": 64,
-    "block_out_channels": (384, 768, 1152, 1536),
+    "block_out_channels": [384, 768, 1152, 1536],
     "center_input_sample": False,
-    "class_embed_type": "identity",
+    "class_embed_type": None,
+    "class_embeddings_concat": None,
+    "conv_in_kernel": 3,
+    "conv_out_kernel": 3,
     "cross_attention_dim": 768,
-    "down_block_types": (
+    "cross_attention_norm": None,
+    "down_block_types": [
         "ResnetDownsampleBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
-    ),
+    ],
     "downsample_padding": 1,
     "dual_cross_attention": False,
+    "encoder_hid_dim": 1024,
+    "encoder_hid_dim_type": "text_image_proj",
     "flip_sin_to_cos": True,
     "freq_shift": 0,
     "in_channels": 9,
     "layers_per_block": 3,
+    "mid_block_only_cross_attention": None,
     "mid_block_scale_factor": 1,
     "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
     "norm_eps": 1e-05,
     "norm_num_groups": 32,
+    "num_class_embeds": None,
     "only_cross_attention": False,
     "out_channels": 8,
+    "projection_class_embeddings_input_dim": None,
+    "resnet_out_scale_factor": 1.0,
+    "resnet_skip_time_act": False,
     "resnet_time_scale_shift": "scale_shift",
     "sample_size": 64,
-    "up_block_types": (
+    "time_cond_proj_dim": None,
+    "time_embedding_act_fn": None,
+    "time_embedding_dim": None,
+    "time_embedding_type": "positional",
+    "timestep_post_act": None,
+    "up_block_types": [
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "ResnetUpsampleBlock2D",
-    ),
+    ],
     "upcast_attention": False,
     "use_linear_projection": False,
 }
@@ -381,10 +418,12 @@ def inpaint_unet_model_from_original_config():
 def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
     diffusers_checkpoint = {}
 
-    num_head_channels = UNET_CONFIG["attention_head_dim"]
+    num_head_channels = INPAINT_UNET_CONFIG["attention_head_dim"]
 
     diffusers_checkpoint.update(unet_time_embeddings(checkpoint))
     diffusers_checkpoint.update(unet_conv_in(checkpoint))
+    diffusers_checkpoint.update(unet_add_embedding(checkpoint))
+    diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint))
 
     # .input_blocks -> .down_blocks
 
@@ -440,38 +479,6 @@ def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
 # done inpaint unet
 
 
-# text proj
-
-TEXT_PROJ_CONFIG = {}
-
-
-def text_proj_from_original_config():
-    model = KandinskyTextProjModel(**TEXT_PROJ_CONFIG)
-    return model
-
-
-# Note that the input checkpoint is the original text2img model checkpoint
-def text_proj_original_checkpoint_to_diffusers_checkpoint(checkpoint):
-    diffusers_checkpoint = {
-        # .text_seq_proj.0 -> .encoder_hidden_states_proj
-        "encoder_hidden_states_proj.weight": checkpoint["to_model_dim_n.weight"],
-        "encoder_hidden_states_proj.bias": checkpoint["to_model_dim_n.bias"],
-        # .clip_tok_proj -> .clip_extra_context_tokens_proj
-        "clip_extra_context_tokens_proj.weight": checkpoint["clip_to_seq.weight"],
-        "clip_extra_context_tokens_proj.bias": checkpoint["clip_to_seq.bias"],
-        # .proj_n -> .embedding_proj
-        "embedding_proj.weight": checkpoint["proj_n.weight"],
-        "embedding_proj.bias": checkpoint["proj_n.bias"],
-        # .ln_model_n -> .embedding_norm
-        "embedding_norm.weight": checkpoint["ln_model_n.weight"],
-        "embedding_norm.bias": checkpoint["ln_model_n.bias"],
-        # .clip_emb -> .clip_image_embeddings_project_to_time_embeddings
-        "clip_image_embeddings_project_to_time_embeddings.weight": checkpoint["img_layer.weight"],
-        "clip_image_embeddings_project_to_time_embeddings.bias": checkpoint["img_layer.bias"],
-    }
-
-    return diffusers_checkpoint
-
-
 # unet utils
 
 
@@ -506,6 +513,38 @@ def unet_conv_in(checkpoint):
     return diffusers_checkpoint
 
 
+def unet_add_embedding(checkpoint):
+    diffusers_checkpoint = {}
+
+    diffusers_checkpoint.update(
+        {
+            "add_embedding.text_norm.weight": checkpoint["ln_model_n.weight"],
+            "add_embedding.text_norm.bias": checkpoint["ln_model_n.bias"],
+            "add_embedding.text_proj.weight": checkpoint["proj_n.weight"],
+            "add_embedding.text_proj.bias": checkpoint["proj_n.bias"],
+            "add_embedding.image_proj.weight": checkpoint["img_layer.weight"],
+            "add_embedding.image_proj.bias": checkpoint["img_layer.bias"],
+        }
+    )
+
+    return diffusers_checkpoint
+
+
+def unet_encoder_hid_proj(checkpoint):
+    diffusers_checkpoint = {}
+
+    diffusers_checkpoint.update(
+        {
+            "encoder_hid_proj.image_embeds.weight": checkpoint["clip_to_seq.weight"],
+            "encoder_hid_proj.image_embeds.bias": checkpoint["clip_to_seq.bias"],
+            "encoder_hid_proj.text_proj.weight": checkpoint["to_model_dim_n.weight"],
+            "encoder_hid_proj.text_proj.bias": checkpoint["to_model_dim_n.bias"],
+        }
+    )
+
+    return diffusers_checkpoint
+
+
 # .out.0 -> .conv_norm_out
 def unet_conv_norm_out(checkpoint):
     diffusers_checkpoint = {}
@@ -857,25 +896,13 @@ def text2img(*, args, checkpoint_map_location):
 
     unet_diffusers_checkpoint = unet_original_checkpoint_to_diffusers_checkpoint(unet_model, text2img_checkpoint)
 
-    # text proj interlude
-
-    # The original decoder implementation includes a set of parameters that are used
-    # for creating the `encoder_hidden_states` which are what the U-net is conditioned
-    # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
-    # the parameters into the KandinskyTextProjModel class
-    text_proj_model = text_proj_from_original_config()
-
-    text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(text2img_checkpoint)
-
-    load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True)
-
     del text2img_checkpoint
 
     load_checkpoint_to_model(unet_diffusers_checkpoint, unet_model, strict=True)
 
     print("done loading text2img")
 
-    return unet_model, text_proj_model
+    return unet_model
 
 
 def inpaint_text2img(*, args, checkpoint_map_location):
@@ -891,25 +918,13 @@ def inpaint_text2img(*, args, checkpoint_map_location):
         inpaint_unet_model, inpaint_text2img_checkpoint
     )
 
-    # text proj interlude
-
-    # The original decoder implementation includes a set of parameters that are used
-    # for creating the `encoder_hidden_states` which are what the U-net is conditioned
-    # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
-    # the parameters into the KandinskyTextProjModel class
-    text_proj_model = text_proj_from_original_config()
-
-    text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(inpaint_text2img_checkpoint)
-
-    load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True)
-
     del inpaint_text2img_checkpoint
 
     load_checkpoint_to_model(inpaint_unet_diffusers_checkpoint, inpaint_unet_model, strict=True)
 
     print("done loading inpaint text2img")
 
-    return inpaint_unet_model, text_proj_model
+    return inpaint_unet_model
 
 
 # movq
@@ -1384,15 +1399,11 @@ def load_checkpoint_to_model(checkpoint, model, strict=False):
         prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
         prior_model.save_pretrained(args.dump_path)
     elif args.debug == "text2img":
-        unet_model, text_proj_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location)
+        unet_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location)
         unet_model.save_pretrained(f"{args.dump_path}/unet")
-        text_proj_model.save_pretrained(f"{args.dump_path}/text_proj")
     elif args.debug == "inpaint_text2img":
-        inpaint_unet_model, inpaint_text_proj_model = inpaint_text2img(
-            args=args, checkpoint_map_location=checkpoint_map_location
-        )
+        inpaint_unet_model = inpaint_text2img(args=args, checkpoint_map_location=checkpoint_map_location)
        inpaint_unet_model.save_pretrained(f"{args.dump_path}/inpaint_unet")
-        inpaint_text_proj_model.save_pretrained(f"{args.dump_path}/inpaint_text_proj")
     elif args.debug == "decoder":
         decoder = movq(args=args, checkpoint_map_location=checkpoint_map_location)
         decoder.save_pretrained(f"{args.dump_path}/decoder")
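Note on the change: the weights that previously lived in the standalone KandinskyTextProjModel are now folded directly into the UNet2DConditionModel, via the add_embedding submodule (text/image time embedding) and the encoder_hid_proj submodule (text/image projection of the encoder hidden states). That is why UNET_CONFIG and INPAINT_UNET_CONFIG gain "addition_embed_type": "text_image" and "encoder_hid_dim_type": "text_image_proj", and why the new unet_add_embedding and unet_encoder_hid_proj helpers remap the same original checkpoint keys (ln_model_n, proj_n, img_layer, clip_to_seq, to_model_dim_n) onto UNet parameter names.

The snippet below is a minimal sanity-check sketch, not part of the script above. It assumes a diffusers version recent enough for UNet2DConditionModel to accept these config keys, and that UNET_CONFIG is the dict defined in the converted script; the check simply confirms the submodules expected by the remapped keys exist, so load_checkpoint_to_model(..., strict=True) has matching parameter names.

    from diffusers import UNet2DConditionModel

    # Build the UNet from the new config and verify that the parameter names
    # produced by unet_add_embedding / unet_encoder_hid_proj are present.
    unet = UNet2DConditionModel(**UNET_CONFIG)
    param_names = set(unet.state_dict().keys())

    expected = [
        "add_embedding.text_norm.weight",
        "add_embedding.text_proj.weight",
        "add_embedding.image_proj.weight",
        "encoder_hid_proj.image_embeds.weight",
        "encoder_hid_proj.text_proj.weight",
    ]
    missing = [name for name in expected if name not in param_names]
    assert not missing, f"missing parameters: {missing}"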