diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index f5f81a4230..49b654591f 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -34,6 +34,7 @@ logger = logging.get_logger(__name__) LOADABLE_CLASSES = { "diffusers": { "ModelMixin": ["save_pretrained", "from_pretrained"], + "DiffusionPipeline": ["save_pretrained", "from_pretrained"], "GaussianDDPMScheduler": ["save_config", "from_config"], "ClassifierFreeGuidanceScheduler": ["save_config", "from_config"], "GlideDDIMScheduler": ["save_config", "from_config"], @@ -74,7 +75,7 @@ class DiffusionPipeline(ConfigMixin): # set models setattr(self, name, module) - register_dict = {"_module": self.__module__.split(".")[-1] + ".py"} + register_dict = {"_module": self.__module__.split(".")[-1]} self.register(**register_dict) def save_pretrained(self, save_directory: Union[str, os.PathLike]): @@ -133,19 +134,21 @@ class DiffusionPipeline(ConfigMixin): config_dict = cls.get_config_dict(cached_folder) # 2. Get class name and module candidates to load custom models - class_name_ = config_dict["_class_name"] - module_candidate = config_dict["_module"] - module_candidate_name = module_candidate.replace(".py", "") + module_candidate_name = config_dict["_module"] + module_candidate = module_candidate_name + ".py" # 3. Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it if cls != DiffusionPipeline: pipeline_class = cls else: + diffusers_module = importlib.import_module(cls.__module__.split(".")[0]) + pipeline_class = getattr(diffusers_module, config_dict["_class_name"]) + + # (TODO - we should allow to load custom pipelines # else we need to load the correct module from the Hub - class_name_ = config_dict["_class_name"] - module = module_candidate - pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder) + # module = module_candidate + # pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder) init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)