1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

add pretrained model and pretrained sampler

This commit is contained in:
Patrick von Platen
2022-06-02 00:25:48 +02:00
parent 18ef809c4d
commit 8cb5e69415
17 changed files with 443 additions and 743 deletions

View File

@@ -1 +0,0 @@
ce075df80e7ba2391d63d026be165c15 src/diffusers/dependency_versions_table.py

View File

@@ -1,19 +1,10 @@
#!/usr/bin/env python3
import torch
from diffusers import GaussianDiffusion, UNetConfig, UNetModel
from diffusers import GaussianDiffusion, UNetModel
config = UNetConfig(dim=64, dim_mults=(1, 2, 4, 8))
model = UNetModel(config)
print(model.config)
model.save_pretrained("/home/patrick/diffusion_example")
import ipdb
ipdb.set_trace()
model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8))
diffusion = GaussianDiffusion(model, image_size=128, timesteps=1000, loss_type="l1") # number of steps # L1 or L2

View File

@@ -1,31 +0,0 @@
Metadata-Version: 2.1
Name: diffusers
Version: 0.0.1
Summary: Diffusers
Home-page: https://github.com/huggingface/diffusers
Author: The HuggingFace team
Author-email: patrick@huggingface.co
License: Apache
Keywords: deep learning
Classifier: Development Status :: 5 - Production/Stable
Classifier: Intended Audience :: Developers
Classifier: Intended Audience :: Education
Classifier: Intended Audience :: Science/Research
Classifier: License :: OSI Approved :: Apache Software License
Classifier: Operating System :: OS Independent
Classifier: Programming Language :: Python :: 3
Classifier: Programming Language :: Python :: 3.6
Classifier: Programming Language :: Python :: 3.7
Classifier: Programming Language :: Python :: 3.8
Classifier: Programming Language :: Python :: 3.9
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
Requires-Python: >=3.6.0
Description-Content-Type: text/markdown
Provides-Extra: quality
Provides-Extra: docs
Provides-Extra: test
Provides-Extra: dev
Provides-Extra: sagemaker
License-File: LICENSE
Super cool library about diffusion models

View File

@@ -1,14 +0,0 @@
LICENSE
README.md
pyproject.toml
setup.cfg
setup.py
src/diffusers/__init__.py
src/diffusers.egg-info/PKG-INFO
src/diffusers.egg-info/SOURCES.txt
src/diffusers.egg-info/dependency_links.txt
src/diffusers.egg-info/requires.txt
src/diffusers.egg-info/top_level.txt
src/diffusers/models/__init__.py
src/diffusers/models/unet.py
src/diffusers/samplers/__init__.py

View File

@@ -1 +0,0 @@

View File

@@ -1,31 +0,0 @@
numpy>=1.17
packaging>=20.0
pyyaml
torch>=1.4.0
[dev]
black~=22.0
isort>=5.5.4
flake8>=3.8.3
pytest
pytest-xdist
pytest-subtests
datasets
transformers
[docs]
[quality]
black~=22.0
isort>=5.5.4
flake8>=3.8.3
[sagemaker]
sagemaker
[test]
pytest
pytest-xdist
pytest-subtests
datasets
transformers

View File

@@ -1 +0,0 @@
diffusers

View File

@@ -4,5 +4,5 @@
__version__ = "0.0.1"
from .models.unet import GaussianDiffusion # TODO(PVP): move somewhere else
from .models.unet import UNetConfig, UNetModel
from .models.unet import UNetModel
from .samplers.gaussian import GaussianDiffusion

329
src/diffusers/configuration_utils.py Executable file → Normal file
View File

