From d75ad93ca73e00bc59d980004c4fbe798f598498 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?cafe+ai=20=E2=80=94=20=E3=81=8B=E3=81=B5=E3=81=87=E3=81=82?= =?UTF-8?q?=E3=81=84?= <116491182+cafeai@users.noreply.github.com> Date: Mon, 23 Jan 2023 17:44:55 +0900 Subject: [PATCH] Safetensors loading in "convert_diffusers_to_original_stable_diffusion" (#2054) * Safetensors loading in "convert_diffusers_to_original_stable_diffusion" Adds diffusers format saftetensors loading support * Fix import sort order: convert_diffusers_to_original_stable_diffusion.py Co-authored-by: Patrick von Platen --- ..._diffusers_to_original_stable_diffusion.py | 33 ++++++++++++++----- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/scripts/convert_diffusers_to_original_stable_diffusion.py b/scripts/convert_diffusers_to_original_stable_diffusion.py index ffcea8332f..5ca9846914 100644 --- a/scripts/convert_diffusers_to_original_stable_diffusion.py +++ b/scripts/convert_diffusers_to_original_stable_diffusion.py @@ -8,7 +8,7 @@ import re import torch -from safetensors.torch import save_file +from safetensors.torch import load_file, save_file # =================# @@ -278,23 +278,38 @@ if __name__ == "__main__": assert args.checkpoint_path is not None, "Must provide a checkpoint path!" - unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin") - vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin") - text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") + # Path for safetensors + unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.safetensors") + vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.safetensors") + text_enc_path = osp.join(args.model_path, "text_encoder", "model.safetensors") + + # Load models from safetensors if it exists, if it doesn't pytorch + if osp.exists(unet_path): + unet_state_dict = load_file(unet_path, device="cpu") + else: + unet_path = osp.join(args.model_path, "unet", "diffusion_pytorch_model.bin") + unet_state_dict = torch.load(unet_path, map_location="cpu") + + if osp.exists(vae_path): + vae_state_dict = load_file(vae_path, device="cpu") + else: + vae_path = osp.join(args.model_path, "vae", "diffusion_pytorch_model.bin") + vae_state_dict = torch.load(vae_path, map_location="cpu") + + if osp.exists(text_enc_path): + text_enc_dict = load_file(text_enc_path, device="cpu") + else: + text_enc_path = osp.join(args.model_path, "text_encoder", "pytorch_model.bin") + text_enc_dict = torch.load(text_enc_path, map_location="cpu") # Convert the UNet model - unet_state_dict = torch.load(unet_path, map_location="cpu") unet_state_dict = convert_unet_state_dict(unet_state_dict) unet_state_dict = {"model.diffusion_model." + k: v for k, v in unet_state_dict.items()} # Convert the VAE model - vae_state_dict = torch.load(vae_path, map_location="cpu") vae_state_dict = convert_vae_state_dict(vae_state_dict) vae_state_dict = {"first_stage_model." + k: v for k, v in vae_state_dict.items()} - # Convert the text encoder model - text_enc_dict = torch.load(text_enc_path, map_location="cpu") - # Easiest way to identify v2.0 model seems to be that the text encoder (OpenCLIP) is deeper is_v20_model = "text_model.encoder.layers.22.layer_norm2.bias" in text_enc_dict