diff --git a/md5sum.saved b/md5sum.saved
deleted file mode 100644
index 2dc63471f2..0000000000
--- a/md5sum.saved
+++ /dev/null
@@ -1 +0,0 @@
-ce075df80e7ba2391d63d026be165c15 src/diffusers/dependency_versions_table.py
diff --git a/models/ddpm/run_ddpm.py b/models/ddpm/run_ddpm.py
index 668c2c301e..6bf131f7b5 100755
--- a/models/ddpm/run_ddpm.py
+++ b/models/ddpm/run_ddpm.py
@@ -1,19 +1,10 @@
#!/usr/bin/env python3
import torch
-from diffusers import GaussianDiffusion, UNetConfig, UNetModel
+from diffusers import GaussianDiffusion, UNetModel
-config = UNetConfig(dim=64, dim_mults=(1, 2, 4, 8))
-model = UNetModel(config)
-print(model.config)
-
-model.save_pretrained("/home/patrick/diffusion_example")
-
-import ipdb
-
-
-ipdb.set_trace()
+model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))
diffusion = GaussianDiffusion(model, image_size=128, timesteps=1000, loss_type="l1") # number of steps # L1 or L2
diff --git a/src/diffusers.egg-info/PKG-INFO b/src/diffusers.egg-info/PKG-INFO
deleted file mode 100644
index 584c731df7..0000000000
--- a/src/diffusers.egg-info/PKG-INFO
+++ /dev/null
@@ -1,31 +0,0 @@
-Metadata-Version: 2.1
-Name: diffusers
-Version: 0.0.1
-Summary: Diffusers
-Home-page: https://github.com/huggingface/diffusers
-Author: The HuggingFace team
-Author-email: patrick@huggingface.co
-License: Apache
-Keywords: deep learning
-Classifier: Development Status :: 5 - Production/Stable
-Classifier: Intended Audience :: Developers
-Classifier: Intended Audience :: Education
-Classifier: Intended Audience :: Science/Research
-Classifier: License :: OSI Approved :: Apache Software License
-Classifier: Operating System :: OS Independent
-Classifier: Programming Language :: Python :: 3
-Classifier: Programming Language :: Python :: 3.6
-Classifier: Programming Language :: Python :: 3.7
-Classifier: Programming Language :: Python :: 3.8
-Classifier: Programming Language :: Python :: 3.9
-Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
-Requires-Python: >=3.6.0
-Description-Content-Type: text/markdown
-Provides-Extra: quality
-Provides-Extra: docs
-Provides-Extra: test
-Provides-Extra: dev
-Provides-Extra: sagemaker
-License-File: LICENSE
-
-Super cool library about diffusion models
diff --git a/src/diffusers.egg-info/SOURCES.txt b/src/diffusers.egg-info/SOURCES.txt
deleted file mode 100644
index 2a49d211f3..0000000000
--- a/src/diffusers.egg-info/SOURCES.txt
+++ /dev/null
@@ -1,14 +0,0 @@
-LICENSE
-README.md
-pyproject.toml
-setup.cfg
-setup.py
-src/diffusers/__init__.py
-src/diffusers.egg-info/PKG-INFO
-src/diffusers.egg-info/SOURCES.txt
-src/diffusers.egg-info/dependency_links.txt
-src/diffusers.egg-info/requires.txt
-src/diffusers.egg-info/top_level.txt
-src/diffusers/models/__init__.py
-src/diffusers/models/unet.py
-src/diffusers/samplers/__init__.py
\ No newline at end of file
diff --git a/src/diffusers.egg-info/dependency_links.txt b/src/diffusers.egg-info/dependency_links.txt
deleted file mode 100644
index 8b13789179..0000000000
--- a/src/diffusers.egg-info/dependency_links.txt
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/src/diffusers.egg-info/requires.txt b/src/diffusers.egg-info/requires.txt
deleted file mode 100644
index 9d13868f00..0000000000
--- a/src/diffusers.egg-info/requires.txt
+++ /dev/null
@@ -1,31 +0,0 @@
-numpy>=1.17
-packaging>=20.0
-pyyaml
-torch>=1.4.0
-
-[dev]
-black~=22.0
-isort>=5.5.4
-flake8>=3.8.3
-pytest
-pytest-xdist
-pytest-subtests
-datasets
-transformers
-
-[docs]
-
-[quality]
-black~=22.0
-isort>=5.5.4
-flake8>=3.8.3
-
-[sagemaker]
-sagemaker
-
-[test]
-pytest
-pytest-xdist
-pytest-subtests
-datasets
-transformers
diff --git a/src/diffusers.egg-info/top_level.txt b/src/diffusers.egg-info/top_level.txt
deleted file mode 100644
index 6033efb6db..0000000000
--- a/src/diffusers.egg-info/top_level.txt
+++ /dev/null
@@ -1 +0,0 @@
-diffusers
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index ad66a4be34..41f680261c 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -4,5 +4,5 @@
__version__ = "0.0.1"
-from .models.unet import GaussianDiffusion # TODO(PVP): move somewhere else
-from .models.unet import UNetConfig, UNetModel
+from .models.unet import UNetModel
+from .samplers.gaussian import GaussianDiffusion
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
old mode 100755
new mode 100644
index 1c25c463b5..bb30751c8e
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -20,11 +20,11 @@ import copy
import json
import os
import re
+import inspect
from typing import Any, Dict, Tuple, Union
from requests import HTTPError
from transformers.utils import (
- CONFIG_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError,
RepositoryNotFoundError,
@@ -44,33 +44,34 @@ logger = logging.get_logger(__name__)
_re_configuration_file = re.compile(r"config\.(.*)\.json")
-class PretrainedConfig:
+class Config:
r"""
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
methods for loading/downloading/saving configurations.
"""
- model_type: str = ""
+ config_name = None
- def __init__(self, **kwargs):
- # Name or path to the pretrained checkpoint
- self._name_or_path = str(kwargs.pop("name_or_path", ""))
+ def register(self, **kwargs):
+ if self.config_name is None:
+ raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
+ kwargs["_class_name"] = self.__class__.__name__
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error(f"Can't set {key} with value {value} for {self}")
+ raise err
- # Drop the diffusers version info
- self.diffusers_version = kwargs.pop("diffusers_version", None)
+ if not hasattr(self, "_dict_to_save"):
+ self._dict_to_save = {}
- @property
- def name_or_path(self) -> str:
- return getattr(self, "_name_or_path", None)
+ self._dict_to_save.update(kwargs)
- @name_or_path.setter
- def name_or_path(self, value):
- self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding)
-
- def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
+ def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
- [`~PretrainedConfig.from_pretrained`] class method.
+ [`~Config.from_config`] class method.
Args:
save_directory (`str` or `os.PathLike`):
@@ -83,119 +84,15 @@ class PretrainedConfig:
os.makedirs(save_directory, exist_ok=True)
- # If we save using the predefined names, we can load using `from_pretrained`
- output_config_file = os.path.join(save_directory, CONFIG_NAME)
+ # If we save using the predefined names, we can load using `from_config`
+ output_config_file = os.path.join(save_directory, self.config_name)
- self.to_json_file(output_config_file, use_diff=True)
+ self.to_json_file(output_config_file)
logger.info(f"Configuration saved in {output_config_file}")
@classmethod
- def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
- r"""
- Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.
-
- Args:
- pretrained_model_name_or_path (`str` or `os.PathLike`):
- This can be either:
-
- - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
- huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
- namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
- - a path to a *directory* containing a configuration file saved using the
- [`~PretrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
- - a path or url to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
- cache_dir (`str` or `os.PathLike`, *optional*):
- Path to a directory in which a downloaded pretrained model configuration should be cached if the
- standard cache should not be used.
- force_download (`bool`, *optional*, defaults to `False`):
- Whether or not to force to (re-)download the configuration files and override the cached versions if
- they exist.
- resume_download (`bool`, *optional*, defaults to `False`):
- Whether or not to delete incompletely received file. Attempts to resume the download if such a file
- exists.
- proxies (`Dict[str, str]`, *optional*):
- A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
- 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
- use_auth_token (`str` or *bool*, *optional*):
- The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
- when running `diffusers-cli login` (stored in `~/.huggingface`).
- revision (`str`, *optional*, defaults to `"main"`):
- The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
- git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
- identifier allowed by git.
- return_unused_kwargs (`bool`, *optional*, defaults to `False`):
- If `False`, then this function returns just the final configuration object.
-
- If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
- dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
- part of `kwargs` which has not been used to update `config` and is otherwise ignored.
- kwargs (`Dict[str, Any]`, *optional*):
- The values in kwargs of any keys which are configuration attributes will be used to override the loaded
- values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
- by the `return_unused_kwargs` keyword parameter.
-
-
-
- Passing `use_auth_token=True` is required when you want to use a private model.
-
-
-
- Returns:
- [`PretrainedConfig`]: The configuration object instantiated from this pretrained model.
-
- Examples:
-
- ```python
- # We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a
- # derived class: BertConfig
- config = BertConfig.from_pretrained(
- "bert-base-uncased"
- ) # Download configuration from huggingface.co and cache.
- config = BertConfig.from_pretrained(
- "./test/saved_model/"
- ) # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
- config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
- config = BertConfig.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
- assert config.output_attentions == True
- config, unused_kwargs = BertConfig.from_pretrained(
- "bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
- )
- assert config.output_attentions == True
- assert unused_kwargs == {"foo": False}
- ```"""
- config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
- if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
- logger.warning(
- f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
- f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
- )
-
- return cls.from_dict(config_dict, **kwargs)
-
- @classmethod
- def get_config_dict(
- cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
- ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
- """
- From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
- [`PretrainedConfig`] using `from_dict`.
-
- Parameters:
- pretrained_model_name_or_path (`str` or `os.PathLike`):
- The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
-
- Returns:
- `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
-
- """
- # Get config dict associated with the base config file
- config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
-
- return config_dict, kwargs
-
- @classmethod
- def _get_config_dict(
- cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+ def from_config(
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
@@ -215,7 +112,7 @@ class PretrainedConfig:
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
else:
- configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
+ configuration_file = cls.config_name
if os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
@@ -286,58 +183,27 @@ class PretrainedConfig:
else:
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
- return config_dict, kwargs
+ expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
+ expected_keys.remove("self")
- @classmethod
- def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
- """
- Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.
+ passed_keys = set(config_dict.keys())
- Args:
- config_dict (`Dict[str, Any]`):
- Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
- retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
- kwargs (`Dict[str, Any]`):
- Additional parameters from which to initialize the configuration object.
+ unused_kwargs = kwargs
+ for key in passed_keys - expected_keys:
+ unused_kwargs[key] = config_dict.pop(key)
- Returns:
- [`PretrainedConfig`]: The configuration object instantiated from those parameters.
- """
- return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
- # Those arguments may be passed along for our internal telemetry.
- # We remove them so they don't appear in `return_unused_kwargs`.
+ if len(expected_keys - passed_keys) > 0:
+ logger.warn(
+ f"{expected_keys - passed_keys} was not found in config. "
+ f"Values will be initialized to default values."
+ )
- config = cls(**config_dict)
+ model = cls(**config_dict)
- to_remove = []
- for key, value in kwargs.items():
- if hasattr(config, key):
- setattr(config, key, value)
- to_remove.append(key)
- for key in to_remove:
- kwargs.pop(key, None)
-
- logger.info(f"Model config {config}")
if return_unused_kwargs:
- return config, kwargs
+ return model, unused_kwargs
else:
- return config
-
- @classmethod
- def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
- """
- Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters.
-
- Args:
- json_file (`str` or `os.PathLike`):
- Path to the JSON file containing the parameters.
-
- Returns:
- [`PretrainedConfig`]: The configuration object instantiated from that JSON file.
-
- """
- config_dict = cls._dict_from_json_file(json_file)
- return cls(**config_dict)
+ return model
@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
@@ -351,38 +217,6 @@ class PretrainedConfig:
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
- def to_diff_dict(self) -> Dict[str, Any]:
- """
- Removes all attributes from config which correspond to the default config attributes for better readability and
- serializes to a Python dictionary.
-
- Returns:
- `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
- """
- config_dict = self.to_dict()
-
- # get the default config dict
- default_config_dict = PretrainedConfig().to_dict()
-
- # get class specific config dict
- class_config_dict = self.__class__().to_dict()
-
- serializable_config_dict = {}
-
- # only serialize values that differ from the default config
- for key, value in config_dict.items():
- if (
- key not in default_config_dict
- or key == "diffusers_version"
- or value != default_config_dict[key]
- or (key in class_config_dict and value != class_config_dict[key])
- ):
- serializable_config_dict[key] = value
-
- self.dict_torch_dtype_to_str(serializable_config_dict)
-
- return serializable_config_dict
-
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary.
@@ -391,106 +225,29 @@ class PretrainedConfig:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
- if hasattr(self.__class__, "model_type"):
- output["model_type"] = self.__class__.model_type
- if "_auto_class" in output:
- del output["_auto_class"]
- # Transformers version when serializing the model
+ # Diffusers version when serializing the model
output["diffusers_version"] = __version__
- self.dict_torch_dtype_to_str(output)
-
return output
- def to_json_string(self, use_diff: bool = True) -> str:
+ def to_json_string(self) -> str:
"""
Serializes this instance to a JSON string.
- Args:
- use_diff (`bool`, *optional*, defaults to `True`):
- If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
- is serialized to JSON string.
-
Returns:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
"""
- if use_diff is True:
- config_dict = self.to_diff_dict()
- else:
- config_dict = self.to_dict()
+ config_dict = self._dict_to_save
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
- def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
+ def to_json_file(self, json_file_path: Union[str, os.PathLike]):
"""
Save this instance to a JSON file.
Args:
json_file_path (`str` or `os.PathLike`):
Path to the JSON file in which this configuration instance's parameters will be saved.
- use_diff (`bool`, *optional*, defaults to `True`):
- If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
- is serialized to JSON file.
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
- writer.write(self.to_json_string(use_diff=use_diff))
-
- def update(self, config_dict: Dict[str, Any]):
- """
- Updates attributes of this class with attributes from `config_dict`.
-
- Args:
- config_dict (`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
- """
- for key, value in config_dict.items():
- setattr(self, key, value)
-
- def update_from_string(self, update_str: str):
- """
- Updates attributes of this class with attributes from `update_str`.
-
- The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
- "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
-
- The keys to change have to already exist in the config object.
-
- Args:
- update_str (`str`): String with attributes that should be updated for this class.
-
- """
-
- d = dict(x.split("=") for x in update_str.split(","))
- for k, v in d.items():
- if not hasattr(self, k):
- raise ValueError(f"key {k} isn't in the original config dict")
-
- old_v = getattr(self, k)
- if isinstance(old_v, bool):
- if v.lower() in ["true", "1", "y", "yes"]:
- v = True
- elif v.lower() in ["false", "0", "n", "no"]:
- v = False
- else:
- raise ValueError(f"can't derive true or false from {v} (key {k})")
- elif isinstance(old_v, int):
- v = int(v)
- elif isinstance(old_v, float):
- v = float(v)
- elif not isinstance(old_v, str):
- raise ValueError(
- f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
- )
-
- setattr(self, k, v)
-
- def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
- """
- Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
- converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
- string, which can then be stored in the json format.
- """
- if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
- d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
- for value in d.values():
- if isinstance(value, dict):
- self.dict_torch_dtype_to_str(value)
+ writer.write(self.to_json_string())
diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py
index a7e2b3895f..d346fb7400 100644
--- a/src/diffusers/modeling_utils.py
+++ b/src/diffusers/modeling_utils.py
@@ -33,10 +33,9 @@ from transformers.utils import (
is_offline_mode,
is_remote_url,
logging,
+ CONFIG_NAME,
)
-from .configuration_utils import PretrainedConfig
-
WEIGHTS_NAME = "diffusion_model.pt"
@@ -135,7 +134,7 @@ class PreTrainedModel(torch.nn.Module):
Class attributes (overridden by derived classes):
- - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
+ - **config_name** (`str`) -- Name of the file the configuration is saved to and loaded from
for this model architecture.
- **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
taking as arguments:
@@ -150,35 +149,16 @@ class PreTrainedModel(torch.nn.Module):
- **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
models, `pixel_values` for vision models and `input_values` for speech models).
"""
- config_class = None
+ config_name = CONFIG_NAME
- def __init__(self, config: PretrainedConfig):
+ def __init__(self):
super().__init__()
- if not isinstance(config, PretrainedConfig):
- raise ValueError(
- f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
- "`PretrainedConfig`. To create a model from a pretrained model use "
- f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
- )
- # Save config and origin of the pretrained weights if given in model
- self.config = config
- self.name_or_path = config.name_or_path
-
- @classmethod
- def _from_config(cls, config, **kwargs):
- """
- All context managers that the model should be initialized under go here.
- """
- model = cls(config, **kwargs)
-
- return model
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
save_function: Callable = torch.save,
- push_to_hub: bool = False,
**kwargs,
):
"""
@@ -195,16 +175,6 @@ class PreTrainedModel(torch.nn.Module):
save_function (`Callable`):
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
need to replace `torch.save` by another method.
- push_to_hub (`bool`, *optional*, defaults to `False`):
- Whether or not to push your model to the Hugging Face model hub after saving it.
-
-
-
- Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
- which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
- folder. Pass along `temp_dir=True` to use a temporary directory instead.
-
-
kwargs:
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
@@ -218,11 +188,9 @@ class PreTrainedModel(torch.nn.Module):
model_to_save = self
# Attach architecture to the config
- model_to_save.config.architectures = [model_to_save.__class__.__name__]
-
# Save the config
if is_main_process:
- model_to_save.config.save_pretrained(save_directory)
+ model_to_save.save_config(save_directory)
# Save the model
state_dict = model_to_save.state_dict()
@@ -241,7 +209,7 @@ class PreTrainedModel(torch.nn.Module):
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
@classmethod
- def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
+ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a pretrained pytorch model from a pre-trained model configuration.
@@ -265,11 +233,11 @@ class PreTrainedModel(torch.nn.Module):
- A path to a *directory* containing model weights saved using
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
- config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
+ config (`Union[Config, str, os.PathLike]`, *optional*):
Can be either:
- - an instance of a class derived from [`PretrainedConfig`],
- - a string or path valid as input to [`~PretrainedConfig.from_pretrained`].
+ - an instance of a class derived from [`Config`],
+ - a string or path valid as input to [`~Config.from_config`].
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:
@@ -327,7 +295,7 @@ class PreTrainedModel(torch.nn.Module):
underlying model's `__init__` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, `kwargs` will be first passed to the configuration class
- initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
+ initialization function ([`~Config.from_config`]). Each key of `kwargs` that
corresponds to a configuration attribute will be used to override said attribute with the
supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
will be passed to the underlying model's `__init__` function.
@@ -356,7 +324,6 @@ class PreTrainedModel(torch.nn.Module):
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
mirror = kwargs.pop("mirror", None)
- from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
@@ -367,7 +334,7 @@ class PreTrainedModel(torch.nn.Module):
# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path
- config, model_kwargs = cls.config_class.from_pretrained(
+ model, unused_kwargs = cls.from_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
@@ -377,12 +344,9 @@ class PreTrainedModel(torch.nn.Module):
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
- _from_auto=from_auto_class,
- _from_pipeline=from_pipeline,
**kwargs,
)
- model_kwargs = kwargs
-
+ model.register(name_or_path=pretrained_model_name_or_path)
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
@@ -456,18 +420,8 @@ class PreTrainedModel(torch.nn.Module):
else:
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
- state_dict = load_state_dict(resolved_archive_file)
- # set dtype to instantiate the model under:
- # 1. If torch_dtype is not None, we use that dtype
- # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
- # weights entry - we assume all weights are of the same dtype
- # we also may have config.torch_dtype available, but we won't rely on it till v5
-
- config.name_or_path = pretrained_model_name_or_path
-
- model = cls(config, *model_args, **model_kwargs)
-
# restore default dtype
+ state_dict = load_state_dict(resolved_archive_file)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,
diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py
index 364653676d..54b1d3ead3 100644
--- a/src/diffusers/models/__init__.py
+++ b/src/diffusers/models/__init__.py
@@ -16,4 +16,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-from . import unet
+from .unet import UNetModel
diff --git a/src/diffusers/models/unet/modeling_unet.py b/src/diffusers/models/unet.py
similarity index 62%
rename from src/diffusers/models/unet/modeling_unet.py
rename to src/diffusers/models/unet.py
index 23c0921633..6b8069d1cb 100644
--- a/src/diffusers/models/unet/modeling_unet.py
+++ b/src/diffusers/models/unet.py
@@ -22,27 +22,22 @@ from inspect import isfunction
from pathlib import Path
import torch
-import torch.nn.functional as F
from torch import einsum, nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam
from torch.utils import data
from einops import rearrange
-from PIL import Image
-from torchvision import transforms, utils
+from torchvision import utils, transforms
from tqdm import tqdm
-from ...modeling_utils import PreTrainedModel
-from .configuration_unet import UNetConfig
-
+from ..configuration_utils import Config
+from ..modeling_utils import PreTrainedModel
+from PIL import Image
# NOTE: the following file is completely copied from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
-
-
-
def exists(x):
return x is not None
@@ -246,13 +241,30 @@ class Attention(nn.Module):
return self.to_out(out)
-class UNetModel(PreTrainedModel):
-
- config_class = UNetConfig
-
- def __init__(self, config):
- super().__init__(config)
+class UNetModel(PreTrainedModel, Config):
+ def __init__(
+ self,
+ dim=64,
+ dim_mults=(1, 2, 4, 8),
+ init_dim=None,
+ out_dim=None,
+ channels=3,
+ with_time_emb=True,
+ resnet_block_groups=8,
+ learned_variance=False,
+ ):
+ super().__init__()
+ self.register(
+ dim=dim,
+ dim_mults=dim_mults,
+ init_dim=init_dim,
+ out_dim=out_dim,
+ channels=channels,
+ with_time_emb=with_time_emb,
+ resnet_block_groups=resnet_block_groups,
+ learned_variance=learned_variance,
+ )
init_dim = None
out_dim = None
channels = 3
@@ -262,9 +274,9 @@ class UNetModel(PreTrainedModel):
# determine dimensions
- dim_mults = config.dim_mults
- dim = config.dim
- self.channels = config.channels
+ dim_mults = dim_mults
+ dim = dim
+ self.channels = channels
init_dim = default(init_dim, dim // 3 * 2)
self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)
@@ -357,238 +369,6 @@ class UNetModel(PreTrainedModel):
return self.final_conv(x)
-# gaussian diffusion trainer class
-
-
-def extract(a, t, x_shape):
- b, *_ = t.shape
- out = a.gather(-1, t)
- return out.reshape(b, *((1,) * (len(x_shape) - 1)))
-
-
-def noise_like(shape, device, repeat=False):
- def repeat_noise():
- return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
-
- def noise():
- return torch.randn(shape, device=device)
-
- return repeat_noise() if repeat else noise()
-
-
-def linear_beta_schedule(timesteps):
- scale = 1000 / timesteps
- beta_start = scale * 0.0001
- beta_end = scale * 0.02
- return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
-
-
-def cosine_beta_schedule(timesteps, s=0.008):
- """
- cosine schedule
- as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
- """
- steps = timesteps + 1
- x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
- alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
- alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
- betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
- return torch.clip(betas, 0, 0.999)
-
-
-class GaussianDiffusion(nn.Module):
- def __init__(
- self,
- denoise_fn,
- *,
- image_size,
- channels=3,
- timesteps=1000,
- loss_type="l1",
- objective="pred_noise",
- beta_schedule="cosine",
- ):
- super().__init__()
- assert not (type(self) == GaussianDiffusion and denoise_fn.channels != denoise_fn.out_dim)
-
- self.channels = channels
- self.image_size = image_size
- self.denoise_fn = denoise_fn
- self.objective = objective
-
- if beta_schedule == "linear":
- betas = linear_beta_schedule(timesteps)
- elif beta_schedule == "cosine":
- betas = cosine_beta_schedule(timesteps)
- else:
- raise ValueError(f"unknown beta schedule {beta_schedule}")
-
- alphas = 1.0 - betas
- alphas_cumprod = torch.cumprod(alphas, axis=0)
- alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
-
- (timesteps,) = betas.shape
- self.num_timesteps = int(timesteps)
- self.loss_type = loss_type
-
- # helper function to register buffer from float64 to float32
-
- def register_buffer(name, val):
- self.register_buffer(name, val.to(torch.float32))
-
- register_buffer("betas", betas)
- register_buffer("alphas_cumprod", alphas_cumprod)
- register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
-
- # calculations for diffusion q(x_t | x_{t-1}) and others
-
- register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
- register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
- register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
- register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
- register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))
-
- # calculations for posterior q(x_{t-1} | x_t, x_0)
-
- posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
-
- # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
-
- register_buffer("posterior_variance", posterior_variance)
-
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
-
- register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20)))
- register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
- register_buffer(
- "posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
- )
-
- def predict_start_from_noise(self, x_t, t, noise):
- return (
- extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
- )
-
- def q_posterior(self, x_start, x_t, t):
- posterior_mean = (
- extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
- + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
- )
- posterior_variance = extract(self.posterior_variance, t, x_t.shape)
- posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
-
- def p_mean_variance(self, x, t, clip_denoised: bool):
- model_output = self.denoise_fn(x, t)
-
- if self.objective == "pred_noise":
- x_start = self.predict_start_from_noise(x, t=t, noise=model_output)
- elif self.objective == "pred_x0":
- x_start = model_output
- else:
- raise ValueError(f"unknown objective {self.objective}")
-
- if clip_denoised:
- x_start.clamp_(-1.0, 1.0)
-
- model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
- return model_mean, posterior_variance, posterior_log_variance
-
- @torch.no_grad()
- def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
- b, *_, device = *x.shape, x.device
- model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
- noise = noise_like(x.shape, device, repeat_noise)
- # no noise when t == 0
- nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
- result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
- return result
-
- @torch.no_grad()
- def p_sample_loop(self, shape):
- device = self.betas.device
-
- b = shape[0]
- img = torch.randn(shape, device=device)
-
- for i in tqdm(
- reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps
- ):
- img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))
-
- img = unnormalize_to_zero_to_one(img)
- return img
-
- @torch.no_grad()
- def sample(self, batch_size=16):
- image_size = self.image_size
- channels = self.channels
- return self.p_sample_loop((batch_size, channels, image_size, image_size))
-
- @torch.no_grad()
- def interpolate(self, x1, x2, t=None, lam=0.5):
- b, *_, device = *x1.shape, x1.device
- t = default(t, self.num_timesteps - 1)
-
- assert x1.shape == x2.shape
-
- t_batched = torch.stack([torch.tensor(t, device=device)] * b)
- xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
-
- img = (1 - lam) * xt1 + lam * xt2
- for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t):
- img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))
-
- return img
-
- def q_sample(self, x_start, t, noise=None):
- noise = default(noise, lambda: torch.randn_like(x_start))
-
- return (
- extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
- + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
- )
-
- @property
- def loss_fn(self):
- if self.loss_type == "l1":
- return F.l1_loss
- elif self.loss_type == "l2":
- return F.mse_loss
- else:
- raise ValueError(f"invalid loss type {self.loss_type}")
-
- def p_losses(self, x_start, t, noise=None):
- b, c, h, w = x_start.shape
- noise = default(noise, lambda: torch.randn_like(x_start))
-
- x = self.q_sample(x_start=x_start, t=t, noise=noise)
- model_out = self.denoise_fn(x, t)
-
- if self.objective == "pred_noise":
- target = noise
- elif self.objective == "pred_x0":
- target = x_start
- else:
- raise ValueError(f"unknown objective {self.objective}")
-
- loss = self.loss_fn(model_out, target)
- return loss
-
- def forward(self, img, *args, **kwargs):
- b, _, h, w, device, img_size, = (
- *img.shape,
- img.device,
- self.image_size,
- )
- assert h == img_size and w == img_size, f"height and width of image must be {img_size}"
- t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
-
- img = normalize_to_neg_one_to_one(img)
- return self.p_losses(img, t, *args, **kwargs)
-
-
# dataset classes
@@ -621,6 +401,7 @@ class Dataset(data.Dataset):
class Trainer(object):
+
def __init__(
self,
diffusion_model,
diff --git a/src/diffusers/models/unet/__init__.py b/src/diffusers/models/unet/__init__.py
deleted file mode 100644
index 01b2a02857..0000000000
--- a/src/diffusers/models/unet/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# flake8: noqa
-# There's no way to ignore "F401 '...' imported but unused" warnings in this
-# module, but to preserve other warnings. So, don't check this module at all
-
-from .configuration_unet import UNetConfig
-from .modeling_unet import GaussianDiffusion, UNetModel
diff --git a/src/diffusers/models/unet/configuration_unet.py b/src/diffusers/models/unet/configuration_unet.py
deleted file mode 100644
index a8a1e12c00..0000000000
--- a/src/diffusers/models/unet/configuration_unet.py
+++ /dev/null
@@ -1,45 +0,0 @@
-# Copyright 2022 The HuggingFace Team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-
-# limitations under the License.
-
-# helpers functions
-
-# NOTE: the following file is completely copied from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
-from ...configuration_utils import PretrainedConfig
-
-
-class UNetConfig(PretrainedConfig):
- model_type = "unet"
-
- def __init__(
- self,
- dim=64,
- dim_mults=(1, 2, 4, 8),
- init_dim=None,
- out_dim=None,
- channels=3,
- with_time_emb=True,
- resnet_block_groups=8,
- learned_variance=False,
- **kwargs,
- ):
- super().__init__(**kwargs)
- self.dim = dim
- self.dim_mults = dim_mults
- self.init_dim = init_dim
- self.out_dim = out_dim
- self.channels = channels
- self.with_time_emb = with_time_emb
- self.resnet_block_groups = resnet_block_groups
- self.learned_variance = learned_variance
diff --git a/src/diffusers/samplers/__init__.py b/src/diffusers/samplers/__init__.py
index 99977b417e..76aa8aab0c 100644
--- a/src/diffusers/samplers/__init__.py
+++ b/src/diffusers/samplers/__init__.py
@@ -1,3 +1,7 @@
+# flake8: noqa
+# There's no way to ignore "F401 '...' imported but unused" warnings in this
+# module, but to preserve other warnings. So, don't check this module at all.
+
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,3 +15,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+
+from .gaussian import GaussianDiffusion
diff --git a/src/diffusers/samplers/gaussian.py b/src/diffusers/samplers/gaussian.py
new file mode 100644
index 0000000000..4a7e350704
--- /dev/null
+++ b/src/diffusers/samplers/gaussian.py
@@ -0,0 +1,312 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+import torch.nn.functional as F
+from torch import nn
+from inspect import isfunction
+from tqdm import tqdm
+
+from ..configuration_utils import Config
+SAMPLING_CONFIG_NAME = "sampler_config.json"
+
+
+def exists(x):
+ return x is not None
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+
+def cycle(dl):
+ while True:
+ for data_dl in dl:
+ yield data_dl
+
+
+def num_to_groups(num, divisor):
+    # Split `num` into full chunks of size `divisor`, plus one trailing partial chunk when needed.
+    full_chunks, leftover = divmod(num, divisor)
+    sizes = [divisor] * full_chunks
+    if leftover > 0:
+        sizes.append(leftover)
+    return sizes
+
+
+def normalize_to_neg_one_to_one(img):
+ return img * 2 - 1
+
+
+def unnormalize_to_zero_to_one(t):
+ return (t + 1) * 0.5
+
+
+# small helper modules
+
+
+class EMA:
+ def __init__(self, beta):
+ super().__init__()
+ self.beta = beta
+
+ def update_model_average(self, ma_model, current_model):
+ for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
+ old_weight, up_weight = ma_params.data, current_params.data
+ ma_params.data = self.update_average(old_weight, up_weight)
+
+ def update_average(self, old, new):
+ if old is None:
+ return new
+ return old * self.beta + (1 - self.beta) * new
+
+
+# gaussian diffusion trainer class
+
+
+def extract(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+
+def noise_like(shape, device, repeat=False):
+ def repeat_noise():
+ return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+
+ def noise():
+ return torch.randn(shape, device=device)
+
+ return repeat_noise() if repeat else noise()
+
+
+def linear_beta_schedule(timesteps):
+    # Evenly spaced float64 betas; the endpoints 1e-4 and 2e-2 are rescaled by 1000 / timesteps.
+    scale = 1000 / timesteps
+    first_beta, last_beta = scale * 0.0001, scale * 0.02
+    return torch.linspace(first_beta, last_beta, timesteps, dtype=torch.float64)
+
+
+def cosine_beta_schedule(timesteps, s=0.008):
+ """
+ cosine schedule
+ as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
+ """
+ steps = timesteps + 1
+ x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
+ alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+ return torch.clip(betas, 0, 0.999)
+
+
+class GaussianDiffusion(nn.Module, Config):
+
+ config_name = SAMPLING_CONFIG_NAME
+
+ def __init__(
+ self,
+ image_size,
+ channels=3,
+ timesteps=1000,
+ loss_type="l1",
+ objective="pred_noise",
+ beta_schedule="cosine",
+ ):
+ super().__init__()
+ self.register(
+ image_size=image_size,
+ channels=channels,
+ timesteps=timesteps,
+ loss_type=loss_type,
+ objective=objective,
+ beta_schedule=beta_schedule,
+ )
+
+ self.channels = channels
+ self.image_size = image_size
+ self.objective = objective
+
+ if beta_schedule == "linear":
+ betas = linear_beta_schedule(timesteps)
+ elif beta_schedule == "cosine":
+ betas = cosine_beta_schedule(timesteps)
+ else:
+ raise ValueError(f"unknown beta schedule {beta_schedule}")
+
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
+
+ (timesteps,) = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.loss_type = loss_type
+
+ # helper function to register buffer from float64 to float32
+
+ def register_buffer(name, val):
+ self.register_buffer(name, val.to(torch.float32))
+
+ register_buffer("betas", betas)
+ register_buffer("alphas_cumprod", alphas_cumprod)
+ register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+
+ register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
+ register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
+ register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
+ register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
+ register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+
+ posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
+
+ # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
+
+ register_buffer("posterior_variance", posterior_variance)
+
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+
+ register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20)))
+ register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
+ register_buffer(
+ "posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
+ )
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ return (
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+ )
+
+ def q_posterior(self, x_start, x_t, t):
+ posterior_mean = (
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, model, x, t, clip_denoised: bool):
+ model_output = model(x, t)
+
+ if self.objective == "pred_noise":
+ x_start = self.predict_start_from_noise(x, t=t, noise=model_output)
+ elif self.objective == "pred_x0":
+ x_start = model_output
+ else:
+ raise ValueError(f"unknown objective {self.objective}")
+
+ if clip_denoised:
+ x_start.clamp_(-1.0, 1.0)
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, model, x, t, clip_denoised=True, repeat_noise=False):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = self.p_mean_variance(model=model, x=x, t=t, clip_denoised=clip_denoised)
+ noise = noise_like(x.shape, device, repeat_noise)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
+ result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+ return result
+
+ @torch.no_grad()
+ def p_sample_loop(self, model, shape):
+ device = self.betas.device
+
+ b = shape[0]
+ img = torch.randn(shape, device=device)
+
+ for i in tqdm(
+ reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps
+ ):
+ img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))
+
+ img = unnormalize_to_zero_to_one(img)
+ return img
+
+ @torch.no_grad()
+ def sample(self, model, batch_size=16):
+ image_size = self.image_size
+ channels = self.channels
+ return self.p_sample_loop(model, (batch_size, channels, image_size, image_size))
+
+    @torch.no_grad()
+    def interpolate(self, model, x1, x2, t=None, lam=0.5):
+        # Noise both inputs to step `t` (default: the last step), blend them with weight `lam`, then denoise.
+        b, *_, device = *x1.shape, x1.device
+        t = default(t, self.num_timesteps - 1)
+        assert x1.shape == x2.shape
+
+        t_batched = torch.full((b,), t, device=device, dtype=torch.long)  # same form as the loop below
+        xt1 = self.q_sample(x1, t=t_batched)
+        xt2 = self.q_sample(x2, t=t_batched)
+
+        img = (1 - lam) * xt1 + lam * xt2
+        for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t):
+            img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))
+        return img
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+
+ return (
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+ )
+
+    @property
+    def loss_fn(self):
+        # Map the configured `loss_type` ("l1" or "l2") onto the matching torch.nn.functional loss.
+        kind = self.loss_type
+        if kind not in ("l1", "l2"):
+            raise ValueError(f"invalid loss type {self.loss_type}")
+
+        return F.l1_loss if kind == "l1" else F.mse_loss
+
+ def p_losses(self, model, x_start, t, noise=None):
+ b, c, h, w = x_start.shape
+ noise = default(noise, lambda: torch.randn_like(x_start))
+
+ x = self.q_sample(x_start=x_start, t=t, noise=noise)
+ model_out = model(x, t)
+
+ if self.objective == "pred_noise":
+ target = noise
+ elif self.objective == "pred_x0":
+ target = x_start
+ else:
+ raise ValueError(f"unknown objective {self.objective}")
+
+ loss = self.loss_fn(model_out, target)
+ return loss
+
+ def forward(self, model, img, *args, **kwargs):
+ b, _, h, w, device, img_size, = (
+ *img.shape,
+ img.device,
+ self.image_size,
+ )
+ assert h == img_size and w == img_size, f"height and width of image must be {img_size}"
+ t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
+
+ img = normalize_to_neg_one_to_one(img)
+ return self.p_losses(model, img, t, *args, **kwargs)
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 92612f9831..1233980f39 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -19,7 +19,7 @@ import unittest
import torch
-from diffusers import UNetConfig, UNetModel
+from diffusers import GaussianDiffusion, UNetModel
global_rng = random.Random()
@@ -42,7 +42,6 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
class ModelTesterMixin(unittest.TestCase):
-
@property
def dummy_input(self):
batch_size = 1
@@ -55,8 +54,7 @@ class ModelTesterMixin(unittest.TestCase):
return (noise, time_step)
def test_from_pretrained_save_pretrained(self):
- config = UNetConfig(dim=8, dim_mults=(1, 2), resnet_block_groups=2)
- model = UNetModel(config)
+ model = UNetModel(dim=8, dim_mults=(1, 2), resnet_block_groups=2)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
@@ -75,3 +73,34 @@ class ModelTesterMixin(unittest.TestCase):
image = model(*self.dummy_input)
assert image is not None, "Make sure output is not None"
+
+
+class SamplerTesterMixin(unittest.TestCase):
+
+ @property
+ def dummy_model(self):
+ return UNetModel.from_pretrained("fusing/ddpm_dummy")
+
+ def test_from_pretrained_save_pretrained(self):
+ sampler = GaussianDiffusion(image_size=128, timesteps=3, loss_type="l1")
+
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ sampler.save_config(tmpdirname)
+ new_sampler = GaussianDiffusion.from_config(tmpdirname, return_unused=False)
+
+ model = self.dummy_model
+
+ torch.manual_seed(0)
+ sampled_out = sampler.sample(model, batch_size=1)
+ torch.manual_seed(0)
+ sampled_out_new = new_sampler.sample(model, batch_size=1)
+
+ assert (sampled_out - sampled_out_new).abs().sum() < 1e-5, "Samplers don't give the same output"
+
+ def test_from_pretrained_hub(self):
+ sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
+ model = self.dummy_model
+
+ sampled_out = sampler.sample(model, batch_size=1)
+
+ assert sampled_out is not None, "Make sure output is not None"