@@ -20,11 +20,11 @@ import copy
import json
import os
import re
import inspect
from typing import Any, Dict, Tuple, Union
from requests import HTTPError
from transformers.utils import (
CONFIG_NAME,
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
EntryNotFoundError,
RepositoryNotFoundError,
@@ -44,33 +44,34 @@ logger = logging.get_logger(__name__)
_re_configuration_file = re.compile(r"config\.(.*)\.json")
class PretrainedConfig:
class Config:
r"""
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
methods for loading/downloading/saving configurations.
"""
model_type: str = ""
config_name = None
def __init__(self, **kwargs):
# Name or path to the pretrained checkpoint
self._name_or_path = str(kwargs.pop("name_or_path", ""))
def register(self, **kwargs):
if self.config_name is None:
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
kwargs["_class_name"] = self.__class__.__name__
for key, value in kwargs.items():
try:
setattr(self, key, value)
except AttributeError as err:
logger.error(f"Can't set {key} with value {value} for {self}")
raise err
# Drop the diffusers version info
self.diffusers_version = kwargs.pop("diffusers_version", None)
if not hasattr(self, "_dict_to_save"):
self._dict_to_save = {}
@property
def name_or_path(self) -> str:
return getattr(self, "_name_or_path", None)
self._dict_to_save.update(kwargs)
@name_or_path.setter
def name_or_path(self, value):
self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding)
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
"""
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
[`~PretrainedConfig.from_pretrained`] class method.
[`~Config.from_config`] class method.
Args:
save_directory (`str` or `os.PathLike`):
@@ -83,119 +84,15 @@ class PretrainedConfig:
os.makedirs(save_directory, exist_ok=True)
# If we save using the predefined names, we can load using `from_pretrained`
output_config_file = os.path.join(save_directory, CONFIG_NAME)
# If we save using the predefined names, we can load using `from_config`
output_config_file = os.path.join(save_directory, self.config_name)
self.to_json_file(output_config_file, use_diff=True)
self.to_json_file(output_config_file)
logger.info(f"Configuration saved in {output_config_file}")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
r"""
Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
This can be either:
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
- a path to a *directory* containing a configuration file saved using the
[`~PretrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
- a path or url to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
cache_dir (`str` or `os.PathLike`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force to (re-)download the configuration files and override the cached versions if
they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received file. Attempts to resume the download if such a file
exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `diffusers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
If `False`, then this function returns just the final configuration object.
If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
part of `kwargs` which has not been used to update `config` and is otherwise ignored.
kwargs (`Dict[str, Any]`, *optional*):
The values in kwargs of any keys which are configuration attributes will be used to override the loaded
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
by the `return_unused_kwargs` keyword parameter.
<Tip>
Passing `use_auth_token=True` is required when you want to use a private model.
</Tip>
Returns:
[`PretrainedConfig`]: The configuration object instantiated from this pretrained model.
Examples:
```python
# We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a
# derived class: BertConfig
config = BertConfig.from_pretrained(
"bert-base-uncased"
) # Download configuration from huggingface.co and cache.
config = BertConfig.from_pretrained(
"./test/saved_model/"
) # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
config = BertConfig.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
assert config.output_attentions == True
config, unused_kwargs = BertConfig.from_pretrained(
"bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
)
assert config.output_attentions == True
assert unused_kwargs == {"foo": False}
```"""
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
logger.warning(
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
)
return cls.from_dict(config_dict, **kwargs)
@classmethod
def get_config_dict(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
[`PretrainedConfig`] using `from_dict`.
Parameters:
pretrained_model_name_or_path (`str` or `os.PathLike`):
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
Returns:
`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
"""
# Get config dict associated with the base config file
config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
return config_dict, kwargs
@classmethod
def _get_config_dict(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
def from_config(
cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
cache_dir = kwargs.pop("cache_dir", None)
force_download = kwargs.pop("force_download", False)
@@ -215,7 +112,7 @@ class PretrainedConfig:
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
config_file = pretrained_model_name_or_path
else:
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
configuration_file = cls.config_name
if os.path.isdir(pretrained_model_name_or_path):
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
@@ -286,58 +183,27 @@ class PretrainedConfig:
else:
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
return config_dict, kwargs
expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys())
expected_keys.remove("self")
@classmethod
def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
"""
Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.
passed_keys = set(config_dict.keys())
Args:
config_dict (`Dict[str, Any]`):
Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
kwargs (`Dict[str, Any]`):
Additional parameters from which to initialize the configuration object.
unused_kwargs = kwargs
for key in passed_keys - expected_keys:
unused_kwargs[key] = config_dict.pop(key)
Returns:
[`PretrainedConfig`]: The configuration object instantiated from those parameters.
"""
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
# Those arguments may be passed along for our internal telemetry.
# We remove them so they don't appear in `return_unused_kwargs`.
if len(expected_keys - passed_keys) > 0:
logger.warn(
f"{expected_keys - passed_keys} was not found in config. "
f"Values will be initialized to default values."
)
config = cls(**config_dict)
model = cls(**config_dict)
to_remove = []
for key, value in kwargs.items():
if hasattr(config, key):
setattr(config, key, value)
to_remove.append(key)
for key in to_remove:
kwargs.pop(key, None)
logger.info(f"Model config {config}")
if return_unused_kwargs:
return config, kwargs
return model, unused_kwargs
else:
return config
@classmethod
def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
"""
Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters.
Args:
json_file (`str` or `os.PathLike`):
Path to the JSON file containing the parameters.
Returns:
[`PretrainedConfig`]: The configuration object instantiated from that JSON file.
"""
config_dict = cls._dict_from_json_file(json_file)
return cls(**config_dict)
return model
@classmethod
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
@@ -351,38 +217,6 @@ class PretrainedConfig:
def __repr__(self):
return f"{self.__class__.__name__} {self.to_json_string()}"
def to_diff_dict(self) -> Dict[str, Any]:
"""
Removes all attributes from config which correspond to the default config attributes for better readability and
serializes to a Python dictionary.
Returns:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
"""
config_dict = self.to_dict()
# get the default config dict
default_config_dict = PretrainedConfig().to_dict()
# get class specific config dict
class_config_dict = self.__class__().to_dict()
serializable_config_dict = {}
# only serialize values that differ from the default config
for key, value in config_dict.items():
if (
key not in default_config_dict
or key == "diffusers_version"
or value != default_config_dict[key]
or (key in class_config_dict and value != class_config_dict[key])
):
serializable_config_dict[key] = value
self.dict_torch_dtype_to_str(serializable_config_dict)
return serializable_config_dict
def to_dict(self) -> Dict[str, Any]:
"""
Serializes this instance to a Python dictionary.
@@ -391,106 +225,29 @@ class PretrainedConfig:
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
"""
output = copy.deepcopy(self.__dict__)
if hasattr(self.__class__, "model_type"):
output["model_type"] = self.__class__.model_type
if "_auto_class" in output:
del output["_auto_class"]
# Transformers version when serializing the model
# Diffusion version when serializing the model
output["diffusers_version"] = __version__
self.dict_torch_dtype_to_str(output)
return output
def to_json_string(self, use_diff: bool = True) -> str:
def to_json_string(self) -> str:
"""
Serializes this instance to a JSON string.
Args:
use_diff (`bool`, *optional*, defaults to `True`):
If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
is serialized to JSON string.
Returns:
`str`: String containing all the attributes that make up this configuration instance in JSON format.
"""
if use_diff is True:
config_dict = self.to_diff_dict()
else:
config_dict = self.to_dict()
config_dict = self._dict_to_save
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
"""
Save this instance to a JSON file.
Args:
json_file_path (`str` or `os.PathLike`):
Path to the JSON file in which this configuration instance's parameters will be saved.
use_diff (`bool`, *optional*, defaults to `True`):
If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
is serialized to JSON file.
"""
with open(json_file_path, "w", encoding="utf-8") as writer:
writer.write(self.to_json_string(use_diff=use_diff))
def update(self, config_dict: Dict[str, Any]):
"""
Updates attributes of this class with attributes from `config_dict`.
Args:
config_dict (`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
"""
for key, value in config_dict.items():
setattr(self, key, value)
def update_from_string(self, update_str: str):
"""
Updates attributes of this class with attributes from `update_str`.
The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
The keys to change have to already exist in the config object.
Args:
update_str (`str`): String with attributes that should be updated for this class.
"""
d = dict(x.split("=") for x in update_str.split(","))
for k, v in d.items():
if not hasattr(self, k):
raise ValueError(f"key {k} isn't in the original config dict")
old_v = getattr(self, k)
if isinstance(old_v, bool):
if v.lower() in ["true", "1", "y", "yes"]:
v = True
elif v.lower() in ["false", "0", "n", "no"]:
v = False
else:
raise ValueError(f"can't derive true or false from {v} (key {k})")
elif isinstance(old_v, int):
v = int(v)
elif isinstance(old_v, float):
v = float(v)
elif not isinstance(old_v, str):
raise ValueError(
f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
)
setattr(self, k, v)
def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
"""
Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
string, which can then be stored in the json format.
"""
if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
for value in d.values():
if isinstance(value, dict):
self.dict_torch_dtype_to_str(value)
writer.write(self.to_json_string())

View File

@@ -33,10 +33,9 @@ from transformers.utils import (
is_offline_mode,
is_remote_url,
logging,
CONFIG_NAME,
)
from .configuration_utils import PretrainedConfig
WEIGHTS_NAME = "diffusion_model.pt"
@@ -135,7 +134,7 @@ class PreTrainedModel(torch.nn.Module):
Class attributes (overridden by derived classes):
- **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
- **config_class** ([`Config`]) -- A subclass of [`Config`] to use as configuration class
for this model architecture.
- **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
taking as arguments:
@@ -150,35 +149,16 @@ class PreTrainedModel(torch.nn.Module):
- **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
models, `pixel_values` for vision models and `input_values` for speech models).
"""
config_class = None
config_name = CONFIG_NAME
def __init__(self, config: PretrainedConfig):
def __init__(self):
super().__init__()
if not isinstance(config, PretrainedConfig):
raise ValueError(
f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
"`PretrainedConfig`. To create a model from a pretrained model use "
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
)
# Save config and origin of the pretrained weights if given in model
self.config = config
self.name_or_path = config.name_or_path
@classmethod
def _from_config(cls, config, **kwargs):
"""
All context managers that the model should be initialized under go here.
"""
model = cls(config, **kwargs)
return model
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
save_function: Callable = torch.save,
push_to_hub: bool = False,
**kwargs,
):
"""
@@ -195,16 +175,6 @@ class PreTrainedModel(torch.nn.Module):
save_function (`Callable`):
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
need to replace `torch.save` by another method.
push_to_hub (`bool`, *optional*, defaults to `False`):
Whether or not to push your model to the Hugging Face model hub after saving it.
<Tip warning={true}>
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
folder. Pass along `temp_dir=True` to use a temporary directory instead.
</Tip>
kwargs:
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
@@ -218,11 +188,9 @@ class PreTrainedModel(torch.nn.Module):
model_to_save = self
# Attach architecture to the config
model_to_save.config.architectures = [model_to_save.__class__.__name__]
# Save the config
if is_main_process:
model_to_save.config.save_pretrained(save_directory)
model_to_save.save_config(save_directory)
# Save the model
state_dict = model_to_save.state_dict()
@@ -241,7 +209,7 @@ class PreTrainedModel(torch.nn.Module):
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
r"""
Instantiate a pretrained pytorch model from a pre-trained model configuration.
@@ -265,11 +233,11 @@ class PreTrainedModel(torch.nn.Module):
- A path to a *directory* containing model weights saved using
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
config (`Union[Config, str, os.PathLike]`, *optional*):
Can be either:
- an instance of a class derived from [`PretrainedConfig`],
- a string or path valid as input to [`~PretrainedConfig.from_pretrained`].
- an instance of a class derived from [`Config`],
- a string or path valid as input to [`~Config.from_pretrained`].
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
be automatically loaded when:
@@ -327,7 +295,7 @@ class PreTrainedModel(torch.nn.Module):
underlying model's `__init__` method (we assume all relevant updates to the configuration have
already been done)
- If a configuration is not provided, `kwargs` will be first passed to the configuration class
initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
initialization function ([`~Config.from_pretrained`]). Each key of `kwargs` that
corresponds to a configuration attribute will be used to override said attribute with the
supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
will be passed to the underlying model's `__init__` function.
@@ -356,7 +324,6 @@ class PreTrainedModel(torch.nn.Module):
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
mirror = kwargs.pop("mirror", None)
from_pipeline = kwargs.pop("_from_pipeline", None)
from_auto_class = kwargs.pop("_from_auto", False)
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
@@ -367,7 +334,7 @@ class PreTrainedModel(torch.nn.Module):
# Load config if we don't provide a configuration
config_path = pretrained_model_name_or_path
config, model_kwargs = cls.config_class.from_pretrained(
model, unused_kwargs = cls.from_config(
config_path,
cache_dir=cache_dir,
return_unused_kwargs=True,
@@ -377,12 +344,9 @@ class PreTrainedModel(torch.nn.Module):
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
_from_auto=from_auto_class,
_from_pipeline=from_pipeline,
**kwargs,
)
model_kwargs = kwargs
model.register(name_or_path=pretrained_model_name_or_path)
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
# Load model
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
@@ -456,18 +420,8 @@ class PreTrainedModel(torch.nn.Module):
else:
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
state_dict = load_state_dict(resolved_archive_file)
# set dtype to instantiate the model under:
# 1. If torch_dtype is not None, we use that dtype
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
# weights entry - we assume all weights are of the same dtype
# we also may have config.torch_dtype available, but we won't rely on it till v5
config.name_or_path = pretrained_model_name_or_path
model = cls(config, *model_args, **model_kwargs)
# restore default dtype
state_dict = load_state_dict(resolved_archive_file)
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
model,
state_dict,

View File

@@ -16,4 +16,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from . import unet
from .unet import UNetModel

View File

@@ -22,27 +22,22 @@ from inspect import isfunction
from pathlib import Path
import torch
import torch.nn.functional as F
from torch import einsum, nn
from torch.cuda.amp import GradScaler, autocast
from torch.optim import Adam
from torch.utils import data
from einops import rearrange
from PIL import Image
from torchvision import transforms, utils
from torchvision import utils, transforms
from tqdm import tqdm
from ...modeling_utils import PreTrainedModel
from .configuration_unet import UNetConfig
from ..configuration_utils import Config
from ..modeling_utils import PreTrainedModel
from PIL import Image
# NOTE: the following file is completely copied from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
def exists(x):
return x is not None
@@ -246,13 +241,30 @@ class Attention(nn.Module):
return self.to_out(out)
class UNetModel(PreTrainedModel):
config_class = UNetConfig
def __init__(self, config):
super().__init__(config)
class UNetModel(PreTrainedModel, Config):
def __init__(
self,
dim=64,
dim_mults=(1, 2, 4, 8),
init_dim=None,
out_dim=None,
channels=3,
with_time_emb=True,
resnet_block_groups=8,
learned_variance=False,
):
super().__init__()
self.register(
dim=dim,
dim_mults=dim_mults,
init_dim=init_dim,
out_dim=out_dim,
channels=channels,
with_time_emb=with_time_emb,
resnet_block_groups=resnet_block_groups,
learned_variance=learned_variance,
)
init_dim = None
out_dim = None
channels = 3
@@ -262,9 +274,9 @@ class UNetModel(PreTrainedModel):
# determine dimensions
dim_mults = config.dim_mults
dim = config.dim
self.channels = config.channels
dim_mults = dim_mults
dim = dim
self.channels = channels
init_dim = default(init_dim, dim // 3 * 2)
self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)
@@ -357,238 +369,6 @@ class UNetModel(PreTrainedModel):
return self.final_conv(x)
# gaussian diffusion trainer class
def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def noise_like(shape, device, repeat=False):
def repeat_noise():
return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
def noise():
return torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
def linear_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
def cosine_beta_schedule(timesteps, s=0.008):
"""
cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
class GaussianDiffusion(nn.Module):
def __init__(
self,
denoise_fn,
*,
image_size,
channels=3,
timesteps=1000,
loss_type="l1",
objective="pred_noise",
beta_schedule="cosine",
):
super().__init__()
assert not (type(self) == GaussianDiffusion and denoise_fn.channels != denoise_fn.out_dim)
self.channels = channels
self.image_size = image_size
self.denoise_fn = denoise_fn
self.objective = objective
if beta_schedule == "linear":
betas = linear_beta_schedule(timesteps)
elif beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
else:
raise ValueError(f"unknown beta schedule {beta_schedule}")
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
(timesteps,) = betas.shape
self.num_timesteps = int(timesteps)
self.loss_type = loss_type
# helper function to register buffer from float64 to float32
def register_buffer(name, val):
self.register_buffer(name, val.to(torch.float32))
register_buffer("betas", betas)
register_buffer("alphas_cumprod", alphas_cumprod)
register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
register_buffer("posterior_variance", posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20)))
register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
register_buffer(
"posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
)
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, x, t, clip_denoised: bool):
model_output = self.denoise_fn(x, t)
if self.objective == "pred_noise":
x_start = self.predict_start_from_noise(x, t=t, noise=model_output)
elif self.objective == "pred_x0":
x_start = model_output
else:
raise ValueError(f"unknown objective {self.objective}")
if clip_denoised:
x_start.clamp_(-1.0, 1.0)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
return result
@torch.no_grad()
def p_sample_loop(self, shape):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
for i in tqdm(
reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps
):
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))
img = unnormalize_to_zero_to_one(img)
return img
@torch.no_grad()
def sample(self, batch_size=16):
image_size = self.image_size
channels = self.channels
return self.p_sample_loop((batch_size, channels, image_size, image_size))
@torch.no_grad()
def interpolate(self, x1, x2, t=None, lam=0.5):
b, *_, device = *x1.shape, x1.device
t = default(t, self.num_timesteps - 1)
assert x1.shape == x2.shape
t_batched = torch.stack([torch.tensor(t, device=device)] * b)
xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
img = (1 - lam) * xt1 + lam * xt2
for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t):
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))
return img
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
@property
def loss_fn(self):
if self.loss_type == "l1":
return F.l1_loss
elif self.loss_type == "l2":
return F.mse_loss
else:
raise ValueError(f"invalid loss type {self.loss_type}")
def p_losses(self, x_start, t, noise=None):
b, c, h, w = x_start.shape
noise = default(noise, lambda: torch.randn_like(x_start))
x = self.q_sample(x_start=x_start, t=t, noise=noise)
model_out = self.denoise_fn(x, t)
if self.objective == "pred_noise":
target = noise
elif self.objective == "pred_x0":
target = x_start
else:
raise ValueError(f"unknown objective {self.objective}")
loss = self.loss_fn(model_out, target)
return loss
def forward(self, img, *args, **kwargs):
b, _, h, w, device, img_size, = (
*img.shape,
img.device,
self.image_size,
)
assert h == img_size and w == img_size, f"height and width of image must be {img_size}"
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
img = normalize_to_neg_one_to_one(img)
return self.p_losses(img, t, *args, **kwargs)
# dataset classes
@@ -621,6 +401,7 @@ class Dataset(data.Dataset):
class Trainer(object):
def __init__(
self,
diffusion_model,

View File

@@ -1,6 +0,0 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all
from .configuration_unet import UNetConfig
from .modeling_unet import GaussianDiffusion, UNetModel

View File

@@ -1,45 +0,0 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# helpers functions
# NOTE: the following file is completely copied from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
from ...configuration_utils import PretrainedConfig
class UNetConfig(PretrainedConfig):
model_type = "unet"
def __init__(
self,
dim=64,
dim_mults=(1, 2, 4, 8),
init_dim=None,
out_dim=None,
channels=3,
with_time_emb=True,
resnet_block_groups=8,
learned_variance=False,
**kwargs,
):
super().__init__(**kwargs)
self.dim = dim
self.dim_mults = dim_mults
self.init_dim = init_dim
self.out_dim = out_dim
self.channels = channels
self.with_time_emb = with_time_emb
self.resnet_block_groups = resnet_block_groups
self.learned_variance = learned_variance

View File

@@ -1,3 +1,7 @@
# flake8: noqa
# There's no way to ignore "F401 '...' imported but unused" warnings in this
# module, but to preserve other warnings. So, don't check this module at all.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -11,3 +15,5 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .gaussian import GaussianDiffusion

View File

@@ -0,0 +1,312 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from torch import nn
from inspect import isfunction
from tqdm import tqdm
from ..configuration_utils import Config
SAMPLING_CONFIG_NAME = "sampler_config.json"
def exists(x):
return x is not None
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
def cycle(dl):
while True:
for data_dl in dl:
yield data_dl
def num_to_groups(num, divisor):
groups = num // divisor
remainder = num % divisor
arr = [divisor] * groups
if remainder > 0:
arr.append(remainder)
return arr
def normalize_to_neg_one_to_one(img):
return img * 2 - 1
def unnormalize_to_zero_to_one(t):
return (t + 1) * 0.5
# small helper modules
class EMA:
def __init__(self, beta):
super().__init__()
self.beta = beta
def update_model_average(self, ma_model, current_model):
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
old_weight, up_weight = ma_params.data, current_params.data
ma_params.data = self.update_average(old_weight, up_weight)
def update_average(self, old, new):
if old is None:
return new
return old * self.beta + (1 - self.beta) * new
# gaussian diffusion trainer class
def extract(a, t, x_shape):
b, *_ = t.shape
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
def noise_like(shape, device, repeat=False):
def repeat_noise():
return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
def noise():
return torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()
def linear_beta_schedule(timesteps):
scale = 1000 / timesteps
beta_start = scale * 0.0001
beta_end = scale * 0.02
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
def cosine_beta_schedule(timesteps, s=0.008):
"""
cosine schedule
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
"""
steps = timesteps + 1
x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
return torch.clip(betas, 0, 0.999)
class GaussianDiffusion(nn.Module, Config):
config_name = SAMPLING_CONFIG_NAME
def __init__(
self,
image_size,
channels=3,
timesteps=1000,
loss_type="l1",
objective="pred_noise",
beta_schedule="cosine",
):
super().__init__()
self.register(
image_size=image_size,
channels=channels,
timesteps=timesteps,
loss_type=loss_type,
objective=objective,
beta_schedule=beta_schedule,
)
self.channels = channels
self.image_size = image_size
self.objective = objective
if beta_schedule == "linear":
betas = linear_beta_schedule(timesteps)
elif beta_schedule == "cosine":
betas = cosine_beta_schedule(timesteps)
else:
raise ValueError(f"unknown beta schedule {beta_schedule}")
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, axis=0)
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
(timesteps,) = betas.shape
self.num_timesteps = int(timesteps)
self.loss_type = loss_type
# helper function to register buffer from float64 to float32
def register_buffer(name, val):
self.register_buffer(name, val.to(torch.float32))
register_buffer("betas", betas)
register_buffer("alphas_cumprod", alphas_cumprod)
register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
# calculations for diffusion q(x_t | x_{t-1}) and others
register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))
# calculations for posterior q(x_{t-1} | x_t, x_0)
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
register_buffer("posterior_variance", posterior_variance)
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20)))
register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
register_buffer(
"posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
)
def predict_start_from_noise(self, x_t, t, noise):
return (
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
- extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
)
def q_posterior(self, x_start, x_t, t):
posterior_mean = (
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
)
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
return posterior_mean, posterior_variance, posterior_log_variance_clipped
def p_mean_variance(self, model, x, t, clip_denoised: bool):
model_output = model(x, t)
if self.objective == "pred_noise":
x_start = self.predict_start_from_noise(x, t=t, noise=model_output)
elif self.objective == "pred_x0":
x_start = model_output
else:
raise ValueError(f"unknown objective {self.objective}")
if clip_denoised:
x_start.clamp_(-1.0, 1.0)
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def p_sample(self, model, x, t, clip_denoised=True, repeat_noise=False):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(model=model, x=x, t=t, clip_denoised=clip_denoised)
noise = noise_like(x.shape, device, repeat_noise)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
return result
@torch.no_grad()
def p_sample_loop(self, model, shape):
device = self.betas.device
b = shape[0]
img = torch.randn(shape, device=device)
for i in tqdm(
reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps
):
img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))
img = unnormalize_to_zero_to_one(img)
return img
@torch.no_grad()
def sample(self, model, batch_size=16):
image_size = self.image_size
channels = self.channels
return self.p_sample_loop(model, (batch_size, channels, image_size, image_size))
@torch.no_grad()
def interpolate(self, model, x1, x2, t=None, lam=0.5):
b, *_, device = *x1.shape, x1.device
t = default(t, self.num_timesteps - 1)
assert x1.shape == x2.shape
t_batched = torch.stack([torch.tensor(t, device=device)] * b)
xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
img = (1 - lam) * xt1 + lam * xt2
for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t):
img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long))
return img
def q_sample(self, x_start, t, noise=None):
noise = default(noise, lambda: torch.randn_like(x_start))
return (
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
)
@property
def loss_fn(self):
if self.loss_type == "l1":
return F.l1_loss
elif self.loss_type == "l2":
return F.mse_loss
else:
raise ValueError(f"invalid loss type {self.loss_type}")
def p_losses(self, model, x_start, t, noise=None):
b, c, h, w = x_start.shape
noise = default(noise, lambda: torch.randn_like(x_start))
x = self.q_sample(x_start=x_start, t=t, noise=noise)
model_out = model(x, t)
if self.objective == "pred_noise":
target = noise
elif self.objective == "pred_x0":
target = x_start
else:
raise ValueError(f"unknown objective {self.objective}")
loss = self.loss_fn(model_out, target)
return loss
def forward(self, model, img, *args, **kwargs):
b, _, h, w, device, img_size, = (
*img.shape,
img.device,
self.image_size,
)
assert h == img_size and w == img_size, f"height and width of image must be {img_size}"
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
img = normalize_to_neg_one_to_one(img)
return self.p_losses(model, img, t, *args, **kwargs)

