mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Allow custom pipeline loading (#3504)
This commit is contained in:
committed by
GitHub
parent
b134f6a8b6
commit
d4197bf4d7
@@ -491,15 +491,19 @@ class DiffusionPipeline(ConfigMixin):
|
||||
library = module.__module__.split(".")[0]
|
||||
|
||||
# check if the module is a pipeline module
|
||||
pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None
|
||||
module_path_items = module.__module__.split(".")
|
||||
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
|
||||
|
||||
path = module.__module__.split(".")
|
||||
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
|
||||
|
||||
# if library is not in LOADABLE_CLASSES, then it is a custom module.
|
||||
# Or if it's a pipeline module, then the module is inside the pipeline
|
||||
# folder so we set the library to module name.
|
||||
if library not in LOADABLE_CLASSES or is_pipeline_module:
|
||||
if is_pipeline_module:
|
||||
library = pipeline_dir
|
||||
elif library not in LOADABLE_CLASSES:
|
||||
library = module.__module__
|
||||
|
||||
# retrieve class_name
|
||||
class_name = module.__class__.__name__
|
||||
@@ -1039,7 +1043,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
# 6.2 Define all importable classes
|
||||
is_pipeline_module = hasattr(pipelines, library_name)
|
||||
importable_classes = ALL_IMPORTABLE_CLASSES if is_pipeline_module else LOADABLE_CLASSES[library_name]
|
||||
importable_classes = ALL_IMPORTABLE_CLASSES
|
||||
loaded_sub_model = None
|
||||
|
||||
# 6.3 Use passed sub model or load class_name from library_name
|
||||
|
||||
Reference in New Issue
Block a user