diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 2d4803d356..bcce66b4c2 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -101,13 +101,15 @@ class DiffusionPipeline(ConfigMixin): config_dict = cls.get_config_dict(cached_folder) + module_candidate = config_dict["_module"] + # if we load from explicit class, let's use it if cls != DiffusionPipeline: pipeline_class = cls else: # else we need to load the correct module from the Hub class_name_ = config_dict["_class_name"] - module = config_dict["_module"] + module = module_candidate pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder) init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) @@ -117,7 +119,7 @@ class DiffusionPipeline(ConfigMixin): for name, (library_name, class_name) in init_dict.items(): importable_classes = LOADABLE_CLASSES[library_name] - if library_name == module: + if library_name == module_candidate: # TODO(Suraj) # for vq pass