1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

Allow custom pipeline loading (#3504)

This commit is contained in:
Patrick von Platen
2023-05-23 14:20:55 +02:00
committed by GitHub
parent b134f6a8b6
commit d4197bf4d7
2 changed files with 34 additions and 3 deletions

View File

@@ -491,15 +491,19 @@ class DiffusionPipeline(ConfigMixin):
library = module.__module__.split(".")[0]
# check if the module is a pipeline module
pipeline_dir = module.__module__.split(".")[-2] if len(module.__module__.split(".")) > 2 else None
module_path_items = module.__module__.split(".")
pipeline_dir = module_path_items[-2] if len(module_path_items) > 2 else None
path = module.__module__.split(".")
is_pipeline_module = pipeline_dir in path and hasattr(pipelines, pipeline_dir)
# if library is not in LOADABLE_CLASSES, then it is a custom module.
# Or if it's a pipeline module, then the module is inside the pipeline
# folder so we set the library to module name.
if library not in LOADABLE_CLASSES or is_pipeline_module:
if is_pipeline_module:
library = pipeline_dir
elif library not in LOADABLE_CLASSES:
library = module.__module__
# retrieve class_name
class_name = module.__class__.__name__
@@ -1039,7 +1043,7 @@ class DiffusionPipeline(ConfigMixin):
# 6.2 Define all importable classes
is_pipeline_module = hasattr(pipelines, library_name)
importable_classes = ALL_IMPORTABLE_CLASSES if is_pipeline_module else LOADABLE_CLASSES[library_name]
importable_classes = ALL_IMPORTABLE_CLASSES
loaded_sub_model = None
# 6.3 Use passed sub model or load class_name from library_name