1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

load pipeline from source

This commit is contained in:
Patrick von Platen
2022-06-12 18:13:23 +00:00
parent e83ff11f57
commit bda825f910

View File

@@ -34,6 +34,7 @@ logger = logging.get_logger(__name__)
LOADABLE_CLASSES = {
"diffusers": {
"ModelMixin": ["save_pretrained", "from_pretrained"],
"DiffusionPipeline": ["save_pretrained", "from_pretrained"],
"GaussianDDPMScheduler": ["save_config", "from_config"],
"ClassifierFreeGuidanceScheduler": ["save_config", "from_config"],
"GlideDDIMScheduler": ["save_config", "from_config"],
@@ -74,7 +75,7 @@ class DiffusionPipeline(ConfigMixin):
# set models
setattr(self, name, module)
register_dict = {"_module": self.__module__.split(".")[-1] + ".py"}
register_dict = {"_module": self.__module__.split(".")[-1]}
self.register(**register_dict)
def save_pretrained(self, save_directory: Union[str, os.PathLike]):
@@ -133,19 +134,21 @@ class DiffusionPipeline(ConfigMixin):
config_dict = cls.get_config_dict(cached_folder)
# 2. Get class name and module candidates to load custom models
class_name_ = config_dict["_class_name"]
module_candidate = config_dict["_module"]
module_candidate_name = module_candidate.replace(".py", "")
module_candidate_name = config_dict["_module"]
module_candidate = module_candidate_name + ".py"
# 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
if cls != DiffusionPipeline:
pipeline_class = cls
else:
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
# (TODO - we should allow to load custom pipelines
# else we need to load the correct module from the Hub
class_name_ = config_dict["_class_name"]
module = module_candidate
pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
# module = module_candidate
# pipeline_class = get_class_from_dynamic_module(cached_folder, module, class_name_, cached_folder)
init_dict, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)