mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Merge remote-tracking branch 'origin/main'
This commit is contained in:
@@ -111,7 +111,8 @@ class DiffusionPipeline(ConfigMixin):
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
|
||||
|
||||
# 1. Download the checkpoints and configs
|
||||
# use snapshot download here to get it working from from_pretrained
|
||||
if not os.path.isdir(pretrained_model_name_or_path):
|
||||
cached_folder = snapshot_download(
|
||||
@@ -127,11 +128,12 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
config_dict = cls.get_config_dict(cached_folder)
|
||||
|
||||
module = config_dict["_module"]
|
||||
# 2. Get class name and module candidates to load custom models
|
||||
class_name_ = config_dict["_class_name"]
|
||||
module_candidate = config_dict["_module"]
|
||||
module_candidate_name = module_candidate.replace(".py", "")
|
||||
|
||||
# 3. Load the pipeline class, if using custom module then load it from the hub
|
||||
# if we load from explicit class, let's use it
|
||||
if cls != DiffusionPipeline:
|
||||
pipeline_class = cls
|
||||
@@ -145,6 +147,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
init_kwargs = {}
|
||||
|
||||
# 4. Load each module in the pipeline
|
||||
for name, (library_name, class_name) in init_dict.items():
|
||||
# if the model is not in diffusers or transformers, we need to load it from the hub
|
||||
# assumes that it's a subclass of ModelMixin
|
||||
@@ -154,6 +157,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
importable_classes = ALL_IMPORTABLE_CLASSES
|
||||
class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
|
||||
else:
|
||||
# else we just import it from the library.
|
||||
library = importlib.import_module(library_name)
|
||||
class_obj = getattr(library, class_name)
|
||||
importable_classes = LOADABLE_CLASSES[library_name]
|
||||
@@ -166,12 +170,15 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
load_method = getattr(class_obj, load_method_name)
|
||||
|
||||
# check if the module is in a subdirectory
|
||||
if os.path.isdir(os.path.join(cached_folder, name)):
|
||||
loaded_sub_model = load_method(os.path.join(cached_folder, name))
|
||||
else:
|
||||
# else load from the root directory
|
||||
loaded_sub_model = load_method(cached_folder)
|
||||
|
||||
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
||||
|
||||
# 5. Instantiate the pipeline
|
||||
model = pipeline_class(**init_kwargs)
|
||||
return model
|
||||
|
||||
Reference in New Issue
Block a user