mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add first logic for from hub code download
This commit is contained in:
@@ -20,6 +20,9 @@ import torch
|
||||
|
||||
|
||||
class DDPM(DiffusionPipeline):
|
||||
|
||||
modeling_file = "modeling_ddpm.py"
|
||||
|
||||
def __init__(self, unet, noise_scheduler):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
|
||||
|
||||
@@ -53,8 +53,11 @@ class DiffusionPipeline(ConfigMixin):
|
||||
# retrive class_name
|
||||
class_name = module.__class__.__name__
|
||||
|
||||
register_dict = {name: (library, class_name)}
|
||||
register_dict["_module"] = self.__module__
|
||||
|
||||
# save model index config
|
||||
self.register(**{name: (library, class_name)})
|
||||
self.register(**register_dict)
|
||||
|
||||
# set models
|
||||
setattr(self, name, module)
|
||||
@@ -84,7 +87,10 @@ class DiffusionPipeline(ConfigMixin):
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
# use snapshot download here to get it working from from_pretrained
|
||||
cached_folder = snapshot_download(pretrained_model_name_or_path)
|
||||
config_dict, _ = cls.get_config_dict(cached_folder)
|
||||
config_dict, pipeline_kwargs = cls.get_config_dict(cached_folder)
|
||||
|
||||
module = pipeline_kwargs["_module"]
|
||||
# TODO(Suraj) - make from hub import work
|
||||
|
||||
init_kwargs = {}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user