mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
[From pretrained] Speed-up loading from cache (#2515)
* [From pretrained] Speed-up loading from cache * up * Fix more * fix one more bug * make style * bigger refactor * factor out function * Improve more * better * deprecate return cache folder * clean up * improve tests * up * upload * add nice tests * simplify * finish * correct * fix version * rename * Apply suggestions from code review Co-authored-by: Lucain <lucainp@gmail.com> * rename * correct doc string * correct more * Apply suggestions from code review Co-authored-by: Pedro Cuenca <pedro@huggingface.co> * apply code suggestions * finish --------- Co-authored-by: Lucain <lucainp@gmail.com> Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
committed by
GitHub
parent
7fe638c502
commit
d761b58bfc
@@ -31,7 +31,15 @@ from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, R
|
||||
from requests import HTTPError
|
||||
|
||||
from . import __version__
|
||||
from .utils import DIFFUSERS_CACHE, HUGGINGFACE_CO_RESOLVE_ENDPOINT, DummyObject, deprecate, logging
|
||||
from .utils import (
|
||||
DIFFUSERS_CACHE,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
DummyObject,
|
||||
deprecate,
|
||||
extract_commit_hash,
|
||||
http_user_agent,
|
||||
logging,
|
||||
)
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
@@ -231,7 +239,11 @@ class ConfigMixin:
|
||||
|
||||
@classmethod
|
||||
def load_config(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
|
||||
cls,
|
||||
pretrained_model_name_or_path: Union[str, os.PathLike],
|
||||
return_unused_kwargs=False,
|
||||
return_commit_hash=False,
|
||||
**kwargs,
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
r"""
|
||||
Instantiate a Python class from a config dictionary
|
||||
@@ -271,6 +283,10 @@ class ConfigMixin:
|
||||
subfolder (`str`, *optional*, defaults to `""`):
|
||||
In case the relevant files are located inside a subfolder of the model repo (either remote in
|
||||
huggingface.co or downloaded locally), you can specify the folder name here.
|
||||
return_unused_kwargs (`bool`, *optional*, defaults to `False):
|
||||
Whether unused keyword arguments of the config shall be returned.
|
||||
return_commit_hash (`bool`, *optional*, defaults to `False):
|
||||
Whether the commit_hash of the loaded configuration shall be returned.
|
||||
|
||||
<Tip>
|
||||
|
||||
@@ -295,8 +311,10 @@ class ConfigMixin:
|
||||
revision = kwargs.pop("revision", None)
|
||||
_ = kwargs.pop("mirror", None)
|
||||
subfolder = kwargs.pop("subfolder", None)
|
||||
user_agent = kwargs.pop("user_agent", {})
|
||||
|
||||
user_agent = {"file_type": "config"}
|
||||
user_agent = {**user_agent, "file_type": "config"}
|
||||
user_agent = http_user_agent(user_agent)
|
||||
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
|
||||
@@ -336,7 +354,6 @@ class ConfigMixin:
|
||||
subfolder=subfolder,
|
||||
revision=revision,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
|
||||
@@ -378,13 +395,23 @@ class ConfigMixin:
|
||||
try:
|
||||
# Load config dict
|
||||
config_dict = cls._dict_from_json_file(config_file)
|
||||
|
||||
commit_hash = extract_commit_hash(config_file)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")
|
||||
|
||||
if return_unused_kwargs:
|
||||
return config_dict, kwargs
|
||||
if not (return_unused_kwargs or return_commit_hash):
|
||||
return config_dict
|
||||
|
||||
return config_dict
|
||||
outputs = (config_dict,)
|
||||
|
||||
if return_unused_kwargs:
|
||||
outputs += (kwargs,)
|
||||
|
||||
if return_commit_hash:
|
||||
outputs += (commit_hash,)
|
||||
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _get_init_keys(cls):
|
||||
|
||||
Reference in New Issue
Block a user