From 8b7eecd4d4bc196f916b737e518fb93d1f625ff7 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Tue, 26 Dec 2023 03:13:43 +0000 Subject: [PATCH] update --- src/diffusers/loaders/single_file.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index 2369362777..9f5c3d48d1 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import importlib import inspect from contextlib import nullcontext from io import BytesIO @@ -30,7 +31,7 @@ from ..utils import ( logging, ) from ..utils.import_utils import BACKENDS_MAPPING -from .single_file_utils import download_from_original_stable_diffusion_ckpt +from .single_file_utils import download_from_original_stable_diffusion_ckpt, fetch_original_config if is_transformers_available(): @@ -72,7 +73,7 @@ def check_valid_url(pretrained_model_link_or_path): return has_valid_url_prefix -def fetch_model_checkpoint(ckpt_path, cache_dir=None, resume_download=False, force_download=False, proxies=None, local_files_only=None, token=None, revision=None): +def download_model_checkpoint(ckpt_path, cache_dir=None, resume_download=False, force_download=False, proxies=None, local_files_only=None, token=None, revision=None): # get repo_id and (potentially nested) file path of ckpt in repo repo_id = "/".join(ckpt_path.parts[:2]) file_path = "/".join(ckpt_path.parts[2:]) @@ -119,7 +120,14 @@ def infer_model_type(pipeline_class_name): return MODEL_TYPE_FROM_PIPELINE_CLASS.get(pipeline_class_name, None) -def build_component(component_name, **kwargs): +def build_component(component_name, original_config, checkpoint, **kwargs): + if component_name in kwargs: + return kwargs.pop(component_name, None) + + component_class = getattr(importlib.import_module("diffusers"), component_name) + + + return @@ -310,14 +318,21 @@ class FromSingleFileMixin: raise ValueError( f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(VALID_URL_PREFIXES)}" ) - pretrained_model_link_or_path = fetch_model_checkpoint(ckpt_path, cache_dir=cache_dir, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, revision=revision) + pretrained_model_link_or_path = download_model_checkpoint(ckpt_path, cache_dir=cache_dir, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, revision=revision) checkpoint = load_checkpoint(pretrained_model_link_or_path, from_safetensors=from_safetensors) + global_step = checkpoint["global_step"] if "global_step" in checkpoint else None + # NOTE: this while loop isn't great but this controlnet checkpoint has one additional + # "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21 + while "state_dict" in checkpoint: + checkpoint = checkpoint["state_dict"] + + original_config = fetch_original_config(checkpoint, config_files) component_names = extract_pipeline_component_names(cls) pipeline_components = {} for component in component_names: - pipeline_components[component] = build_component(component, checkpoint, **kwargs) + pipeline_components[component] = build_component(component, checkpoint, original_config, **kwargs) pipe = download_from_original_stable_diffusion_ckpt( pretrained_model_link_or_path, @@ -344,6 +359,8 @@ class FromSingleFileMixin: local_files_only=local_files_only, ) + pipe = cls(**pipeline_components, **kwargs) + if torch_dtype is not None: pipe.to(dtype=torch_dtype)