1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Allow passing a checkpoint state_dict to convert_from_ckpt (instead of just a string path) (#4653)

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
cmdr2
2023-08-26 00:20:39 +05:30
committed by GitHub
parent b7b1a30bc4
commit cb432c4ebc

View File

@@ -17,7 +17,7 @@
import re
from contextlib import nullcontext
from io import BytesIO
from typing import Optional
from typing import Optional, Union, Dict
import requests
import torch
@@ -1111,7 +1111,7 @@ def convert_controlnet_checkpoint(
def download_from_original_stable_diffusion_ckpt(
checkpoint_path: str,
checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]],
original_config_file: str = None,
image_size: Optional[int] = None,
prediction_type: str = None,
@@ -1144,7 +1144,7 @@ def download_from_original_stable_diffusion_ckpt(
recommended that you override the default values and/or supply an `original_config_file` wherever possible.
Args:
checkpoint_path (`str`): Path to `.ckpt` file.
checkpoint_path_or_dict (`str` or `dict`): Path to `.ckpt` file, or the state dict.
original_config_file (`str`):
Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
inferred by looking for a key that only exists in SD2.0 models.
@@ -1226,16 +1226,19 @@ def download_from_original_stable_diffusion_ckpt(
from omegaconf import OmegaConf
if from_safetensors:
from safetensors.torch import load_file as safe_load
if isinstance(checkpoint_path_or_dict, str):
if from_safetensors:
from safetensors.torch import load_file as safe_load
checkpoint = safe_load(checkpoint_path, device="cpu")
else:
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path, map_location=device)
checkpoint = safe_load(checkpoint_path_or_dict, device="cpu")
else:
checkpoint = torch.load(checkpoint_path, map_location=device)
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
else:
checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
elif isinstance(checkpoint_path_or_dict, dict):
checkpoint = checkpoint_path_or_dict
# Sometimes models don't have the global_step item
if "global_step" in checkpoint:
@@ -1318,8 +1321,9 @@ def download_from_original_stable_diffusion_ckpt(
image_size = 512
if controlnet is None and "control_stage_config" in original_config.model.params:
path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
controlnet = convert_controlnet_checkpoint(
checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
checkpoint, original_config, path, image_size, upcast_attention, extract_ema
)
num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
@@ -1378,8 +1382,9 @@ def download_from_original_stable_diffusion_ckpt(
# Convert the UNet2DConditionModel model.
unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
unet_config["upcast_attention"] = upcast_attention
path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
converted_unet_checkpoint = convert_ldm_unet_checkpoint(
checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
checkpoint, unet_config, path=path, extract_ema=extract_ema
)
ctx = init_empty_weights if is_accelerate_available() else nullcontext