Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)

Commit: add pretrained model and pretrained sampler
@@ -1 +0,0 @@
ce075df80e7ba2391d63d026be165c15 src/diffusers/dependency_versions_table.py
@@ -1,19 +1,10 @@
#!/usr/bin/env python3
import torch

from diffusers import GaussianDiffusion, UNetConfig, UNetModel
from diffusers import GaussianDiffusion, UNetModel


config = UNetConfig(dim=64, dim_mults=(1, 2, 4, 8))
model = UNetModel(config)
print(model.config)

model.save_pretrained("/home/patrick/diffusion_example")

import ipdb


ipdb.set_trace()
model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))

diffusion = GaussianDiffusion(model, image_size=128, timesteps=1000, loss_type="l1")  # number of steps  # L1 or L2
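Taken together with the tests further down in this commit, the flow the new pretrained APIs are meant to enable looks roughly like the sketch below. This is illustrative only; the local directory path is a placeholder and not part of the diff.

# Minimal sketch (not part of the diff): round-tripping the new pretrained model and sampler.
import torch

from diffusers import GaussianDiffusion, UNetModel

model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))
model.save_pretrained("./diffusion_example")           # writes the model config plus diffusion_model.pt
restored = UNetModel.from_pretrained("./diffusion_example")

sampler = GaussianDiffusion(image_size=128, timesteps=1000, loss_type="l1")
sampler.save_config("./diffusion_example")             # writes sampler_config.json
sampler = GaussianDiffusion.from_config("./diffusion_example")

images = sampler.sample(restored, batch_size=1)        # tensor of shape (1, 3, 128, 128)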
@@ -1,31 +0,0 @@
Metadata-Version: 2.1
Name: diffusers
Version: 0.0.1
Summary: Diffusers
Home-page: https://github.com/huggingface/diffusers
Author: The HuggingFace team
Author-email: patrick@huggingface.co
License: Apache
Keywords: deep learning
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Education
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.6.0
Description-Content-Type: text/markdown
Provides-Extra: quality
Provides-Extra: docs
Provides-Extra: test
Provides-Extra: dev
Provides-Extra: sagemaker
License-File: LICENSE

Super cool library about diffusion models
@@ -1,14 +0,0 @@
LICENSE
README.md
pyproject.toml
setup.cfg
setup.py
src/diffusers/__init__.py
src/diffusers.egg-info/PKG-INFO
src/diffusers.egg-info/SOURCES.txt
src/diffusers.egg-info/dependency_links.txt
src/diffusers.egg-info/requires.txt
src/diffusers.egg-info/top_level.txt
src/diffusers/models/__init__.py
src/diffusers/models/unet.py
src/diffusers/samplers/__init__.py
@@ -1 +0,0 @@

@@ -1,31 +0,0 @@
numpy>=1.17
packaging>=20.0
pyyaml
torch>=1.4.0

[dev]
black~=22.0
isort>=5.5.4
flake8>=3.8.3
pytest
pytest-xdist
pytest-subtests
datasets
transformers

[docs]

[quality]
black~=22.0
isort>=5.5.4
flake8>=3.8.3

[sagemaker]
sagemaker

[test]
pytest
pytest-xdist
pytest-subtests
datasets
transformers
@@ -1 +0,0 @@
diffusers
@@ -4,5 +4,5 @@

__version__ = "0.0.1"

from .models.unet import GaussianDiffusion  # TODO(PVP): move somewhere else
from .models.unet import UNetConfig, UNetModel
from .models.unet import UNetModel
from .samplers.gaussian import GaussianDiffusion
329  src/diffusers/configuration_utils.py  (Executable file → Normal file)
@@ -20,11 +20,11 @@ import copy
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import inspect
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from requests import HTTPError
|
||||
from transformers.utils import (
|
||||
CONFIG_NAME,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
EntryNotFoundError,
|
||||
RepositoryNotFoundError,
|
||||
@@ -44,33 +44,34 @@ logger = logging.get_logger(__name__)
_re_configuration_file = re.compile(r"config\.(.*)\.json")


class PretrainedConfig:
class Config:
    r"""
    Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
    methods for loading/downloading/saving configurations.

    """
    model_type: str = ""
    config_name = None

    def __init__(self, **kwargs):
        # Name or path to the pretrained checkpoint
        self._name_or_path = str(kwargs.pop("name_or_path", ""))
    def register(self, **kwargs):
        if self.config_name is None:
            raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
        kwargs["_class_name"] = self.__class__.__name__
        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
            except AttributeError as err:
                logger.error(f"Can't set {key} with value {value} for {self}")
                raise err

        # Drop the diffusers version info
        self.diffusers_version = kwargs.pop("diffusers_version", None)
        if not hasattr(self, "_dict_to_save"):
            self._dict_to_save = {}

    @property
    def name_or_path(self) -> str:
        return getattr(self, "_name_or_path", None)
        self._dict_to_save.update(kwargs)

    @name_or_path.setter
    def name_or_path(self, value):
        self._name_or_path = str(value)  # Make sure that name_or_path is a string (for JSON encoding)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
    def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
        [`~PretrainedConfig.from_pretrained`] class method.
        [`~Config.from_config`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
@@ -83,119 +84,15 @@ class PretrainedConfig:
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_config_file = os.path.join(save_directory, CONFIG_NAME)
|
||||
# If we save using the predefined names, we can load using `from_config`
|
||||
output_config_file = os.path.join(save_directory, self.config_name)
|
||||
|
||||
self.to_json_file(output_config_file, use_diff=True)
|
||||
self.to_json_file(output_config_file)
|
||||
logger.info(f"Configuration saved in {output_config_file}")
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
||||
r"""
|
||||
Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||
This can be either:
|
||||
|
||||
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
|
||||
huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
|
||||
namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
|
||||
- a path to a *directory* containing a configuration file saved using the
|
||||
[`~PretrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
|
||||
- a path or url to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
|
||||
cache_dir (`str` or `os.PathLike`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force to (re-)download the configuration files and override the cached versions if
|
||||
they exist.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to delete incompletely received file. Attempts to resume the download if such a file
|
||||
exists.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `diffusers-cli login` (stored in `~/.huggingface`).
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
||||
If `False`, then this function returns just the final configuration object.
|
||||
|
||||
If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
|
||||
dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
|
||||
part of `kwargs` which has not been used to update `config` and is otherwise ignored.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
The values in kwargs of any keys which are configuration attributes will be used to override the loaded
|
||||
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
|
||||
by the `return_unused_kwargs` keyword parameter.
|
||||
|
||||
<Tip>
|
||||
|
||||
Passing `use_auth_token=True` is required when you want to use a private model.
|
||||
|
||||
</Tip>
|
||||
|
||||
Returns:
|
||||
[`PretrainedConfig`]: The configuration object instantiated from this pretrained model.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
# We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a
|
||||
# derived class: BertConfig
|
||||
config = BertConfig.from_pretrained(
|
||||
"bert-base-uncased"
|
||||
) # Download configuration from huggingface.co and cache.
|
||||
config = BertConfig.from_pretrained(
|
||||
"./test/saved_model/"
|
||||
) # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
|
||||
config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
|
||||
config = BertConfig.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
|
||||
assert config.output_attentions == True
|
||||
config, unused_kwargs = BertConfig.from_pretrained(
|
||||
"bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
|
||||
)
|
||||
assert config.output_attentions == True
|
||||
assert unused_kwargs == {"foo": False}
|
||||
```"""
|
||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
||||
logger.warning(
|
||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||
)
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_config_dict(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""
|
||||
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
|
||||
[`PretrainedConfig`] using `from_dict`.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
|
||||
|
||||
Returns:
|
||||
`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
|
||||
|
||||
"""
|
||||
# Get config dict associated with the base config file
|
||||
config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
return config_dict, kwargs
|
||||
|
||||
@classmethod
|
||||
def _get_config_dict(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
def from_config(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
@@ -215,7 +112,7 @@ class PretrainedConfig:
|
||||
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
config_file = pretrained_model_name_or_path
|
||||
else:
|
||||
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
|
||||
configuration_file = cls.config_name
|
||||
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
|
||||
@@ -286,58 +183,27 @@ class PretrainedConfig:
|
||||
else:
|
||||
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
|
||||
|
||||
return config_dict, kwargs
|
||||
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
|
||||
expected_keys.remove("self")
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
|
||||
"""
|
||||
Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.
|
||||
passed_keys = set(config_dict.keys())
|
||||
|
||||
Args:
|
||||
config_dict (`Dict[str, Any]`):
|
||||
Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
|
||||
retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
|
||||
kwargs (`Dict[str, Any]`):
|
||||
Additional parameters from which to initialize the configuration object.
|
||||
unused_kwargs = kwargs
|
||||
for key in passed_keys - expected_keys:
|
||||
unused_kwargs[key] = config_dict.pop(key)
|
||||
|
||||
Returns:
|
||||
[`PretrainedConfig`]: The configuration object instantiated from those parameters.
|
||||
"""
|
||||
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
||||
# Those arguments may be passed along for our internal telemetry.
|
||||
# We remove them so they don't appear in `return_unused_kwargs`.
|
||||
if len(expected_keys - passed_keys) > 0:
|
||||
logger.warn(
|
||||
f"{expected_keys - passed_keys} was not found in config. "
|
||||
f"Values will be initialized to default values."
|
||||
)
|
||||
|
||||
config = cls(**config_dict)
|
||||
model = cls(**config_dict)
|
||||
|
||||
to_remove = []
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
to_remove.append(key)
|
||||
for key in to_remove:
|
||||
kwargs.pop(key, None)
|
||||
|
||||
logger.info(f"Model config {config}")
|
||||
if return_unused_kwargs:
|
||||
return config, kwargs
|
||||
return model, unused_kwargs
|
||||
else:
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
|
||||
"""
|
||||
Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters.
|
||||
|
||||
Args:
|
||||
json_file (`str` or `os.PathLike`):
|
||||
Path to the JSON file containing the parameters.
|
||||
|
||||
Returns:
|
||||
[`PretrainedConfig`]: The configuration object instantiated from that JSON file.
|
||||
|
||||
"""
|
||||
config_dict = cls._dict_from_json_file(json_file)
|
||||
return cls(**config_dict)
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
||||
@@ -351,38 +217,6 @@ class PretrainedConfig:
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__} {self.to_json_string()}"
|
||||
|
||||
def to_diff_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Removes all attributes from config which correspond to the default config attributes for better readability and
|
||||
serializes to a Python dictionary.
|
||||
|
||||
Returns:
|
||||
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
config_dict = self.to_dict()
|
||||
|
||||
# get the default config dict
|
||||
default_config_dict = PretrainedConfig().to_dict()
|
||||
|
||||
# get class specific config dict
|
||||
class_config_dict = self.__class__().to_dict()
|
||||
|
||||
serializable_config_dict = {}
|
||||
|
||||
# only serialize values that differ from the default config
|
||||
for key, value in config_dict.items():
|
||||
if (
|
||||
key not in default_config_dict
|
||||
or key == "diffusers_version"
|
||||
or value != default_config_dict[key]
|
||||
or (key in class_config_dict and value != class_config_dict[key])
|
||||
):
|
||||
serializable_config_dict[key] = value
|
||||
|
||||
self.dict_torch_dtype_to_str(serializable_config_dict)
|
||||
|
||||
return serializable_config_dict
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes this instance to a Python dictionary.
|
||||
@@ -391,106 +225,29 @@ class PretrainedConfig:
|
||||
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
if hasattr(self.__class__, "model_type"):
|
||||
output["model_type"] = self.__class__.model_type
|
||||
if "_auto_class" in output:
|
||||
del output["_auto_class"]
|
||||
|
||||
# Transformers version when serializing the model
|
||||
# Diffusion version when serializing the model
|
||||
output["diffusers_version"] = __version__
|
||||
|
||||
self.dict_torch_dtype_to_str(output)
|
||||
|
||||
return output
|
||||
|
||||
def to_json_string(self, use_diff: bool = True) -> str:
|
||||
def to_json_string(self) -> str:
|
||||
"""
|
||||
Serializes this instance to a JSON string.
|
||||
|
||||
Args:
|
||||
use_diff (`bool`, *optional*, defaults to `True`):
|
||||
If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
|
||||
is serialized to JSON string.
|
||||
|
||||
Returns:
|
||||
`str`: String containing all the attributes that make up this configuration instance in JSON format.
|
||||
"""
|
||||
if use_diff is True:
|
||||
config_dict = self.to_diff_dict()
|
||||
else:
|
||||
config_dict = self.to_dict()
|
||||
config_dict = self._dict_to_save
|
||||
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
|
||||
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
||||
"""
|
||||
Save this instance to a JSON file.
|
||||
|
||||
Args:
|
||||
json_file_path (`str` or `os.PathLike`):
|
||||
Path to the JSON file in which this configuration instance's parameters will be saved.
|
||||
use_diff (`bool`, *optional*, defaults to `True`):
|
||||
If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
|
||||
is serialized to JSON file.
|
||||
"""
|
||||
with open(json_file_path, "w", encoding="utf-8") as writer:
|
||||
writer.write(self.to_json_string(use_diff=use_diff))
|
||||
|
||||
def update(self, config_dict: Dict[str, Any]):
|
||||
"""
|
||||
Updates attributes of this class with attributes from `config_dict`.
|
||||
|
||||
Args:
|
||||
config_dict (`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
|
||||
"""
|
||||
for key, value in config_dict.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def update_from_string(self, update_str: str):
|
||||
"""
|
||||
Updates attributes of this class with attributes from `update_str`.
|
||||
|
||||
The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
|
||||
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
||||
|
||||
The keys to change have to already exist in the config object.
|
||||
|
||||
Args:
|
||||
update_str (`str`): String with attributes that should be updated for this class.
|
||||
|
||||
"""
|
||||
|
||||
d = dict(x.split("=") for x in update_str.split(","))
|
||||
for k, v in d.items():
|
||||
if not hasattr(self, k):
|
||||
raise ValueError(f"key {k} isn't in the original config dict")
|
||||
|
||||
old_v = getattr(self, k)
|
||||
if isinstance(old_v, bool):
|
||||
if v.lower() in ["true", "1", "y", "yes"]:
|
||||
v = True
|
||||
elif v.lower() in ["false", "0", "n", "no"]:
|
||||
v = False
|
||||
else:
|
||||
raise ValueError(f"can't derive true or false from {v} (key {k})")
|
||||
elif isinstance(old_v, int):
|
||||
v = int(v)
|
||||
elif isinstance(old_v, float):
|
||||
v = float(v)
|
||||
elif not isinstance(old_v, str):
|
||||
raise ValueError(
|
||||
f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
|
||||
)
|
||||
|
||||
setattr(self, k, v)
|
||||
|
||||
def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
|
||||
converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
|
||||
string, which can then be stored in the json format.
|
||||
"""
|
||||
if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
|
||||
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
|
||||
for value in d.values():
|
||||
if isinstance(value, dict):
|
||||
self.dict_torch_dtype_to_str(value)
|
||||
writer.write(self.to_json_string())
|
||||
|
||||
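To recap the renamed API in this file (register / save_config / from_config replacing the PretrainedConfig methods), here is a hedged sketch of a minimal Config subclass. MyScheduler and its config file name are made-up names used only for illustration; the pattern simply mirrors how UNetModel and GaussianDiffusion use Config elsewhere in this commit.

# Illustrative sketch, not part of the diff.
from diffusers.configuration_utils import Config


class MyScheduler(Config):
    config_name = "my_scheduler_config.json"  # required, otherwise register() raises NotImplementedError

    def __init__(self, timesteps=1000, beta_schedule="cosine"):
        super().__init__()
        # register() stores the kwargs (plus "_class_name") in self._dict_to_save,
        # which save_config() writes to <save_directory>/my_scheduler_config.json.
        self.register(timesteps=timesteps, beta_schedule=beta_schedule)


scheduler = MyScheduler(timesteps=50)
scheduler.save_config("./my_scheduler")
# from_config() reads the JSON back, routes keys that are not in __init__'s
# signature into the unused kwargs, and re-instantiates the class with the saved values.
restored = MyScheduler.from_config("./my_scheduler")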
@@ -33,10 +33,9 @@ from transformers.utils import (
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
logging,
|
||||
CONFIG_NAME,
|
||||
)
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
WEIGHTS_NAME = "diffusion_model.pt"
|
||||
|
||||
@@ -135,7 +134,7 @@ class PreTrainedModel(torch.nn.Module):
|
||||
|
||||
Class attributes (overridden by derived classes):
|
||||
|
||||
- **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
|
||||
- **config_class** ([`Config`]) -- A subclass of [`Config`] to use as configuration class
|
||||
for this model architecture.
|
||||
- **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
|
||||
taking as arguments:
|
||||
@@ -150,35 +149,16 @@ class PreTrainedModel(torch.nn.Module):
|
||||
- **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
|
||||
models, `pixel_values` for vision models and `input_values` for speech models).
|
||||
"""
|
||||
config_class = None
|
||||
config_name = CONFIG_NAME
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
raise ValueError(
|
||||
f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
|
||||
"`PretrainedConfig`. To create a model from a pretrained model use "
|
||||
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
||||
)
|
||||
# Save config and origin of the pretrained weights if given in model
|
||||
self.config = config
|
||||
self.name_or_path = config.name_or_path
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, **kwargs):
|
||||
"""
|
||||
All context managers that the model should be initialized under go here.
|
||||
"""
|
||||
model = cls(config, **kwargs)
|
||||
|
||||
return model
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
is_main_process: bool = True,
|
||||
save_function: Callable = torch.save,
|
||||
push_to_hub: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
@@ -195,16 +175,6 @@ class PreTrainedModel(torch.nn.Module):
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
||||
needs to replace `torch.save` by another method.
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to push your model to the Hugging Face model hub after saving it.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
|
||||
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
|
||||
folder. Pass along `temp_dir=True` to use a temporary directory instead.
|
||||
|
||||
</Tip>
|
||||
|
||||
kwargs:
|
||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
@@ -218,11 +188,9 @@ class PreTrainedModel(torch.nn.Module):
|
||||
model_to_save = self
|
||||
|
||||
# Attach architecture to the config
|
||||
model_to_save.config.architectures = [model_to_save.__class__.__name__]
|
||||
|
||||
# Save the config
|
||||
if is_main_process:
|
||||
model_to_save.config.save_pretrained(save_directory)
|
||||
model_to_save.save_config(save_directory)
|
||||
|
||||
# Save the model
|
||||
state_dict = model_to_save.state_dict()
|
||||
@@ -241,7 +209,7 @@ class PreTrainedModel(torch.nn.Module):
|
||||
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained pytorch model from a pre-trained model configuration.
|
||||
|
||||
@@ -265,11 +233,11 @@ class PreTrainedModel(torch.nn.Module):
|
||||
- A path to a *directory* containing model weights saved using
|
||||
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
|
||||
|
||||
config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
|
||||
config (`Union[Config, str, os.PathLike]`, *optional*):
|
||||
Can be either:
|
||||
|
||||
- an instance of a class derived from [`PretrainedConfig`],
|
||||
- a string or path valid as input to [`~PretrainedConfig.from_pretrained`].
|
||||
- an instance of a class derived from [`Config`],
|
||||
- a string or path valid as input to [`~Config.from_pretrained`].
|
||||
|
||||
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
||||
be automatically loaded when:
|
||||
@@ -327,7 +295,7 @@ class PreTrainedModel(torch.nn.Module):
|
||||
underlying model's `__init__` method (we assume all relevant updates to the configuration have
|
||||
already been done)
|
||||
- If a configuration is not provided, `kwargs` will be first passed to the configuration class
|
||||
initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
|
||||
initialization function ([`~Config.from_pretrained`]). Each key of `kwargs` that
|
||||
corresponds to a configuration attribute will be used to override said attribute with the
|
||||
supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
|
||||
will be passed to the underlying model's `__init__` function.
|
||||
@@ -356,7 +324,6 @@ class PreTrainedModel(torch.nn.Module):
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
mirror = kwargs.pop("mirror", None)
|
||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
|
||||
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
|
||||
@@ -367,7 +334,7 @@ class PreTrainedModel(torch.nn.Module):
|
||||
|
||||
# Load config if we don't provide a configuration
|
||||
config_path = pretrained_model_name_or_path
|
||||
config, model_kwargs = cls.config_class.from_pretrained(
|
||||
model, unused_kwargs = cls.from_config(
|
||||
config_path,
|
||||
cache_dir=cache_dir,
|
||||
return_unused_kwargs=True,
|
||||
@@ -377,12 +344,9 @@ class PreTrainedModel(torch.nn.Module):
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
_from_auto=from_auto_class,
|
||||
_from_pipeline=from_pipeline,
|
||||
**kwargs,
|
||||
)
|
||||
model_kwargs = kwargs
|
||||
|
||||
model.register(name_or_path=pretrained_model_name_or_path)
|
||||
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
||||
# Load model
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
@@ -456,18 +420,8 @@ class PreTrainedModel(torch.nn.Module):
|
||||
else:
|
||||
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
|
||||
|
||||
state_dict = load_state_dict(resolved_archive_file)
|
||||
# set dtype to instantiate the model under:
|
||||
# 1. If torch_dtype is not None, we use that dtype
|
||||
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
|
||||
# weights entry - we assume all weights are of the same dtype
|
||||
# we also may have config.torch_dtype available, but we won't rely on it till v5
|
||||
|
||||
config.name_or_path = pretrained_model_name_or_path
|
||||
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
|
||||
# restore default dtype
|
||||
state_dict = load_state_dict(resolved_archive_file)
|
||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||
model,
|
||||
state_dict,
|
||||
|
||||
@@ -16,4 +16,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from . import unet
from .unet import UNetModel
@@ -22,27 +22,22 @@ from inspect import isfunction
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import einsum, nn
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
from torch.optim import Adam
|
||||
from torch.utils import data
|
||||
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
from torchvision import transforms, utils
|
||||
from torchvision import utils, transforms
|
||||
from tqdm import tqdm
|
||||
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from .configuration_unet import UNetConfig
|
||||
|
||||
from ..configuration_utils import Config
|
||||
from ..modeling_utils import PreTrainedModel
|
||||
from PIL import Image
|
||||
|
||||
# NOTE: the following file is completely copied from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def exists(x):
|
||||
return x is not None
|
||||
|
||||
@@ -246,13 +241,30 @@ class Attention(nn.Module):
        return self.to_out(out)


class UNetModel(PreTrainedModel):

    config_class = UNetConfig

    def __init__(self, config):
        super().__init__(config)
class UNetModel(PreTrainedModel, Config):

    def __init__(
        self,
        dim=64,
        dim_mults=(1, 2, 4, 8),
        init_dim=None,
        out_dim=None,
        channels=3,
        with_time_emb=True,
        resnet_block_groups=8,
        learned_variance=False,
    ):
        super().__init__()
        self.register(
            dim=dim,
            dim_mults=dim_mults,
            init_dim=init_dim,
            out_dim=out_dim,
            channels=channels,
            with_time_emb=with_time_emb,
            resnet_block_groups=resnet_block_groups,
            learned_variance=learned_variance,
        )
        init_dim = None
        out_dim = None
        channels = 3
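As a hedged aside, the kwargs captured by the self.register(...) call above are what get serialized when the model is saved; the printed JSON in the comment below is illustrative, not verified output of this commit.

# Sketch only: the registered kwargs become the model's serializable config.
model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))
print(model.to_json_string())
# Roughly:
# {
#   "_class_name": "UNetModel",
#   "channels": 3,
#   "dim": 64,
#   "dim_mults": [1, 2, 4, 8],
#   ...
# }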
@@ -262,9 +274,9 @@ class UNetModel(PreTrainedModel):

        # determine dimensions

        dim_mults = config.dim_mults
        dim = config.dim
        self.channels = config.channels
        dim_mults = dim_mults
        dim = dim
        self.channels = channels

        init_dim = default(init_dim, dim // 3 * 2)
        self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)
@@ -357,238 +369,6 @@ class UNetModel(PreTrainedModel):
|
||||
return self.final_conv(x)
|
||||
|
||||
|
||||
# gaussian diffusion trainer class
|
||||
|
||||
|
||||
def extract(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
|
||||
def repeat_noise():
|
||||
return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
||||
|
||||
def noise():
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
return repeat_noise() if repeat else noise()
|
||||
|
||||
|
||||
def linear_beta_schedule(timesteps):
|
||||
scale = 1000 / timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
|
||||
|
||||
|
||||
def cosine_beta_schedule(timesteps, s=0.008):
|
||||
"""
|
||||
cosine schedule
|
||||
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
||||
"""
|
||||
steps = timesteps + 1
|
||||
x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
|
||||
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||
return torch.clip(betas, 0, 0.999)
|
||||
|
||||
|
||||
class GaussianDiffusion(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
denoise_fn,
|
||||
*,
|
||||
image_size,
|
||||
channels=3,
|
||||
timesteps=1000,
|
||||
loss_type="l1",
|
||||
objective="pred_noise",
|
||||
beta_schedule="cosine",
|
||||
):
|
||||
super().__init__()
|
||||
assert not (type(self) == GaussianDiffusion and denoise_fn.channels != denoise_fn.out_dim)
|
||||
|
||||
self.channels = channels
|
||||
self.image_size = image_size
|
||||
self.denoise_fn = denoise_fn
|
||||
self.objective = objective
|
||||
|
||||
if beta_schedule == "linear":
|
||||
betas = linear_beta_schedule(timesteps)
|
||||
elif beta_schedule == "cosine":
|
||||
betas = cosine_beta_schedule(timesteps)
|
||||
else:
|
||||
raise ValueError(f"unknown beta schedule {beta_schedule}")
|
||||
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
|
||||
|
||||
(timesteps,) = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.loss_type = loss_type
|
||||
|
||||
# helper function to register buffer from float64 to float32
|
||||
|
||||
def register_buffer(name, val):
|
||||
self.register_buffer(name, val.to(torch.float32))
|
||||
|
||||
register_buffer("betas", betas)
|
||||
register_buffer("alphas_cumprod", alphas_cumprod)
|
||||
register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
|
||||
register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
|
||||
register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
|
||||
register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
|
||||
register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
|
||||
register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
|
||||
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
||||
|
||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||
|
||||
register_buffer("posterior_variance", posterior_variance)
|
||||
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
|
||||
register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20)))
|
||||
register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
|
||||
register_buffer(
|
||||
"posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
|
||||
)
|
||||
|
||||
def predict_start_from_noise(self, x_t, t, noise):
|
||||
return (
|
||||
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
||||
- extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
||||
)
|
||||
|
||||
def q_posterior(self, x_start, x_t, t):
|
||||
posterior_mean = (
|
||||
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
||||
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
||||
)
|
||||
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, x, t, clip_denoised: bool):
|
||||
model_output = self.denoise_fn(x, t)
|
||||
|
||||
if self.objective == "pred_noise":
|
||||
x_start = self.predict_start_from_noise(x, t=t, noise=model_output)
|
||||
elif self.objective == "pred_x0":
|
||||
x_start = model_output
|
||||
else:
|
||||
raise ValueError(f"unknown objective {self.objective}")
|
||||
|
||||
if clip_denoised:
|
||||
x_start.clamp_(-1.0, 1.0)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
return result
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, shape):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device=device)
|
||||
|
||||
for i in tqdm(
|
||||
reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps
|
||||
):
|
||||
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))
|
||||
|
||||
img = unnormalize_to_zero_to_one(img)
|
||||
return img
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, batch_size=16):
|
||||
image_size = self.image_size
|
||||
channels = self.channels
|
||||
return self.p_sample_loop((batch_size, channels, image_size, image_size))
|
||||
|
||||
@torch.no_grad()
|
||||
def interpolate(self, x1, x2, t=None, lam=0.5):
|
||||
b, *_, device = *x1.shape, x1.device
|
||||
t = default(t, self.num_timesteps - 1)
|
||||
|
||||
assert x1.shape == x2.shape
|
||||
|
||||
t_batched = torch.stack([torch.tensor(t, device=device)] * b)
|
||||
xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
|
||||
|
||||
img = (1 - lam) * xt1 + lam * xt2
|
||||
for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t):
|
||||
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))
|
||||
|
||||
return img
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
return (
|
||||
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
@property
|
||||
def loss_fn(self):
|
||||
if self.loss_type == "l1":
|
||||
return F.l1_loss
|
||||
elif self.loss_type == "l2":
|
||||
return F.mse_loss
|
||||
else:
|
||||
raise ValueError(f"invalid loss type {self.loss_type}")
|
||||
|
||||
def p_losses(self, x_start, t, noise=None):
|
||||
b, c, h, w = x_start.shape
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
x = self.q_sample(x_start=x_start, t=t, noise=noise)
|
||||
model_out = self.denoise_fn(x, t)
|
||||
|
||||
if self.objective == "pred_noise":
|
||||
target = noise
|
||||
elif self.objective == "pred_x0":
|
||||
target = x_start
|
||||
else:
|
||||
raise ValueError(f"unknown objective {self.objective}")
|
||||
|
||||
loss = self.loss_fn(model_out, target)
|
||||
return loss
|
||||
|
||||
def forward(self, img, *args, **kwargs):
|
||||
b, _, h, w, device, img_size, = (
|
||||
*img.shape,
|
||||
img.device,
|
||||
self.image_size,
|
||||
)
|
||||
assert h == img_size and w == img_size, f"height and width of image must be {img_size}"
|
||||
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
|
||||
|
||||
img = normalize_to_neg_one_to_one(img)
|
||||
return self.p_losses(img, t, *args, **kwargs)
|
||||
|
||||
|
||||
# dataset classes
|
||||
|
||||
|
||||
@@ -621,6 +401,7 @@ class Dataset(data.Dataset):
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
diffusion_model,
|
||||
@@ -1,6 +0,0 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all
|
||||
|
||||
from .configuration_unet import UNetConfig
|
||||
from .modeling_unet import GaussianDiffusion, UNetModel
|
||||
@@ -1,45 +0,0 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
# limitations under the License.
|
||||
|
||||
# helpers functions
|
||||
|
||||
# NOTE: the following file is completely copied from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class UNetConfig(PretrainedConfig):
|
||||
model_type = "unet"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=64,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
init_dim=None,
|
||||
out_dim=None,
|
||||
channels=3,
|
||||
with_time_emb=True,
|
||||
resnet_block_groups=8,
|
||||
learned_variance=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.dim = dim
|
||||
self.dim_mults = dim_mults
|
||||
self.init_dim = init_dim
|
||||
self.out_dim = out_dim
|
||||
self.channels = channels
|
||||
self.with_time_emb = with_time_emb
|
||||
self.resnet_block_groups = resnet_block_groups
|
||||
self.learned_variance = learned_variance
|
||||
@@ -1,3 +1,7 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.

# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,3 +15,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .gaussian import GaussianDiffusion
312  src/diffusers/samplers/gaussian.py  (new file)
@@ -0,0 +1,312 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from inspect import isfunction
|
||||
from tqdm import tqdm
|
||||
|
||||
from ..configuration_utils import Config
|
||||
SAMPLING_CONFIG_NAME = "sampler_config.json"
|
||||
|
||||
|
||||
def exists(x):
|
||||
return x is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def cycle(dl):
|
||||
while True:
|
||||
for data_dl in dl:
|
||||
yield data_dl
|
||||
|
||||
|
||||
def num_to_groups(num, divisor):
|
||||
groups = num // divisor
|
||||
remainder = num % divisor
|
||||
arr = [divisor] * groups
|
||||
if remainder > 0:
|
||||
arr.append(remainder)
|
||||
return arr
|
||||
|
||||
|
||||
def normalize_to_neg_one_to_one(img):
|
||||
return img * 2 - 1
|
||||
|
||||
|
||||
def unnormalize_to_zero_to_one(t):
|
||||
return (t + 1) * 0.5
|
||||
|
||||
|
||||
# small helper modules
|
||||
|
||||
|
||||
class EMA:
|
||||
def __init__(self, beta):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
|
||||
def update_model_average(self, ma_model, current_model):
|
||||
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
||||
old_weight, up_weight = ma_params.data, current_params.data
|
||||
ma_params.data = self.update_average(old_weight, up_weight)
|
||||
|
||||
def update_average(self, old, new):
|
||||
if old is None:
|
||||
return new
|
||||
return old * self.beta + (1 - self.beta) * new
|
||||
|
||||
|
||||
# gaussian diffusion trainer class
|
||||
|
||||
|
||||
def extract(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
|
||||
def repeat_noise():
|
||||
return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
||||
|
||||
def noise():
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
return repeat_noise() if repeat else noise()
|
||||
|
||||
|
||||
def linear_beta_schedule(timesteps):
|
||||
scale = 1000 / timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
|
||||
|
||||
|
||||
def cosine_beta_schedule(timesteps, s=0.008):
|
||||
"""
|
||||
cosine schedule
|
||||
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
||||
"""
|
||||
steps = timesteps + 1
|
||||
x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
|
||||
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||
return torch.clip(betas, 0, 0.999)
|
||||
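For reference, the cosine schedule implemented by cosine_beta_schedule above corresponds to the formula from the paper linked in its docstring:

f(t) = \cos\!\left(\frac{t/T + s}{1 + s} \cdot \frac{\pi}{2}\right)^{2},
\qquad
\bar{\alpha}_t = \frac{f(t)}{f(0)},
\qquad
\beta_t = 1 - \frac{\bar{\alpha}_t}{\bar{\alpha}_{t-1}},

with s = 0.008 by default and each beta clipped to [0, 0.999], matching the code.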
|
||||
|
||||
class GaussianDiffusion(nn.Module, Config):
|
||||
|
||||
config_name = SAMPLING_CONFIG_NAME
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
image_size,
|
||||
channels=3,
|
||||
timesteps=1000,
|
||||
loss_type="l1",
|
||||
objective="pred_noise",
|
||||
beta_schedule="cosine",
|
||||
):
|
||||
super().__init__()
|
||||
self.register(
|
||||
image_size=image_size,
|
||||
channels=channels,
|
||||
timesteps=timesteps,
|
||||
loss_type=loss_type,
|
||||
objective=objective,
|
||||
beta_schedule=beta_schedule,
|
||||
)
|
||||
|
||||
self.channels = channels
|
||||
self.image_size = image_size
|
||||
self.objective = objective
|
||||
|
||||
if beta_schedule == "linear":
|
||||
betas = linear_beta_schedule(timesteps)
|
||||
elif beta_schedule == "cosine":
|
||||
betas = cosine_beta_schedule(timesteps)
|
||||
else:
|
||||
raise ValueError(f"unknown beta schedule {beta_schedule}")
|
||||
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
|
||||
|
||||
(timesteps,) = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.loss_type = loss_type
|
||||
|
||||
# helper function to register buffer from float64 to float32
|
||||
|
||||
def register_buffer(name, val):
|
||||
self.register_buffer(name, val.to(torch.float32))
|
||||
|
||||
register_buffer("betas", betas)
|
||||
register_buffer("alphas_cumprod", alphas_cumprod)
|
||||
register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
|
||||
register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
|
||||
register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
|
||||
register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
|
||||
register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
|
||||
register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
|
||||
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
||||
|
||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||
|
||||
register_buffer("posterior_variance", posterior_variance)
|
||||
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
|
||||
register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20)))
|
||||
register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
|
||||
register_buffer(
|
||||
"posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
|
||||
)
|
||||
|
||||
def predict_start_from_noise(self, x_t, t, noise):
|
||||
return (
|
||||
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
||||
- extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
||||
)
|
||||
|
||||
def q_posterior(self, x_start, x_t, t):
|
||||
posterior_mean = (
|
||||
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
||||
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
||||
)
|
||||
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, model, x, t, clip_denoised: bool):
|
||||
model_output = model(x, t)
|
||||
|
||||
if self.objective == "pred_noise":
|
||||
x_start = self.predict_start_from_noise(x, t=t, noise=model_output)
|
||||
elif self.objective == "pred_x0":
|
||||
x_start = model_output
|
||||
else:
|
||||
raise ValueError(f"unknown objective {self.objective}")
|
||||
|
||||
if clip_denoised:
|
||||
x_start.clamp_(-1.0, 1.0)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, model, x, t, clip_denoised=True, repeat_noise=False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(model=model, x=x, t=t, clip_denoised=clip_denoised)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
return result
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, model, shape):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device=device)
|
||||
|
||||
for i in tqdm(
|
||||
reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps
|
||||
):
|
||||
img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))
|
||||
|
||||
img = unnormalize_to_zero_to_one(img)
|
||||
return img
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, model, batch_size=16):
|
||||
image_size = self.image_size
|
||||
channels = self.channels
|
||||
return self.p_sample_loop(model, (batch_size, channels, image_size, image_size))
|
||||
|
||||
@torch.no_grad()
|
||||
def interpolate(self, model, x1, x2, t=None, lam=0.5):
|
||||
b, *_, device = *x1.shape, x1.device
|
||||
t = default(t, self.num_timesteps - 1)
|
||||
|
||||
assert x1.shape == x2.shape
|
||||
|
||||
t_batched = torch.stack([torch.tensor(t, device=device)] * b)
|
||||
xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
|
||||
|
||||
img = (1 - lam) * xt1 + lam * xt2
|
||||
for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t):
|
||||
img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))
|
||||
|
||||
return img
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
return (
|
||||
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
@property
|
||||
def loss_fn(self):
|
||||
if self.loss_type == "l1":
|
||||
return F.l1_loss
|
||||
elif self.loss_type == "l2":
|
||||
return F.mse_loss
|
||||
else:
|
||||
raise ValueError(f"invalid loss type {self.loss_type}")
|
||||
|
||||
def p_losses(self, model, x_start, t, noise=None):
|
||||
b, c, h, w = x_start.shape
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
x = self.q_sample(x_start=x_start, t=t, noise=noise)
|
||||
model_out = model(x, t)
|
||||
|
||||
if self.objective == "pred_noise":
|
||||
target = noise
|
||||
elif self.objective == "pred_x0":
|
||||
target = x_start
|
||||
else:
|
||||
raise ValueError(f"unknown objective {self.objective}")
|
||||
|
||||
loss = self.loss_fn(model_out, target)
|
||||
return loss
|
||||
|
||||
def forward(self, model, img, *args, **kwargs):
|
||||
b, _, h, w, device, img_size, = (
|
||||
*img.shape,
|
||||
img.device,
|
||||
self.image_size,
|
||||
)
|
||||
assert h == img_size and w == img_size, f"height and width of image must be {img_size}"
|
||||
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
|
||||
|
||||
img = normalize_to_neg_one_to_one(img)
|
||||
return self.p_losses(model, img, t, *args, **kwargs)
|
||||
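Because the sampler in this new file no longer owns a denoise_fn, every sampling helper (p_sample, p_sample_loop, sample, interpolate, p_losses) takes the model explicitly. A hedged sketch of driving the reverse process by hand follows; the checkpoint id is the dummy one used in the tests and the sizes are illustrative.

# Sketch only: manual denoising loop with the model-agnostic sampler API.
import torch

from diffusers import GaussianDiffusion, UNetModel

model = UNetModel.from_pretrained("fusing/ddpm_dummy")  # dummy checkpoint from the tests
sampler = GaussianDiffusion(image_size=32, timesteps=10)

shape = (1, sampler.channels, sampler.image_size, sampler.image_size)
img = torch.randn(shape)
for i in reversed(range(sampler.num_timesteps)):
    t = torch.full((shape[0],), i, dtype=torch.long)
    img = sampler.p_sample(model, img, t)  # one reverse-diffusion step

# equivalent convenience call (also rescales the output to [0, 1]):
images = sampler.sample(model, batch_size=1)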
@@ -19,7 +19,7 @@ import unittest

import torch

from diffusers import UNetConfig, UNetModel
from diffusers import GaussianDiffusion, UNetModel


global_rng = random.Random()
@@ -42,7 +42,6 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):


class ModelTesterMixin(unittest.TestCase):

    @property
    def dummy_input(self):
        batch_size = 1
@@ -55,8 +54,7 @@ class ModelTesterMixin(unittest.TestCase):
        return (noise, time_step)

    def test_from_pretrained_save_pretrained(self):
        config = UNetConfig(dim=8, dim_mults=(1, 2), resnet_block_groups=2)
        model = UNetModel(config)
        model = UNetModel(dim=8, dim_mults=(1, 2), resnet_block_groups=2)

        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_pretrained(tmpdirname)
@@ -75,3 +73,34 @@ class ModelTesterMixin(unittest.TestCase):
        image = model(*self.dummy_input)

        assert image is not None, "Make sure output is not None"


class SamplerTesterMixin(unittest.TestCase):

    @property
    def dummy_model(self):
        return UNetModel.from_pretrained("fusing/ddpm_dummy")

    def test_from_pretrained_save_pretrained(self):
        sampler = GaussianDiffusion(image_size=128, timesteps=3, loss_type="l1")

        with tempfile.TemporaryDirectory() as tmpdirname:
            sampler.save_config(tmpdirname)
            new_sampler = GaussianDiffusion.from_config(tmpdirname, return_unused=False)

        model = self.dummy_model

        torch.manual_seed(0)
        sampled_out = sampler.sample(model, batch_size=1)
        torch.manual_seed(0)
        sampled_out_new = new_sampler.sample(model, batch_size=1)

        assert (sampled_out - sampled_out_new).abs().sum() < 1e-5, "Samplers don't give the same output"

    def test_from_pretrained_hub(self):
        sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
        model = self.dummy_model

        sampled_out = sampler.sample(model, batch_size=1)

        assert sampled_out is not None, "Make sure output is not None"