mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix naming
This commit is contained in:
@@ -23,7 +23,7 @@ class DDPM(DiffusionPipeline):
|
||||
|
||||
modeling_file = "modeling_ddpm.py"
|
||||
|
||||
def __init__(self, unet, noise_scheduler, vqvae):
|
||||
def __init__(self, unet, noise_scheduler):
|
||||
super().__init__()
|
||||
self.register_modules(unet=unet, noise_scheduler=noise_scheduler)
|
||||
|
||||
|
||||
@@ -90,10 +90,14 @@ class DiffusionPipeline(ConfigMixin):
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
# use snapshot download here to get it working from from_pretrained
|
||||
cached_folder = snapshot_download(pretrained_model_name_or_path)
|
||||
if not os.path.isdir(pretrained_model_name_or_path):
|
||||
cached_folder = snapshot_download(pretrained_model_name_or_path)
|
||||
else:
|
||||
cached_folder = pretrained_model_name_or_path
|
||||
|
||||
config_dict, pipeline_kwargs = cls.get_config_dict(cached_folder)
|
||||
|
||||
module = pipeline_kwargs["_module"]
|
||||
module = pipeline_kwargs.pop("_module", None)
|
||||
# TODO(Suraj) - make from hub import work
|
||||
# Make `ddpm = DiffusionPipeline.from_pretrained("fusing/ddpm-lsun-bedroom-pipe")` work
|
||||
# Add Sylvains code from transformers
|
||||
@@ -118,7 +122,7 @@ class DiffusionPipeline(ConfigMixin):
|
||||
|
||||
load_method = getattr(class_obj, load_method_name)
|
||||
|
||||
if os.path.dir(os.path.join(cached_folder, name)):
|
||||
if os.path.isdir(os.path.join(cached_folder, name)):
|
||||
loaded_sub_model = load_method(os.path.join(cached_folder, name))
|
||||
else:
|
||||
loaded_sub_model = load_method(cached_folder)
|
||||
|
||||
Reference in New Issue
Block a user