diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 06e2ab2e56..7389b7f500 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -113,7 +113,8 @@ class DiffusionPipeline(ConfigMixin): proxies = kwargs.pop("proxies", None) local_files_only = kwargs.pop("local_files_only", False) use_auth_token = kwargs.pop("use_auth_token", None) - + + # 1. Download the checkpoints and configs # use snapshot download here to get it working from from_pretrained if not os.path.isdir(pretrained_model_name_or_path): cached_folder = snapshot_download( @@ -129,11 +130,12 @@ class DiffusionPipeline(ConfigMixin): config_dict = cls.get_config_dict(cached_folder) - module = config_dict["_module"] + # 2. Get class name and module candidates to load custom models class_name_ = config_dict["_class_name"] module_candidate = config_dict["_module"] module_candidate_name = module_candidate.replace(".py", "") + # 3. Load the pipeline class, if using custom module then load it from the hub # if we load from explicit class, let's use it if cls != DiffusionPipeline: pipeline_class = cls @@ -147,6 +149,7 @@ class DiffusionPipeline(ConfigMixin): init_kwargs = {} + # 4. Load each module in the pipeline for name, (library_name, class_name) in init_dict.items(): # if the model is not in diffusers or transformers, we need to load it from the hub # assumes that it's a subclass of ModelMixin @@ -156,6 +159,7 @@ class DiffusionPipeline(ConfigMixin): importable_classes = ALL_IMPORTABLE_CLASSES class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()} else: + # else we just import it from the library. library = importlib.import_module(library_name) class_obj = getattr(library, class_name) importable_classes = LOADABLE_CLASSES[library_name] @@ -168,12 +172,15 @@ class DiffusionPipeline(ConfigMixin): load_method = getattr(class_obj, load_method_name) + # check if the module is in a subdirectory if os.path.isdir(os.path.join(cached_folder, name)): loaded_sub_model = load_method(os.path.join(cached_folder, name)) else: + # else load from the root directory loaded_sub_model = load_method(cached_folder) init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...) + # 5. Instantiate the pipeline model = pipeline_class(**init_kwargs) return model