mirror of https://github.com/huggingface/diffusers.git
add conversion script
@@ -1,6 +1,6 @@
 import argparse
 import pathlib
-from typing import Any, Dict
+from typing import Any, Dict, Tuple

 import torch
 from accelerate import init_empty_weights
@@ -14,6 +14,8 @@ from diffusers import (
     WanImageToVideoPipeline,
     WanPipeline,
     WanTransformer3DModel,
+    WanVACEPipeline,
+    WanVACETransformer3DModel,
 )

@@ -59,7 +61,52 @@ TRANSFORMER_KEYS_RENAME_DICT = {
     "attn2.norm_k_img": "attn2.norm_added_k",
 }

+VACE_TRANSFORMER_KEYS_RENAME_DICT = {
+    "time_embedding.0": "condition_embedder.time_embedder.linear_1",
+    "time_embedding.2": "condition_embedder.time_embedder.linear_2",
+    "text_embedding.0": "condition_embedder.text_embedder.linear_1",
+    "text_embedding.2": "condition_embedder.text_embedder.linear_2",
+    "time_projection.1": "condition_embedder.time_proj",
+    "head.modulation": "scale_shift_table",
+    "head.head": "proj_out",
+    "modulation": "scale_shift_table",
+    "ffn.0": "ffn.net.0.proj",
+    "ffn.2": "ffn.net.2",
+    # Hack to swap the layer names
+    # The original model calls the norms in the following order: norm1, norm3, norm2
+    # We convert it to: norm1, norm2, norm3
+    "norm2": "norm__placeholder",
+    "norm3": "norm2",
+    "norm__placeholder": "norm3",
+    # # For the I2V model
+    # "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
+    # "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
+    # "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
+    # "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+    # # For the FLF2V model
+    # "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
+    # Add attention component mappings
+    "self_attn.q": "attn1.to_q",
+    "self_attn.k": "attn1.to_k",
+    "self_attn.v": "attn1.to_v",
+    "self_attn.o": "attn1.to_out.0",
+    "self_attn.norm_q": "attn1.norm_q",
+    "self_attn.norm_k": "attn1.norm_k",
+    "cross_attn.q": "attn2.to_q",
+    "cross_attn.k": "attn2.to_k",
+    "cross_attn.v": "attn2.to_v",
+    "cross_attn.o": "attn2.to_out.0",
+    "cross_attn.norm_q": "attn2.norm_q",
+    "cross_attn.norm_k": "attn2.norm_k",
+    "attn2.to_k_img": "attn2.add_k_proj",
+    "attn2.to_v_img": "attn2.add_v_proj",
+    "attn2.norm_k_img": "attn2.norm_added_k",
+    "before_proj": "proj_in",
+    "after_proj": "proj_out",
+}
+
 TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}


 def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
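A note on the norm2/norm3 entries above: the swap only works because Python dicts preserve insertion order and the conversion loop (later in this diff) applies each replacement to the output of the previous one, so "norm2" is parked in a placeholder before "norm3" is renamed. A runnable miniature of the trick, with an illustrative key name:

    renames = {
        "norm2": "norm__placeholder",  # park norm2 out of the way first
        "norm3": "norm2",              # now norm3 can take its place safely
        "norm__placeholder": "norm3",  # finally the parked key becomes norm3
    }

    key = "blocks.0.norm3.weight"
    for old, new in renames.items():
        key = key.replace(old, new)
    assert key == "blocks.0.norm2.weight"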
@@ -74,7 +121,7 @@ def load_sharded_safetensors(dir: pathlib.Path):
     return state_dict


-def get_transformer_config(model_type: str) -> Dict[str, Any]:
+def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
     if model_type == "Wan-T2V-1.3B":
         config = {
             "model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
@@ -94,6 +141,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
                 "text_dim": 4096,
             },
         }
+        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
     elif model_type == "Wan-T2V-14B":
         config = {
             "model_id": "StevenZhang/Wan2.1-T2V-14B-Diff",
@@ -113,6 +162,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
                 "text_dim": 4096,
             },
         }
+        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
     elif model_type == "Wan-I2V-14B-480p":
         config = {
             "model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff",
@@ -133,6 +184,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
                 "text_dim": 4096,
             },
         }
+        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
     elif model_type == "Wan-I2V-14B-720p":
         config = {
             "model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff",
@@ -153,6 +206,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
                 "text_dim": 4096,
             },
         }
+        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
     elif model_type == "Wan-FLF2V-14B-720P":
         config = {
             "model_id": "ypyp/Wan2.1-FLF2V-14B-720P",  # This is just a placeholder
@@ -175,11 +230,60 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
                 "pos_embed_seq_len": 257 * 2,
             },
         }
-    return config
+        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
+    elif model_type == "Wan-VACE-1.3B":
+        config = {
+            "model_id": "Wan-AI/Wan2.1-VACE-1.3B",
+            "diffusers_config": {
+                "added_kv_proj_dim": None,
+                "attention_head_dim": 128,
+                "cross_attn_norm": True,
+                "eps": 1e-06,
+                "ffn_dim": 8960,
+                "freq_dim": 256,
+                "in_channels": 16,
+                "num_attention_heads": 12,
+                "num_layers": 30,
+                "out_channels": 16,
+                "patch_size": [1, 2, 2],
+                "qk_norm": "rms_norm_across_heads",
+                "text_dim": 4096,
+                "vace_layers": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
+                "vace_in_channels": 96,
+            },
+        }
+        RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
+    elif model_type == "Wan-VACE-14B":
+        config = {
+            "model_id": "Wan-AI/Wan2.1-VACE-14B",
+            "diffusers_config": {
+                "added_kv_proj_dim": None,
+                "attention_head_dim": 128,
+                "cross_attn_norm": True,
+                "eps": 1e-06,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "in_channels": 16,
+                "num_attention_heads": 40,
+                "num_layers": 40,
+                "out_channels": 16,
+                "patch_size": [1, 2, 2],
+                "qk_norm": "rms_norm_across_heads",
+                "text_dim": 4096,
+                "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
+                "vace_in_channels": 96,
+            },
+        }
+        RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
+    return config, RENAME_DICT, SPECIAL_KEYS_REMAP


 def convert_transformer(model_type: str):
-    config = get_transformer_config(model_type)
+    config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type)

     diffusers_config = config["diffusers_config"]
     model_id = config["model_id"]
     model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))
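The next hunk begins by loading the snapshot's sharded weights. The body of load_sharded_safetensors sits outside this diff, so the following is only a hedged sketch of the usual pattern: merge every *.safetensors shard in the downloaded directory into a single state dict.

    import pathlib

    from safetensors.torch import load_file

    def load_sharded_safetensors_sketch(dir: pathlib.Path):
        state_dict = {}
        # Each shard holds a disjoint slice of the checkpoint's keys.
        for path in sorted(dir.glob("*.safetensors")):
            state_dict.update(load_file(path))
        return state_dict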
@@ -187,16 +291,19 @@ def convert_transformer(model_type: str):
     original_state_dict = load_sharded_safetensors(model_dir)

     with init_empty_weights():
-        transformer = WanTransformer3DModel.from_config(diffusers_config)
+        if "VACE" not in model_type:
+            transformer = WanTransformer3DModel.from_config(diffusers_config)
+        else:
+            transformer = WanVACETransformer3DModel.from_config(diffusers_config)

     for key in list(original_state_dict.keys()):
         new_key = key[:]
-        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+        for replace_key, rename_key in RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
         update_state_dict_(original_state_dict, key, new_key)

     for key in list(original_state_dict.keys()):
-        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+        for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items():
             if special_key not in key:
                 continue
             handler_fn_inplace(key, original_state_dict)
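For orientation, a self-contained miniature of the two passes above: pass one rewrites every key through RENAME_DICT, pass two fires special-case handlers on matching keys (empty for Wan, as in this patch). The body of update_state_dict_ is an assumption here, a pop/reinsert consistent with how the hunk calls it:

    def update_state_dict_(state_dict, old_key, new_key):
        # Assumed behavior: move the value to its new key in place.
        if old_key != new_key:
            state_dict[new_key] = state_dict.pop(old_key)

    RENAME_DICT = {"self_attn.q": "attn1.to_q"}
    SPECIAL_KEYS_REMAP = {}  # no special handlers registered for Wan

    state_dict = {"blocks.0.self_attn.q.weight": "tensor-placeholder"}
    for key in list(state_dict.keys()):
        new_key = key[:]
        for replace_key, rename_key in RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        update_state_dict_(state_dict, key, new_key)

    assert "blocks.0.attn1.to_q.weight" in state_dict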
@@ -412,7 +519,7 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_type", type=str, default=None)
     parser.add_argument("--output_path", type=str, required=True)
-    parser.add_argument("--dtype", default="fp32")
+    parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16", "none"])
     return parser.parse_args()

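The new choices list pairs with the dtype handling added in the next hunk: "none" is deliberately absent from DTYPE_MAPPING so no cast happens. A sketch of that flow; the mapping's contents are an assumption consistent with the flag choices, since the real dict is defined outside this diff:

    import torch

    DTYPE_MAPPING = {  # assumed contents; the real dict is not shown here
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }

    def maybe_cast(module: torch.nn.Module, dtype_flag: str) -> torch.nn.Module:
        # "none" means: keep the original dtypes of the converted state dict.
        if dtype_flag != "none":
            module.to(DTYPE_MAPPING[dtype_flag])
        return module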
@@ -426,18 +533,20 @@ DTYPE_MAPPING = {
 if __name__ == "__main__":
     args = get_args()

-    transformer = None
-    dtype = DTYPE_MAPPING[args.dtype]
-
-    transformer = convert_transformer(args.model_type).to(dtype=dtype)
+    transformer = convert_transformer(args.model_type)
     vae = convert_vae()
-    text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
+    text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
     tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
     flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
     scheduler = UniPCMultistepScheduler(
         prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
     )

+    # If the user has specified "none", we keep the original dtypes of the state dict without any conversion
+    if args.dtype != "none":
+        dtype = DTYPE_MAPPING[args.dtype]
+        transformer.to(dtype)
+
     if "I2V" in args.model_type or "FLF2V" in args.model_type:
         image_encoder = CLIPVisionModelWithProjection.from_pretrained(
             "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
@@ -452,6 +561,14 @@ if __name__ == "__main__":
             image_encoder=image_encoder,
             image_processor=image_processor,
         )
+    elif "VACE" in args.model_type:
+        pipe = WanVACEPipeline(
+            transformer=transformer,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            vae=vae,
+            scheduler=scheduler,
+        )
     else:
         pipe = WanPipeline(
             transformer=transformer,
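Assuming the script lives at scripts/convert_wan_to_diffusers.py (the file path is not visible in this view), a VACE conversion run would look like:

    python scripts/convert_wan_to_diffusers.py \
        --model_type Wan-VACE-1.3B \
        --output_path ./Wan2.1-VACE-1.3B-Diffusers \
        --dtype none

With --dtype none the transformer keeps the checkpoint's original dtypes; fp32/fp16/bf16 cast it after conversion, while the text encoder is always loaded in bfloat16 per the change above.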