From 86064df7b556e1ffb2a37c0d03ec5d10ecba940a Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Wed, 8 Jun 2022 09:14:50 +0000 Subject: [PATCH] fix --- src/diffusers/pipeline_utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 2d4803d356..bcce66b4c2 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -101,13 +101,15 @@ class DiffusionPipeline(ConfigMixin): config_dict = cls.get_config_dict(cached_folder) + module_candidate = config_dict["_module"] + # if we load from explicit class, let's use it if cls != DiffusionPipeline: pipeline_class = cls else: # else we need to load the correct module from the Hub class_name_ = config_dict["_class_name"] - module = config_dict["_module"] + module = module_candidate pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder) init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs) @@ -117,7 +119,7 @@ class DiffusionPipeline(ConfigMixin): for name, (library_name, class_name) in init_dict.items(): importable_classes = LOADABLE_CLASSES[library_name] - if library_name == module: + if library_name == module_candidate: # TODO(Suraj) # for vq pass