diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py
index ef91e9e6c1..6d25cde071 100644
--- a/scripts/convert_wan_to_diffusers.py
+++ b/scripts/convert_wan_to_diffusers.py
@@ -1,6 +1,6 @@
 import argparse
 import pathlib
-from typing import Any, Dict
+from typing import Any, Dict, Tuple
 
 import torch
 from accelerate import init_empty_weights
@@ -14,6 +14,8 @@ from diffusers import (
     WanImageToVideoPipeline,
     WanPipeline,
     WanTransformer3DModel,
+    WanVACEPipeline,
+    WanVACETransformer3DModel,
 )
 
 
@@ -59,7 +61,52 @@ TRANSFORMER_KEYS_RENAME_DICT = {
     "attn2.norm_k_img": "attn2.norm_added_k",
 }
 
+VACE_TRANSFORMER_KEYS_RENAME_DICT = {
+    "time_embedding.0": "condition_embedder.time_embedder.linear_1",
+    "time_embedding.2": "condition_embedder.time_embedder.linear_2",
+    "text_embedding.0": "condition_embedder.text_embedder.linear_1",
+    "text_embedding.2": "condition_embedder.text_embedder.linear_2",
+    "time_projection.1": "condition_embedder.time_proj",
+    "head.modulation": "scale_shift_table",
+    "head.head": "proj_out",
+    "modulation": "scale_shift_table",
+    "ffn.0": "ffn.net.0.proj",
+    "ffn.2": "ffn.net.2",
+    # Hack to swap the layer names
+    # The original model calls the norms in following order: norm1, norm3, norm2
+    # We convert it to: norm1, norm2, norm3
+    "norm2": "norm__placeholder",
+    "norm3": "norm2",
+    "norm__placeholder": "norm3",
+    # # For the I2V model
+    # "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
+    # "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
+    # "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
+    # "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+    # # for the FLF2V model
+    # "img_emb.emb_pos": "condition_embedder.image_embedder.pos_embed",
+    # Add attention component mappings
+    "self_attn.q": "attn1.to_q",
+    "self_attn.k": "attn1.to_k",
+    "self_attn.v": "attn1.to_v",
+    "self_attn.o": "attn1.to_out.0",
+    "self_attn.norm_q": "attn1.norm_q",
+    "self_attn.norm_k": "attn1.norm_k",
+    "cross_attn.q": "attn2.to_q",
+    "cross_attn.k": "attn2.to_k",
+    "cross_attn.v": "attn2.to_v",
+    "cross_attn.o": "attn2.to_out.0",
+    "cross_attn.norm_q": "attn2.norm_q",
+    "cross_attn.norm_k": "attn2.norm_k",
+    "attn2.to_k_img": "attn2.add_k_proj",
+    "attn2.to_v_img": "attn2.add_v_proj",
+    "attn2.norm_k_img": "attn2.norm_added_k",
+    "before_proj": "proj_in",
+    "after_proj": "proj_out",
+}
+
 TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
 
 
 def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
@@ -74,7 +121,7 @@ def load_sharded_safetensors(dir: pathlib.Path):
     return state_dict
 
 
-def get_transformer_config(model_type: str) -> Dict[str, Any]:
+def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
     if model_type == "Wan-T2V-1.3B":
         config = {
             "model_id": "StevenZhang/Wan2.1-T2V-1.3B-Diff",
@@ -94,6 +141,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
                 "text_dim": 4096,
             },
         }
+        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
     elif model_type == "Wan-T2V-14B":
         config = {
             "model_id": "StevenZhang/Wan2.1-T2V-14B-Diff",
@@ -113,6 +162,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
                 "text_dim": 4096,
             },
         }
+        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
     elif model_type == "Wan-I2V-14B-480p":
         config = {
             "model_id": "StevenZhang/Wan2.1-I2V-14B-480P-Diff",
@@ -133,6 +184,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
                 "text_dim": 4096,
             },
         }
+        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
     elif model_type == "Wan-I2V-14B-720p":
         config = {
             "model_id": "StevenZhang/Wan2.1-I2V-14B-720P-Diff",
@@ -153,6 +206,8 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
                 "text_dim": 4096,
             },
         }
+        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
     elif model_type == "Wan-FLF2V-14B-720P":
         config = {
             "model_id": "ypyp/Wan2.1-FLF2V-14B-720P",  # This is just a placeholder
@@ -175,11 +230,60 @@ def get_transformer_config(model_type: str) -> Dict[str, Any]:
                 "pos_embed_seq_len": 257 * 2,
             },
         }
-    return config
+        RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
+    elif model_type == "Wan-VACE-1.3B":
+        config = {
+            "model_id": "Wan-AI/Wan2.1-VACE-1.3B",
+            "diffusers_config": {
+                "added_kv_proj_dim": None,
+                "attention_head_dim": 128,
+                "cross_attn_norm": True,
+                "eps": 1e-06,
+                "ffn_dim": 8960,
+                "freq_dim": 256,
+                "in_channels": 16,
+                "num_attention_heads": 12,
+                "num_layers": 30,
+                "out_channels": 16,
+                "patch_size": [1, 2, 2],
+                "qk_norm": "rms_norm_across_heads",
+                "text_dim": 4096,
+                "vace_layers": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
+                "vace_in_channels": 96,
+            },
+        }
+        RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
+    elif model_type == "Wan-VACE-14B":
+        config = {
+            "model_id": "Wan-AI/Wan2.1-VACE-14B",
+            "diffusers_config": {
+                "added_kv_proj_dim": None,
+                "attention_head_dim": 128,
+                "cross_attn_norm": True,
+                "eps": 1e-06,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "in_channels": 16,
+                "num_attention_heads": 40,
+                "num_layers": 40,
+                "out_channels": 16,
+                "patch_size": [1, 2, 2],
+                "qk_norm": "rms_norm_across_heads",
+                "text_dim": 4096,
+                "vace_layers": [0, 5, 10, 15, 20, 25, 30, 35],
+                "vace_in_channels": 96,
+            },
+        }
+        RENAME_DICT = VACE_TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = VACE_TRANSFORMER_SPECIAL_KEYS_REMAP
+    return config, RENAME_DICT, SPECIAL_KEYS_REMAP
 
 
 def convert_transformer(model_type: str):
-    config = get_transformer_config(model_type)
+    config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config(model_type)
+
     diffusers_config = config["diffusers_config"]
     model_id = config["model_id"]
     model_dir = pathlib.Path(snapshot_download(model_id, repo_type="model"))
@@ -187,16 +291,19 @@ def convert_transformer(model_type: str):
     original_state_dict = load_sharded_safetensors(model_dir)
 
     with init_empty_weights():
-        transformer = WanTransformer3DModel.from_config(diffusers_config)
+        if "VACE" not in model_type:
+            transformer = WanTransformer3DModel.from_config(diffusers_config)
+        else:
+            transformer = WanVACETransformer3DModel.from_config(diffusers_config)
 
     for key in list(original_state_dict.keys()):
         new_key = key[:]
-        for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
+        for replace_key, rename_key in RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
         update_state_dict_(original_state_dict, key, new_key)
 
     for key in list(original_state_dict.keys()):
-        for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
+        for special_key, handler_fn_inplace in SPECIAL_KEYS_REMAP.items():
             if special_key not in key:
                 continue
             handler_fn_inplace(key, original_state_dict)
@@ -412,7 +519,7 @@ def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument("--model_type", type=str, default=None)
     parser.add_argument("--output_path", type=str, required=True)
-    parser.add_argument("--dtype", default="fp32")
+    parser.add_argument("--dtype", default="fp32", choices=["fp32", "fp16", "bf16", "none"])
     return parser.parse_args()
 
 
@@ -426,18 +533,20 @@ DTYPE_MAPPING = {
 if __name__ == "__main__":
     args = get_args()
 
-    transformer = None
-    dtype = DTYPE_MAPPING[args.dtype]
-
-    transformer = convert_transformer(args.model_type).to(dtype=dtype)
+    transformer = convert_transformer(args.model_type)
     vae = convert_vae()
-    text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl")
+    text_encoder = UMT5EncoderModel.from_pretrained("google/umt5-xxl", torch_dtype=torch.bfloat16)
     tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
     flow_shift = 16.0 if "FLF2V" in args.model_type else 3.0
     scheduler = UniPCMultistepScheduler(
         prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
     )
 
+    # If user has specified "none", we keep the original dtypes of the state dict without any conversion
+    if args.dtype != "none":
+        dtype = DTYPE_MAPPING[args.dtype]
+        transformer.to(dtype)
+
     if "I2V" in args.model_type or "FLF2V" in args.model_type:
         image_encoder = CLIPVisionModelWithProjection.from_pretrained(
             "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
@@ -452,6 +561,14 @@ if __name__ == "__main__":
             image_encoder=image_encoder,
             image_processor=image_processor,
         )
+    elif "VACE" in args.model_type:
+        pipe = WanVACEPipeline(
+            transformer=transformer,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            vae=vae,
+            scheduler=scheduler,
+        )
     else:
         pipe = WanPipeline(
             transformer=transformer,
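Example invocation after applying this patch (a sketch, not part of the patch: the output directory name is arbitrary, Wan-VACE-1.3B is one of the two model types added here, and --dtype none keeps the checkpoint's original dtypes instead of casting the transformer):

    python scripts/convert_wan_to_diffusers.py \
        --model_type Wan-VACE-1.3B \
        --output_path ./Wan2.1-VACE-1.3B-Diffusers \
        --dtype none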