diff --git a/src/diffusers/utils/constants.py b/src/diffusers/utils/constants.py index 44d383814c..d9867fb875 100644 --- a/src/diffusers/utils/constants.py +++ b/src/diffusers/utils/constants.py @@ -45,7 +45,7 @@ DIFFUSERS_ATTN_BACKEND = os.getenv("DIFFUSERS_ATTN_BACKEND", "native") DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES DEFAULT_HF_PARALLEL_LOADING_WORKERS = 8 HF_ENABLE_PARALLEL_LOADING = os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES -DIFFUSERS_DISABLE_CUSTOM_CODE = os.getenv("DIFFUSERS_DISABLE_CUSTOM_CODE", "false").lower() in ENV_VARS_TRUE_VALUES +DIFFUSERS_DISABLE_REMOTE_CODE = os.getenv("DIFFUSERS_DISABLE_REMOTE_CODE", "false").lower() in ENV_VARS_TRUE_VALUES # Below should be `True` if the current version of `peft` and `transformers` are compatible with # PEFT backend. Will automatically fall back to PEFT backend if the correct versions of the libraries are diff --git a/src/diffusers/utils/dynamic_modules_utils.py b/src/diffusers/utils/dynamic_modules_utils.py index 28f220a877..08cd4ca1eb 100644 --- a/src/diffusers/utils/dynamic_modules_utils.py +++ b/src/diffusers/utils/dynamic_modules_utils.py @@ -168,11 +168,23 @@ def _raise_timeout_error(signum, frame): def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code): trust_remote_code = trust_remote_code and not DIFFUSERS_DISABLE_REMOTE_CODE + if DIFFUSERS_DISABLE_REMOTE_CODE: + logger.warning( + "Remote code execution has been disabled globally via DIFFUSERS_DISABLE_REMOTE_CODE environment variable. Ignoring `trust_remote_code`." + ) + if has_remote_code and not trust_remote_code: - raise ValueError( - f"The repository for {model_name} contains custom code which must be executed to correctly " - f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n" - f"Please pass the argument `trust_remote_code=True` to allow custom code to be run." + error_msg = f"The repository for {model_name} contains custom code. " + error_msg += ( + "Remote code is disabled globally via DIFFUSERS_DISABLE_REMOTE_CODE." + if DIFFUSERS_DISABLE_REMOTE_CODE + else "Pass `trust_remote_code=True` to allow loading remote code modules." + ) + raise ValueError(error_msg) + + elif has_remote_code and trust_remote_code: + logger.warning( + f"`trust_remote_code` is enabled. Downloading code from {model_name}. Please ensure you trust the contents of this repository" ) return trust_remote_code