diff --git a/models/vision/ddpm/modeling_ddpm.py b/models/vision/ddpm/modeling_ddpm.py index 52acc7b986..ae049a8c0a 100644 --- a/models/vision/ddpm/modeling_ddpm.py +++ b/models/vision/ddpm/modeling_ddpm.py @@ -20,6 +20,9 @@ import torch class DDPM(DiffusionPipeline): + + modeling_file = "modeling_ddpm.py" + def __init__(self, unet, noise_scheduler): super().__init__() self.register_modules(unet=unet, noise_scheduler=noise_scheduler) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index 3ec9fc8575..d4e050681a 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -53,8 +53,11 @@ class DiffusionPipeline(ConfigMixin): # retrive class_name class_name = module.__class__.__name__ + register_dict = {name: (library, class_name)} + register_dict["_module"] = self.__module__ + # save model index config - self.register(**{name: (library, class_name)}) + self.register(**register_dict) # set models setattr(self, name, module) @@ -84,7 +87,10 @@ class DiffusionPipeline(ConfigMixin): def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): # use snapshot download here to get it working from from_pretrained cached_folder = snapshot_download(pretrained_model_name_or_path) - config_dict, _ = cls.get_config_dict(cached_folder) + config_dict, pipeline_kwargs = cls.get_config_dict(cached_folder) + + module = pipeline_kwargs["_module"] + # TODO(Suraj) - make from hub import work init_kwargs = {}