mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Modular] Provide option to disable custom code loading globally via env variable (#12177)
* update * update * update * update
This commit is contained in:
@@ -299,7 +299,7 @@ class ModularPipelineBlocks(ConfigMixin, PushToHubMixin):
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: str,
|
||||
trust_remote_code: Optional[bool] = None,
|
||||
trust_remote_code: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
hub_kwargs_names = [
|
||||
|
||||
@@ -45,6 +45,7 @@ DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native")
|
||||
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
|
||||
DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8
|
||||
HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
|
||||
DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").lower() in ENV_VARS_TRUE_VALUES
|
||||
|
||||
# Below should be `True` if the current version of `peft` and `transformers` are compatible with
|
||||
# PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are
|
||||
|
||||
@@ -20,7 +20,6 @@ import json
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import signal
|
||||
import sys
|
||||
import threading
|
||||
from pathlib import Path
|
||||
@@ -34,6 +33,7 @@ from packaging import version
|
||||
|
||||
from .. import __version__
|
||||
from . import DIFFUSERS_DYNAMIC_MODULE_NAME, HF_MODULES_CACHE, logging
|
||||
from .constants import DIFFUSERS_DISABLE_REMOTE_CODE
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
@@ -159,52 +159,25 @@ def check_imports(filename):
|
||||
return get_relative_imports(filename)
|
||||
|
||||
|
||||
def _raise_timeout_error(signum, frame):
|
||||
raise ValueError(
|
||||
"Loading this model requires you to execute custom code contained in the model repository on your local "
|
||||
"machine. Please set the option `trust_remote_code=True` to permit loading of this model."
|
||||
)
|
||||
|
||||
|
||||
def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
|
||||
if trust_remote_code is None:
|
||||
if has_remote_code and TIME_OUT_REMOTE_CODE > 0:
|
||||
prev_sig_handler = None
|
||||
try:
|
||||
prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
|
||||
signal.alarm(TIME_OUT_REMOTE_CODE)
|
||||
while trust_remote_code is None:
|
||||
answer = input(
|
||||
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
||||
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||
f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
|
||||
f"Do you wish to run the custom code? [y/N] "
|
||||
)
|
||||
if answer.lower() in ["yes", "y", "1"]:
|
||||
trust_remote_code = True
|
||||
elif answer.lower() in ["no", "n", "0", ""]:
|
||||
trust_remote_code = False
|
||||
signal.alarm(0)
|
||||
except Exception:
|
||||
# OS which does not support signal.SIGALRM
|
||||
raise ValueError(
|
||||
f"The repository for {model_name} contains custom code which must be executed to correctly "
|
||||
f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
|
||||
f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
|
||||
)
|
||||
finally:
|
||||
if prev_sig_handler is not None:
|
||||
signal.signal(signal.SIGALRM, prev_sig_handler)
|
||||
signal.alarm(0)
|
||||
elif has_remote_code:
|
||||
# For the CI which puts the timeout at 0
|
||||
_raise_timeout_error(None, None)
|
||||
trust_remote_code = trust_remote_code and not DIFFUSERS_DISABLE_REMOTE_CODE
|
||||
if DIFFUSERS_DISABLE_REMOTE_CODE:
|
||||
logger.warning(
|
||||
"Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable. Ignoring `trust_remote_code`."
|
||||
)
|
||||
|
||||
if has_remote_code and not trust_remote_code:
|
||||
raise ValueError(
|
||||
f"Loading {model_name} requires you to execute the configuration file in that"
|
||||
" repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
|
||||
" set the option `trust_remote_code=True` to remove this error."
|
||||
error_msg = f"The repository for {model_name} contains custom code. "
|
||||
error_msg += (
|
||||
"Downloading remote code is disabled globally via the DIFFUSERS_DISABLE_REMOTE_CODE environment variable."
|
||||
if DIFFUSERS_DISABLE_REMOTE_CODE
|
||||
else "Pass `trust_remote_code=True` to allow loading remote code modules."
|
||||
)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
elif has_remote_code and trust_remote_code:
|
||||
logger.warning(
|
||||
f"`trust_remote_code` is enabled. Downloading code from {model_name}. Please ensure you trust the contents of this repository"
|
||||
)
|
||||
|
||||
return trust_remote_code
|
||||
|
||||
Reference in New Issue
Block a user