From 717a956a02bc2f20e40e16770b5e98c94063131b Mon Sep 17 00:00:00 2001 From: chavinlo <85657083+chavinlo@users.noreply.github.com> Date: Tue, 7 Feb 2023 03:10:34 -0500 Subject: [PATCH] Create convert_vae_pt_to_diffusers.py (#2215) * Create convert_vae_pt_to_diffusers.py Just a simple script to convert VAE.pt files to diffusers format Tested with: https://huggingface.co/WarriorMama777/OrangeMixs/blob/main/VAEs/orangemix.vae.pt * Update convert_vae_pt_to_diffusers.py Forgot to add the function call * make style --------- Co-authored-by: Patrick von Platen Co-authored-by: chavinlo --- .../textual_inversion/textual_inversion.py | 6 +- .../textual_inversion/textual_inversion.py | 6 +- scripts/convert_vae_pt_to_diffusers.py | 149 ++++++++++++++++++ src/diffusers/utils/hub_utils.py | 6 +- 4 files changed, 158 insertions(+), 9 deletions(-) create mode 100644 scripts/convert_vae_pt_to_diffusers.py diff --git a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py index 0796f0a708..3a7e6a74a1 100644 --- a/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py +++ b/examples/research_projects/onnxruntime/textual_inversion/textual_inversion.py @@ -752,9 +752,9 @@ def main(): # Let's make sure we don't update any embedding weights besides the newly added token index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id with torch.no_grad(): - accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ - index_no_updates - ] = orig_embeds_params[index_no_updates] + accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = ( + orig_embeds_params[index_no_updates] + ) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py index d41198c369..285d9864c9 100644 --- a/examples/textual_inversion/textual_inversion.py +++ b/examples/textual_inversion/textual_inversion.py @@ -749,9 +749,9 @@ def main(): # Let's make sure we don't update any embedding weights besides the newly added token index_no_updates = torch.arange(len(tokenizer)) != placeholder_token_id with torch.no_grad(): - accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[ - index_no_updates - ] = orig_embeds_params[index_no_updates] + accelerator.unwrap_model(text_encoder).get_input_embeddings().weight[index_no_updates] = ( + orig_embeds_params[index_no_updates] + ) # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: diff --git a/scripts/convert_vae_pt_to_diffusers.py b/scripts/convert_vae_pt_to_diffusers.py new file mode 100644 index 0000000000..e82ad85001 --- /dev/null +++ b/scripts/convert_vae_pt_to_diffusers.py @@ -0,0 +1,149 @@ +import argparse +import io + +import torch + +import requests +from diffusers import AutoencoderKL +from diffusers.pipelines.stable_diffusion.convert_from_ckpt import ( + assign_to_checkpoint, + conv_attn_to_linear, + create_vae_diffusers_config, + renew_vae_attention_paths, + renew_vae_resnet_paths, +) +from omegaconf import OmegaConf + + +def custom_convert_ldm_vae_checkpoint(checkpoint, config): + vae_state_dict = checkpoint + + new_checkpoint = {} + + new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"] + new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"] + new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"] + new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"] + new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"] + new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"] + + new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"] + new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"] + new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"] + new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"] + new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"] + new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"] + + new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"] + new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"] + new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"] + new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"] + + # Retrieves the keys for the encoder down blocks only + num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer}) + down_blocks = { + layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks) + } + + # Retrieves the keys for the decoder up blocks only + num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer}) + up_blocks = { + layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks) + } + + for i in range(num_down_blocks): + resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key] + + if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict: + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.weight" + ) + new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop( + f"encoder.down.{i}.downsample.conv.bias" + ) + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + + for i in range(num_up_blocks): + block_id = num_up_blocks - 1 - i + resnets = [ + key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key + ] + + if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict: + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.weight" + ] + new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[ + f"decoder.up.{block_id}.upsample.conv.bias" + ] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key] + num_mid_res_blocks = 2 + for i in range(1, num_mid_res_blocks + 1): + resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key] + + paths = renew_vae_resnet_paths(resnets) + meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + + mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key] + paths = renew_vae_attention_paths(mid_attentions) + meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"} + assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config) + conv_attn_to_linear(new_checkpoint) + return new_checkpoint + + +def vae_pt_to_vae_diffuser( + checkpoint_path: str, + output_path: str, +): + # Only support V1 + r = requests.get( + " https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml" + ) + io_obj = io.BytesIO(r.content) + + original_config = OmegaConf.load(io_obj) + image_size = 512 + device = "cuda" if torch.cuda.is_available() else "cpu" + checkpoint = torch.load(checkpoint_path, map_location=device) + + # Convert the VAE model. + vae_config = create_vae_diffusers_config(original_config, image_size=image_size) + converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint["state_dict"], vae_config) + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(converted_vae_checkpoint) + vae.save_pretrained(output_path) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--vae_pt_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.") + parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the VAE.pt to convert.") + + vae_pt_to_vae_diffuser(args.vae_pt_path, args.dump_path) diff --git a/src/diffusers/utils/hub_utils.py b/src/diffusers/utils/hub_utils.py index ea00058fe6..7e6bd7870d 100644 --- a/src/diffusers/utils/hub_utils.py +++ b/src/diffusers/utils/hub_utils.py @@ -112,9 +112,9 @@ def create_model_card(args, model_name): learning_rate=args.learning_rate, train_batch_size=args.train_batch_size, eval_batch_size=args.eval_batch_size, - gradient_accumulation_steps=args.gradient_accumulation_steps - if hasattr(args, "gradient_accumulation_steps") - else None, + gradient_accumulation_steps=( + args.gradient_accumulation_steps if hasattr(args, "gradient_accumulation_steps") else None + ), adam_beta1=args.adam_beta1 if hasattr(args, "adam_beta1") else None, adam_beta2=args.adam_beta2 if hasattr(args, "adam_beta2") else None, adam_weight_decay=args.adam_weight_decay if hasattr(args, "adam_weight_decay") else None,