mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
conversion script
This commit is contained in:
@@ -764,7 +764,7 @@ def convert_vae_22():
|
||||
"conv2.weight": "post_quant_conv.weight",
|
||||
"conv2.bias": "post_quant_conv.bias",
|
||||
}
|
||||
|
||||
|
||||
# Process each key in the state dict
|
||||
for key, value in old_state_dict.items():
|
||||
# Handle middle block keys using the mapping
|
||||
@@ -797,12 +797,12 @@ def convert_vae_22():
|
||||
elif key.startswith("encoder.downsamples."):
|
||||
# Change encoder.downsamples to encoder.down_blocks
|
||||
new_key = key.replace("encoder.downsamples.", "encoder.down_blocks.")
|
||||
|
||||
|
||||
# Handle residual blocks - change downsamples to resnets and rename components
|
||||
if "residual" in new_key or "shortcut" in new_key:
|
||||
# Change the second downsamples to resnets
|
||||
new_key = new_key.replace(".downsamples.", ".resnets.")
|
||||
|
||||
|
||||
# Rename residual components
|
||||
if ".residual.0.gamma" in new_key:
|
||||
new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
|
||||
@@ -820,7 +820,7 @@ def convert_vae_22():
|
||||
new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
|
||||
elif ".shortcut.bias" in new_key:
|
||||
new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
|
||||
|
||||
|
||||
# Handle resample blocks - change downsamples to downsampler and remove index
|
||||
elif "resample" in new_key or "time_conv" in new_key:
|
||||
# Change the second downsamples to downsampler and remove the index
|
||||
@@ -831,19 +831,19 @@ def convert_vae_22():
|
||||
# Remove the index (parts[4]) and change downsamples to downsampler
|
||||
new_parts = parts[:3] + ["downsampler"] + parts[5:]
|
||||
new_key = ".".join(new_parts)
|
||||
|
||||
|
||||
new_state_dict[new_key] = value
|
||||
|
||||
# Handle decoder upsamples
|
||||
elif key.startswith("decoder.upsamples."):
|
||||
# Change decoder.upsamples to decoder.up_blocks
|
||||
new_key = key.replace("decoder.upsamples.", "decoder.up_blocks.")
|
||||
|
||||
|
||||
# Handle residual blocks - change upsamples to resnets and rename components
|
||||
if "residual" in new_key or "shortcut" in new_key:
|
||||
# Change the second upsamples to resnets
|
||||
new_key = new_key.replace(".upsamples.", ".resnets.")
|
||||
|
||||
|
||||
# Rename residual components
|
||||
if ".residual.0.gamma" in new_key:
|
||||
new_key = new_key.replace(".residual.0.gamma", ".norm1.gamma")
|
||||
@@ -861,7 +861,7 @@ def convert_vae_22():
|
||||
new_key = new_key.replace(".shortcut.weight", ".conv_shortcut.weight")
|
||||
elif ".shortcut.bias" in new_key:
|
||||
new_key = new_key.replace(".shortcut.bias", ".conv_shortcut.bias")
|
||||
|
||||
|
||||
# Handle resample blocks - change upsamples to upsampler and remove index
|
||||
elif "resample" in new_key or "time_conv" in new_key:
|
||||
# Change the second upsamples to upsampler and remove the index
|
||||
@@ -872,7 +872,7 @@ def convert_vae_22():
|
||||
# Remove the index (parts[4]) and change upsamples to upsampler
|
||||
new_parts = parts[:3] + ["upsampler"] + parts[5:]
|
||||
new_key = ".".join(new_parts)
|
||||
|
||||
|
||||
new_state_dict[new_key] = value
|
||||
else:
|
||||
# Keep other keys unchanged
|
||||
|
||||
Reference in New Issue
Block a user