diff --git a/models/vision/ddpm/example.py b/models/vision/ddpm/example.py new file mode 100755 index 0000000000..ec339c4cdf --- /dev/null +++ b/models/vision/ddpm/example.py @@ -0,0 +1,22 @@ +#!/usr/bin/env python3 +from diffusers import UNetModel, GaussianDiffusion +from modeling_ddpm import DDPM +import tempfile + +unet = UNetModel.from_pretrained("fusing/ddpm_dummy") +sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy") + +# compose Diffusion Pipeline +ddpm = DDPM(unet, sampler) +# generate / sample +image = ddpm() +print(image) + + +# save and load with 0 extra code (handled by general `DiffusionPipeline` class) +with tempfile.TemporaryDirectory() as tmpdirname: + ddpm.save_pretrained(tmpdirname) + print("Model saved") + ddpm_new = DDPM.from_pretrained(tmpdirname) + print("Model loaded") + print(ddpm_new) diff --git a/models/vision/ddpm/modeling_ddpm.py b/models/vision/ddpm/modeling_ddpm.py index e69de29bb2..3525ec30c0 100644 --- a/models/vision/ddpm/modeling_ddpm.py +++ b/models/vision/ddpm/modeling_ddpm.py @@ -0,0 +1,27 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +# limitations under the License. + + +from diffusers import DiffusionPipeline + + +class DDPM(DiffusionPipeline): + + def __init__(self, unet, gaussian_sampler): + super().__init__(unet=unet, gaussian_sampler=gaussian_sampler) + + def __call__(self, batch_size=1): + image = self.gaussian_sampler.sample(self.unet, batch_size=batch_size) + return image diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index 41f680261c..135d49c83e 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -6,3 +6,6 @@ __version__ = "0.0.1" from .models.unet import UNetModel from .samplers.gaussian import GaussianDiffusion + +from .pipeline_utils import DiffusionPipeline +from .modeling_utils import PreTrainedModel diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py index bb30751c8e..164156437e 100644 --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -91,8 +91,8 @@ class Config: logger.info(f"Configuration saved in {output_config_file}") @classmethod - def from_config( - cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs + def get_config_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs ) -> Tuple[Dict[str, Any], Dict[str, Any]]: cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) @@ -198,6 +198,14 @@ class Config: f"Values will be initialized to default values." ) + return config_dict, unused_kwargs + + @classmethod + def from_config( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs + ): + config_dict, unused_kwargs = cls.get_config_dict(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs) + model = cls(**config_dict) if return_unused_kwargs: diff --git a/src/diffusers/pipeline_utils.py b/src/diffusers/pipeline_utils.py new file mode 100644 index 0000000000..2e4c88b785 --- /dev/null +++ b/src/diffusers/pipeline_utils.py @@ -0,0 +1,108 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Optional, Union +import importlib + +from .configuration_utils import Config + +# CHANGE to diffusers.utils +from transformers.utils import logging + + +INDEX_FILE = "diffusion_model.pt" + + +logger = logging.get_logger(__name__) + + +LOADABLE_CLASSES = { + "diffusers": { + "PreTrainedModel": ["save_pretrained", "from_pretrained"], + "GaussianDiffusion": ["save_config", "from_config"], + }, + "transformers": { + "PreTrainedModel": ["save_pretrained", "from_pretrained"], + }, +} + + +class DiffusionPipeline(Config): + + config_name = "model_index.json" + + def __init__(self, **kwargs): + for name, module in kwargs.items(): + # retrive library + library = module.__module__.split(".")[0] + # retrive class_name + class_name = module.__class__.__name__ + + # save model index config + self.register(**{name: (library, class_name)}) + + # set models + setattr(self, name, module) + + def save_pretrained(self, save_directory: Union[str, os.PathLike]): + self.save_config(save_directory) + + model_index_dict = self._dict_to_save + model_index_dict.pop("_class_name") + + for name, (library_name, class_name) in self._dict_to_save.items(): + importable_classes = LOADABLE_CLASSES[library_name] + + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} + + save_method_name = None + for class_name, class_candidate in class_candidates.items(): + if issubclass(class_obj, class_candidate): + save_method_name = importable_classes[class_name][0] + + save_method = getattr(getattr(self, name), save_method_name) + save_method(os.path.join(save_directory, name)) + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): + # use snapshot download here to get it working from from_pretrained + config_dict, _ = cls.get_config_dict(pretrained_model_name_or_path) + + init_kwargs = {} + + for name, (library_name, class_name) in config_dict.items(): + importable_classes = LOADABLE_CLASSES[library_name] + + library = importlib.import_module(library_name) + class_obj = getattr(library, class_name) + class_candidates = {c: getattr(library, c) for c in importable_classes.keys()} + + load_method_name = None + for class_name, class_candidate in class_candidates.items(): + if issubclass(class_obj, class_candidate): + load_method_name = importable_classes[class_name][1] + + load_method = getattr(class_obj, load_method_name) + + loaded_sub_model = load_method(os.path.join(pretrained_model_name_or_path, name)) + + init_kwargs[name] = loaded_sub_model + + model = cls(**init_kwargs) + return model