From 5709f7e04d29569f5ee19f7aa60d185699c7e207 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Mon, 28 Jul 2025 12:28:42 +0200 Subject: [PATCH] conversion script --- scripts/convert_wan_to_diffusers.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/scripts/convert_wan_to_diffusers.py b/scripts/convert_wan_to_diffusers.py index 0a46ae80f0..e6a09e0d98 100644 --- a/scripts/convert_wan_to_diffusers.py +++ b/scripts/convert_wan_to_diffusers.py @@ -764,7 +764,7 @@ def convert_vae_22(): "conv2.weight": "post_quant_conv.weight", "conv2.bias": "post_quant_conv.bias", } - + # Process each key in the state dict for key, value in old_state_dict.items(): # Handle middle block keys using the mapping @@ -797,12 +797,12 @@ def convert_vae_22(): elif key.startswith("encoder.downsamples."): # Change encoder.downsamples to encoder.down_blocks new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.") - + # Handle residual blocks - change downsamples to resnets and rename components if "residual" in new_key or "shortcut" in new_key: # Change the second downsamples to resnets new_key = new_key.replace(".downsamples.", ".resnets.") - + # Rename residual components if ".residual.0.gamma" in new_key: new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") @@ -820,7 +820,7 @@ def convert_vae_22(): new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") elif ".shortcut.bias" in new_key: new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") - + # Handle resample blocks - change downsamples to downsampler and remove index elif "resample" in new_key or "time_conv" in new_key: # Change the second downsamples to downsampler and remove the index @@ -831,19 +831,19 @@ def convert_vae_22(): # Remove the index (parts[4]) and change downsamples to downsampler new_parts = parts[:3] + ["downsampler"] + parts[5:] new_key = ".".join(new_parts) - + new_state_dict[new_key] = value # Handle decoder upsamples elif key.startswith("decoder.upsamples."): # Change decoder.upsamples to decoder.up_blocks new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.") - + # Handle residual blocks - change upsamples to resnets and rename components if "residual" in new_key or "shortcut" in new_key: # Change the second upsamples to resnets new_key = new_key.replace(".upsamples.", ".resnets.") - + # Rename residual components if ".residual.0.gamma" in new_key: new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma") @@ -861,7 +861,7 @@ def convert_vae_22(): new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight") elif ".shortcut.bias" in new_key: new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias") - + # Handle resample blocks - change upsamples to upsampler and remove index elif "resample" in new_key or "time_conv" in new_key: # Change the second upsamples to upsampler and remove the index @@ -872,7 +872,7 @@ def convert_vae_22(): # Remove the index (parts[4]) and change upsamples to upsampler new_parts = parts[:3] + ["upsampler"] + parts[5:] new_key = ".".join(new_parts) - + new_state_dict[new_key] = value else: # Keep other keys unchanged