From 261a448c6aa467810545d87499f6f3bed334754f Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 20 Dec 2022 02:07:15 +0100 Subject: [PATCH] Correct hf hub download (#1767) * allow model download when no internet * up * make style --- src/diffusers/modeling_utils.py | 3 ++- src/diffusers/pipeline_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index edc519db6e..6d934e6b30 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -26,6 +26,7 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R from requests import HTTPError from . import __version__ +from .hub_utils import HF_HUB_OFFLINE from .utils import ( CONFIG_NAME, DIFFUSERS_CACHE, @@ -376,7 +377,7 @@ class ModelMixin(torch.nn.Module): resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) output_loading_info = kwargs.pop("output_loading_info", False) - local_files_only = kwargs.pop("local_files_only", False) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 54358de217..23c4b29a53 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -33,7 +33,7 @@ from tqdm.auto import tqdm from .configuration_utils import ConfigMixin from .dynamic_modules_utils import get_class_from_dynamic_module -from .hub_utils import http_user_agent +from .hub_utils import HF_HUB_OFFLINE, http_user_agent from .modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT from .schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME from .utils import ( @@ -441,7 +441,7 @@ class DiffusionPipeline(ConfigMixin): resume_download = kwargs.pop("resume_download", False) force_download = kwargs.pop("force_download", False) proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", False) + local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE) use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) torch_dtype = kwargs.pop("torch_dtype", None)