diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index e836b765b0..06e2ab2e56 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -45,6 +45,10 @@ LOADABLE_CLASSES = { }, } +ALL_IMPORTABLE_CLASSES = {} +for library in LOADABLE_CLASSES: + ALL_IMPORTABLE_CLASSES.update(LOADABLE_CLASSES[library]) + class DiffusionPipeline(ConfigMixin): @@ -105,10 +109,8 @@ class DiffusionPipeline(ConfigMixin): Add docstrings """ cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE) - force_download = kwargs.pop("force_download", False) resume_download = kwargs.pop("resume_download", False) proxies = kwargs.pop("proxies", None) - output_loading_info = kwargs.pop("output_loading_info", False) local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) @@ -117,10 +119,8 @@ class DiffusionPipeline(ConfigMixin): cached_folder = snapshot_download( pretrained_model_name_or_path, cache_dir=cache_dir, - force_download=force_download, resume_download=resume_download, proxies=proxies, - output_loading_info=output_loading_info, local_files_only=local_files_only, use_auth_token=use_auth_token, ) @@ -147,20 +147,14 @@ class DiffusionPipeline(ConfigMixin): init_kwargs = {} - # get all importable classes to get the load method name for custom models/components - # here we enforce that custom models/components should always subclass from base classes in tansformers and diffusers - all_importable_classes = {} - for library in LOADABLE_CLASSES: - all_importable_classes.update(LOADABLE_CLASSES[library]) - for name, (library_name, class_name) in init_dict.items(): # if the model is not in diffusers or transformers, we need to load it from the hub # assumes that it's a subclass of ModelMixin if library_name == module_candidate_name: class_obj = get_class_from_dynamic_module(cached_folder, module, class_name, cached_folder) # since it's not from a library, we need to check class candidates for all importable classes - importable_classes = all_importable_classes - class_candidates = {c: class_obj for c in all_importable_classes} + importable_classes = ALL_IMPORTABLE_CLASSES + class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()} else: library = importlib.import_module(library_name) class_obj = getattr(library, class_name)