From 5df2acf7d299346e2cb5ff921cb499ca774c6213 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Thu, 22 Jun 2023 13:52:59 +0200
Subject: [PATCH] [Conversion] Small fixes (#3848)

* [Conversion] Small fixes

* Update src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
---
 scripts/convert_vae_pt_to_diffusers.py              | 12 ++++++++++--
 .../pipelines/stable_diffusion/convert_from_ckpt.py |  5 +++--
 2 files changed, 13 insertions(+), 4 deletions(-)

diff --git a/scripts/convert_vae_pt_to_diffusers.py b/scripts/convert_vae_pt_to_diffusers.py
index 4762ffcf8d..a8ba48bc00 100644
--- a/scripts/convert_vae_pt_to_diffusers.py
+++ b/scripts/convert_vae_pt_to_diffusers.py
@@ -129,11 +129,19 @@ def vae_pt_to_vae_diffuser(
     original_config = OmegaConf.load(io_obj)
     image_size = 512
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    checkpoint = torch.load(checkpoint_path, map_location=device)
+    if checkpoint_path.endswith("safetensors"):
+        from safetensors import safe_open
+
+        checkpoint = {}
+        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
+            for key in f.keys():
+                checkpoint[key] = f.get_tensor(key)
+    else:
+        checkpoint = torch.load(checkpoint_path, map_location=device)["state_dict"]
 
     # Convert the VAE model.
     vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
-    converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint["state_dict"], vae_config)
+    converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint, vae_config)
 
     vae = AutoencoderKL(**vae_config)
     vae.load_state_dict(converted_vae_checkpoint)

diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 746844ea1e..3b3724f0d0 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -286,10 +286,11 @@ def create_unet_diffusers_config(original_config, image_size: int, controlnet=Fa
         "use_linear_projection": use_linear_projection,
         "class_embed_type": class_embed_type,
         "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
-        "conditioning_channels": unet_params.hint_channels,
     }
 
-    if not controlnet:
+    if controlnet:
+        config["conditioning_channels"] = unet_params.hint_channels
+    else:
         config["out_channels"] = unet_params.out_channels
         config["up_block_types"] = tuple(up_block_types)