1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

genric logic to get load method for custom model

This commit is contained in:
patil-suraj
2022-06-09 12:21:56 +02:00
parent 74d2da9950
commit 02cdd68331

View File

@@ -125,24 +125,31 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs = {}
# get all importable classes to get the load method name for custom models/components
# here we enforce that custom models/components should always subclass from base classes in tansformers and diffusers
all_importable_classes = {}
for library in LOADABLE_CLASSES:
all_importable_classes.update(LOADABLE_CLASSES[library])
for name, (library_name, class_name) in init_dict.items():
# if the model is not in diffusers or transformers, we need to load it from the hub
# assumes that it's a subclass of ModelMixin
if library_name == module_candidate_name:
class_obj = get_class_from_dynamic_module(cached_folder, module, class_name, cached_folder)
load_method_name = "from_pretrained"
# since it's not from a library, we need to check class candidates for all importable classes
importable_classes = all_importable_classes
class_candidates = {c: class_obj for c in all_importable_classes}
else:
importable_classes = LOADABLE_CLASSES[library_name]
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
class_candidates = {c: getattr(library, c) for c in importable_classes.keys()}
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
load_method_name = importable_classes[class_name][1]
load_method_name = None
for class_name, class_candidate in class_candidates.items():
if issubclass(class_obj, class_candidate):
load_method_name = importable_classes[class_name][1]
load_method = getattr(class_obj, load_method_name)