Co-authored-by: Tolga Cangöz <mtcangoz@gmail.com>
Co-authored-by: Tolga Cangöz <46008593+tolgacangoz@users.noreply.github.com>
@@ -6,11 +6,20 @@ import torch
 from accelerate import init_empty_weights
 from huggingface_hub import hf_hub_download, snapshot_download
 from safetensors.torch import load_file
-from transformers import AutoProcessor, AutoTokenizer, CLIPVisionModelWithProjection, UMT5EncoderModel
+from transformers import (
+    AutoProcessor,
+    AutoTokenizer,
+    CLIPImageProcessor,
+    CLIPVisionModel,
+    CLIPVisionModelWithProjection,
+    UMT5EncoderModel,
+)
 
 from diffusers import (
     AutoencoderKLWan,
     UniPCMultistepScheduler,
+    WanAnimatePipeline,
+    WanAnimateTransformer3DModel,
     WanImageToVideoPipeline,
     WanPipeline,
     WanTransformer3DModel,
@@ -105,8 +114,203 @@ VACE_TRANSFORMER_KEYS_RENAME_DICT = {
     "after_proj": "proj_out",
 }
 
+ANIMATE_TRANSFORMER_KEYS_RENAME_DICT = {
+    "time_embedding.0": "condition_embedder.time_embedder.linear_1",
+    "time_embedding.2": "condition_embedder.time_embedder.linear_2",
+    "text_embedding.0": "condition_embedder.text_embedder.linear_1",
+    "text_embedding.2": "condition_embedder.text_embedder.linear_2",
+    "time_projection.1": "condition_embedder.time_proj",
+    "head.modulation": "scale_shift_table",
+    "head.head": "proj_out",
+    "modulation": "scale_shift_table",
+    "ffn.0": "ffn.net.0.proj",
+    "ffn.2": "ffn.net.2",
+    # Hack to swap the layer names
+    # The original model calls the norms in the following order: norm1, norm3, norm2
+    # We convert it to: norm1, norm2, norm3
+    "norm2": "norm__placeholder",
+    "norm3": "norm2",
+    "norm__placeholder": "norm3",
+    "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
+    "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
+    "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
+    "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
+    # Add attention component mappings
+    "self_attn.q": "attn1.to_q",
+    "self_attn.k": "attn1.to_k",
+    "self_attn.v": "attn1.to_v",
+    "self_attn.o": "attn1.to_out.0",
+    "self_attn.norm_q": "attn1.norm_q",
+    "self_attn.norm_k": "attn1.norm_k",
+    "cross_attn.q": "attn2.to_q",
+    "cross_attn.k": "attn2.to_k",
+    "cross_attn.v": "attn2.to_v",
+    "cross_attn.o": "attn2.to_out.0",
+    "cross_attn.norm_q": "attn2.norm_q",
+    "cross_attn.norm_k": "attn2.norm_k",
+    "cross_attn.k_img": "attn2.to_k_img",
+    "cross_attn.v_img": "attn2.to_v_img",
+    "cross_attn.norm_k_img": "attn2.norm_k_img",
+    # After the cross_attn -> attn2 rename, we need to rename the img keys
+    "attn2.to_k_img": "attn2.add_k_proj",
+    "attn2.to_v_img": "attn2.add_v_proj",
+    "attn2.norm_k_img": "attn2.norm_added_k",
+    # Wan Animate-specific mappings (motion encoder, face encoder, face adapter)
+    # Motion encoder mappings
+    # The name mapping is complicated for the convolutional part, so we handle that in its own function
+    "motion_encoder.enc.fc": "motion_encoder.motion_network",
+    "motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
+    # Face encoder mappings - CausalConv1d has a .conv submodule that we need to flatten
+    "face_encoder.conv1_local.conv": "face_encoder.conv1_local",
+    "face_encoder.conv2.conv": "face_encoder.conv2",
+    "face_encoder.conv3.conv": "face_encoder.conv3",
+    # Face adapter mappings are handled in a separate function
+}
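The placeholder entries above rely on the rename table being applied entry by entry to every key (that loop lives later in the script and is outside this hunk). A minimal sketch of the trick, assuming plain str.replace applied in insertion order; the block/key names are invented:

rename_dict = {
    "norm2": "norm__placeholder",
    "norm3": "norm2",
    "norm__placeholder": "norm3",
}


def rename(key: str) -> str:
    # Apply every mapping in order, exactly like a substring-based rename pass would.
    for old, new in rename_dict.items():
        key = key.replace(old, new)
    return key


assert rename("blocks.0.norm2.weight") == "blocks.0.norm3.weight"  # norm2 -> norm3 via the placeholder
assert rename("blocks.0.norm3.weight") == "blocks.0.norm2.weight"  # norm3 -> norm2 directly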
+
+
+# TODO: Verify this and simplify if possible.
+def convert_animate_motion_encoder_weights(key: str, state_dict: Dict[str, Any], final_conv_idx: int = 8) -> None:
+    """
+    Convert all motion encoder weights for the Animate model.
+
+    In the original model:
+    - All Linear layers in fc use EqualLinear
+    - All Conv2d layers in convs use EqualConv2d (except blur_conv, which is initialized separately)
+    - Blur kernels are stored as buffers in Sequential modules
+    - ConvLayer is nn.Sequential with indices: [Blur (optional), EqualConv2d, FusedLeakyReLU (optional)]
+
+    Conversion strategy:
+    1. Drop .kernel buffers (blur kernels)
+    2. Rename sequential indices to named components (e.g., 0 -> conv2d, 1 -> bias_leaky_relu)
+    """
+    # Skip if not a weight, bias, or kernel
+    if ".weight" not in key and ".bias" not in key and ".kernel" not in key:
+        return
+
+    # Handle Blur kernel buffers from the original implementation.
+    # After renaming, these appear under: motion_encoder.res_blocks.*.conv{2,skip}.blur_kernel
+    # Diffusers constructs blur kernels as a non-persistent buffer, so we must drop these keys
+    if ".kernel" in key and "motion_encoder" in key:
+        # Remove unexpected blur kernel buffers to avoid strict load errors
+        state_dict.pop(key, None)
+        return
+
+    # Rename Sequential indices to named components in ConvLayer and ResBlock
+    if ".enc.net_app.convs." in key and (".weight" in key or ".bias" in key):
+        parts = key.split(".")
+
+        # Find the sequential index (digit) after convs or after conv1/conv2/skip
+        # Examples:
+        # - enc.net_app.convs.0.0.weight -> conv_in.weight (initial conv layer weight)
+        # - enc.net_app.convs.0.1.bias -> conv_in.act_fn.bias (initial conv layer bias)
+        # - enc.net_app.convs.{n:1-7}.conv1.0.weight -> res_blocks.{(n-1):0-6}.conv1.weight (conv1 weight)
+        #   - e.g. enc.net_app.convs.1.conv1.0.weight -> res_blocks.0.conv1.weight
+        # - enc.net_app.convs.{n:1-7}.conv1.1.bias -> res_blocks.{(n-1):0-6}.conv1.act_fn.bias (conv1 bias)
+        #   - e.g. enc.net_app.convs.1.conv1.1.bias -> res_blocks.0.conv1.act_fn.bias
+        # - enc.net_app.convs.{n:1-7}.conv2.1.weight -> res_blocks.{(n-1):0-6}.conv2.weight (conv2 weight)
+        # - enc.net_app.convs.1.conv2.2.bias -> res_blocks.0.conv2.act_fn.bias (conv2 bias)
+        # - enc.net_app.convs.{n:1-7}.skip.1.weight -> res_blocks.{(n-1):0-6}.conv_skip.weight (skip conv weight)
+        # - enc.net_app.convs.8 -> conv_out (final conv layer)
+
+        convs_idx = parts.index("convs") if "convs" in parts else -1
+        if convs_idx >= 0 and len(parts) - convs_idx >= 2:
+            bias = False
+            # The nn.Sequential index will always follow convs
+            sequential_idx = int(parts[convs_idx + 1])
+            if sequential_idx == 0:
+                if key.endswith(".weight"):
+                    new_key = "motion_encoder.conv_in.weight"
+                elif key.endswith(".bias"):
+                    new_key = "motion_encoder.conv_in.act_fn.bias"
+                    bias = True
+            elif sequential_idx == final_conv_idx:
+                if key.endswith(".weight"):
+                    new_key = "motion_encoder.conv_out.weight"
+            else:
+                # Intermediate .convs. layers, which get mapped to .res_blocks.
+                prefix = "motion_encoder.res_blocks."
+
+                layer_name = parts[convs_idx + 2]
+                if layer_name == "skip":
+                    layer_name = "conv_skip"
+
+                if key.endswith(".weight"):
+                    param_name = "weight"
+                elif key.endswith(".bias"):
+                    param_name = "act_fn.bias"
+                    bias = True
+
+                suffix_parts = [str(sequential_idx - 1), layer_name, param_name]
+                suffix = ".".join(suffix_parts)
+                new_key = prefix + suffix
+
+            param = state_dict.pop(key)
+            if bias:
+                param = param.squeeze()
+            state_dict[new_key] = param
+            return
+        return
+    return
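As an illustration, a small sketch (not part of the commit) of what the handler above does to a few invented motion-encoder keys; the tensor shapes are made up and the function is assumed to be in scope:

import torch

toy = {
    "motion_encoder.enc.net_app.convs.1.conv1.0.weight": torch.randn(64, 32, 3, 3),
    "motion_encoder.enc.net_app.convs.1.conv1.1.bias": torch.randn(1, 64, 1, 1),
    "motion_encoder.enc.net_app.convs.1.conv2.0.kernel": torch.randn(4, 4),  # hypothetical blur kernel buffer
}
for key in list(toy):
    convert_animate_motion_encoder_weights(key, toy)

# The blur kernel is dropped; the conv weight and the activation bias are renamed,
# and the bias is squeezed down to 1-D:
#   motion_encoder.res_blocks.0.conv1.weight      (64, 32, 3, 3)
#   motion_encoder.res_blocks.0.conv1.act_fn.bias (64,)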
+
+
+def convert_animate_face_adapter_weights(key: str, state_dict: Dict[str, Any]) -> None:
+    """
+    Convert face adapter weights for the Animate model.
+
+    The original model uses a fused KV projection, but the diffusers model uses separate K and V projections.
+    """
+    # Skip if not a weight or bias
+    if ".weight" not in key and ".bias" not in key:
+        return
+
+    prefix = "face_adapter."
+    if ".fuser_blocks." in key:
+        parts = key.split(".")
+
+        module_list_idx = parts.index("fuser_blocks") if "fuser_blocks" in parts else -1
+        if module_list_idx >= 0 and (len(parts) - 1) - module_list_idx == 3:
+            block_idx = parts[module_list_idx + 1]
+            layer_name = parts[module_list_idx + 2]
+            param_name = parts[module_list_idx + 3]
+
+            if layer_name == "linear1_kv":
+                layer_name_k = "to_k"
+                layer_name_v = "to_v"
+
+                suffix_k = ".".join([block_idx, layer_name_k, param_name])
+                suffix_v = ".".join([block_idx, layer_name_v, param_name])
+                new_key_k = prefix + suffix_k
+                new_key_v = prefix + suffix_v
+
+                kv_proj = state_dict.pop(key)
+                k_proj, v_proj = torch.chunk(kv_proj, 2, dim=0)
+                state_dict[new_key_k] = k_proj
+                state_dict[new_key_v] = v_proj
+                return
+            else:
+                if layer_name == "q_norm":
+                    new_layer_name = "norm_q"
+                elif layer_name == "k_norm":
+                    new_layer_name = "norm_k"
+                elif layer_name == "linear1_q":
+                    new_layer_name = "to_q"
+                elif layer_name == "linear2":
+                    new_layer_name = "to_out"
+
+                suffix_parts = [block_idx, new_layer_name, param_name]
+                suffix = ".".join(suffix_parts)
+                new_key = prefix + suffix
+                state_dict[new_key] = state_dict.pop(key)
+                return
+    return
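A quick sketch (not from the commit) of the fused-KV split, using a toy dimension: a fused (2 * dim, dim) projection is chunked along dim 0 into separate K and V projections.

import torch

dim = 8
state = {"face_adapter.fuser_blocks.0.linear1_kv.weight": torch.randn(2 * dim, dim)}
convert_animate_face_adapter_weights("face_adapter.fuser_blocks.0.linear1_kv.weight", state)

# torch.chunk(kv_proj, 2, dim=0) gives K the first half of the rows and V the second half:
#   face_adapter.0.to_k.weight (8, 8)
#   face_adapter.0.to_v.weight (8, 8)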
+
+
 TRANSFORMER_SPECIAL_KEYS_REMAP = {}
 VACE_TRANSFORMER_SPECIAL_KEYS_REMAP = {}
+ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP = {
+    "motion_encoder": convert_animate_motion_encoder_weights,
+    "face_adapter": convert_animate_face_adapter_weights,
+}
 
 
 def update_state_dict_(state_dict: Dict[str, Any], old_key: str, new_key: str) -> Dict[str, Any]:
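These tables are keyed by substring: in the conversion loop shown in a later hunk, every checkpoint key containing "motion_encoder" or "face_adapter" is routed to the matching handler. A toy sketch of that dispatch (keys and shapes invented, handlers from above assumed in scope):

import torch

state_dict = {
    "motion_encoder.enc.net_app.convs.0.0.weight": torch.randn(32, 3, 3, 3),
    "face_adapter.fuser_blocks.0.q_norm.weight": torch.randn(8),
}
for key in list(state_dict.keys()):
    for special_key, handler_fn_inplace in ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP.items():
        if special_key not in key:
            continue
        handler_fn_inplace(key, state_dict)  # handlers rename or drop keys in place

# Resulting keys: motion_encoder.conv_in.weight, face_adapter.0.norm_q.weight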
@@ -364,6 +568,37 @@ def get_transformer_config(model_type: str) -> Tuple[Dict[str, Any], ...]:
         }
         RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT
         SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP
+    elif model_type == "Wan2.2-Animate-14B":
+        config = {
+            "model_id": "Wan-AI/Wan2.2-Animate-14B",
+            "diffusers_config": {
+                "image_dim": 1280,
+                "added_kv_proj_dim": 5120,
+                "attention_head_dim": 128,
+                "cross_attn_norm": True,
+                "eps": 1e-06,
+                "ffn_dim": 13824,
+                "freq_dim": 256,
+                "in_channels": 36,
+                "num_attention_heads": 40,
+                "num_layers": 40,
+                "out_channels": 16,
+                "patch_size": (1, 2, 2),
+                "qk_norm": "rms_norm_across_heads",
+                "text_dim": 4096,
+                "rope_max_seq_len": 1024,
+                "pos_embed_seq_len": None,
+                "motion_encoder_size": 512,  # Start of Wan Animate-specific configs
+                "motion_style_dim": 512,
+                "motion_dim": 20,
+                "motion_encoder_dim": 512,
+                "face_encoder_hidden_dim": 1024,
+                "face_encoder_num_heads": 4,
+                "inject_face_latents_blocks": 5,
+            },
+        }
+        RENAME_DICT = ANIMATE_TRANSFORMER_KEYS_RENAME_DICT
+        SPECIAL_KEYS_REMAP = ANIMATE_TRANSFORMER_SPECIAL_KEYS_REMAP
     return config, RENAME_DICT, SPECIAL_KEYS_REMAP
 
 
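A hedged sketch of how the returned triple is presumably consumed further down in the script (the call sites are outside this diff):

config, RENAME_DICT, SPECIAL_KEYS_REMAP = get_transformer_config("Wan2.2-Animate-14B")
diffusers_config = config["diffusers_config"]  # constructor kwargs for WanAnimateTransformer3DModel
model_id = config["model_id"]  # Hub repo to fetch the original checkpoint from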
@@ -380,10 +615,12 @@ def convert_transformer(model_type: str, stage: str = None):
     original_state_dict = load_sharded_safetensors(model_dir)
 
     with init_empty_weights():
-        if "VACE" not in model_type:
-            transformer = WanTransformer3DModel.from_config(diffusers_config)
-        else:
+        if "Animate" in model_type:
+            transformer = WanAnimateTransformer3DModel.from_config(diffusers_config)
+        elif "VACE" in model_type:
             transformer = WanVACETransformer3DModel.from_config(diffusers_config)
+        else:
+            transformer = WanTransformer3DModel.from_config(diffusers_config)
 
     for key in list(original_state_dict.keys()):
         new_key = key[:]
@@ -397,7 +634,12 @@ def convert_transformer(model_type: str, stage: str = None):
                 continue
             handler_fn_inplace(key, original_state_dict)
 
+    # Load state dict into the meta model, which will materialize the tensors
     transformer.load_state_dict(original_state_dict, strict=True, assign=True)
+
+    # Move to CPU to ensure all tensors are materialized
+    transformer = transformer.to("cpu")
+
     return transformer
 
 
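The two comments added in this hunk describe the meta-device flow: parameters created under init_empty_weights live on the "meta" device and own no storage, and load_state_dict(..., assign=True) swaps the real checkpoint tensors in instead of copying into the storage-less meta parameters. A standalone sketch of the same mechanism on a toy module (requires a PyTorch version that supports the assign argument):

import torch
import torch.nn as nn
from accelerate import init_empty_weights

with init_empty_weights():
    layer = nn.Linear(4, 4)  # weight and bias are meta tensors, no memory allocated
print(layer.weight.device)  # meta

state = {"weight": torch.randn(4, 4), "bias": torch.zeros(4)}
layer.load_state_dict(state, strict=True, assign=True)  # assign real tensors in place of the meta ones
print(layer.weight.device)  # cpu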
@@ -926,7 +1168,7 @@ DTYPE_MAPPING = {
 if __name__ == "__main__":
     args = get_args()
 
-    if "Wan2.2" in args.model_type and "TI2V" not in args.model_type:
+    if "Wan2.2" in args.model_type and "TI2V" not in args.model_type and "Animate" not in args.model_type:
         transformer = convert_transformer(args.model_type, stage="high_noise_model")
         transformer_2 = convert_transformer(args.model_type, stage="low_noise_model")
     else:
@@ -942,7 +1184,7 @@ if __name__ == "__main__":
     tokenizer = AutoTokenizer.from_pretrained("google/umt5-xxl")
     if "FLF2V" in args.model_type:
         flow_shift = 16.0
-    elif "TI2V" in args.model_type:
+    elif "TI2V" in args.model_type or "Animate" in args.model_type:
         flow_shift = 5.0
     else:
         flow_shift = 3.0
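flow_shift feeds the flow-matching scheduler that the script constructs a few lines further down (not shown in this hunk); presumably something along these lines:

from diffusers import UniPCMultistepScheduler

# Hedged reconstruction of the scheduler setup that consumes flow_shift.
scheduler = UniPCMultistepScheduler(
    prediction_type="flow_prediction", use_flow_sigmas=True, num_train_timesteps=1000, flow_shift=flow_shift
)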
@@ -954,6 +1196,8 @@ if __name__ == "__main__":
     if args.dtype != "none":
         dtype = DTYPE_MAPPING[args.dtype]
         transformer.to(dtype)
+        if transformer_2 is not None:
+            transformer_2.to(dtype)
 
     if "Wan2.2" in args.model_type and "I2V" in args.model_type and "TI2V" not in args.model_type:
         pipe = WanImageToVideoPipeline(
@@ -1016,6 +1260,21 @@ if __name__ == "__main__":
             vae=vae,
             scheduler=scheduler,
         )
+    elif "Animate" in args.model_type:
+        image_encoder = CLIPVisionModel.from_pretrained(
+            "laion/CLIP-ViT-H-14-laion2B-s32B-b79K", torch_dtype=torch.bfloat16
+        )
+        image_processor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
+
+        pipe = WanAnimatePipeline(
+            transformer=transformer,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            vae=vae,
+            scheduler=scheduler,
+            image_encoder=image_encoder,
+            image_processor=image_processor,
+        )
     else:
         pipe = WanPipeline(
             transformer=transformer,
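Downstream of this hunk the assembled pipeline is presumably written out with save_pretrained to the script's output path; loading the converted Animate checkpoint back then looks roughly like this (the local path is hypothetical):

import torch
from diffusers import WanAnimatePipeline

pipe = WanAnimatePipeline.from_pretrained("path/to/Wan2.2-Animate-14B-Diffusers", torch_dtype=torch.bfloat16)
pipe.to("cuda")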