mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Merge branch 'main' of github.com:huggingface/diffusers
This commit is contained in:
@@ -45,6 +45,10 @@ LOADABLE_CLASSES = {
|
||||
},
|
||||
}
|
||||
|
||||
ALL_IMPORTABLE_CLASSES = {}
|
||||
for library in LOADABLE_CLASSES:
|
||||
ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library])
|
||||
|
||||
|
||||
class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
@@ -105,10 +109,8 @@ class DiffusionPipeline(ConfigMixin):
|
||||
Add docstrings
|
||||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
output_loading_info = kwargs.pop("output_loading_info", False)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
|
||||
@@ -117,10 +119,8 @@ class DiffusionPipeline(ConfigMixin):
|
||||
cached_folder = snapshot_download(
|
||||
pretrained_model_name_or_path,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
output_loading_info=output_loading_info,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
)
|
||||
@@ -147,20 +147,14 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
init_kwargs = {}
|
||||
|
||||
# get all importable classes to get the load method name for custom models/components
|
||||
# here we enforce that custom models/components should always subclass from base classes in tansformers and diffusers
|
||||
all_importable_classes = {}
|
||||
for library in LOADABLE_CLASSES:
|
||||
all_importable_classes.update(LOADABLE_CLASSES[library])
|
||||
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
# if the model is not in diffusers or transformers, we need to load it from the hub
|
||||
# assumes that it's a subclass of ModelMixin
|
||||
if library_name == module_candidate_name:
|
||||
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name, cached_folder)
|
||||
# since it's not from a library, we need to check class candidates for all importable classes
|
||||
importable_classes = all_importable_classes
|
||||
class_candidates = {c: class_obj for c in all_importable_classes}
|
||||
importable_classes = ALL_IMPORTABLE_CLASSES
|
||||
class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
|
||||
else:
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
|
||||
Reference in New Issue
Block a user