1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

update conversion script for Kandinsky unet (#3766)

* update kandinsky conversion script

* style

---------

Co-authored-by: yiyixuxu <yixu310@gmail,com>
This commit is contained in:
YiYi Xu
2023-06-14 06:57:53 -10:00
committed by GitHub
parent ce5504934a
commit 7761b89d7b

View File

@@ -8,7 +8,6 @@ from accelerate import load_checkpoint_and_dispatch
from diffusers import UNet2DConditionModel
from diffusers.models.prior_transformer import PriorTransformer
from diffusers.models.vq_model import VQModel
from diffusers.pipelines.kandinsky.text_proj import KandinskyTextProjModel
"""
@@ -225,37 +224,55 @@ def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix
UNET_CONFIG = {
"act_fn": "silu",
"addition_embed_type": "text_image",
"addition_embed_type_num_heads": 64,
"attention_head_dim": 64,
"block_out_channels": (384, 768, 1152, 1536),
"block_out_channels": [384, 768, 1152, 1536],
"center_input_sample": False,
"class_embed_type": "identity",
"class_embed_type": None,
"class_embeddings_concat": False,
"conv_in_kernel": 3,
"conv_out_kernel": 3,
"cross_attention_dim": 768,
"down_block_types": (
"cross_attention_norm": None,
"down_block_types": [
"ResnetDownsampleBlock2D",
"SimpleCrossAttnDownBlock2D",
"SimpleCrossAttnDownBlock2D",
"SimpleCrossAttnDownBlock2D",
),
],
"downsample_padding": 1,
"dual_cross_attention": False,
"encoder_hid_dim": 1024,
"encoder_hid_dim_type": "text_image_proj",
"flip_sin_to_cos": True,
"freq_shift": 0,
"in_channels": 4,
"layers_per_block": 3,
"mid_block_only_cross_attention": None,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_class_embeds": None,
"only_cross_attention": False,
"out_channels": 8,
"projection_class_embeddings_input_dim": None,
"resnet_out_scale_factor": 1.0,
"resnet_skip_time_act": False,
"resnet_time_scale_shift": "scale_shift",
"sample_size": 64,
"up_block_types": (
"time_cond_proj_dim": None,
"time_embedding_act_fn": None,
"time_embedding_dim": None,
"time_embedding_type": "positional",
"timestep_post_act": None,
"up_block_types": [
"SimpleCrossAttnUpBlock2D",
"SimpleCrossAttnUpBlock2D",
"SimpleCrossAttnUpBlock2D",
"ResnetUpsampleBlock2D",
),
],
"upcast_attention": False,
"use_linear_projection": False,
}
@@ -274,6 +291,8 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
diffusers_checkpoint.update(unet_time_embeddings(checkpoint))
diffusers_checkpoint.update(unet_conv_in(checkpoint))
diffusers_checkpoint.update(unet_add_embedding(checkpoint))
diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint))
# <original>.input_blocks -> <diffusers>.down_blocks
@@ -336,37 +355,55 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
INPAINT_UNET_CONFIG = {
"act_fn": "silu",
"addition_embed_type": "text_image",
"addition_embed_type_num_heads": 64,
"attention_head_dim": 64,
"block_out_channels": (384, 768, 1152, 1536),
"block_out_channels": [384, 768, 1152, 1536],
"center_input_sample": False,
"class_embed_type": "identity",
"class_embed_type": None,
"class_embeddings_concat": None,
"conv_in_kernel": 3,
"conv_out_kernel": 3,
"cross_attention_dim": 768,
"down_block_types": (
"cross_attention_norm": None,
"down_block_types": [
"ResnetDownsampleBlock2D",
"SimpleCrossAttnDownBlock2D",
"SimpleCrossAttnDownBlock2D",
"SimpleCrossAttnDownBlock2D",
),
],
"downsample_padding": 1,
"dual_cross_attention": False,
"encoder_hid_dim": 1024,
"encoder_hid_dim_type": "text_image_proj",
"flip_sin_to_cos": True,
"freq_shift": 0,
"in_channels": 9,
"layers_per_block": 3,
"mid_block_only_cross_attention": None,
"mid_block_scale_factor": 1,
"mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
"norm_eps": 1e-05,
"norm_num_groups": 32,
"num_class_embeds": None,
"only_cross_attention": False,
"out_channels": 8,
"projection_class_embeddings_input_dim": None,
"resnet_out_scale_factor": 1.0,
"resnet_skip_time_act": False,
"resnet_time_scale_shift": "scale_shift",
"sample_size": 64,
"up_block_types": (
"time_cond_proj_dim": None,
"time_embedding_act_fn": None,
"time_embedding_dim": None,
"time_embedding_type": "positional",
"timestep_post_act": None,
"up_block_types": [
"SimpleCrossAttnUpBlock2D",
"SimpleCrossAttnUpBlock2D",
"SimpleCrossAttnUpBlock2D",
"ResnetUpsampleBlock2D",
),
],
"upcast_attention": False,
"use_linear_projection": False,
}
@@ -381,10 +418,12 @@ def inpaint_unet_model_from_original_config():
def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
diffusers_checkpoint = {}
num_head_channels = UNET_CONFIG["attention_head_dim"]
num_head_channels = INPAINT_UNET_CONFIG["attention_head_dim"]
diffusers_checkpoint.update(unet_time_embeddings(checkpoint))
diffusers_checkpoint.update(unet_conv_in(checkpoint))
diffusers_checkpoint.update(unet_add_embedding(checkpoint))
diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint))
# <original>.input_blocks -> <diffusers>.down_blocks
@@ -440,38 +479,6 @@ def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
# done inpaint unet
# text proj
TEXT_PROJ_CONFIG = {}
def text_proj_from_original_config():
model = KandinskyTextProjModel(**TEXT_PROJ_CONFIG)
return model
# Note that the input checkpoint is the original text2img model checkpoint
def text_proj_original_checkpoint_to_diffusers_checkpoint(checkpoint):
diffusers_checkpoint = {
# <original>.text_seq_proj.0 -> <diffusers>.encoder_hidden_states_proj
"encoder_hidden_states_proj.weight": checkpoint["to_model_dim_n.weight"],
"encoder_hidden_states_proj.bias": checkpoint["to_model_dim_n.bias"],
# <original>.clip_tok_proj -> <diffusers>.clip_extra_context_tokens_proj
"clip_extra_context_tokens_proj.weight": checkpoint["clip_to_seq.weight"],
"clip_extra_context_tokens_proj.bias": checkpoint["clip_to_seq.bias"],
# <original>.proj_n -> <diffusers>.embedding_proj
"embedding_proj.weight": checkpoint["proj_n.weight"],
"embedding_proj.bias": checkpoint["proj_n.bias"],
# <original>.ln_model_n -> <diffusers>.embedding_norm
"embedding_norm.weight": checkpoint["ln_model_n.weight"],
"embedding_norm.bias": checkpoint["ln_model_n.bias"],
# <original>.clip_emb -> <diffusers>.clip_image_embeddings_project_to_time_embeddings
"clip_image_embeddings_project_to_time_embeddings.weight": checkpoint["img_layer.weight"],
"clip_image_embeddings_project_to_time_embeddings.bias": checkpoint["img_layer.bias"],
}
return diffusers_checkpoint
# unet utils
@@ -506,6 +513,38 @@ def unet_conv_in(checkpoint):
return diffusers_checkpoint
def unet_add_embedding(checkpoint):
diffusers_checkpoint = {}
diffusers_checkpoint.update(
{
"add_embedding.text_norm.weight": checkpoint["ln_model_n.weight"],
"add_embedding.text_norm.bias": checkpoint["ln_model_n.bias"],
"add_embedding.text_proj.weight": checkpoint["proj_n.weight"],
"add_embedding.text_proj.bias": checkpoint["proj_n.bias"],
"add_embedding.image_proj.weight": checkpoint["img_layer.weight"],
"add_embedding.image_proj.bias": checkpoint["img_layer.bias"],
}
)
return diffusers_checkpoint
def unet_encoder_hid_proj(checkpoint):
diffusers_checkpoint = {}
diffusers_checkpoint.update(
{
"encoder_hid_proj.image_embeds.weight": checkpoint["clip_to_seq.weight"],
"encoder_hid_proj.image_embeds.bias": checkpoint["clip_to_seq.bias"],
"encoder_hid_proj.text_proj.weight": checkpoint["to_model_dim_n.weight"],
"encoder_hid_proj.text_proj.bias": checkpoint["to_model_dim_n.bias"],
}
)
return diffusers_checkpoint
# <original>.out.0 -> <diffusers>.conv_norm_out
def unet_conv_norm_out(checkpoint):
diffusers_checkpoint = {}
@@ -857,25 +896,13 @@ def text2img(*, args, checkpoint_map_location):
unet_diffusers_checkpoint = unet_original_checkpoint_to_diffusers_checkpoint(unet_model, text2img_checkpoint)
# text proj interlude
# The original decoder implementation includes a set of parameters that are used
# for creating the `encoder_hidden_states` which are what the U-net is conditioned
# on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
# the parameters into the KandinskyTextProjModel class
text_proj_model = text_proj_from_original_config()
text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(text2img_checkpoint)
load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True)
del text2img_checkpoint
load_checkpoint_to_model(unet_diffusers_checkpoint, unet_model, strict=True)
print("done loading text2img")
return unet_model, text_proj_model
return unet_model
def inpaint_text2img(*, args, checkpoint_map_location):
@@ -891,25 +918,13 @@ def inpaint_text2img(*, args, checkpoint_map_location):
inpaint_unet_model, inpaint_text2img_checkpoint
)
# text proj interlude
# The original decoder implementation includes a set of parameters that are used
# for creating the `encoder_hidden_states` which are what the U-net is conditioned
# on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
# the parameters into the KandinskyTextProjModel class
text_proj_model = text_proj_from_original_config()
text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(inpaint_text2img_checkpoint)
load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True)
del inpaint_text2img_checkpoint
load_checkpoint_to_model(inpaint_unet_diffusers_checkpoint, inpaint_unet_model, strict=True)
print("done loading inpaint text2img")
return inpaint_unet_model, text_proj_model
return inpaint_unet_model
# movq
@@ -1384,15 +1399,11 @@ if __name__ == "__main__":
prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
prior_model.save_pretrained(args.dump_path)
elif args.debug == "text2img":
unet_model, text_proj_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location)
unet_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location)
unet_model.save_pretrained(f"{args.dump_path}/unet")
text_proj_model.save_pretrained(f"{args.dump_path}/text_proj")
elif args.debug == "inpaint_text2img":
inpaint_unet_model, inpaint_text_proj_model = inpaint_text2img(
args=args, checkpoint_map_location=checkpoint_map_location
)
inpaint_unet_model = inpaint_text2img(args=args, checkpoint_map_location=checkpoint_map_location)
inpaint_unet_model.save_pretrained(f"{args.dump_path}/inpaint_unet")
inpaint_text_proj_model.save_pretrained(f"{args.dump_path}/inpaint_text_proj")
elif args.debug == "decoder":
decoder = movq(args=args, checkpoint_map_location=checkpoint_map_location)
decoder.save_pretrained(f"{args.dump_path}/decoder")