mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
update
This commit is contained in:
@@ -11,6 +11,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
import inspect
|
||||
from contextlib import nullcontext
|
||||
from io import BytesIO
|
||||
@@ -30,7 +31,7 @@ from ..utils import (
|
||||
logging,
|
||||
)
|
||||
from ..utils.import_utils import BACKENDS_MAPPING
|
||||
from .single_file_utils import download_from_original_stable_diffusion_ckpt
|
||||
from .single_file_utils import download_from_original_stable_diffusion_ckpt, fetch_original_config
|
||||
|
||||
|
||||
if is_transformers_available():
|
||||
@@ -72,7 +73,7 @@ def check_valid_url(pretrained_model_link_or_path):
|
||||
return has_valid_url_prefix
|
||||
|
||||
|
||||
def fetch_model_checkpoint(ckpt_path, cache_dir=None, resume_download=False, force_download=False, proxies=None, local_files_only=None, token=None, revision=None):
|
||||
def download_model_checkpoint(ckpt_path, cache_dir=None, resume_download=False, force_download=False, proxies=None, local_files_only=None, token=None, revision=None):
|
||||
# get repo_id and (potentially nested) file path of ckpt in repo
|
||||
repo_id = "/".join(ckpt_path.parts[:2])
|
||||
file_path = "/".join(ckpt_path.parts[2:])
|
||||
@@ -119,7 +120,14 @@ def infer_model_type(pipeline_class_name):
|
||||
return MODEL_TYPE_FROM_PIPELINE_CLASS.get(pipeline_class_name, None)
|
||||
|
||||
|
||||
def build_component(component_name, **kwargs):
|
||||
def build_component(component_name, original_config, checkpoint, **kwargs):
|
||||
if component_name in kwargs:
|
||||
return kwargs.pop(component_name, None)
|
||||
|
||||
component_class = getattr(importlib.import_module("diffusers"), component_name)
|
||||
|
||||
|
||||
|
||||
return
|
||||
|
||||
|
||||
@@ -310,14 +318,21 @@ class FromSingleFileMixin:
|
||||
raise ValueError(
|
||||
f"The provided path is either not a file or a valid huggingface URL was not provided. Valid URLs begin with {', '.join(VALID_URL_PREFIXES)}"
|
||||
)
|
||||
pretrained_model_link_or_path = fetch_model_checkpoint(ckpt_path, cache_dir=cache_dir, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, revision=revision)
|
||||
pretrained_model_link_or_path = download_model_checkpoint(ckpt_path, cache_dir=cache_dir, resume_download=resume_download, proxies=proxies, local_files_only=local_files_only, token=token, revision=revision)
|
||||
checkpoint = load_checkpoint(pretrained_model_link_or_path, from_safetensors=from_safetensors)
|
||||
global_step = checkpoint["global_step"] if "global_step" in checkpoint else None
|
||||
|
||||
# NOTE: this while loop isn't great but this controlnet checkpoint has one additional
|
||||
# "state_dict" key https://huggingface.co/thibaud/controlnet-canny-sd21
|
||||
while "state_dict" in checkpoint:
|
||||
checkpoint = checkpoint["state_dict"]
|
||||
|
||||
original_config = fetch_original_config(checkpoint, config_files)
|
||||
component_names = extract_pipeline_component_names(cls)
|
||||
|
||||
pipeline_components = {}
|
||||
for component in component_names:
|
||||
pipeline_components[component] = build_component(component, checkpoint, **kwargs)
|
||||
pipeline_components[component] = build_component(component, checkpoint, original_config, **kwargs)
|
||||
|
||||
pipe = download_from_original_stable_diffusion_ckpt(
|
||||
pretrained_model_link_or_path,
|
||||
@@ -344,6 +359,8 @@ class FromSingleFileMixin:
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
|
||||
pipe = cls(**pipeline_components, **kwargs)
|
||||
|
||||
if torch_dtype is not None:
|
||||
pipe.to(dtype=torch_dtype)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user