From dd4cd081db39d6769060bb48d0137b832789f015 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Tue, 7 Jun 2022 11:54:50 +0000 Subject: [PATCH] fix naming --- models/vision/ddpm/modeling_ddpm.py | 2 +- src/diffusers/pipeline_utils.py | 10 +++++++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/models/vision/ddpm/modeling_ddpm.py b/models/vision/ddpm/modeling_ddpm.py index 4a3f0b24b7..ae049a8c0a 100644 --- a/models/vision/ddpm/modeling_ddpm.py +++ b/models/vision/ddpm/modeling_ddpm.py @@ -23,7 +23,7 @@ class DDPM(DiffusionPipeline): modeling_file = "modeling_ddpm.py" - def __init__(self, unet, noise_scheduler, vqvae): + def __init__(self, unet, noise_scheduler): super().__init__() self.register_modules(unet=unet, noise_scheduler=noise_scheduler) diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py index a9f5c6e2f3..2749ad68b7 100644 --- a/src/diffusers/pipeline_utils.py +++ b/src/diffusers/pipeline_utils.py @@ -90,10 +90,14 @@ class DiffusionPipeline(ConfigMixin): @classmethod def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): # use snapshot download here to get it working from from_pretrained - cached_folder = snapshot_download(pretrained_model_name_or_path) + if not os.path.isdir(pretrained_model_name_or_path): + cached_folder = snapshot_download(pretrained_model_name_or_path) + else: + cached_folder = pretrained_model_name_or_path + config_dict, pipeline_kwargs = cls.get_config_dict(cached_folder) - module = pipeline_kwargs["_module"] + module = pipeline_kwargs.pop("_module", None) # TODO(Suraj) - make from hub import work # Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work # Add Sylvains code from transformers @@ -118,7 +122,7 @@ class DiffusionPipeline(ConfigMixin): load_method = getattr(class_obj, load_method_name) - if os.path.dir(os.path.join(cached_folder, name)): + if os.path.isdir(os.path.join(cached_folder, name)): loaded_sub_model = load_method(os.path.join(cached_folder, name)) else: loaded_sub_model = load_method(cached_folder)