diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 743a807510..77be534009 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -55,11 +55,20 @@ class DiffusionPipeline(ConfigMixin): config_name = "model_index.json" def register_modules(self, **kwargs): + # import it here to avoid circular import + from diffusers import pipelines + for name, module in kwargs.items(): + # check if the module is a pipeline module + is_pipeline_module = hasattr(pipelines, module.__module__.split(".")[-1]) + # retrive library library = module.__module__.split(".")[0] - # if library is not in LOADABLE_CLASSES, then it is a custom module - if library not in LOADABLE_CLASSES: + + # if library is not in LOADABLE_CLASSES, then it is a custom module. + # Or if it's a pipeline module, then the module is inside the pipeline + # so we set the library to module name. + if library not in LOADABLE_CLASSES or is_pipeline_module: library = module.__module__.split(".")[-1] # retrive class_name @@ -151,12 +160,22 @@ class DiffusionPipeline(ConfigMixin): init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) init_kwargs = {} - + + # import it here to avoid circular import + from diffusers import pipelines + # 4. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): - # if the model is not in diffusers or transformers, we need to load it from the hub - # assumes that it's a subclass of ModelMixin - if library_name == module_candidate_name: + is_pipeline_module = hasattr(pipelines, library_name) + # if the model is in a pipeline module, then we load it from the pipeline + if is_pipeline_module: + pipeline_module = getattr(pipelines, library_name) + class_obj = getattr(pipeline_module, class_name) + importable_classes = ALL_IMPORTABLE_CLASSES + class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()} + elif library_name == module_candidate_name: + # if the model is not in diffusers or transformers, we need to load it from the hub + # assumes that it's a subclass of ModelMixin class_obj = get_class_from_dynamic_module(cached_folder, module_candidate, class_name, cached_folder) # since it's not from a library, we need to check class candidates for all importable classes importable_classes = ALL_IMPORTABLE_CLASSES