mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -20,6 +20,7 @@ import requests
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import validate_hf_hub_args
|
||||
from safetensors.torch import load_file as safe_load
|
||||
|
||||
from ..utils import (
|
||||
deprecate,
|
||||
@@ -55,7 +56,7 @@ MODEL_TYPE_FROM_PIPELINE_CLASS = {
|
||||
}
|
||||
|
||||
|
||||
def extract_pipeline_compoments(pipeline_class):
|
||||
def extract_pipeline_component_names(pipeline_class):
|
||||
components = inspect.signature(pipeline_class).parameters.keys()
|
||||
return components
|
||||
|
||||
@@ -97,10 +98,31 @@ def fetch_model_checkpoint(ckpt_path, cache_dir=None, resume_download=False, for
|
||||
return path
|
||||
|
||||
|
||||
def load_checkpoint(checkpoint_path_or_dict, device=None, from_safetensors=True):
|
||||
if device is None:
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
|
||||
if isinstance(checkpoint_path_or_dict, str):
|
||||
if from_safetensors:
|
||||
checkpoint = safe_load(checkpoint_path_or_dict, device="cpu")
|
||||
|
||||
else:
|
||||
checkpoint = torch.load(checkpoint_path_or_dict, map_location=device)
|
||||
|
||||
elif isinstance(checkpoint_path_or_dict, dict):
|
||||
checkpoint = checkpoint_path_or_dict
|
||||
|
||||
return checkpoint
|
||||
|
||||
|
||||
def infer_model_type(pipeline_class_name):
|
||||
return MODEL_TYPE_FROM_PIPELINE_CLASS.get(pipeline_class_name, None)
|
||||
|
||||
|
||||
def build_component(component_name, **kwargs):
|
||||
return
|
||||
|
||||
|
||||
class FromSingleFileMixin:
|
||||
"""
|
||||
Load model weights saved in the `.ckpt` format into a [`DiffusionPipeline`].
|
||||
@@ -289,7 +311,13 @@ class FromSingleFileMixin:
|
||||
f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(VALID_URL_PREFIXES)}"
|
||||
)
|
||||
pretrained_model_link_or_path = fetch_model_checkpoint(ckpt_path, cache_dir=cache_dir, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, revision=revision)
|
||||
components = extract_pipeline_compoments(cls)
|
||||
checkpoint = load_checkpoint(pretrained_model_link_or_path, from_safetensors=from_safetensors)
|
||||
|
||||
component_names = extract_pipeline_component_names(cls)
|
||||
|
||||
pipeline_components = {}
|
||||
for component in component_names:
|
||||
pipeline_components[component] = build_component(component, checkpoint, **kwargs)
|
||||
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
pretrained_model_link_or_path,
|
||||
|
||||
Reference in New Issue
Block a user