From 6b66999e7552ea3f5333d2a88e13d3f3a196d069 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 9 Jun 2022 12:40:23 +0200 Subject: [PATCH 1/2] make ALL_IMPORTABLE_CLASSES static --- src/diffusers/pipeline_utils.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index cd69b9cf70..3fbf95bbc3 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -46,6 +46,10 @@ LOADABLE_CLASSES = { }, } +ALL_IMPORTABLE_CLASSES = {} +for library in LOADABLE_CLASSES: + ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) + class DiffusionPipeline(ConfigMixin): @@ -125,12 +129,6 @@ class DiffusionPipeline(ConfigMixin): init_kwargs = {} - # get all importable classes to get the load method name for custom models/components - # here we enforce that custom models/components should always subclass from base classes in tansformers and diffusers - all_importable_classes = {} - for library in LOADABLE_CLASSES: - all_importable_classes.update(LOADABLE_CLASSES[library]) - for name, (library_name, class_name) in init_dict.items(): # if the model is not in diffusers or transformers, we need to load it from the hub @@ -138,8 +136,8 @@ class DiffusionPipeline(ConfigMixin): if library_name == module_candidate_name: class_obj = get_class_from_dynamic_module(cached_folder, module, class_name, cached_folder) # since it's not from a library, we need to check class candidates for all importable classes - importable_classes = all_importable_classes - class_candidates = {c: class_obj for c in all_importable_classes} + importable_classes = ALL_IMPORTABLE_CLASSES + class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()} else: library = importlib.import_module(library_name) class_obj = getattr(library, class_name) From 2fa1d64841fdd0290d6abf5f0c4129643c089441 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Thu, 9 Jun 2022 12:46:02 +0200 Subject: [PATCH 2/2] remove incorrect args --- src/diffusers/pipeline_utils.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 6cc19ae1df..06e2ab2e56 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -109,10 +109,8 @@ class DiffusionPipeline(ConfigMixin): Add docstrings """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) - output_loading_info = kwargs.pop("output_loading_info", False) local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) @@ -121,10 +119,8 @@ class DiffusionPipeline(ConfigMixin): cached_folder = snapshot_download( pretrained_model_name_or_path, cache_dir=cache_dir, - force_download=force_download, resume_download=resume_download, proxies=proxies, - output_loading_info=output_loading_info, local_files_only=local_files_only, use_auth_token=use_auth_token, )