1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

conversion script

This commit is contained in:
yiyixuxu
2025-07-28 12:28:42 +02:00
parent 27ce75b984
commit 5709f7e04d

View File

@@ -764,7 +764,7 @@ def convert_vae_22():
"conv2.weight": "post_quant_conv.weight",
"conv2.bias": "post_quant_conv.bias",
}
# Process each key in the state dict
for key, value in old_state_dict.items():
# Handle middle block keys using the mapping
@@ -797,12 +797,12 @@ def convert_vae_22():
elif key.startswith("encoder.downsamples."):
# Change encoder.downsamples to encoder.down_blocks
new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
# Handle residual blocks - change downsamples to resnets and rename components
if "residual" in new_key or "shortcut" in new_key:
# Change the second downsamples to resnets
new_key = new_key.replace(".downsamples.", ".resnets.")
# Rename residual components
if ".residual.0.gamma" in new_key:
new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
@@ -820,7 +820,7 @@ def convert_vae_22():
new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
elif ".shortcut.bias" in new_key:
new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
# Handle resample blocks - change downsamples to downsampler and remove index
elif "resample" in new_key or "time_conv" in new_key:
# Change the second downsamples to downsampler and remove the index
@@ -831,19 +831,19 @@ def convert_vae_22():
# Remove the index (parts[4]) and change downsamples to downsampler
new_parts = parts[:3] + ["downsampler"] + parts[5:]
new_key = ".".join(new_parts)
new_state_dict[new_key] = value
# Handle decoder upsamples
elif key.startswith("decoder.upsamples."):
# Change decoder.upsamples to decoder.up_blocks
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
# Handle residual blocks - change upsamples to resnets and rename components
if "residual" in new_key or "shortcut" in new_key:
# Change the second upsamples to resnets
new_key = new_key.replace(".upsamples.", ".resnets.")
# Rename residual components
if ".residual.0.gamma" in new_key:
new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
@@ -861,7 +861,7 @@ def convert_vae_22():
new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
elif ".shortcut.bias" in new_key:
new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
# Handle resample blocks - change upsamples to upsampler and remove index
elif "resample" in new_key or "time_conv" in new_key:
# Change the second upsamples to upsampler and remove the index
@@ -872,7 +872,7 @@ def convert_vae_22():
# Remove the index (parts[4]) and change upsamples to upsampler
new_parts = parts[:3] + ["upsampler"] + parts[5:]
new_key = ".".join(new_parts)
new_state_dict[new_key] = value
else:
# Keep other keys unchanged