1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

add first logic for from hub code download

This commit is contained in:
Patrick von Platen
2022-06-07 11:31:20 +02:00
parent e8ad2b75e7
commit 40dc888fca
2 changed files with 11 additions and 2 deletions

View File

@@ -20,6 +20,9 @@ import torch
class DDPM(DiffusionPipeline):
modeling_file = "modeling_ddpm.py"
def __init__(self, unet, noise_scheduler):
super().__init__()
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)

View File

@@ -53,8 +53,11 @@ class DiffusionPipeline(ConfigMixin):
# retrive class_name
class_name = module.__class__.__name__
register_dict = {name: (library, class_name)}
register_dict["_module"] = self.__module__
# save model index config
self.register(**{name: (library, class_name)})
self.register(**register_dict)
# set models
setattr(self, name, module)
@@ -84,7 +87,10 @@ class DiffusionPipeline(ConfigMixin):
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
# use snapshot download here to get it working from from_pretrained
cached_folder = snapshot_download(pretrained_model_name_or_path)
config_dict, _ = cls.get_config_dict(cached_folder)
config_dict, pipeline_kwargs = cls.get_config_dict(cached_folder)
module = pipeline_kwargs["_module"]
# TODO(Suraj) - make from hub import work
init_kwargs = {}