mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Update Ruff to latest Version (#10919)
* update * update * update * update
This commit is contained in:
@@ -468,7 +468,7 @@ def make_vqvae(old_vae):
|
||||
|
||||
# assert (old_output == new_output).all()
|
||||
print("skipping full vae equivalence check")
|
||||
print(f"vae full diff { (old_output - new_output).float().abs().sum()}")
|
||||
print(f"vae full diff {(old_output - new_output).float().abs().sum()}")
|
||||
|
||||
return new_vae
|
||||
|
||||
|
||||
@@ -239,7 +239,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
|
||||
|
||||
if i != len(up_block_types) - 1:
|
||||
new_prefix = f"up_blocks.{i}.upsamplers.0"
|
||||
old_prefix = f"output_blocks.{current_layer-1}.1"
|
||||
old_prefix = f"output_blocks.{current_layer - 1}.1"
|
||||
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
|
||||
elif layer_type == "AttnUpBlock2D":
|
||||
for j in range(layers_per_block + 1):
|
||||
@@ -255,7 +255,7 @@ def con_pt_to_diffuser(checkpoint_path: str, unet_config):
|
||||
|
||||
if i != len(up_block_types) - 1:
|
||||
new_prefix = f"up_blocks.{i}.upsamplers.0"
|
||||
old_prefix = f"output_blocks.{current_layer-1}.2"
|
||||
old_prefix = f"output_blocks.{current_layer - 1}.2"
|
||||
new_checkpoint = convert_resnet(checkpoint, new_checkpoint, old_prefix, new_prefix)
|
||||
|
||||
new_checkpoint["conv_norm_out.weight"] = checkpoint["out.0.weight"]
|
||||
|
||||
@@ -261,9 +261,9 @@ def main(args):
|
||||
|
||||
model_name = args.model_path.split("/")[-1].split(".")[0]
|
||||
if not os.path.isfile(args.model_path):
|
||||
assert (
|
||||
model_name == args.model_path
|
||||
), f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
|
||||
assert model_name == args.model_path, (
|
||||
f"Make sure to provide one of the official model names {MODELS_MAP.keys()}"
|
||||
)
|
||||
args.model_path = download(model_name)
|
||||
|
||||
sample_rate = MODELS_MAP[model_name]["sample_rate"]
|
||||
@@ -290,9 +290,9 @@ def main(args):
|
||||
assert all(k.endswith("kernel") for k in list(diffusers_minus_renamed)), f"Problem with {diffusers_minus_renamed}"
|
||||
|
||||
for key, value in renamed_state_dict.items():
|
||||
assert (
|
||||
diffusers_state_dict[key].squeeze().shape == value.squeeze().shape
|
||||
), f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
|
||||
assert diffusers_state_dict[key].squeeze().shape == value.squeeze().shape, (
|
||||
f"Shape for {key} doesn't match. Diffusers: {diffusers_state_dict[key].shape} vs. {value.shape}"
|
||||
)
|
||||
if key == "time_proj.weight":
|
||||
value = value.squeeze()
|
||||
|
||||
|
||||
@@ -52,18 +52,18 @@ for i in range(3):
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i > 0:
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(4):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
if i < 2:
|
||||
@@ -75,12 +75,12 @@ for i in range(3):
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
||||
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
unet_conversion_map_layer.append(("output_blocks.2.2.conv.", "output_blocks.2.1.conv."))
|
||||
|
||||
@@ -89,7 +89,7 @@ sd_mid_atn_prefix = "middle_block.1."
|
||||
unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2 * j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
@@ -137,20 +137,20 @@ for i in range(4):
|
||||
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"up.{3-i}.upsample."
|
||||
sd_upsample_prefix = f"up.{3 - i}.upsample."
|
||||
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
# up_blocks have three resnets
|
||||
# also, up blocks in hf are numbered in reverse from sd
|
||||
for j in range(3):
|
||||
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
||||
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
||||
sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
|
||||
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
||||
|
||||
# this part accounts for mid blocks in both the encoder and the decoder
|
||||
for i in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
||||
sd_mid_res_prefix = f"mid.block_{i+1}."
|
||||
sd_mid_res_prefix = f"mid.block_{i + 1}."
|
||||
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
|
||||
@@ -47,36 +47,36 @@ for i in range(4):
|
||||
for j in range(2):
|
||||
# loop over resnets/attentions for downblocks
|
||||
hf_down_res_prefix = f"down_blocks.{i}.resnets.{j}."
|
||||
sd_down_res_prefix = f"input_blocks.{3*i + j + 1}.0."
|
||||
sd_down_res_prefix = f"input_blocks.{3 * i + j + 1}.0."
|
||||
unet_conversion_map_layer.append((sd_down_res_prefix, hf_down_res_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no attention layers in down_blocks.3
|
||||
hf_down_atn_prefix = f"down_blocks.{i}.attentions.{j}."
|
||||
sd_down_atn_prefix = f"input_blocks.{3*i + j + 1}.1."
|
||||
sd_down_atn_prefix = f"input_blocks.{3 * i + j + 1}.1."
|
||||
unet_conversion_map_layer.append((sd_down_atn_prefix, hf_down_atn_prefix))
|
||||
|
||||
for j in range(3):
|
||||
# loop over resnets/attentions for upblocks
|
||||
hf_up_res_prefix = f"up_blocks.{i}.resnets.{j}."
|
||||
sd_up_res_prefix = f"output_blocks.{3*i + j}.0."
|
||||
sd_up_res_prefix = f"output_blocks.{3 * i + j}.0."
|
||||
unet_conversion_map_layer.append((sd_up_res_prefix, hf_up_res_prefix))
|
||||
|
||||
if i > 0:
|
||||
# no attention layers in up_blocks.0
|
||||
hf_up_atn_prefix = f"up_blocks.{i}.attentions.{j}."
|
||||
sd_up_atn_prefix = f"output_blocks.{3*i + j}.1."
|
||||
sd_up_atn_prefix = f"output_blocks.{3 * i + j}.1."
|
||||
unet_conversion_map_layer.append((sd_up_atn_prefix, hf_up_atn_prefix))
|
||||
|
||||
if i < 3:
|
||||
# no downsample in down_blocks.3
|
||||
hf_downsample_prefix = f"down_blocks.{i}.downsamplers.0.conv."
|
||||
sd_downsample_prefix = f"input_blocks.{3*(i+1)}.0.op."
|
||||
sd_downsample_prefix = f"input_blocks.{3 * (i + 1)}.0.op."
|
||||
unet_conversion_map_layer.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
# no upsample in up_blocks.3
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"output_blocks.{3*i + 2}.{1 if i == 0 else 2}."
|
||||
sd_upsample_prefix = f"output_blocks.{3 * i + 2}.{1 if i == 0 else 2}."
|
||||
unet_conversion_map_layer.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
hf_mid_atn_prefix = "mid_block.attentions.0."
|
||||
@@ -85,7 +85,7 @@ unet_conversion_map_layer.append((sd_mid_atn_prefix, hf_mid_atn_prefix))
|
||||
|
||||
for j in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2*j}."
|
||||
sd_mid_res_prefix = f"middle_block.{2 * j}."
|
||||
unet_conversion_map_layer.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
@@ -133,20 +133,20 @@ for i in range(4):
|
||||
vae_conversion_map.append((sd_downsample_prefix, hf_downsample_prefix))
|
||||
|
||||
hf_upsample_prefix = f"up_blocks.{i}.upsamplers.0."
|
||||
sd_upsample_prefix = f"up.{3-i}.upsample."
|
||||
sd_upsample_prefix = f"up.{3 - i}.upsample."
|
||||
vae_conversion_map.append((sd_upsample_prefix, hf_upsample_prefix))
|
||||
|
||||
# up_blocks have three resnets
|
||||
# also, up blocks in hf are numbered in reverse from sd
|
||||
for j in range(3):
|
||||
hf_up_prefix = f"decoder.up_blocks.{i}.resnets.{j}."
|
||||
sd_up_prefix = f"decoder.up.{3-i}.block.{j}."
|
||||
sd_up_prefix = f"decoder.up.{3 - i}.block.{j}."
|
||||
vae_conversion_map.append((sd_up_prefix, hf_up_prefix))
|
||||
|
||||
# this part accounts for mid blocks in both the encoder and the decoder
|
||||
for i in range(2):
|
||||
hf_mid_res_prefix = f"mid_block.resnets.{i}."
|
||||
sd_mid_res_prefix = f"mid.block_{i+1}."
|
||||
sd_mid_res_prefix = f"mid.block_{i + 1}."
|
||||
vae_conversion_map.append((sd_mid_res_prefix, hf_mid_res_prefix))
|
||||
|
||||
|
||||
|
||||
@@ -21,9 +21,9 @@ def main(args):
|
||||
model_config = HunyuanDiT2DControlNetModel.load_config(
|
||||
"Tencent-Hunyuan/HunyuanDiT-v1.2-Diffusers", subfolder="transformer"
|
||||
)
|
||||
model_config[
|
||||
"use_style_cond_and_image_meta_size"
|
||||
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
|
||||
model_config["use_style_cond_and_image_meta_size"] = (
|
||||
args.use_style_cond_and_image_meta_size
|
||||
) ### version <= v1.1: True; version >= v1.2: False
|
||||
print(model_config)
|
||||
|
||||
for key in state_dict:
|
||||
|
||||
@@ -13,15 +13,14 @@ def main(args):
|
||||
state_dict = state_dict[args.load_key]
|
||||
except KeyError:
|
||||
raise KeyError(
|
||||
f"{args.load_key} not found in the checkpoint."
|
||||
f"Please load from the following keys:{state_dict.keys()}"
|
||||
f"{args.load_key} not found in the checkpoint.Please load from the following keys:{state_dict.keys()}"
|
||||
)
|
||||
|
||||
device = "cuda"
|
||||
model_config = HunyuanDiT2DModel.load_config("Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer")
|
||||
model_config[
|
||||
"use_style_cond_and_image_meta_size"
|
||||
] = args.use_style_cond_and_image_meta_size ### version <= v1.1: True; version >= v1.2: False
|
||||
model_config["use_style_cond_and_image_meta_size"] = (
|
||||
args.use_style_cond_and_image_meta_size
|
||||
) ### version <= v1.1: True; version >= v1.2: False
|
||||
|
||||
# input_size -> sample_size, text_dim -> cross_attention_dim
|
||||
for key in state_dict:
|
||||
|
||||
@@ -142,14 +142,14 @@ def block_to_diffusers_checkpoint(block, checkpoint, block_idx, block_type):
|
||||
diffusers_attention_prefix = f"{block_type}_blocks.{block_idx}.attentions.{attention_idx}"
|
||||
idx = n * attention_idx + 1 if block_type == "up" else n * attention_idx + 2
|
||||
self_attention_prefix = f"{block_prefix}.{idx}"
|
||||
cross_attention_prefix = f"{block_prefix}.{idx }"
|
||||
cross_attention_prefix = f"{block_prefix}.{idx}"
|
||||
cross_attention_index = 1 if not attention.add_self_attention else 2
|
||||
idx = (
|
||||
n * attention_idx + cross_attention_index
|
||||
if block_type == "up"
|
||||
else n * attention_idx + cross_attention_index + 1
|
||||
)
|
||||
cross_attention_prefix = f"{block_prefix}.{idx }"
|
||||
cross_attention_prefix = f"{block_prefix}.{idx}"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
cross_attn_to_diffusers_checkpoint(
|
||||
@@ -220,9 +220,9 @@ def unet_model_from_original_config(original_config):
|
||||
|
||||
block_out_channels = original_config["channels"]
|
||||
|
||||
assert (
|
||||
len(set(original_config["depths"])) == 1
|
||||
), "UNet2DConditionModel currently do not support blocks with different number of layers"
|
||||
assert len(set(original_config["depths"])) == 1, (
|
||||
"UNet2DConditionModel currently do not support blocks with different number of layers"
|
||||
)
|
||||
layers_per_block = original_config["depths"][0]
|
||||
|
||||
class_labels_dim = original_config["mapping_cond_dim"]
|
||||
|
||||
@@ -168,28 +168,28 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
|
||||
# Convert block_in (MochiMidBlock3D)
|
||||
for i in range(3): # layers_per_block[-1] = 3
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i+1}.stack.0.weight"
|
||||
f"blocks.0.{i + 1}.stack.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i+1}.stack.0.bias"
|
||||
f"blocks.0.{i + 1}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i+1}.stack.2.weight"
|
||||
f"blocks.0.{i + 1}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i+1}.stack.2.bias"
|
||||
f"blocks.0.{i + 1}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i+1}.stack.3.weight"
|
||||
f"blocks.0.{i + 1}.stack.3.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i+1}.stack.3.bias"
|
||||
f"blocks.0.{i + 1}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i+1}.stack.5.weight"
|
||||
f"blocks.0.{i + 1}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.0.{i+1}.stack.5.bias"
|
||||
f"blocks.0.{i + 1}.stack.5.bias"
|
||||
)
|
||||
|
||||
# Convert up_blocks (MochiUpBlock3D)
|
||||
@@ -197,33 +197,35 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
|
||||
for block in range(3):
|
||||
for i in range(down_block_layers[block]):
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.blocks.{i}.stack.0.weight"
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.blocks.{i}.stack.0.bias"
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.blocks.{i}.stack.2.weight"
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv1.conv.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.blocks.{i}.stack.2.bias"
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.blocks.{i}.stack.3.weight"
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.3.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.blocks.{i}.stack.3.bias"
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.blocks.{i}.stack.5.weight"
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.resnets.{i}.conv2.conv.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.blocks.{i}.stack.5.bias"
|
||||
f"blocks.{block + 1}.blocks.{i}.stack.5.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.proj.weight"] = decoder_state_dict.pop(
|
||||
f"blocks.{block+1}.proj.weight"
|
||||
f"blocks.{block + 1}.proj.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(
|
||||
f"blocks.{block + 1}.proj.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}up_blocks.{block}.proj.bias"] = decoder_state_dict.pop(f"blocks.{block+1}.proj.bias")
|
||||
|
||||
# Convert block_out (MochiMidBlock3D)
|
||||
for i in range(3): # layers_per_block[0] = 3
|
||||
@@ -267,133 +269,133 @@ def convert_mochi_vae_state_dict_to_diffusers(encoder_ckpt_path, decoder_ckpt_pa
|
||||
# Convert block_in (MochiMidBlock3D)
|
||||
for i in range(3): # layers_per_block[0] = 3
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.0.weight"
|
||||
f"layers.{i + 1}.stack.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.0.bias"
|
||||
f"layers.{i + 1}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.2.weight"
|
||||
f"layers.{i + 1}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.2.bias"
|
||||
f"layers.{i + 1}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.3.weight"
|
||||
f"layers.{i + 1}.stack.3.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.3.bias"
|
||||
f"layers.{i + 1}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.5.weight"
|
||||
f"layers.{i + 1}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_in.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+1}.stack.5.bias"
|
||||
f"layers.{i + 1}.stack.5.bias"
|
||||
)
|
||||
|
||||
# Convert down_blocks (MochiDownBlock3D)
|
||||
down_block_layers = [3, 4, 6] # layers_per_block[1], layers_per_block[2], layers_per_block[3]
|
||||
for block in range(3):
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.0.weight"
|
||||
f"layers.{block + 4}.layers.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.conv_in.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.0.bias"
|
||||
f"layers.{block + 4}.layers.0.bias"
|
||||
)
|
||||
|
||||
for i in range(down_block_layers[block]):
|
||||
# Convert resnets
|
||||
new_state_dict[
|
||||
f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"
|
||||
] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.0.weight")
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.weight"] = (
|
||||
encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.0.weight")
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.stack.0.bias"
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.stack.2.weight"
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.stack.2.bias"
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"] = (
|
||||
encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.stack.3.weight")
|
||||
)
|
||||
new_state_dict[
|
||||
f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.weight"
|
||||
] = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.stack.3.weight")
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.stack.3.bias"
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.stack.5.weight"
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.stack.5.bias"
|
||||
f"layers.{block + 4}.layers.{i + 1}.stack.5.bias"
|
||||
)
|
||||
|
||||
# Convert attentions
|
||||
qkv_weight = encoder_state_dict.pop(f"layers.{block+4}.layers.{i+1}.attn_block.attn.qkv.weight")
|
||||
qkv_weight = encoder_state_dict.pop(f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.qkv.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_q.weight"] = q
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_k.weight"] = k
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_v.weight"] = v
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.weight"
|
||||
f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.attn_block.attn.out.bias"
|
||||
f"layers.{block + 4}.layers.{i + 1}.attn_block.attn.out.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.attn_block.norm.weight"
|
||||
f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}down_blocks.{block}.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{block+4}.layers.{i+1}.attn_block.norm.bias"
|
||||
f"layers.{block + 4}.layers.{i + 1}.attn_block.norm.bias"
|
||||
)
|
||||
|
||||
# Convert block_out (MochiMidBlock3D)
|
||||
for i in range(3): # layers_per_block[-1] = 3
|
||||
# Convert resnets
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.0.weight"
|
||||
f"layers.{i + 7}.stack.0.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm1.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.0.bias"
|
||||
f"layers.{i + 7}.stack.0.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.2.weight"
|
||||
f"layers.{i + 7}.stack.2.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv1.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.2.bias"
|
||||
f"layers.{i + 7}.stack.2.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.3.weight"
|
||||
f"layers.{i + 7}.stack.3.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.norm2.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.3.bias"
|
||||
f"layers.{i + 7}.stack.3.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.5.weight"
|
||||
f"layers.{i + 7}.stack.5.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.resnets.{i}.conv2.conv.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.stack.5.bias"
|
||||
f"layers.{i + 7}.stack.5.bias"
|
||||
)
|
||||
|
||||
# Convert attentions
|
||||
qkv_weight = encoder_state_dict.pop(f"layers.{i+7}.attn_block.attn.qkv.weight")
|
||||
qkv_weight = encoder_state_dict.pop(f"layers.{i + 7}.attn_block.attn.qkv.weight")
|
||||
q, k, v = qkv_weight.chunk(3, dim=0)
|
||||
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_q.weight"] = q
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_k.weight"] = k
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_v.weight"] = v
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.attn_block.attn.out.weight"
|
||||
f"layers.{i + 7}.attn_block.attn.out.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.attentions.{i}.to_out.0.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.attn_block.attn.out.bias"
|
||||
f"layers.{i + 7}.attn_block.attn.out.bias"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.weight"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.attn_block.norm.weight"
|
||||
f"layers.{i + 7}.attn_block.norm.weight"
|
||||
)
|
||||
new_state_dict[f"{prefix}block_out.norms.{i}.norm_layer.bias"] = encoder_state_dict.pop(
|
||||
f"layers.{i+7}.attn_block.norm.bias"
|
||||
f"layers.{i + 7}.attn_block.norm.bias"
|
||||
)
|
||||
|
||||
# Convert output layers
|
||||
|
||||
@@ -662,7 +662,7 @@ def convert_open_clap_checkpoint(checkpoint):
|
||||
# replace sequential layers with list
|
||||
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
|
||||
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
|
||||
elif re.match(text_projection_pattern, key):
|
||||
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
|
||||
|
||||
|
||||
@@ -636,7 +636,7 @@ def convert_open_clap_checkpoint(checkpoint):
|
||||
# replace sequential layers with list
|
||||
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
|
||||
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
|
||||
elif re.match(text_projection_pattern, key):
|
||||
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
|
||||
|
||||
|
||||
@@ -642,7 +642,7 @@ def convert_open_clap_checkpoint(checkpoint):
|
||||
# replace sequential layers with list
|
||||
sequential_layer = re.match(sequential_layers_pattern, key).group(1)
|
||||
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer)//3}.linear.")
|
||||
key = key.replace(f"sequential.{sequential_layer}.", f"layers.{int(sequential_layer) // 3}.linear.")
|
||||
elif re.match(text_projection_pattern, key):
|
||||
projecton_layer = int(re.match(text_projection_pattern, key).group(1))
|
||||
|
||||
|
||||
@@ -95,18 +95,18 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
|
||||
# get idx of the layer
|
||||
idx = int(new_key.split("coder.layers.")[1].split(".")[0])
|
||||
|
||||
new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx-1}")
|
||||
new_key = new_key.replace(f"coder.layers.{idx}", f"coder.block.{idx - 1}")
|
||||
|
||||
if "encoder" in new_key:
|
||||
for i in range(3):
|
||||
new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i+1}")
|
||||
new_key = new_key.replace(f"block.{idx-1}.layers.3", f"block.{idx-1}.snake1")
|
||||
new_key = new_key.replace(f"block.{idx-1}.layers.4", f"block.{idx-1}.conv1")
|
||||
new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i + 1}")
|
||||
new_key = new_key.replace(f"block.{idx - 1}.layers.3", f"block.{idx - 1}.snake1")
|
||||
new_key = new_key.replace(f"block.{idx - 1}.layers.4", f"block.{idx - 1}.conv1")
|
||||
else:
|
||||
for i in range(2, 5):
|
||||
new_key = new_key.replace(f"block.{idx-1}.layers.{i}", f"block.{idx-1}.res_unit{i-1}")
|
||||
new_key = new_key.replace(f"block.{idx-1}.layers.0", f"block.{idx-1}.snake1")
|
||||
new_key = new_key.replace(f"block.{idx-1}.layers.1", f"block.{idx-1}.conv_t1")
|
||||
new_key = new_key.replace(f"block.{idx - 1}.layers.{i}", f"block.{idx - 1}.res_unit{i - 1}")
|
||||
new_key = new_key.replace(f"block.{idx - 1}.layers.0", f"block.{idx - 1}.snake1")
|
||||
new_key = new_key.replace(f"block.{idx - 1}.layers.1", f"block.{idx - 1}.conv_t1")
|
||||
|
||||
new_key = new_key.replace("layers.0.beta", "snake1.beta")
|
||||
new_key = new_key.replace("layers.0.alpha", "snake1.alpha")
|
||||
@@ -118,9 +118,9 @@ def convert_stable_audio_state_dict_to_diffusers(state_dict, num_autoencoder_lay
|
||||
new_key = new_key.replace("layers.3.weight_", "conv2.weight_")
|
||||
|
||||
if idx == num_autoencoder_layers + 1:
|
||||
new_key = new_key.replace(f"block.{idx-1}", "snake1")
|
||||
new_key = new_key.replace(f"block.{idx - 1}", "snake1")
|
||||
elif idx == num_autoencoder_layers + 2:
|
||||
new_key = new_key.replace(f"block.{idx-1}", "conv2")
|
||||
new_key = new_key.replace(f"block.{idx - 1}", "conv2")
|
||||
|
||||
else:
|
||||
new_key = new_key
|
||||
|
||||
@@ -381,9 +381,9 @@ def convert_ldm_unet_checkpoint(
|
||||
|
||||
# TODO resnet time_mixer.mix_factor
|
||||
if f"input_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
|
||||
new_checkpoint[
|
||||
f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
|
||||
] = unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
|
||||
new_checkpoint[f"down_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
|
||||
unet_state_dict[f"input_blocks.{i}.0.time_mixer.mix_factor"]
|
||||
)
|
||||
|
||||
if len(attentions):
|
||||
paths = renew_attention_paths(attentions)
|
||||
@@ -478,9 +478,9 @@ def convert_ldm_unet_checkpoint(
|
||||
)
|
||||
|
||||
if f"output_blocks.{i}.0.time_mixer.mix_factor" in unet_state_dict:
|
||||
new_checkpoint[
|
||||
f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"
|
||||
] = unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
|
||||
new_checkpoint[f"up_blocks.{block_id}.resnets.{layer_in_block_id}.time_mixer.mix_factor"] = (
|
||||
unet_state_dict[f"output_blocks.{i}.0.time_mixer.mix_factor"]
|
||||
)
|
||||
|
||||
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
|
||||
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
||||
|
||||
@@ -51,9 +51,9 @@ PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchV
|
||||
|
||||
|
||||
def vqvae_model_from_original_config(original_config):
|
||||
assert (
|
||||
original_config["target"] in PORTED_VQVAES
|
||||
), f"{original_config['target']} has not yet been ported to diffusers."
|
||||
assert original_config["target"] in PORTED_VQVAES, (
|
||||
f"{original_config['target']} has not yet been ported to diffusers."
|
||||
)
|
||||
|
||||
original_config = original_config["params"]
|
||||
|
||||
@@ -464,15 +464,15 @@ PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_ima
|
||||
def transformer_model_from_original_config(
|
||||
original_diffusion_config, original_transformer_config, original_content_embedding_config
|
||||
):
|
||||
assert (
|
||||
original_diffusion_config["target"] in PORTED_DIFFUSIONS
|
||||
), f"{original_diffusion_config['target']} has not yet been ported to diffusers."
|
||||
assert (
|
||||
original_transformer_config["target"] in PORTED_TRANSFORMERS
|
||||
), f"{original_transformer_config['target']} has not yet been ported to diffusers."
|
||||
assert (
|
||||
original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS
|
||||
), f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
|
||||
assert original_diffusion_config["target"] in PORTED_DIFFUSIONS, (
|
||||
f"{original_diffusion_config['target']} has not yet been ported to diffusers."
|
||||
)
|
||||
assert original_transformer_config["target"] in PORTED_TRANSFORMERS, (
|
||||
f"{original_transformer_config['target']} has not yet been ported to diffusers."
|
||||
)
|
||||
assert original_content_embedding_config["target"] in PORTED_CONTENT_EMBEDDINGS, (
|
||||
f"{original_content_embedding_config['target']} has not yet been ported to diffusers."
|
||||
)
|
||||
|
||||
original_diffusion_config = original_diffusion_config["params"]
|
||||
original_transformer_config = original_transformer_config["params"]
|
||||
|
||||
Reference in New Issue
Block a user