From 02cdd68331d423177d62351bfde40659da626318 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 9 Jun 2022 12:21:56 +0200 Subject: [PATCH] genric logic to get load method for custom model --- src/diffusers/pipeline_utils.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index d1d93a9a78..cd69b9cf70 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -125,24 +125,31 @@ class DiffusionPipeline(ConfigMixin): init_kwargs = {} + # get all importable classes to get the load method name for custom models/components + # here we enforce that custom models/components should always subclass from base classes in tansformers and diffusers + all_importable_classes = {} + for library in LOADABLE_CLASSES: + all_importable_classes.update(LOADABLE_CLASSES[library]) + for name, (library_name, class_name) in init_dict.items(): # if the model is not in diffusers or transformers, we need to load it from the hub # assumes that it's a subclass of ModelMixin if library_name == module_candidate_name: class_obj = get_class_from_dynamic_module(cached_folder, module, class_name, cached_folder) - load_method_name = "from_pretrained" + # since it's not from a library, we need to check class candidates for all importable classes + importable_classes = all_importable_classes + class_candidates = {c: class_obj for c in all_importable_classes} else: - importable_classes = LOADABLE_CLASSES[library_name] - library = importlib.import_module(library_name) class_obj = getattr(library, class_name) + importable_classes = LOADABLE_CLASSES[library_name] class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} - load_method_name = None - for class_name, class_candidate in class_candidates.items(): - if issubclass(class_obj, class_candidate): - load_method_name = importable_classes[class_name][1] + load_method_name = None + for class_name, class_candidate in class_candidates.items(): + if issubclass(class_obj, class_candidate): + load_method_name = importable_classes[class_name][1] load_method = getattr(class_obj, load_method_name)