View File

@@ -19,7 +19,7 @@ import unittest
import torch
from diffusers import UNetConfig, UNetModel
from diffusers import GaussianDiffusion, UNetModel
global_rng = random.Random()
@@ -42,7 +42,6 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None):
class ModelTesterMixin(unittest.TestCase):
@property
def dummy_input(self):
batch_size = 1
@@ -55,8 +54,7 @@ class ModelTesterMixin(unittest.TestCase):
return (noise, time_step)
def test_from_pretrained_save_pretrained(self):
config = UNetConfig(dim=8, dim_mults=(1, 2), resnet_block_groups=2)
model = UNetModel(config)
model = UNetModel(dim=8, dim_mults=(1, 2), resnet_block_groups=2)
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_pretrained(tmpdirname)
@@ -75,3 +73,34 @@ class ModelTesterMixin(unittest.TestCase):
image = model(*self.dummy_input)
assert image is not None, "Make sure output is not None"
class SamplerTesterMixin(unittest.TestCase):
@property
def dummy_model(self):
return UNetModel.from_pretrained("fusing/ddpm_dummy")
def test_from_pretrained_save_pretrained(self):
sampler = GaussianDiffusion(image_size=128, timesteps=3, loss_type="l1")
with tempfile.TemporaryDirectory() as tmpdirname:
sampler.save_config(tmpdirname)
new_sampler = GaussianDiffusion.from_config(tmpdirname, return_unused=False)
model = self.dummy_model
torch.manual_seed(0)
sampled_out = sampler.sample(model, batch_size=1)
torch.manual_seed(0)
sampled_out_new = new_sampler.sample(model, batch_size=1)
assert (sampled_out - sampled_out_new).abs().sum() < 1e-5, "Samplers don't give the same output"
def test_from_pretrained_hub(self):
sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy")
model = self.dummy_model
sampled_out = sampler.sample(model, batch_size=1)
assert sampled_out is not None, "Make sure output is not None"