1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

add some comments

This commit is contained in:
patil-suraj
2022-06-09 14:12:22 +02:00
parent 2234877e01
commit 758f9d2201

View File

@@ -113,7 +113,8 @@ class DiffusionPipeline(ConfigMixin):
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", False)
use_auth_token = kwargs.pop("use_auth_token", None)
# 1. Download the checkpoints and configs
# use snapshot download here to get it working from from_pretrained
if not os.path.isdir(pretrained_model_name_or_path):
cached_folder = snapshot_download(
@@ -129,11 +130,12 @@ class DiffusionPipeline(ConfigMixin):
config_dict = cls.get_config_dict(cached_folder)
module = config_dict["_module"]
# 2. Get class name and module candidates to load custom models
class_name_ = config_dict["_class_name"]
module_candidate = config_dict["_module"]
module_candidate_name = module_candidate.replace(".py", "")
# 3. Load the pipeline class, if using custom module then load it from the hub
# if we load from explicit class, let's use it
if cls != DiffusionPipeline:
pipeline_class = cls
@@ -147,6 +149,7 @@ class DiffusionPipeline(ConfigMixin):
init_kwargs = {}
# 4. Load each module in the pipeline
for name, (library_name, class_name) in init_dict.items():
# if the model is not in diffusers or transformers, we need to load it from the hub
# assumes that it's a subclass of ModelMixin
@@ -156,6 +159,7 @@ class DiffusionPipeline(ConfigMixin):
importable_classes = ALL_IMPORTABLE_CLASSES
class_candidates = {c: class_obj for c in ALL_IMPORTABLE_CLASSES.keys()}
else:
# else we just import it from the library.
library = importlib.import_module(library_name)
class_obj = getattr(library, class_name)
importable_classes = LOADABLE_CLASSES[library_name]
@@ -168,12 +172,15 @@ class DiffusionPipeline(ConfigMixin):
load_method = getattr(class_obj, load_method_name)
# check if the module is in a subdirectory
if os.path.isdir(os.path.join(cached_folder, name)):
loaded_sub_model = load_method(os.path.join(cached_folder, name))
else:
# else load from the root directory
loaded_sub_model = load_method(cached_folder)
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
# 5. Instantiate the pipeline
model = pipeline_class(**init_kwargs)
return model