update conversion script for Kandinsky unet (#3766)
* update kandinsky conversion script

* style

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
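This commit retires the standalone KandinskyTextProjModel and folds its parameters into the decoder UNet itself, through the UNet2DConditionModel hooks for addition_embed_type="text_image" and encoder_hid_dim_type="text_image_proj". Below is a minimal sketch of how a UNet built from the updated UNET_CONFIG is driven afterwards; the shapes follow the config in this diff, and the calling convention is the stock diffusers one for these config values, so treat it as an illustration rather than part of the commit:

import torch
from diffusers import UNet2DConditionModel

# UNET_CONFIG is the dict defined in this script (see the hunk below)
unet = UNet2DConditionModel(**UNET_CONFIG)

sample = torch.randn(1, 4, 64, 64)                # "in_channels": 4, "sample_size": 64
encoder_hidden_states = torch.randn(1, 77, 1024)  # "encoder_hid_dim": 1024
added_cond_kwargs = {
    "text_embeds": torch.randn(1, 768),   # consumed by add_embedding (text_norm/text_proj)
    "image_embeds": torch.randn(1, 768),  # consumed by add_embedding and encoder_hid_proj
}

out = unet(
    sample,
    timestep=0,
    encoder_hidden_states=encoder_hidden_states,
    added_cond_kwargs=added_cond_kwargs,
).sample
print(out.shape)  # torch.Size([1, 8, 64, 64]); "out_channels": 8, presumably prediction + learned variance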
@@ -8,7 +8,6 @@ from accelerate import load_checkpoint_and_dispatch
 from diffusers import UNet2DConditionModel
 from diffusers.models.prior_transformer import PriorTransformer
 from diffusers.models.vq_model import VQModel
-from diffusers.pipelines.kandinsky.text_proj import KandinskyTextProjModel
 
 
 """
@@ -225,37 +224,55 @@ def prior_ff_to_diffusers(checkpoint, *, diffusers_ff_prefix, original_ff_prefix
 
 
 UNET_CONFIG = {
     "act_fn": "silu",
+    "addition_embed_type": "text_image",
+    "addition_embed_type_num_heads": 64,
     "attention_head_dim": 64,
-    "block_out_channels": (384, 768, 1152, 1536),
+    "block_out_channels": [384, 768, 1152, 1536],
     "center_input_sample": False,
-    "class_embed_type": "identity",
+    "class_embed_type": None,
     "class_embeddings_concat": False,
     "conv_in_kernel": 3,
     "conv_out_kernel": 3,
     "cross_attention_dim": 768,
-    "down_block_types": (
+    "cross_attention_norm": None,
+    "down_block_types": [
         "ResnetDownsampleBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
-    ),
+    ],
     "downsample_padding": 1,
     "dual_cross_attention": False,
+    "encoder_hid_dim": 1024,
+    "encoder_hid_dim_type": "text_image_proj",
     "flip_sin_to_cos": True,
     "freq_shift": 0,
     "in_channels": 4,
     "layers_per_block": 3,
+    "mid_block_only_cross_attention": None,
     "mid_block_scale_factor": 1,
     "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
     "norm_eps": 1e-05,
     "norm_num_groups": 32,
     "num_class_embeds": None,
     "only_cross_attention": False,
     "out_channels": 8,
     "projection_class_embeddings_input_dim": None,
     "resnet_out_scale_factor": 1.0,
     "resnet_skip_time_act": False,
     "resnet_time_scale_shift": "scale_shift",
     "sample_size": 64,
-    "up_block_types": (
+    "time_cond_proj_dim": None,
+    "time_embedding_act_fn": None,
+    "time_embedding_dim": None,
+    "time_embedding_type": "positional",
+    "timestep_post_act": None,
+    "up_block_types": [
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "ResnetUpsampleBlock2D",
-    ),
+    ],
     "upcast_attention": False,
     "use_linear_projection": False,
 }
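Two changes are interleaved in this hunk: every tuple becomes a list, and the conditioning switches from the class_embed_type="identity" workaround to explicit text/image embedding hooks. The tuple-to-list swap likely just keeps this dict identical to what the model config serializes to (configs round-trip through JSON, which has no tuples). A quick sketch of which submodules the new keys instantiate; the class names assume the diffusers API of this era:

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(**UNET_CONFIG)
print(type(unet.add_embedding).__name__)     # expected: TextImageTimeEmbedding
print(type(unet.encoder_hid_proj).__name__)  # expected: TextImageProjection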
@@ -274,6 +291,8 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
 
     diffusers_checkpoint.update(unet_time_embeddings(checkpoint))
     diffusers_checkpoint.update(unet_conv_in(checkpoint))
+    diffusers_checkpoint.update(unet_add_embedding(checkpoint))
+    diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint))
 
     # <original>.input_blocks -> <diffusers>.down_blocks
 
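The converter accumulates one flat {diffusers_key: tensor} dict, one helper per submodule, and later loads it strictly; without the two new update calls, the freshly added add_embedding and encoder_hid_proj weights would be missing and strict loading would fail. The pattern, as a sketch with the helpers and the original checkpoint assumed in scope:

# each helper maps a slice of the original state dict to diffusers names
diffusers_checkpoint = {}
for helper in (unet_time_embeddings, unet_conv_in, unet_add_embedding, unet_encoder_hid_proj):
    diffusers_checkpoint.update(helper(checkpoint))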
@@ -336,37 +355,55 @@ def unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
 
 
 INPAINT_UNET_CONFIG = {
     "act_fn": "silu",
+    "addition_embed_type": "text_image",
+    "addition_embed_type_num_heads": 64,
     "attention_head_dim": 64,
-    "block_out_channels": (384, 768, 1152, 1536),
+    "block_out_channels": [384, 768, 1152, 1536],
     "center_input_sample": False,
-    "class_embed_type": "identity",
+    "class_embed_type": None,
     "class_embeddings_concat": None,
     "conv_in_kernel": 3,
     "conv_out_kernel": 3,
     "cross_attention_dim": 768,
-    "down_block_types": (
+    "cross_attention_norm": None,
+    "down_block_types": [
         "ResnetDownsampleBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
         "SimpleCrossAttnDownBlock2D",
-    ),
+    ],
     "downsample_padding": 1,
     "dual_cross_attention": False,
+    "encoder_hid_dim": 1024,
+    "encoder_hid_dim_type": "text_image_proj",
     "flip_sin_to_cos": True,
     "freq_shift": 0,
     "in_channels": 9,
     "layers_per_block": 3,
+    "mid_block_only_cross_attention": None,
     "mid_block_scale_factor": 1,
     "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
     "norm_eps": 1e-05,
     "norm_num_groups": 32,
     "num_class_embeds": None,
     "only_cross_attention": False,
     "out_channels": 8,
     "projection_class_embeddings_input_dim": None,
     "resnet_out_scale_factor": 1.0,
     "resnet_skip_time_act": False,
     "resnet_time_scale_shift": "scale_shift",
     "sample_size": 64,
-    "up_block_types": (
+    "time_cond_proj_dim": None,
+    "time_embedding_act_fn": None,
+    "time_embedding_dim": None,
+    "time_embedding_type": "positional",
+    "timestep_post_act": None,
+    "up_block_types": [
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "SimpleCrossAttnUpBlock2D",
         "ResnetUpsampleBlock2D",
-    ),
+    ],
     "upcast_attention": False,
     "use_linear_projection": False,
 }
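The inpaint config differs materially from UNET_CONFIG only in "in_channels": 9. A sketch of where the nine channels plausibly come from; the concatenation order here is an assumption, the Kandinsky inpaint pipeline is authoritative:

import torch

latents = torch.randn(1, 4, 64, 64)               # noisy latents (4 channels)
masked_image_latents = torch.randn(1, 4, 64, 64)  # encoded masked image (4 channels)
mask = torch.ones(1, 1, 64, 64)                   # binary inpainting mask (1 channel)

unet_input = torch.cat([latents, masked_image_latents, mask], dim=1)
print(unet_input.shape)  # torch.Size([1, 9, 64, 64]) -> "in_channels": 9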
@@ -381,10 +418,12 @@ def inpaint_unet_model_from_original_config():
 def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
     diffusers_checkpoint = {}
 
-    num_head_channels = UNET_CONFIG["attention_head_dim"]
+    num_head_channels = INPAINT_UNET_CONFIG["attention_head_dim"]
 
     diffusers_checkpoint.update(unet_time_embeddings(checkpoint))
     diffusers_checkpoint.update(unet_conv_in(checkpoint))
+    diffusers_checkpoint.update(unet_add_embedding(checkpoint))
+    diffusers_checkpoint.update(unet_encoder_hid_proj(checkpoint))
 
     # <original>.input_blocks -> <diffusers>.down_blocks
 
@@ -440,38 +479,6 @@ def inpaint_unet_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
 # done inpaint unet
 
 
-# text proj
-
-
-TEXT_PROJ_CONFIG = {}
-
-
-def text_proj_from_original_config():
-    model = KandinskyTextProjModel(**TEXT_PROJ_CONFIG)
-    return model
-
-
-# Note that the input checkpoint is the original text2img model checkpoint
-def text_proj_original_checkpoint_to_diffusers_checkpoint(checkpoint):
-    diffusers_checkpoint = {
-        # <original>.text_seq_proj.0 -> <diffusers>.encoder_hidden_states_proj
-        "encoder_hidden_states_proj.weight": checkpoint["to_model_dim_n.weight"],
-        "encoder_hidden_states_proj.bias": checkpoint["to_model_dim_n.bias"],
-        # <original>.clip_tok_proj -> <diffusers>.clip_extra_context_tokens_proj
-        "clip_extra_context_tokens_proj.weight": checkpoint["clip_to_seq.weight"],
-        "clip_extra_context_tokens_proj.bias": checkpoint["clip_to_seq.bias"],
-        # <original>.proj_n -> <diffusers>.embedding_proj
-        "embedding_proj.weight": checkpoint["proj_n.weight"],
-        "embedding_proj.bias": checkpoint["proj_n.bias"],
-        # <original>.ln_model_n -> <diffusers>.embedding_norm
-        "embedding_norm.weight": checkpoint["ln_model_n.weight"],
-        "embedding_norm.bias": checkpoint["ln_model_n.bias"],
-        # <original>.clip_emb -> <diffusers>.clip_image_embeddings_project_to_time_embeddings
-        "clip_image_embeddings_project_to_time_embeddings.weight": checkpoint["img_layer.weight"],
-        "clip_image_embeddings_project_to_time_embeddings.bias": checkpoint["img_layer.bias"],
-    }
-
-    return diffusers_checkpoint
-
-
 # unet utils
 
 
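Nothing is dropped by this deletion: every original tensor that KandinskyTextProjModel consumed is re-routed into the UNet by the helpers added further down. The mapping, read directly off this diff (keyed by the original checkpoint prefix; values are the old and new destinations, each covering both .weight and .bias):

REROUTED = {
    "to_model_dim_n": ("encoder_hidden_states_proj", "encoder_hid_proj.text_proj"),
    "clip_to_seq": ("clip_extra_context_tokens_proj", "encoder_hid_proj.image_embeds"),
    "proj_n": ("embedding_proj", "add_embedding.text_proj"),
    "ln_model_n": ("embedding_norm", "add_embedding.text_norm"),
    "img_layer": ("clip_image_embeddings_project_to_time_embeddings", "add_embedding.image_proj"),
}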
@@ -506,6 +513,38 @@ def unet_conv_in(checkpoint):
     return diffusers_checkpoint
 
 
+def unet_add_embedding(checkpoint):
+    diffusers_checkpoint = {}
+
+    diffusers_checkpoint.update(
+        {
+            "add_embedding.text_norm.weight": checkpoint["ln_model_n.weight"],
+            "add_embedding.text_norm.bias": checkpoint["ln_model_n.bias"],
+            "add_embedding.text_proj.weight": checkpoint["proj_n.weight"],
+            "add_embedding.text_proj.bias": checkpoint["proj_n.bias"],
+            "add_embedding.image_proj.weight": checkpoint["img_layer.weight"],
+            "add_embedding.image_proj.bias": checkpoint["img_layer.bias"],
+        }
+    )
+
+    return diffusers_checkpoint
+
+
+def unet_encoder_hid_proj(checkpoint):
+    diffusers_checkpoint = {}
+
+    diffusers_checkpoint.update(
+        {
+            "encoder_hid_proj.image_embeds.weight": checkpoint["clip_to_seq.weight"],
+            "encoder_hid_proj.image_embeds.bias": checkpoint["clip_to_seq.bias"],
+            "encoder_hid_proj.text_proj.weight": checkpoint["to_model_dim_n.weight"],
+            "encoder_hid_proj.text_proj.bias": checkpoint["to_model_dim_n.bias"],
+        }
+    )
+
+    return diffusers_checkpoint
+
+
 # <original>.out.0 -> <diffusers>.conv_norm_out
 def unet_conv_norm_out(checkpoint):
     diffusers_checkpoint = {}
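A quick sanity check for the two new helpers: run them over a dummy original checkpoint and inspect the emitted keys. Tensor shapes are placeholders, and the helpers above are assumed to be in scope:

import torch

dummy = {
    name + suffix: torch.zeros(1)
    for name in ("ln_model_n", "proj_n", "img_layer", "clip_to_seq", "to_model_dim_n")
    for suffix in (".weight", ".bias")
}

state = {}
state.update(unet_add_embedding(dummy))
state.update(unet_encoder_hid_proj(dummy))
print(sorted(state))  # only add_embedding.* and encoder_hid_proj.* keys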
@@ -857,25 +896,13 @@ def text2img(*, args, checkpoint_map_location):
 
     unet_diffusers_checkpoint = unet_original_checkpoint_to_diffusers_checkpoint(unet_model, text2img_checkpoint)
 
-    # text proj interlude
-
-    # The original decoder implementation includes a set of parameters that are used
-    # for creating the `encoder_hidden_states` which are what the U-net is conditioned
-    # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
-    # the parameters into the KandinskyTextProjModel class
-    text_proj_model = text_proj_from_original_config()
-
-    text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(text2img_checkpoint)
-
-    load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True)
-
     del text2img_checkpoint
 
     load_checkpoint_to_model(unet_diffusers_checkpoint, unet_model, strict=True)
 
     print("done loading text2img")
 
-    return unet_model, text_proj_model
+    return unet_model
 
 
 def inpaint_text2img(*, args, checkpoint_map_location):
@@ -891,25 +918,13 @@ def inpaint_text2img(*, args, checkpoint_map_location):
         inpaint_unet_model, inpaint_text2img_checkpoint
     )
 
-    # text proj interlude
-
-    # The original decoder implementation includes a set of parameters that are used
-    # for creating the `encoder_hidden_states` which are what the U-net is conditioned
-    # on. The diffusers conditional unet directly takes the encoder_hidden_states. We pull
-    # the parameters into the KandinskyTextProjModel class
-    text_proj_model = text_proj_from_original_config()
-
-    text_proj_checkpoint = text_proj_original_checkpoint_to_diffusers_checkpoint(inpaint_text2img_checkpoint)
-
-    load_checkpoint_to_model(text_proj_checkpoint, text_proj_model, strict=True)
-
     del inpaint_text2img_checkpoint
 
     load_checkpoint_to_model(inpaint_unet_diffusers_checkpoint, inpaint_unet_model, strict=True)
 
     print("done loading inpaint text2img")
 
-    return inpaint_unet_model, text_proj_model
+    return inpaint_unet_model
 
 
 # movq
@@ -1384,15 +1399,11 @@ if __name__ == "__main__":
         prior_model = prior(args=args, checkpoint_map_location=checkpoint_map_location)
         prior_model.save_pretrained(args.dump_path)
     elif args.debug == "text2img":
-        unet_model, text_proj_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location)
+        unet_model = text2img(args=args, checkpoint_map_location=checkpoint_map_location)
         unet_model.save_pretrained(f"{args.dump_path}/unet")
-        text_proj_model.save_pretrained(f"{args.dump_path}/text_proj")
     elif args.debug == "inpaint_text2img":
-        inpaint_unet_model, inpaint_text_proj_model = inpaint_text2img(
-            args=args, checkpoint_map_location=checkpoint_map_location
-        )
+        inpaint_unet_model = inpaint_text2img(args=args, checkpoint_map_location=checkpoint_map_location)
         inpaint_unet_model.save_pretrained(f"{args.dump_path}/inpaint_unet")
-        inpaint_text_proj_model.save_pretrained(f"{args.dump_path}/inpaint_text_proj")
     elif args.debug == "decoder":
         decoder = movq(args=args, checkpoint_map_location=checkpoint_map_location)
         decoder.save_pretrained(f"{args.dump_path}/decoder")
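With text_proj gone, each --debug branch now saves a single model per subfolder, and the artifacts load back with the stock diffusers API. A sketch, using a hypothetical --dump_path value:

from diffusers import UNet2DConditionModel

dump_path = "./kandinsky_converted"  # stand-in for whatever --dump_path was
unet = UNet2DConditionModel.from_pretrained(dump_path, subfolder="unet")
inpaint_unet = UNet2DConditionModel.from_pretrained(dump_path, subfolder="inpaint_unet")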