Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
[From Single File] support from_single_file method for WanAnimateTransformer3DModel (#12691)
* Add `WanAnimateTransformer3DModel` to `SINGLE_FILE_LOADABLE_CLASSES`
* Fix a dtype mismatch when loading from a single file
* Fix a bug that produced white noise during generation
* Update the dtype check for the time embedder (the old check caused white-noise output)
* Improve code readability
* Optimize dtype handling: remove unnecessary dtype conversions for the timestep and weight
* Apply style fixes
* Refactor time embedding dtype handling: adjust the time embedding type conversion for compatibility
* Apply style fixes
* Modify the comment on the `WanTimeTextImageEmbedding` class

Co-authored-by: Sam Edwards <sam.edwards1976@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
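For context, a minimal sketch of what this change enables, assuming a local copy of an original (non-Diffusers-layout) Wan 2.2 Animate checkpoint; the filename below is hypothetical, and the loaded transformer can then be passed to the pipeline via `from_pretrained(..., transformer=...)`:

```python
import torch
from diffusers import WanAnimateTransformer3DModel

# Hypothetical local path to an original Wan 2.2 Animate single-file checkpoint
ckpt_path = "Wan2.2-Animate-14B/wan2.2_animate_14B_bf16.safetensors"

transformer = WanAnimateTransformer3DModel.from_single_file(ckpt_path, torch_dtype=torch.bfloat16)
```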
@@ -152,6 +152,10 @@ SINGLE_FILE_LOADABLE_CLASSES = {
         "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "WanAnimateTransformer3DModel": {
+        "checkpoint_mapping_fn": convert_wan_transformer_to_diffusers,
+        "default_subfolder": "transformer",
+    },
     "AutoencoderKLWan": {
         "checkpoint_mapping_fn": convert_wan_vae_to_diffusers,
         "default_subfolder": "vae",
@@ -136,6 +136,7 @@ CHECKPOINT_KEY_NAMES = {
     "wan": ["model.diffusion_model.head.modulation", "head.modulation"],
     "wan_vae": "decoder.middle.0.residual.0.gamma",
     "wan_vace": "vace_blocks.0.after_proj.bias",
+    "wan_animate": "motion_encoder.dec.direction.weight",
     "hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
     "cosmos-1.0": [
         "net.x_embedder.proj.1.weight",
@@ -219,6 +220,7 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
     "wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
     "wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
     "wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
+    "wan-animate-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.2-Animate-14B-Diffusers"},
     "wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
     "wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
     "hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
@@ -759,6 +761,9 @@ def infer_diffusers_model_type(checkpoint):
             elif checkpoint[target_key].shape[0] == 5120:
                 model_type = "wan-vace-14B"
 
+        if CHECKPOINT_KEY_NAMES["wan_animate"] in checkpoint:
+            model_type = "wan-animate-14B"
+
         elif checkpoint[target_key].shape[0] == 1536:
             model_type = "wan-t2v-1.3B"
         elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
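Putting the three tables together: the new `wan_animate` key makes `infer_diffusers_model_type` return `"wan-animate-14B"`, which `DIFFUSERS_DEFAULT_PIPELINE_PATHS` maps to the Diffusers-format repo used to fetch a default config when none is supplied. A rough sketch with a dummy state dict (the dicts below are illustrative excerpts, not the full tables):

```python
# Illustrative excerpts of the lookup tables added above
CHECKPOINT_KEY_NAMES = {"wan_animate": "motion_encoder.dec.direction.weight"}
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
    "wan-animate-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.2-Animate-14B-Diffusers"},
}

checkpoint = {"motion_encoder.dec.direction.weight": None}  # stand-in for a loaded state dict
if CHECKPOINT_KEY_NAMES["wan_animate"] in checkpoint:
    model_type = "wan-animate-14B"
    default_repo = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]["pretrained_model_name_or_path"]
    print(default_repo)  # Wan-AI/Wan2.2-Animate-14B-Diffusers
```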
@@ -3154,13 +3159,64 @@ def convert_sana_transformer_to_diffusers(checkpoint, **kwargs):
 
 
 def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
+    def generate_motion_encoder_mappings():
+        mappings = {
+            "motion_encoder.dec.direction.weight": "motion_encoder.motion_synthesis_weight",
+            "motion_encoder.enc.net_app.convs.0.0.weight": "motion_encoder.conv_in.weight",
+            "motion_encoder.enc.net_app.convs.0.1.bias": "motion_encoder.conv_in.act_fn.bias",
+            "motion_encoder.enc.net_app.convs.8.weight": "motion_encoder.conv_out.weight",
+            "motion_encoder.enc.fc": "motion_encoder.motion_network",
+        }
+
+        for i in range(7):
+            conv_idx = i + 1
+            mappings.update(
+                {
+                    f"motion_encoder.enc.net_app.convs.{conv_idx}.conv1.0.weight": f"motion_encoder.res_blocks.{i}.conv1.weight",
+                    f"motion_encoder.enc.net_app.convs.{conv_idx}.conv1.1.bias": f"motion_encoder.res_blocks.{i}.conv1.act_fn.bias",
+                    f"motion_encoder.enc.net_app.convs.{conv_idx}.conv2.1.weight": f"motion_encoder.res_blocks.{i}.conv2.weight",
+                    f"motion_encoder.enc.net_app.convs.{conv_idx}.conv2.2.bias": f"motion_encoder.res_blocks.{i}.conv2.act_fn.bias",
+                    f"motion_encoder.enc.net_app.convs.{conv_idx}.skip.1.weight": f"motion_encoder.res_blocks.{i}.conv_skip.weight",
+                }
+            )
+
+        return mappings
+
+    def generate_face_adapter_mappings():
+        return {
+            "face_adapter.fuser_blocks": "face_adapter",
+            ".k_norm.": ".norm_k.",
+            ".q_norm.": ".norm_q.",
+            ".linear1_q.": ".to_q.",
+            ".linear2.": ".to_out.",
+            "conv1_local.conv": "conv1_local",
+            "conv2.conv": "conv2",
+            "conv3.conv": "conv3",
+        }
+
+    def split_tensor_handler(key, state_dict, split_pattern, target_keys):
+        tensor = state_dict.pop(key)
+        split_idx = tensor.shape[0] // 2
+
+        new_key_1 = key.replace(split_pattern, target_keys[0])
+        new_key_2 = key.replace(split_pattern, target_keys[1])
+
+        state_dict[new_key_1] = tensor[:split_idx]
+        state_dict[new_key_2] = tensor[split_idx:]
+
+    def reshape_bias_handler(key, state_dict):
+        if "motion_encoder.enc.net_app.convs." in key and ".bias" in key:
+            state_dict[key] = state_dict[key][0, :, 0, 0]
+
     converted_state_dict = {}
 
+    # Strip model.diffusion_model prefix
     keys = list(checkpoint.keys())
     for k in keys:
         if "model.diffusion_model." in k:
             checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
 
+    # Base transformer mappings
     TRANSFORMER_KEYS_RENAME_DICT = {
         "time_embedding.0": "condition_embedder.time_embedder.linear_1",
         "time_embedding.2": "condition_embedder.time_embedder.linear_2",
@@ -3182,28 +3238,43 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
         "ffn.0": "ffn.net.0.proj",
         "ffn.2": "ffn.net.2",
         # Hack to swap the layer names
         # The original model calls the norms in following order: norm1, norm3, norm2
         # We convert it to: norm1, norm2, norm3
         "norm2": "norm__placeholder",
         "norm3": "norm2",
         "norm__placeholder": "norm3",
-        # For the I2V model
+        # I2V model
         "img_emb.proj.0": "condition_embedder.image_embedder.norm1",
         "img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
         "img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
         "img_emb.proj.4": "condition_embedder.image_embedder.norm2",
-        # For the VACE model
+        # VACE model
         "before_proj": "proj_in",
         "after_proj": "proj_out",
     }
 
+    SPECIAL_KEYS_HANDLERS = {}
+    if any("face_adapter" in k for k in checkpoint.keys()):
+        TRANSFORMER_KEYS_RENAME_DICT.update(generate_face_adapter_mappings())
+        SPECIAL_KEYS_HANDLERS[".linear1_kv."] = (split_tensor_handler, [".to_k.", ".to_v."])
+
+    if any("motion_encoder" in k for k in checkpoint.keys()):
+        TRANSFORMER_KEYS_RENAME_DICT.update(generate_motion_encoder_mappings())
+
     for key in list(checkpoint.keys()):
-        new_key = key[:]
+        reshape_bias_handler(key, checkpoint)
 
+    for key in list(checkpoint.keys()):
+        new_key = key
         for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
             new_key = new_key.replace(replace_key, rename_key)
 
         converted_state_dict[new_key] = checkpoint.pop(key)
 
+    for key in list(converted_state_dict.keys()):
+        for pattern, (handler_fn, target_keys) in SPECIAL_KEYS_HANDLERS.items():
+            if pattern not in key:
+                continue
+            handler_fn(key, converted_state_dict, pattern, target_keys)
+            break
+
     return converted_state_dict
 
 
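To illustrate the face-adapter handling above: the original checkpoints store a fused key/value projection under a `.linear1_kv.` name, and `split_tensor_handler` halves it along dim 0 into separate `to_k`/`to_v` tensors. A standalone sketch with made-up shapes and a hypothetical key name:

```python
import torch

# Hypothetical fused KV projection; the real keys and sizes come from the checkpoint
state_dict = {"face_adapter.0.linear1_kv.weight": torch.randn(2 * 64, 64)}

key = "face_adapter.0.linear1_kv.weight"
tensor = state_dict.pop(key)
half = tensor.shape[0] // 2  # first half becomes to_k, second half becomes to_v
state_dict[key.replace(".linear1_kv.", ".to_k.")] = tensor[:half]
state_dict[key.replace(".linear1_kv.", ".to_v.")] = tensor[half:]

assert state_dict["face_adapter.0.to_k.weight"].shape == (64, 64)
```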
@@ -166,9 +166,11 @@ class MotionConv2d(nn.Module):
         # NOTE: the original implementation uses a 2D upfirdn operation with the upsampling and downsampling rates
         # set to 1, which should be equivalent to a 2D convolution
         expanded_kernel = self.blur_kernel[None, None, :, :].expand(self.in_channels, 1, -1, -1)
+        x = x.to(expanded_kernel.dtype)
         x = F.conv2d(x, expanded_kernel, padding=self.blur_padding, groups=self.in_channels)
 
         # Main Conv2D with scaling
+        x = x.to(self.weight.dtype)
         x = F.conv2d(x, self.weight * self.scale, bias=self.bias, stride=self.stride, padding=self.padding)
 
         # Activation with fused bias, if using
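The two added casts address the dtype mismatch mentioned in the commit message: `F.conv2d` requires its input and weight to share a dtype, so a bf16 activation meeting an fp32 blur kernel (or the reverse) raises a RuntimeError. A standalone illustration with made-up shapes and values:

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 8, 8, dtype=torch.bfloat16)   # bf16 activation, illustrative
blur = torch.ones(4, 1, 3, 3) / 9.0                  # fp32 depthwise kernel, illustrative

try:
    F.conv2d(x, blur, padding=1, groups=4)            # mixed dtypes: fails
except RuntimeError as err:
    print(f"dtype mismatch: {err}")

out = F.conv2d(x.to(blur.dtype), blur, padding=1, groups=4)  # cast first, as in the patch
```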
@@ -338,8 +340,7 @@ class WanAnimateMotionEncoder(nn.Module):
         weight = self.motion_synthesis_weight + 1e-8
         # Upcast the QR orthogonalization operation to FP32
         original_motion_dtype = motion_feat.dtype
-        motion_feat = motion_feat.to(torch.float32)
-        weight = weight.to(torch.float32)
+        motion_feat = motion_feat.to(weight.dtype)
 
         Q = torch.linalg.qr(weight)[0].to(device=motion_feat.device)
 
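For reference on what this block computes: `torch.linalg.qr` returns `(Q, R)`, and taking `Q` gives an orthonormal basis for the direction weight; the `1e-8` offset appears to be a small guard against a degenerate all-zero matrix. A tiny sketch with an illustrative shape:

```python
import torch

weight = torch.randn(512, 20)             # illustrative size for the motion direction basis
Q = torch.linalg.qr(weight + 1e-8)[0]     # orthonormal columns spanning the same subspace
assert torch.allclose(Q.T @ Q, torch.eye(20), atol=1e-5)
```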
@@ -769,7 +770,7 @@ class WanImageEmbedding(torch.nn.Module):
         return hidden_states
 
 
-# Copied from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding
+# Modified from diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding
 class WanTimeTextImageEmbedding(nn.Module):
     def __init__(
         self,
@@ -803,10 +804,12 @@
         if timestep_seq_len is not None:
             timestep = timestep.unflatten(0, (-1, timestep_seq_len))
 
-        time_embedder_dtype = next(iter(self.time_embedder.parameters())).dtype
-        if timestep.dtype != time_embedder_dtype and time_embedder_dtype != torch.int8:
-            timestep = timestep.to(time_embedder_dtype)
-        temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+        if self.time_embedder.linear_1.weight.dtype.is_floating_point:
+            time_embedder_dtype = self.time_embedder.linear_1.weight.dtype
+        else:
+            time_embedder_dtype = encoder_hidden_states.dtype
+
+        temb = self.time_embedder(timestep.to(time_embedder_dtype)).type_as(encoder_hidden_states)
         timestep_proj = self.time_proj(self.act_fn(temb))
 
         encoder_hidden_states = self.text_embedder(encoder_hidden_states)
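The new branch reads the dtype directly off `linear_1.weight`, so a quantized time embedder (for example int8 weights, which are not a valid cast target for the float timestep) falls back to the `encoder_hidden_states` dtype. A rough standalone sketch of just the selection rule (the helper name is made up):

```python
import torch

def pick_timestep_dtype(weight_dtype: torch.dtype, encoder_dtype: torch.dtype) -> torch.dtype:
    # Mirrors the branch above: only use the embedder's own dtype if it is floating point.
    return weight_dtype if weight_dtype.is_floating_point else encoder_dtype

assert pick_timestep_dtype(torch.bfloat16, torch.float16) is torch.bfloat16
assert pick_timestep_dtype(torch.int8, torch.bfloat16) is torch.bfloat16  # quantized weights
```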