mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Conversion] Small fixes (#3848)
* [Conversion] Small fixes * Update src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
This commit is contained in:
committed by
GitHub
parent
88d269461c
commit
5df2acf7d2
@@ -129,11 +129,19 @@ def vae_pt_to_vae_diffuser(
|
||||
original_config = OmegaConf.load(io_obj)
|
||||
image_size = 512
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)
|
||||
if checkpoint_path.endswith("safetensors"):
|
||||
from safetensors import safe_open
|
||||
|
||||
checkpoint = {}
|
||||
with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
|
||||
for key in f.keys():
|
||||
checkpoint[key] = f.get_tensor(key)
|
||||
else:
|
||||
checkpoint = torch.load(checkpoint_path, map_location=device)["state_dict"]
|
||||
|
||||
# Convert the VAE model.
|
||||
vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
|
||||
converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint["state_dict"], vae_config)
|
||||
converted_vae_checkpoint = custom_convert_ldm_vae_checkpoint(checkpoint, vae_config)
|
||||
|
||||
vae = AutoencoderKL(**vae_config)
|
||||
vae.load_state_dict(converted_vae_checkpoint)
|
||||
|
||||
Reference in New Issue
Block a user