Allow passing a checkpoint state_dict to convert_from_ckpt (instead of just a string path) (#4653)
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
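What this enables, as a minimal usage sketch (the checkpoint file name is a placeholder, and the import path reflects the diffusers module layout at the time of this commit; unless `original_config_file` is supplied, the converter may fetch the original YAML config over the network):

    import torch
    from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
        download_from_original_stable_diffusion_ckpt,
    )

    # Previously the converter only accepted a path and loaded the file itself:
    pipe = download_from_original_stable_diffusion_ckpt("v1-5-pruned-emaonly.ckpt")

    # With this change, a state dict that is already in memory can be passed
    # directly, e.g. after merging or editing weights, without a round-trip
    # through disk:
    state_dict = torch.load("v1-5-pruned-emaonly.ckpt", map_location="cpu")
    pipe = download_from_original_stable_diffusion_ckpt(state_dict)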
@@ -17,7 +17,7 @@
 import re
 from contextlib import nullcontext
 from io import BytesIO
-from typing import Optional
+from typing import Optional, Union, Dict
 
 import requests
 import torch
@@ -1111,7 +1111,7 @@ def convert_controlnet_checkpoint(
 
 
 def download_from_original_stable_diffusion_ckpt(
-    checkpoint_path: str,
+    checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]],
     original_config_file: str = None,
     image_size: Optional[int] = None,
     prediction_type: str = None,
@@ -1144,7 +1144,7 @@ def download_from_original_stable_diffusion_ckpt(
     recommended that you override the default values and/or supply an `original_config_file` wherever possible.
 
     Args:
-        checkpoint_path (`str`): Path to `.ckpt` file.
+        checkpoint_path_or_dict (`str` or `dict`): Path to `.ckpt` file, or the state dict.
         original_config_file (`str`):
             Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
             inferred by looking for a key that only exists in SD2.0 models.
@@ -1226,16 +1226,19 @@ def download_from_original_stable_diffusion_ckpt(
 
     from omegaconf import OmegaConf
 
-    if from_safetensors:
-        from safetensors.torch import load_file as safe_load
+    if isinstance(checkpoint_path_or_dict, str):
+        if from_safetensors:
+            from safetensors.torch import load_file as safe_load
 
-        checkpoint = safe_load(checkpoint_path, device="cpu")
-    else:
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-            checkpoint = torch.load(checkpoint_path, map_location=device)
+            checkpoint = safe_load(checkpoint_path_or_dict, device="cpu")
         else:
-            checkpoint = torch.load(checkpoint_path, map_location=device)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+                checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
+            else:
+                checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
+    elif isinstance(checkpoint_path_or_dict, dict):
+        checkpoint = checkpoint_path_or_dict
 
     # Sometimes models don't have the global_step item
     if "global_step" in checkpoint:
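For the new `dict` branch, the caller is responsible for loading the weights; a sketch of feeding it a safetensors file (the file name is a placeholder):

    from safetensors.torch import load_file

    # load_file returns a plain {name: tensor} dict, which the new
    # elif isinstance(checkpoint_path_or_dict, dict) branch accepts as-is;
    # from_safetensors and device are only consulted in the string-path case.
    state_dict = load_file("model.safetensors", device="cpu")
    pipe = download_from_original_stable_diffusion_ckpt(state_dict)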
@@ -1318,8 +1321,9 @@ def download_from_original_stable_diffusion_ckpt(
         image_size = 512
 
     if controlnet is None and "control_stage_config" in original_config.model.params:
+        path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
         controlnet = convert_controlnet_checkpoint(
-            checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
+            checkpoint, original_config, path, image_size, upcast_attention, extract_ema
         )
 
     num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
@@ -1378,8 +1382,9 @@ def download_from_original_stable_diffusion_ckpt(
     # Convert the UNet2DConditionModel model.
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["upcast_attention"] = upcast_attention
+    path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
    converted_unet_checkpoint = convert_ldm_unet_checkpoint(
-        checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
+        checkpoint, unet_config, path=path, extract_ema=extract_ema
     )
 
     ctx = init_empty_weights if is_accelerate_available() else nullcontext
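A note on the `path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""` fallback used in the last two hunks: when a dict is passed there is no file path to forward, and `convert_ldm_unet_checkpoint` and `convert_controlnet_checkpoint` appear to use `path` only in informational messages (e.g., when reporting EMA weights), so passing an empty string keeps their signatures unchanged without affecting the conversion itself.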