diff --git a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
index 0a534d4193..ea80167644 100644
--- a/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
+++ b/src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py
@@ -17,7 +17,7 @@
 import re
 from contextlib import nullcontext
 from io import BytesIO
-from typing import Optional
+from typing import Optional, Union, Dict
 
 import requests
 import torch
@@ -1111,7 +1111,7 @@ def convert_controlnet_checkpoint(
 
 
 def download_from_original_stable_diffusion_ckpt(
-    checkpoint_path: str,
+    checkpoint_path_or_dict: Union[str, Dict[str, torch.Tensor]],
     original_config_file: str = None,
     image_size: Optional[int] = None,
     prediction_type: str = None,
@@ -1144,7 +1144,7 @@ def download_from_original_stable_diffusion_ckpt(
     recommended that you override the default values and/or supply an `original_config_file` wherever possible.
 
     Args:
-        checkpoint_path (`str`): Path to `.ckpt` file.
+        checkpoint_path_or_dict (`str` or `dict`): Path to `.ckpt` file, or the state dict.
         original_config_file (`str`):
            Path to `.yaml` config file corresponding to the original architecture. If `None`, will be automatically
            inferred by looking for a key that only exists in SD2.0 models.
@@ -1226,16 +1226,19 @@ def download_from_original_stable_diffusion_ckpt(
 
     from omegaconf import OmegaConf
 
-    if from_safetensors:
-        from safetensors.torch import load_file as safe_load
+    if isinstance(checkpoint_path_or_dict, str):
+        if from_safetensors:
+            from safetensors.torch import load_file as safe_load
 
-        checkpoint = safe_load(checkpoint_path, device="cpu")
-    else:
-        if device is None:
-            device = "cuda" if torch.cuda.is_available() else "cpu"
-            checkpoint = torch.load(checkpoint_path, map_location=device)
+            checkpoint = safe_load(checkpoint_path_or_dict, device="cpu")
         else:
-            checkpoint = torch.load(checkpoint_path, map_location=device)
+            if device is None:
+                device = "cuda" if torch.cuda.is_available() else "cpu"
+                checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
+            else:
+                checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
+    elif isinstance(checkpoint_path_or_dict, dict):
+        checkpoint = checkpoint_path_or_dict
 
     # Sometimes models don't have the global_step item
     if "global_step" in checkpoint:
@@ -1318,8 +1321,9 @@ def download_from_original_stable_diffusion_ckpt(
             image_size = 512
 
     if controlnet is None and "control_stage_config" in original_config.model.params:
+        path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
         controlnet = convert_controlnet_checkpoint(
-            checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
+            checkpoint, original_config, path, image_size, upcast_attention, extract_ema
         )
 
     num_train_timesteps = getattr(original_config.model.params, "timesteps", None) or 1000
@@ -1378,8 +1382,9 @@ def download_from_original_stable_diffusion_ckpt(
     # Convert the UNet2DConditionModel model.
     unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
     unet_config["upcast_attention"] = upcast_attention
+    path = checkpoint_path_or_dict if isinstance(checkpoint_path_or_dict, str) else ""
     converted_unet_checkpoint = convert_ldm_unet_checkpoint(
-        checkpoint, unet_config, path=checkpoint_path, extract_ema=extract_ema
+        checkpoint, unet_config, path=path, extract_ema=extract_ema
     )
 
     ctx = init_empty_weights if is_accelerate_available() else nullcontext
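
For reference, a minimal usage sketch of what this change enables, not part of the diff: the converter can now be handed an already-loaded state dict instead of a checkpoint path. The file paths below are placeholders, and the keyword arguments shown are only the ones touched by this patch.

# Usage sketch (assumption, not part of the diff): pass an in-memory state dict
# to download_from_original_stable_diffusion_ckpt instead of a checkpoint path.
import torch

from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

# Load the checkpoint ourselves from any source (a local .ckpt file shown here;
# the path is a placeholder).
state_dict = torch.load("/path/to/sd-checkpoint.ckpt", map_location="cpu")

# With this patch the function accepts the dict directly and skips its own
# torch.load / safetensors loading step; the per-component converters receive
# an empty string in place of the checkpoint path.
pipe = download_from_original_stable_diffusion_ckpt(
    checkpoint_path_or_dict=state_dict,
    original_config_file="/path/to/v1-inference.yaml",
)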