From daf4d05b1fccf84af21d7ce1db1e86ff46d45a24 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Mon, 25 Dec 2023 05:49:40 +0000 Subject: [PATCH] update --- src/diffusers/loaders/single_file.py | 32 ++++++++++++++++++++++++++-- 1 file changed, 30 insertions(+), 2 deletions(-) diff --git a/src/diffusers/loaders/single_file.py b/src/diffusers/loaders/single_file.py index 451592763e..2369362777 100644 --- a/src/diffusers/loaders/single_file.py +++ b/src/diffusers/loaders/single_file.py @@ -20,6 +20,7 @@ import requests import torch from huggingface_hub import hf_hub_download from huggingface_hub.utils import validate_hf_hub_args +from safetensors.torch import load_file as safe_load from ..utils import ( deprecate, @@ -55,7 +56,7 @@ MODEL_TYPE_FROM_PIPELINE_CLASS = { } -def extract_pipeline_compoments(pipeline_class): +def extract_pipeline_component_names(pipeline_class): components = inspect.signature(pipeline_class).parameters.keys() return components @@ -97,10 +98,31 @@ def fetch_model_checkpoint(ckpt_path, cache_dir=None, resume_download=False, for return path +def load_checkpoint(checkpoint_path_or_dict, device=None, from_safetensors=True): + if device is None: + device = "cuda" if torch.cuda.is_available() else "cpu" + + if isinstance(checkpoint_path_or_dict, str): + if from_safetensors: + checkpoint = safe_load(checkpoint_path_or_dict, device="cpu") + + else: + checkpoint = torch.load(checkpoint_path_or_dict, map_location=device) + + elif isinstance(checkpoint_path_or_dict, dict): + checkpoint = checkpoint_path_or_dict + + return checkpoint + + def infer_model_type(pipeline_class_name): return MODEL_TYPE_FROM_PIPELINE_CLASS.get(pipeline_class_name, None) +def build_component(component_name, **kwargs): + return + + class FromSingleFileMixin: """ Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`]. @@ -289,7 +311,13 @@ class FromSingleFileMixin: f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(VALID_URL_PREFIXES)}" ) pretrained_model_link_or_path = fetch_model_checkpoint(ckpt_path, cache_dir=cache_dir, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, revision=revision) - components = extract_pipeline_compoments(cls) + checkpoint = load_checkpoint(pretrained_model_link_or_path, from_safetensors=from_safetensors) + + component_names = extract_pipeline_component_names(cls) + + pipeline_components = {} + for component in component_names: + pipeline_components[component] = build_component(component, checkpoint, **kwargs) pipe = download_from_original_stable_diffusion_ckpt( pretrained_model_link_or_path,