mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
[WIP] Modular Diffusers support custom code/pipeline blocks (#11539)
* update * update
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -15,13 +15,16 @@
|
||||
"""Utilities to dynamically load objects from the Hub."""
|
||||
|
||||
import importlib
|
||||
import signal
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
import threading
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Dict, Optional, Union
|
||||
from urllib import request
|
||||
|
||||
@@ -37,6 +40,8 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
# See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
|
||||
COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
|
||||
TIME_OUT_REMOTE_CODE = int(os.getenv("DIFFUSERS_TIMEOUT_REMOTE_CODE", 15))
|
||||
_HF_REMOTE_CODE_LOCK = threading.Lock()
|
||||
|
||||
|
||||
def get_diffusers_versions():
|
||||
@@ -154,15 +159,87 @@ def check_imports(filename):
|
||||
return get_relative_imports(filename)
|
||||
|
||||
|
||||
def get_class_in_module(class_name, module_path):
|
||||
def _raise_timeout_error(signum, frame):
|
||||
raise ValueError(
|
||||
"Loading this model requires you to execute custom code contained in the model repository on your local "
|
||||
"machine. Please set the option `trust_remote_code=True` to permit loading of this model."
|
||||
)
|
||||
|
||||
|
||||
def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
|
||||
if trust_remote_code is None:
|
||||
if has_remote_code and TIME_OUT_REMOTE_CODE > 0:
|
||||
prev_sig_handler = None
|
||||
try:
|
||||
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
|
||||
signal.alarm(TIME_OUT_REMOTE_CODE)
|
||||
while trust_remote_code is None:
|
||||
answer = input(
|
||||
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
||||
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
|
||||
f"Do you wish to run the custom code? [y/N] "
|
||||
)
|
||||
if answer.lower() in ["yes", "y", "1"]:
|
||||
trust_remote_code = True
|
||||
elif answer.lower() in ["no", "n", "0", ""]:
|
||||
trust_remote_code = False
|
||||
signal.alarm(0)
|
||||
except Exception:
|
||||
# OS which does not support signal.SIGALRM
|
||||
raise ValueError(
|
||||
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
||||
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
||||
)
|
||||
finally:
|
||||
if prev_sig_handler is not None:
|
||||
signal.signal(signal.SIGALRM, prev_sig_handler)
|
||||
signal.alarm(0)
|
||||
elif has_remote_code:
|
||||
# For the CI which puts the timeout at 0
|
||||
_raise_timeout_error(None, None)
|
||||
|
||||
if has_remote_code and not trust_remote_code:
|
||||
raise ValueError(
|
||||
f"Loading {model_name} requires you to execute the configuration file in that"
|
||||
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
|
||||
" set the option `trust_remote_code=True` to remove this error."
|
||||
)
|
||||
|
||||
return trust_remote_code
|
||||
|
||||
|
||||
def get_class_in_module(class_name, module_path, force_reload=False):
|
||||
"""
|
||||
Import a module on the cache directory for modules and extract a class from it.
|
||||
"""
|
||||
module_path = module_path.replace(os.path.sep, ".")
|
||||
module = importlib.import_module(module_path)
|
||||
name = os.path.normpath(module_path)
|
||||
if name.endswith(".py"):
|
||||
name = name[:-3]
|
||||
name = name.replace(os.path.sep, ".")
|
||||
module_file: Path = Path(HF_MODULES_CACHE) / module_path
|
||||
|
||||
with _HF_REMOTE_CODE_LOCK:
|
||||
if force_reload:
|
||||
sys.modules.pop(name, None)
|
||||
importlib.invalidate_caches()
|
||||
cached_module: Optional[ModuleType] = sys.modules.get(name)
|
||||
module_spec = importlib.util.spec_from_file_location(name, location=module_file)
|
||||
|
||||
module: ModuleType
|
||||
if cached_module is None:
|
||||
module = importlib.util.module_from_spec(module_spec)
|
||||
# insert it into sys.modules before any loading begins
|
||||
sys.modules[name] = module
|
||||
else:
|
||||
module = cached_module
|
||||
|
||||
module_spec.loader.exec_module(module)
|
||||
|
||||
if class_name is None:
|
||||
return find_pipeline_class(module)
|
||||
|
||||
return getattr(module, class_name)
|
||||
|
||||
|
||||
@@ -454,4 +531,4 @@ def get_class_from_dynamic_module(
|
||||
revision=revision,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
return get_class_in_module(class_name, final_module.replace(".py", ""))
|
||||
return get_class_in_module(class_name, final_module)
|
||||
|
||||
Reference in New Issue
Block a user