mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
add first template for DDPM forward
This commit is contained in:
26
models/ddpm/run_ddpm.py
Executable file
26
models/ddpm/run_ddpm.py
Executable file
@@ -0,0 +1,26 @@
|
||||
#!/usr/bin/env python3
|
||||
import torch
|
||||
|
||||
from diffusers import GaussianDiffusion, UNetConfig, UNetModel
|
||||
|
||||
|
||||
config = UNetConfig(dim=64, dim_mults=(1, 2, 4, 8))
|
||||
model = UNetModel(config)
|
||||
print(model.config)
|
||||
|
||||
model.save_pretrained("/home/patrick/diffusion_example")
|
||||
|
||||
import ipdb
|
||||
|
||||
|
||||
ipdb.set_trace()
|
||||
|
||||
diffusion = GaussianDiffusion(model, image_size=128, timesteps=1000, loss_type="l1") # number of steps # L1 or L2
|
||||
|
||||
training_images = torch.randn(8, 3, 128, 128) # your images need to be normalized from a range of -1 to +1
|
||||
loss = diffusion(training_images)
|
||||
loss.backward()
|
||||
# after a lot of training
|
||||
|
||||
sampled_images = diffusion.sample(batch_size=4)
|
||||
sampled_images.shape # (4, 3, 128, 128)
|
||||
@@ -4,4 +4,5 @@
|
||||
|
||||
__version__ = "0.0.1"
|
||||
|
||||
from .models import UNetModel
|
||||
from .models.unet import GaussianDiffusion # TODO(PVP): move somewhere else
|
||||
from .models.unet import UNetConfig, UNetModel
|
||||
|
||||
496
src/diffusers/configuration_utils.py
Executable file
496
src/diffusers/configuration_utils.py
Executable file
@@ -0,0 +1,496 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Configuration base class and utilities."""
|
||||
|
||||
|
||||
import copy
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Any, Dict, Tuple, Union
|
||||
|
||||
from requests import HTTPError
|
||||
from transformers.utils import (
|
||||
CONFIG_NAME,
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
EntryNotFoundError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
cached_path,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
logging,
|
||||
)
|
||||
|
||||
from . import __version__
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
_re_configuration_file = re.compile(r"config\.(.*)\.json")
|
||||
|
||||
|
||||
class PretrainedConfig:
|
||||
r"""
|
||||
Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as
|
||||
methods for loading/downloading/saving configurations.
|
||||
|
||||
"""
|
||||
model_type: str = ""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
# Name or path to the pretrained checkpoint
|
||||
self._name_or_path = str(kwargs.pop("name_or_path", ""))
|
||||
|
||||
# Drop the diffusers version info
|
||||
self.diffusers_version = kwargs.pop("diffusers_version", None)
|
||||
|
||||
@property
|
||||
def name_or_path(self) -> str:
|
||||
return getattr(self, "_name_or_path", None)
|
||||
|
||||
@name_or_path.setter
|
||||
def name_or_path(self, value):
|
||||
self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding)
|
||||
|
||||
def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
||||
"""
|
||||
Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
|
||||
[`~PretrainedConfig.from_pretrained`] class method.
|
||||
|
||||
Args:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory where the configuration JSON file will be saved (will be created if it does not exist).
|
||||
kwargs:
|
||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
# If we save using the predefined names, we can load using `from_pretrained`
|
||||
output_config_file = os.path.join(save_directory, CONFIG_NAME)
|
||||
|
||||
self.to_json_file(output_config_file, use_diff=True)
|
||||
logger.info(f"Configuration saved in {output_config_file}")
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig":
|
||||
r"""
|
||||
Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||
This can be either:
|
||||
|
||||
- a string, the *model id* of a pretrained model configuration hosted inside a model repo on
|
||||
huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or
|
||||
namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`.
|
||||
- a path to a *directory* containing a configuration file saved using the
|
||||
[`~PretrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`.
|
||||
- a path or url to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
|
||||
cache_dir (`str` or `os.PathLike`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force to (re-)download the configuration files and override the cached versions if
|
||||
they exist.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to delete incompletely received file. Attempts to resume the download if such a file
|
||||
exists.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `diffusers-cli login` (stored in `~/.huggingface`).
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
return_unused_kwargs (`bool`, *optional*, defaults to `False`):
|
||||
If `False`, then this function returns just the final configuration object.
|
||||
|
||||
If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a
|
||||
dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the
|
||||
part of `kwargs` which has not been used to update `config` and is otherwise ignored.
|
||||
kwargs (`Dict[str, Any]`, *optional*):
|
||||
The values in kwargs of any keys which are configuration attributes will be used to override the loaded
|
||||
values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled
|
||||
by the `return_unused_kwargs` keyword parameter.
|
||||
|
||||
<Tip>
|
||||
|
||||
Passing `use_auth_token=True` is required when you want to use a private model.
|
||||
|
||||
</Tip>
|
||||
|
||||
Returns:
|
||||
[`PretrainedConfig`]: The configuration object instantiated from this pretrained model.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
# We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a
|
||||
# derived class: BertConfig
|
||||
config = BertConfig.from_pretrained(
|
||||
"bert-base-uncased"
|
||||
) # Download configuration from huggingface.co and cache.
|
||||
config = BertConfig.from_pretrained(
|
||||
"./test/saved_model/"
|
||||
) # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')*
|
||||
config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json")
|
||||
config = BertConfig.from_pretrained("bert-base-uncased", output_attentions=True, foo=False)
|
||||
assert config.output_attentions == True
|
||||
config, unused_kwargs = BertConfig.from_pretrained(
|
||||
"bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True
|
||||
)
|
||||
assert config.output_attentions == True
|
||||
assert unused_kwargs == {"foo": False}
|
||||
```"""
|
||||
config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
|
||||
logger.warning(
|
||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||
)
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_config_dict(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
"""
|
||||
From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
|
||||
[`PretrainedConfig`] using `from_dict`.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`):
|
||||
The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
|
||||
|
||||
Returns:
|
||||
`Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object.
|
||||
|
||||
"""
|
||||
# Get config dict associated with the base config file
|
||||
config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
|
||||
|
||||
return config_dict, kwargs
|
||||
|
||||
@classmethod
|
||||
def _get_config_dict(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
revision = kwargs.pop("revision", None)
|
||||
|
||||
user_agent = {"file_type": "config"}
|
||||
|
||||
if is_offline_mode() and not local_files_only:
|
||||
logger.info("Offline mode: forcing local_files_only=True")
|
||||
local_files_only = True
|
||||
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
config_file = pretrained_model_name_or_path
|
||||
else:
|
||||
configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
|
||||
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
config_file = os.path.join(pretrained_model_name_or_path, configuration_file)
|
||||
else:
|
||||
config_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None
|
||||
)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_config_file = cached_path(
|
||||
config_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on "
|
||||
"'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having "
|
||||
"permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass "
|
||||
"`use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this "
|
||||
f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for "
|
||||
"available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}."
|
||||
)
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in"
|
||||
f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory"
|
||||
f" containing a {configuration_file} file.\nCheckout your internet connection or see how to run the"
|
||||
" library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a {configuration_file} file"
|
||||
)
|
||||
|
||||
try:
|
||||
# Load config dict
|
||||
config_dict = cls._dict_from_json_file(resolved_config_file)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
raise EnvironmentError(
|
||||
f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
|
||||
)
|
||||
|
||||
if resolved_config_file == config_file:
|
||||
logger.info(f"loading configuration file {config_file}")
|
||||
else:
|
||||
logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}")
|
||||
|
||||
return config_dict, kwargs
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig":
|
||||
"""
|
||||
Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters.
|
||||
|
||||
Args:
|
||||
config_dict (`Dict[str, Any]`):
|
||||
Dictionary that will be used to instantiate the configuration object. Such a dictionary can be
|
||||
retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method.
|
||||
kwargs (`Dict[str, Any]`):
|
||||
Additional parameters from which to initialize the configuration object.
|
||||
|
||||
Returns:
|
||||
[`PretrainedConfig`]: The configuration object instantiated from those parameters.
|
||||
"""
|
||||
return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
|
||||
# Those arguments may be passed along for our internal telemetry.
|
||||
# We remove them so they don't appear in `return_unused_kwargs`.
|
||||
|
||||
config = cls(**config_dict)
|
||||
|
||||
to_remove = []
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
to_remove.append(key)
|
||||
for key in to_remove:
|
||||
kwargs.pop(key, None)
|
||||
|
||||
logger.info(f"Model config {config}")
|
||||
if return_unused_kwargs:
|
||||
return config, kwargs
|
||||
else:
|
||||
return config
|
||||
|
||||
@classmethod
|
||||
def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig":
|
||||
"""
|
||||
Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters.
|
||||
|
||||
Args:
|
||||
json_file (`str` or `os.PathLike`):
|
||||
Path to the JSON file containing the parameters.
|
||||
|
||||
Returns:
|
||||
[`PretrainedConfig`]: The configuration object instantiated from that JSON file.
|
||||
|
||||
"""
|
||||
config_dict = cls._dict_from_json_file(json_file)
|
||||
return cls(**config_dict)
|
||||
|
||||
@classmethod
|
||||
def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]):
|
||||
with open(json_file, "r", encoding="utf-8") as reader:
|
||||
text = reader.read()
|
||||
return json.loads(text)
|
||||
|
||||
def __eq__(self, other):
|
||||
return self.__dict__ == other.__dict__
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__} {self.to_json_string()}"
|
||||
|
||||
def to_diff_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Removes all attributes from config which correspond to the default config attributes for better readability and
|
||||
serializes to a Python dictionary.
|
||||
|
||||
Returns:
|
||||
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
config_dict = self.to_dict()
|
||||
|
||||
# get the default config dict
|
||||
default_config_dict = PretrainedConfig().to_dict()
|
||||
|
||||
# get class specific config dict
|
||||
class_config_dict = self.__class__().to_dict()
|
||||
|
||||
serializable_config_dict = {}
|
||||
|
||||
# only serialize values that differ from the default config
|
||||
for key, value in config_dict.items():
|
||||
if (
|
||||
key not in default_config_dict
|
||||
or key == "diffusers_version"
|
||||
or value != default_config_dict[key]
|
||||
or (key in class_config_dict and value != class_config_dict[key])
|
||||
):
|
||||
serializable_config_dict[key] = value
|
||||
|
||||
self.dict_torch_dtype_to_str(serializable_config_dict)
|
||||
|
||||
return serializable_config_dict
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Serializes this instance to a Python dictionary.
|
||||
|
||||
Returns:
|
||||
`Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
if hasattr(self.__class__, "model_type"):
|
||||
output["model_type"] = self.__class__.model_type
|
||||
if "_auto_class" in output:
|
||||
del output["_auto_class"]
|
||||
|
||||
# Transformers version when serializing the model
|
||||
output["diffusers_version"] = __version__
|
||||
|
||||
self.dict_torch_dtype_to_str(output)
|
||||
|
||||
return output
|
||||
|
||||
def to_json_string(self, use_diff: bool = True) -> str:
|
||||
"""
|
||||
Serializes this instance to a JSON string.
|
||||
|
||||
Args:
|
||||
use_diff (`bool`, *optional*, defaults to `True`):
|
||||
If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
|
||||
is serialized to JSON string.
|
||||
|
||||
Returns:
|
||||
`str`: String containing all the attributes that make up this configuration instance in JSON format.
|
||||
"""
|
||||
if use_diff is True:
|
||||
config_dict = self.to_diff_dict()
|
||||
else:
|
||||
config_dict = self.to_dict()
|
||||
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
||||
|
||||
def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
|
||||
"""
|
||||
Save this instance to a JSON file.
|
||||
|
||||
Args:
|
||||
json_file_path (`str` or `os.PathLike`):
|
||||
Path to the JSON file in which this configuration instance's parameters will be saved.
|
||||
use_diff (`bool`, *optional*, defaults to `True`):
|
||||
If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
|
||||
is serialized to JSON file.
|
||||
"""
|
||||
with open(json_file_path, "w", encoding="utf-8") as writer:
|
||||
writer.write(self.to_json_string(use_diff=use_diff))
|
||||
|
||||
def update(self, config_dict: Dict[str, Any]):
|
||||
"""
|
||||
Updates attributes of this class with attributes from `config_dict`.
|
||||
|
||||
Args:
|
||||
config_dict (`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
|
||||
"""
|
||||
for key, value in config_dict.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def update_from_string(self, update_str: str):
|
||||
"""
|
||||
Updates attributes of this class with attributes from `update_str`.
|
||||
|
||||
The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
|
||||
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
|
||||
|
||||
The keys to change have to already exist in the config object.
|
||||
|
||||
Args:
|
||||
update_str (`str`): String with attributes that should be updated for this class.
|
||||
|
||||
"""
|
||||
|
||||
d = dict(x.split("=") for x in update_str.split(","))
|
||||
for k, v in d.items():
|
||||
if not hasattr(self, k):
|
||||
raise ValueError(f"key {k} isn't in the original config dict")
|
||||
|
||||
old_v = getattr(self, k)
|
||||
if isinstance(old_v, bool):
|
||||
if v.lower() in ["true", "1", "y", "yes"]:
|
||||
v = True
|
||||
elif v.lower() in ["false", "0", "n", "no"]:
|
||||
v = False
|
||||
else:
|
||||
raise ValueError(f"can't derive true or false from {v} (key {k})")
|
||||
elif isinstance(old_v, int):
|
||||
v = int(v)
|
||||
elif isinstance(old_v, float):
|
||||
v = float(v)
|
||||
elif not isinstance(old_v, str):
|
||||
raise ValueError(
|
||||
f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
|
||||
)
|
||||
|
||||
setattr(self, k, v)
|
||||
|
||||
def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None,
|
||||
converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"*
|
||||
string, which can then be stored in the json format.
|
||||
"""
|
||||
if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str):
|
||||
d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1]
|
||||
for value in d.values():
|
||||
if isinstance(value, dict):
|
||||
self.dict_torch_dtype_to_str(value)
|
||||
637
src/diffusers/modeling_utils.py
Normal file
637
src/diffusers/modeling_utils.py
Normal file
@@ -0,0 +1,637 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 The HuggingFace Inc. team.
|
||||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from torch import Tensor, device
|
||||
|
||||
from requests import HTTPError
|
||||
|
||||
# CHANGE to diffusers.utils
|
||||
from transformers.utils import (
|
||||
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
||||
EntryNotFoundError,
|
||||
RepositoryNotFoundError,
|
||||
RevisionNotFoundError,
|
||||
cached_path,
|
||||
hf_bucket_url,
|
||||
is_offline_mode,
|
||||
is_remote_url,
|
||||
logging,
|
||||
)
|
||||
|
||||
from .configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
WEIGHTS_NAME = "diffusion_model.pt"
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
def get_parameter_device(parameter: torch.nn.Module):
|
||||
try:
|
||||
return next(parameter.parameters()).device
|
||||
except StopIteration:
|
||||
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
||||
|
||||
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
|
||||
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
||||
first_tuple = next(gen)
|
||||
return first_tuple[1].device
|
||||
|
||||
|
||||
def get_parameter_dtype(parameter: torch.nn.Module):
|
||||
try:
|
||||
return next(parameter.parameters()).dtype
|
||||
except StopIteration:
|
||||
# For torch.nn.DataParallel compatibility in PyTorch 1.5
|
||||
|
||||
def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
|
||||
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
||||
return tuples
|
||||
|
||||
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
||||
first_tuple = next(gen)
|
||||
return first_tuple[1].dtype
|
||||
|
||||
|
||||
def load_state_dict(checkpoint_file: Union[str, os.PathLike]):
|
||||
"""
|
||||
Reads a PyTorch checkpoint file, returning properly formatted errors if they arise.
|
||||
"""
|
||||
try:
|
||||
return torch.load(checkpoint_file, map_location="cpu")
|
||||
except Exception as e:
|
||||
try:
|
||||
with open(checkpoint_file) as f:
|
||||
if f.read().startswith("version"):
|
||||
raise OSError(
|
||||
"You seem to have cloned a repository without having git-lfs installed. Please install "
|
||||
"git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
|
||||
"you cloned."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
|
||||
"model. Make sure you have saved the model properly."
|
||||
) from e
|
||||
except (UnicodeDecodeError, ValueError):
|
||||
raise OSError(
|
||||
f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
|
||||
f"at '{checkpoint_file}'. "
|
||||
"If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
|
||||
)
|
||||
|
||||
|
||||
def _load_state_dict_into_model(model_to_load, state_dict):
|
||||
# Convert old format to new format if needed from a PyTorch state_dict
|
||||
# copy state_dict so _load_from_state_dict can modify it
|
||||
state_dict = state_dict.copy()
|
||||
error_msgs = []
|
||||
|
||||
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
||||
# so we need to apply the function recursively.
|
||||
def load(module: torch.nn.Module, prefix=""):
|
||||
args = (state_dict, prefix, {}, True, [], [], error_msgs)
|
||||
module._load_from_state_dict(*args)
|
||||
|
||||
for name, child in module._modules.items():
|
||||
if child is not None:
|
||||
load(child, prefix + name + ".")
|
||||
|
||||
load(model_to_load)
|
||||
|
||||
return error_msgs
|
||||
|
||||
|
||||
class PreTrainedModel(torch.nn.Module):
|
||||
r"""
|
||||
Base class for all models.
|
||||
|
||||
[`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
|
||||
downloading and saving models as well as a few methods common to all models to:
|
||||
|
||||
- resize the input embeddings,
|
||||
- prune heads in the self-attention heads.
|
||||
|
||||
Class attributes (overridden by derived classes):
|
||||
|
||||
- **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
|
||||
for this model architecture.
|
||||
- **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
|
||||
taking as arguments:
|
||||
|
||||
- **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint.
|
||||
- **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model.
|
||||
- **path** (`str`) -- A path to the TensorFlow checkpoint.
|
||||
|
||||
- **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
|
||||
classes of the same architecture adding modules on top of the base model.
|
||||
- **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
|
||||
- **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
|
||||
models, `pixel_values` for vision models and `input_values` for speech models).
|
||||
"""
|
||||
config_class = None
|
||||
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__()
|
||||
if not isinstance(config, PretrainedConfig):
|
||||
raise ValueError(
|
||||
f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
|
||||
"`PretrainedConfig`. To create a model from a pretrained model use "
|
||||
f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
|
||||
)
|
||||
# Save config and origin of the pretrained weights if given in model
|
||||
self.config = config
|
||||
self.name_or_path = config.name_or_path
|
||||
|
||||
@classmethod
|
||||
def _from_config(cls, config, **kwargs):
|
||||
"""
|
||||
All context managers that the model should be initialized under go here.
|
||||
"""
|
||||
model = cls(config, **kwargs)
|
||||
|
||||
return model
|
||||
|
||||
def save_pretrained(
|
||||
self,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
is_main_process: bool = True,
|
||||
save_function: Callable = torch.save,
|
||||
push_to_hub: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Save a model and its configuration file to a directory, so that it can be re-loaded using the
|
||||
`[`~PreTrainedModel.from_pretrained`]` class method.
|
||||
|
||||
Arguments:
|
||||
save_directory (`str` or `os.PathLike`):
|
||||
Directory to which to save. Will be created if it doesn't exist.
|
||||
is_main_process (`bool`, *optional*, defaults to `True`):
|
||||
Whether the process calling this is the main process or not. Useful when in distributed training like
|
||||
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
|
||||
the main process to avoid race conditions.
|
||||
save_function (`Callable`):
|
||||
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
|
||||
need to replace `torch.save` by another method.
|
||||
push_to_hub (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to push your model to the Hugging Face model hub after saving it.
|
||||
|
||||
<Tip warning={true}>
|
||||
|
||||
Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
|
||||
which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing
|
||||
folder. Pass along `temp_dir=True` to use a temporary directory instead.
|
||||
|
||||
</Tip>
|
||||
|
||||
kwargs:
|
||||
Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
||||
"""
|
||||
if os.path.isfile(save_directory):
|
||||
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
|
||||
return
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
model_to_save = self
|
||||
|
||||
# Attach architecture to the config
|
||||
model_to_save.config.architectures = [model_to_save.__class__.__name__]
|
||||
|
||||
# Save the config
|
||||
if is_main_process:
|
||||
model_to_save.config.save_pretrained(save_directory)
|
||||
|
||||
# Save the model
|
||||
state_dict = model_to_save.state_dict()
|
||||
|
||||
# Clean the folder from a previous save
|
||||
for filename in os.listdir(save_directory):
|
||||
full_filename = os.path.join(save_directory, filename)
|
||||
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
|
||||
# in distributed settings to avoid race conditions.
|
||||
if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process:
|
||||
os.remove(full_filename)
|
||||
|
||||
# Save the model
|
||||
save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME))
|
||||
|
||||
logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}")
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs):
|
||||
r"""
|
||||
Instantiate a pretrained pytorch model from a pre-trained model configuration.
|
||||
|
||||
The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
|
||||
the model, you should first set it back in training mode with `model.train()`.
|
||||
|
||||
The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
|
||||
pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
|
||||
task.
|
||||
|
||||
The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
|
||||
weights are discarded.
|
||||
|
||||
Parameters:
|
||||
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
|
||||
Can be either:
|
||||
|
||||
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
|
||||
Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a
|
||||
user or organization name, like `dbmdz/bert-base-german-cased`.
|
||||
- A path to a *directory* containing model weights saved using
|
||||
[`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
|
||||
|
||||
config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
|
||||
Can be either:
|
||||
|
||||
- an instance of a class derived from [`PretrainedConfig`],
|
||||
- a string or path valid as input to [`~PretrainedConfig.from_pretrained`].
|
||||
|
||||
Configuration for the model to use instead of an automatically loaded configuration. Configuration can
|
||||
be automatically loaded when:
|
||||
|
||||
- The model is a model provided by the library (loaded with the *model id* string of a pretrained
|
||||
model).
|
||||
- The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
|
||||
save directory.
|
||||
- The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
|
||||
configuration JSON file named *config.json* is found in the directory.
|
||||
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
||||
Path to a directory in which a downloaded pretrained model configuration should be cached if the
|
||||
standard cache should not be used.
|
||||
from_tf (`bool`, *optional*, defaults to `False`):
|
||||
Load the model weights from a TensorFlow checkpoint save file (see docstring of
|
||||
`pretrained_model_name_or_path` argument).
|
||||
from_flax (`bool`, *optional*, defaults to `False`):
|
||||
Load the model weights from a Flax checkpoint save file (see docstring of
|
||||
`pretrained_model_name_or_path` argument).
|
||||
ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
|
||||
as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
|
||||
checkpoint with 3 labels).
|
||||
force_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
||||
cached versions if they exist.
|
||||
resume_download (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
|
||||
file exists.
|
||||
proxies (`Dict[str, str]`, *optional*):
|
||||
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
||||
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
||||
output_loading_info(`bool`, *optional*, defaults to `False`):
|
||||
Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
|
||||
local_files_only(`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to only look at local files (i.e., do not try to download the model).
|
||||
use_auth_token (`str` or *bool*, *optional*):
|
||||
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
|
||||
when running `transformers-cli login` (stored in `~/.huggingface`).
|
||||
revision (`str`, *optional*, defaults to `"main"`):
|
||||
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
|
||||
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
|
||||
identifier allowed by git.
|
||||
mirror (`str`, *optional*):
|
||||
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
|
||||
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
|
||||
Please refer to the mirror site for more information.
|
||||
|
||||
kwargs (remaining dictionary of keyword arguments, *optional*):
|
||||
Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
|
||||
`output_attentions=True`). Behaves differently depending on whether a `config` is provided or
|
||||
automatically loaded:
|
||||
|
||||
- If a configuration is provided with `config`, `**kwargs` will be directly passed to the
|
||||
underlying model's `__init__` method (we assume all relevant updates to the configuration have
|
||||
already been done)
|
||||
- If a configuration is not provided, `kwargs` will be first passed to the configuration class
|
||||
initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
|
||||
corresponds to a configuration attribute will be used to override said attribute with the
|
||||
supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
|
||||
will be passed to the underlying model's `__init__` function.
|
||||
|
||||
<Tip>
|
||||
|
||||
Passing `use_auth_token=True`` is required when you want to use a private model.
|
||||
|
||||
</Tip>
|
||||
|
||||
<Tip>
|
||||
|
||||
Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
|
||||
use this method in a firewalled environment.
|
||||
|
||||
</Tip>
|
||||
|
||||
"""
|
||||
cache_dir = kwargs.pop("cache_dir", None)
|
||||
ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
|
||||
force_download = kwargs.pop("force_download", False)
|
||||
resume_download = kwargs.pop("resume_download", False)
|
||||
proxies = kwargs.pop("proxies", None)
|
||||
output_loading_info = kwargs.pop("output_loading_info", False)
|
||||
local_files_only = kwargs.pop("local_files_only", False)
|
||||
use_auth_token = kwargs.pop("use_auth_token", None)
|
||||
revision = kwargs.pop("revision", None)
|
||||
mirror = kwargs.pop("mirror", None)
|
||||
from_pipeline = kwargs.pop("_from_pipeline", None)
|
||||
from_auto_class = kwargs.pop("_from_auto", False)
|
||||
|
||||
user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
|
||||
|
||||
if is_offline_mode() and not local_files_only:
|
||||
logger.info("Offline mode: forcing local_files_only=True")
|
||||
local_files_only = True
|
||||
|
||||
# Load config if we don't provide a configuration
|
||||
config_path = pretrained_model_name_or_path
|
||||
config, model_kwargs = cls.config_class.from_pretrained(
|
||||
config_path,
|
||||
cache_dir=cache_dir,
|
||||
return_unused_kwargs=True,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
revision=revision,
|
||||
_from_auto=from_auto_class,
|
||||
_from_pipeline=from_pipeline,
|
||||
**kwargs,
|
||||
)
|
||||
model_kwargs = kwargs
|
||||
|
||||
# This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the
|
||||
# Load model
|
||||
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
|
||||
if os.path.isdir(pretrained_model_name_or_path):
|
||||
if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)):
|
||||
# Load from a PyTorch checkpoint
|
||||
archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)
|
||||
else:
|
||||
raise EnvironmentError(
|
||||
f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}."
|
||||
)
|
||||
elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
|
||||
archive_file = pretrained_model_name_or_path
|
||||
else:
|
||||
filename = WEIGHTS_NAME
|
||||
|
||||
archive_file = hf_bucket_url(
|
||||
pretrained_model_name_or_path, filename=filename, revision=revision, mirror=mirror
|
||||
)
|
||||
|
||||
try:
|
||||
# Load from URL or cache if already cached
|
||||
resolved_archive_file = cached_path(
|
||||
archive_file,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
local_files_only=local_files_only,
|
||||
use_auth_token=use_auth_token,
|
||||
user_agent=user_agent,
|
||||
)
|
||||
|
||||
except RepositoryNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
|
||||
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
|
||||
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
|
||||
"login` and pass `use_auth_token=True`."
|
||||
)
|
||||
except RevisionNotFoundError:
|
||||
raise EnvironmentError(
|
||||
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
|
||||
"this model name. Check the model page at "
|
||||
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
|
||||
)
|
||||
except EntryNotFoundError:
|
||||
raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {filename}.")
|
||||
except HTTPError as err:
|
||||
raise EnvironmentError(
|
||||
f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
|
||||
)
|
||||
except ValueError:
|
||||
raise EnvironmentError(
|
||||
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
|
||||
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
|
||||
f" directory containing a file named {WEIGHTS_NAME} or"
|
||||
" \nCheckout your internet connection or see how to run the library in"
|
||||
" offline mode at 'https://huggingface.co/docs/transformers/installation#offline-mode'."
|
||||
)
|
||||
except EnvironmentError:
|
||||
raise EnvironmentError(
|
||||
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
|
||||
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
|
||||
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
|
||||
f"containing a file named {WEIGHTS_NAME}"
|
||||
)
|
||||
|
||||
if resolved_archive_file == archive_file:
|
||||
logger.info(f"loading weights file {archive_file}")
|
||||
else:
|
||||
logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
|
||||
|
||||
state_dict = load_state_dict(resolved_archive_file)
|
||||
# set dtype to instantiate the model under:
|
||||
# 1. If torch_dtype is not None, we use that dtype
|
||||
# 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
|
||||
# weights entry - we assume all weights are of the same dtype
|
||||
# we also may have config.torch_dtype available, but we won't rely on it till v5
|
||||
|
||||
config.name_or_path = pretrained_model_name_or_path
|
||||
|
||||
model = cls(config, *model_args, **model_kwargs)
|
||||
|
||||
# restore default dtype
|
||||
model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
|
||||
model,
|
||||
state_dict,
|
||||
resolved_archive_file,
|
||||
pretrained_model_name_or_path,
|
||||
ignore_mismatched_sizes=ignore_mismatched_sizes,
|
||||
)
|
||||
|
||||
# Set model in evaluation mode to deactivate DropOut modules by default
|
||||
model.eval()
|
||||
|
||||
if output_loading_info:
|
||||
loading_info = {
|
||||
"missing_keys": missing_keys,
|
||||
"unexpected_keys": unexpected_keys,
|
||||
"mismatched_keys": mismatched_keys,
|
||||
"error_msgs": error_msgs,
|
||||
}
|
||||
return model, loading_info
|
||||
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def _load_pretrained_model(
|
||||
cls,
|
||||
model,
|
||||
state_dict,
|
||||
resolved_archive_file,
|
||||
pretrained_model_name_or_path,
|
||||
ignore_mismatched_sizes=False,
|
||||
):
|
||||
# Retrieve missing & unexpected_keys
|
||||
model_state_dict = model.state_dict()
|
||||
loaded_keys = [k for k in state_dict.keys()]
|
||||
|
||||
expected_keys = list(model_state_dict.keys())
|
||||
|
||||
original_loaded_keys = loaded_keys
|
||||
|
||||
missing_keys = list(set(expected_keys) - set(loaded_keys))
|
||||
unexpected_keys = list(set(loaded_keys) - set(expected_keys))
|
||||
|
||||
# Make sure we are able to load base models as well as derived models (with heads)
|
||||
model_to_load = model
|
||||
|
||||
def _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
loaded_keys,
|
||||
ignore_mismatched_sizes,
|
||||
):
|
||||
mismatched_keys = []
|
||||
if ignore_mismatched_sizes:
|
||||
for checkpoint_key in loaded_keys:
|
||||
model_key = checkpoint_key
|
||||
|
||||
if (
|
||||
model_key in model_state_dict
|
||||
and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
|
||||
):
|
||||
mismatched_keys.append(
|
||||
(checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
|
||||
)
|
||||
del state_dict[checkpoint_key]
|
||||
return mismatched_keys
|
||||
|
||||
if state_dict is not None:
|
||||
# Whole checkpoint
|
||||
mismatched_keys = _find_mismatched_keys(
|
||||
state_dict,
|
||||
model_state_dict,
|
||||
original_loaded_keys,
|
||||
ignore_mismatched_sizes,
|
||||
)
|
||||
error_msgs = _load_state_dict_into_model(model_to_load, state_dict)
|
||||
|
||||
if len(error_msgs) > 0:
|
||||
error_msg = "\n\t".join(error_msgs)
|
||||
if "size mismatch" in error_msg:
|
||||
error_msg += (
|
||||
"\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
|
||||
)
|
||||
raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
|
||||
|
||||
if len(unexpected_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
|
||||
f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
|
||||
f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
|
||||
" with another architecture (e.g. initializing a BertForSequenceClassification model from a"
|
||||
" BertForPreTraining model).\n- This IS NOT expected if you are initializing"
|
||||
f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
|
||||
" (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
|
||||
)
|
||||
else:
|
||||
logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n")
|
||||
if len(missing_keys) > 0:
|
||||
logger.warning(
|
||||
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
||||
f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
|
||||
" TRAIN this model on a down-stream task to be able to use it for predictions and inference."
|
||||
)
|
||||
elif len(mismatched_keys) == 0:
|
||||
logger.info(
|
||||
f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at"
|
||||
f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint"
|
||||
f" was trained on, you can already use {model.__class__.__name__} for predictions without further"
|
||||
" training."
|
||||
)
|
||||
if len(mismatched_keys) > 0:
|
||||
mismatched_warning = "\n".join(
|
||||
[
|
||||
f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
|
||||
for key, shape1, shape2 in mismatched_keys
|
||||
]
|
||||
)
|
||||
logger.warning(
|
||||
f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
|
||||
f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
|
||||
f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
|
||||
" to use it for predictions and inference."
|
||||
)
|
||||
|
||||
return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
|
||||
|
||||
@property
|
||||
def device(self) -> device:
|
||||
"""
|
||||
`torch.device`: The device on which the module is (assuming that all the module parameters are on the same
|
||||
device).
|
||||
"""
|
||||
return get_parameter_device(self)
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
"""
|
||||
`torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
|
||||
"""
|
||||
return get_parameter_dtype(self)
|
||||
|
||||
def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
|
||||
"""
|
||||
Get number of (optionally, trainable or non-embeddings) parameters in the module.
|
||||
|
||||
Args:
|
||||
only_trainable (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return only the number of trainable parameters
|
||||
|
||||
exclude_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not to return only the number of non-embeddings parameters
|
||||
|
||||
Returns:
|
||||
`int`: The number of parameters.
|
||||
"""
|
||||
|
||||
if exclude_embeddings:
|
||||
embedding_param_names = [
|
||||
f"{name}.weight"
|
||||
for name, module_type in self.named_modules()
|
||||
if isinstance(module_type, torch.nn.Embedding)
|
||||
]
|
||||
non_embedding_parameters = [
|
||||
parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
|
||||
]
|
||||
return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
|
||||
else:
|
||||
return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
|
||||
@@ -1,5 +1,19 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all
|
||||
# module, but to preserve other warnings. So, don't check this module at all.
|
||||
|
||||
from .unet import UNetModel
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from . import unet
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
class UNetModel:
|
||||
def __init__(self, config):
|
||||
self.config = config
|
||||
print("I can diffuse!")
|
||||
6
src/diffusers/models/unet/__init__.py
Normal file
6
src/diffusers/models/unet/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
# flake8: noqa
|
||||
# There's no way to ignore "F401 '...' imported but unused" warnings in this
|
||||
# module, but to preserve other warnings. So, don't check this module at all
|
||||
|
||||
from .configuration_unet import UNetConfig
|
||||
from .modeling_unet import GaussianDiffusion, UNetModel
|
||||
45
src/diffusers/models/unet/configuration_unet.py
Normal file
45
src/diffusers/models/unet/configuration_unet.py
Normal file
@@ -0,0 +1,45 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
# limitations under the License.
|
||||
|
||||
# helpers functions
|
||||
|
||||
# NOTE: the following file is completely copied from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
from ...configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class UNetConfig(PretrainedConfig):
|
||||
model_type = "unet"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim=64,
|
||||
dim_mults=(1, 2, 4, 8),
|
||||
init_dim=None,
|
||||
out_dim=None,
|
||||
channels=3,
|
||||
with_time_emb=True,
|
||||
resnet_block_groups=8,
|
||||
learned_variance=False,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.dim = dim
|
||||
self.dim_mults = dim_mults
|
||||
self.init_dim = init_dim
|
||||
self.out_dim = out_dim
|
||||
self.channels = channels
|
||||
self.with_time_emb = with_time_emb
|
||||
self.resnet_block_groups = resnet_block_groups
|
||||
self.learned_variance = learned_variance
|
||||
728
src/diffusers/models/unet/modeling_unet.py
Normal file
728
src/diffusers/models/unet/modeling_unet.py
Normal file
@@ -0,0 +1,728 @@
|
||||
# Copyright 2022 The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
|
||||
# limitations under the License.
|
||||
|
||||
# helpers functions
|
||||
|
||||
import copy
|
||||
import math
|
||||
from functools import partial
|
||||
from inspect import isfunction
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import einsum, nn
|
||||
from torch.cuda.amp import GradScaler, autocast
|
||||
from torch.optim import Adam
|
||||
from torch.utils import data
|
||||
|
||||
from einops import rearrange
|
||||
from PIL import Image
|
||||
from torchvision import transforms, utils
|
||||
from tqdm import tqdm
|
||||
|
||||
from ...modeling_utils import PreTrainedModel
|
||||
from .configuration_unet import UNetConfig
|
||||
|
||||
|
||||
# NOTE: the following file is completely copied from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def exists(x):
|
||||
return x is not None
|
||||
|
||||
|
||||
def default(val, d):
|
||||
if exists(val):
|
||||
return val
|
||||
return d() if isfunction(d) else d
|
||||
|
||||
|
||||
def cycle(dl):
|
||||
while True:
|
||||
for data_dl in dl:
|
||||
yield data_dl
|
||||
|
||||
|
||||
def num_to_groups(num, divisor):
|
||||
groups = num // divisor
|
||||
remainder = num % divisor
|
||||
arr = [divisor] * groups
|
||||
if remainder > 0:
|
||||
arr.append(remainder)
|
||||
return arr
|
||||
|
||||
|
||||
def normalize_to_neg_one_to_one(img):
|
||||
return img * 2 - 1
|
||||
|
||||
|
||||
def unnormalize_to_zero_to_one(t):
|
||||
return (t + 1) * 0.5
|
||||
|
||||
|
||||
# small helper modules
|
||||
|
||||
|
||||
class EMA:
|
||||
def __init__(self, beta):
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
|
||||
def update_model_average(self, ma_model, current_model):
|
||||
for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
|
||||
old_weight, up_weight = ma_params.data, current_params.data
|
||||
ma_params.data = self.update_average(old_weight, up_weight)
|
||||
|
||||
def update_average(self, old, new):
|
||||
if old is None:
|
||||
return new
|
||||
return old * self.beta + (1 - self.beta) * new
|
||||
|
||||
|
||||
class Residual(nn.Module):
|
||||
def __init__(self, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
|
||||
def forward(self, x, *args, **kwargs):
|
||||
return self.fn(x, *args, **kwargs) + x
|
||||
|
||||
|
||||
class SinusoidalPosEmb(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x):
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
||||
emb = x[:, None] * emb[None, :]
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
def Upsample(dim):
|
||||
return nn.ConvTranspose2d(dim, dim, 4, 2, 1)
|
||||
|
||||
|
||||
def Downsample(dim):
|
||||
return nn.Conv2d(dim, dim, 4, 2, 1)
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, dim, eps=1e-5):
|
||||
super().__init__()
|
||||
self.eps = eps
|
||||
self.g = nn.Parameter(torch.ones(1, dim, 1, 1))
|
||||
self.b = nn.Parameter(torch.zeros(1, dim, 1, 1))
|
||||
|
||||
def forward(self, x):
|
||||
var = torch.var(x, dim=1, unbiased=False, keepdim=True)
|
||||
mean = torch.mean(x, dim=1, keepdim=True)
|
||||
return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
|
||||
|
||||
|
||||
class PreNorm(nn.Module):
|
||||
def __init__(self, dim, fn):
|
||||
super().__init__()
|
||||
self.fn = fn
|
||||
self.norm = LayerNorm(dim)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.norm(x)
|
||||
return self.fn(x)
|
||||
|
||||
|
||||
# building block modules
|
||||
|
||||
|
||||
class Block(nn.Module):
|
||||
def __init__(self, dim, dim_out, groups=8):
|
||||
super().__init__()
|
||||
self.proj = nn.Conv2d(dim, dim_out, 3, padding=1)
|
||||
self.norm = nn.GroupNorm(groups, dim_out)
|
||||
self.act = nn.SiLU()
|
||||
|
||||
def forward(self, x, scale_shift=None):
|
||||
x = self.proj(x)
|
||||
x = self.norm(x)
|
||||
|
||||
if exists(scale_shift):
|
||||
scale, shift = scale_shift
|
||||
x = x * (scale + 1) + shift
|
||||
|
||||
x = self.act(x)
|
||||
return x
|
||||
|
||||
|
||||
class ResnetBlock(nn.Module):
|
||||
def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8):
|
||||
super().__init__()
|
||||
self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) if exists(time_emb_dim) else None
|
||||
|
||||
self.block1 = Block(dim, dim_out, groups=groups)
|
||||
self.block2 = Block(dim_out, dim_out, groups=groups)
|
||||
self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity()
|
||||
|
||||
def forward(self, x, time_emb=None):
|
||||
|
||||
scale_shift = None
|
||||
if exists(self.mlp) and exists(time_emb):
|
||||
time_emb = self.mlp(time_emb)
|
||||
time_emb = rearrange(time_emb, "b c -> b c 1 1")
|
||||
scale_shift = time_emb.chunk(2, dim=1)
|
||||
|
||||
h = self.block1(x, scale_shift=scale_shift)
|
||||
|
||||
h = self.block2(h)
|
||||
return h + self.res_conv(x)
|
||||
|
||||
|
||||
class LinearAttention(nn.Module):
|
||||
def __init__(self, dim, heads=4, dim_head=32):
|
||||
super().__init__()
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
hidden_dim = dim_head * heads
|
||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
||||
|
||||
self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), LayerNorm(dim))
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x).chunk(3, dim=1)
|
||||
q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv)
|
||||
|
||||
q = q.softmax(dim=-2)
|
||||
k = k.softmax(dim=-1)
|
||||
|
||||
q = q * self.scale
|
||||
context = torch.einsum("b h d n, b h e n -> b h d e", k, v)
|
||||
|
||||
out = torch.einsum("b h d e, b h d n -> b h e n", context, q)
|
||||
out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, heads=4, dim_head=32):
|
||||
super().__init__()
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
hidden_dim = dim_head * heads
|
||||
self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
|
||||
self.to_out = nn.Conv2d(hidden_dim, dim, 1)
|
||||
|
||||
def forward(self, x):
|
||||
b, c, h, w = x.shape
|
||||
qkv = self.to_qkv(x).chunk(3, dim=1)
|
||||
q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv)
|
||||
q = q * self.scale
|
||||
|
||||
sim = einsum("b h d i, b h d j -> b h i j", q, k)
|
||||
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
|
||||
attn = sim.softmax(dim=-1)
|
||||
|
||||
out = einsum("b h i j, b h d j -> b h i d", attn, v)
|
||||
out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w)
|
||||
return self.to_out(out)
|
||||
|
||||
|
||||
class UNetModel(PreTrainedModel):
|
||||
|
||||
config_class = UNetConfig
|
||||
|
||||
def __init__(self, config):
|
||||
super().__init__(config)
|
||||
|
||||
init_dim = None
|
||||
out_dim = None
|
||||
channels = 3
|
||||
with_time_emb = True
|
||||
resnet_block_groups = 8
|
||||
learned_variance = False
|
||||
|
||||
# determine dimensions
|
||||
|
||||
dim_mults = config.dim_mults
|
||||
dim = config.dim
|
||||
self.channels = config.channels
|
||||
|
||||
init_dim = default(init_dim, dim // 3 * 2)
|
||||
self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3)
|
||||
|
||||
dims = [init_dim, *map(lambda m: dim * m, dim_mults)]
|
||||
in_out = list(zip(dims[:-1], dims[1:]))
|
||||
|
||||
block_klass = partial(ResnetBlock, groups=resnet_block_groups)
|
||||
|
||||
# time embeddings
|
||||
|
||||
if with_time_emb:
|
||||
time_dim = dim * 4
|
||||
self.time_mlp = nn.Sequential(
|
||||
SinusoidalPosEmb(dim), nn.Linear(dim, time_dim), nn.GELU(), nn.Linear(time_dim, time_dim)
|
||||
)
|
||||
else:
|
||||
time_dim = None
|
||||
self.time_mlp = None
|
||||
|
||||
# layers
|
||||
|
||||
self.downs = nn.ModuleList([])
|
||||
self.ups = nn.ModuleList([])
|
||||
num_resolutions = len(in_out)
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(in_out):
|
||||
is_last = ind >= (num_resolutions - 1)
|
||||
|
||||
self.downs.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
block_klass(dim_in, dim_out, time_emb_dim=time_dim),
|
||||
block_klass(dim_out, dim_out, time_emb_dim=time_dim),
|
||||
Residual(PreNorm(dim_out, LinearAttention(dim_out))),
|
||||
Downsample(dim_out) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
mid_dim = dims[-1]
|
||||
self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
|
||||
self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim)))
|
||||
self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim)
|
||||
|
||||
for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
|
||||
is_last = ind >= (num_resolutions - 1)
|
||||
|
||||
self.ups.append(
|
||||
nn.ModuleList(
|
||||
[
|
||||
block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim),
|
||||
block_klass(dim_in, dim_in, time_emb_dim=time_dim),
|
||||
Residual(PreNorm(dim_in, LinearAttention(dim_in))),
|
||||
Upsample(dim_in) if not is_last else nn.Identity(),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
default_out_dim = channels * (1 if not learned_variance else 2)
|
||||
self.out_dim = default(out_dim, default_out_dim)
|
||||
|
||||
self.final_conv = nn.Sequential(block_klass(dim, dim), nn.Conv2d(dim, self.out_dim, 1))
|
||||
|
||||
def forward(self, x, time):
|
||||
x = self.init_conv(x)
|
||||
|
||||
t = self.time_mlp(time) if exists(self.time_mlp) else None
|
||||
|
||||
h = []
|
||||
|
||||
for block1, block2, attn, downsample in self.downs:
|
||||
x = block1(x, t)
|
||||
x = block2(x, t)
|
||||
x = attn(x)
|
||||
h.append(x)
|
||||
x = downsample(x)
|
||||
|
||||
x = self.mid_block1(x, t)
|
||||
x = self.mid_attn(x)
|
||||
x = self.mid_block2(x, t)
|
||||
|
||||
for block1, block2, attn, upsample in self.ups:
|
||||
x = torch.cat((x, h.pop()), dim=1)
|
||||
x = block1(x, t)
|
||||
x = block2(x, t)
|
||||
x = attn(x)
|
||||
x = upsample(x)
|
||||
|
||||
return self.final_conv(x)
|
||||
|
||||
|
||||
# gaussian diffusion trainer class
|
||||
|
||||
|
||||
def extract(a, t, x_shape):
|
||||
b, *_ = t.shape
|
||||
out = a.gather(-1, t)
|
||||
return out.reshape(b, *((1,) * (len(x_shape) - 1)))
|
||||
|
||||
|
||||
def noise_like(shape, device, repeat=False):
|
||||
def repeat_noise():
|
||||
return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
||||
|
||||
def noise():
|
||||
return torch.randn(shape, device=device)
|
||||
|
||||
return repeat_noise() if repeat else noise()
|
||||
|
||||
|
||||
def linear_beta_schedule(timesteps):
|
||||
scale = 1000 / timesteps
|
||||
beta_start = scale * 0.0001
|
||||
beta_end = scale * 0.02
|
||||
return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64)
|
||||
|
||||
|
||||
def cosine_beta_schedule(timesteps, s=0.008):
|
||||
"""
|
||||
cosine schedule
|
||||
as proposed in https://openreview.net/forum?id=-NEXDKk8gZ
|
||||
"""
|
||||
steps = timesteps + 1
|
||||
x = torch.linspace(0, timesteps, steps, dtype=torch.float64)
|
||||
alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2
|
||||
alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
|
||||
betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
|
||||
return torch.clip(betas, 0, 0.999)
|
||||
|
||||
|
||||
class GaussianDiffusion(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
denoise_fn,
|
||||
*,
|
||||
image_size,
|
||||
channels=3,
|
||||
timesteps=1000,
|
||||
loss_type="l1",
|
||||
objective="pred_noise",
|
||||
beta_schedule="cosine",
|
||||
):
|
||||
super().__init__()
|
||||
assert not (type(self) == GaussianDiffusion and denoise_fn.channels != denoise_fn.out_dim)
|
||||
|
||||
self.channels = channels
|
||||
self.image_size = image_size
|
||||
self.denoise_fn = denoise_fn
|
||||
self.objective = objective
|
||||
|
||||
if beta_schedule == "linear":
|
||||
betas = linear_beta_schedule(timesteps)
|
||||
elif beta_schedule == "cosine":
|
||||
betas = cosine_beta_schedule(timesteps)
|
||||
else:
|
||||
raise ValueError(f"unknown beta schedule {beta_schedule}")
|
||||
|
||||
alphas = 1.0 - betas
|
||||
alphas_cumprod = torch.cumprod(alphas, axis=0)
|
||||
alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0)
|
||||
|
||||
(timesteps,) = betas.shape
|
||||
self.num_timesteps = int(timesteps)
|
||||
self.loss_type = loss_type
|
||||
|
||||
# helper function to register buffer from float64 to float32
|
||||
|
||||
def register_buffer(name, val):
|
||||
self.register_buffer(name, val.to(torch.float32))
|
||||
|
||||
register_buffer("betas", betas)
|
||||
register_buffer("alphas_cumprod", alphas_cumprod)
|
||||
register_buffer("alphas_cumprod_prev", alphas_cumprod_prev)
|
||||
|
||||
# calculations for diffusion q(x_t | x_{t-1}) and others
|
||||
|
||||
register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
|
||||
register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod))
|
||||
register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod))
|
||||
register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod))
|
||||
register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1))
|
||||
|
||||
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
||||
|
||||
posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
|
||||
|
||||
# above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
|
||||
|
||||
register_buffer("posterior_variance", posterior_variance)
|
||||
|
||||
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
||||
|
||||
register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20)))
|
||||
register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod))
|
||||
register_buffer(
|
||||
"posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod)
|
||||
)
|
||||
|
||||
def predict_start_from_noise(self, x_t, t, noise):
|
||||
return (
|
||||
extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
||||
- extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
|
||||
)
|
||||
|
||||
def q_posterior(self, x_start, x_t, t):
|
||||
posterior_mean = (
|
||||
extract(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
||||
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
||||
)
|
||||
posterior_variance = extract(self.posterior_variance, t, x_t.shape)
|
||||
posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
|
||||
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
||||
|
||||
def p_mean_variance(self, x, t, clip_denoised: bool):
|
||||
model_output = self.denoise_fn(x, t)
|
||||
|
||||
if self.objective == "pred_noise":
|
||||
x_start = self.predict_start_from_noise(x, t=t, noise=model_output)
|
||||
elif self.objective == "pred_x0":
|
||||
x_start = model_output
|
||||
else:
|
||||
raise ValueError(f"unknown objective {self.objective}")
|
||||
|
||||
if clip_denoised:
|
||||
x_start.clamp_(-1.0, 1.0)
|
||||
|
||||
model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t)
|
||||
return model_mean, posterior_variance, posterior_log_variance
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample(self, x, t, clip_denoised=True, repeat_noise=False):
|
||||
b, *_, device = *x.shape, x.device
|
||||
model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised)
|
||||
noise = noise_like(x.shape, device, repeat_noise)
|
||||
# no noise when t == 0
|
||||
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
|
||||
result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
|
||||
return result
|
||||
|
||||
@torch.no_grad()
|
||||
def p_sample_loop(self, shape):
|
||||
device = self.betas.device
|
||||
|
||||
b = shape[0]
|
||||
img = torch.randn(shape, device=device)
|
||||
|
||||
for i in tqdm(
|
||||
reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps
|
||||
):
|
||||
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))
|
||||
|
||||
img = unnormalize_to_zero_to_one(img)
|
||||
return img
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self, batch_size=16):
|
||||
image_size = self.image_size
|
||||
channels = self.channels
|
||||
return self.p_sample_loop((batch_size, channels, image_size, image_size))
|
||||
|
||||
@torch.no_grad()
|
||||
def interpolate(self, x1, x2, t=None, lam=0.5):
|
||||
b, *_, device = *x1.shape, x1.device
|
||||
t = default(t, self.num_timesteps - 1)
|
||||
|
||||
assert x1.shape == x2.shape
|
||||
|
||||
t_batched = torch.stack([torch.tensor(t, device=device)] * b)
|
||||
xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
|
||||
|
||||
img = (1 - lam) * xt1 + lam * xt2
|
||||
for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t):
|
||||
img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))
|
||||
|
||||
return img
|
||||
|
||||
def q_sample(self, x_start, t, noise=None):
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
return (
|
||||
extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
||||
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
||||
)
|
||||
|
||||
@property
|
||||
def loss_fn(self):
|
||||
if self.loss_type == "l1":
|
||||
return F.l1_loss
|
||||
elif self.loss_type == "l2":
|
||||
return F.mse_loss
|
||||
else:
|
||||
raise ValueError(f"invalid loss type {self.loss_type}")
|
||||
|
||||
def p_losses(self, x_start, t, noise=None):
|
||||
b, c, h, w = x_start.shape
|
||||
noise = default(noise, lambda: torch.randn_like(x_start))
|
||||
|
||||
x = self.q_sample(x_start=x_start, t=t, noise=noise)
|
||||
model_out = self.denoise_fn(x, t)
|
||||
|
||||
if self.objective == "pred_noise":
|
||||
target = noise
|
||||
elif self.objective == "pred_x0":
|
||||
target = x_start
|
||||
else:
|
||||
raise ValueError(f"unknown objective {self.objective}")
|
||||
|
||||
loss = self.loss_fn(model_out, target)
|
||||
return loss
|
||||
|
||||
def forward(self, img, *args, **kwargs):
|
||||
b, _, h, w, device, img_size, = (
|
||||
*img.shape,
|
||||
img.device,
|
||||
self.image_size,
|
||||
)
|
||||
assert h == img_size and w == img_size, f"height and width of image must be {img_size}"
|
||||
t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
|
||||
|
||||
img = normalize_to_neg_one_to_one(img)
|
||||
return self.p_losses(img, t, *args, **kwargs)
|
||||
|
||||
|
||||
# dataset classes
|
||||
|
||||
|
||||
class Dataset(data.Dataset):
|
||||
def __init__(self, folder, image_size, exts=["jpg", "jpeg", "png"]):
|
||||
super().__init__()
|
||||
self.folder = folder
|
||||
self.image_size = image_size
|
||||
self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]
|
||||
|
||||
self.transform = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(image_size),
|
||||
transforms.RandomHorizontalFlip(),
|
||||
transforms.CenterCrop(image_size),
|
||||
transforms.ToTensor(),
|
||||
]
|
||||
)
|
||||
|
||||
def __len__(self):
|
||||
return len(self.paths)
|
||||
|
||||
def __getitem__(self, index):
|
||||
path = self.paths[index]
|
||||
img = Image.open(path)
|
||||
return self.transform(img)
|
||||
|
||||
|
||||
# trainer class
|
||||
|
||||
|
||||
class Trainer(object):
|
||||
def __init__(
|
||||
self,
|
||||
diffusion_model,
|
||||
folder,
|
||||
*,
|
||||
ema_decay=0.995,
|
||||
image_size=128,
|
||||
train_batch_size=32,
|
||||
train_lr=1e-4,
|
||||
train_num_steps=100000,
|
||||
gradient_accumulate_every=2,
|
||||
amp=False,
|
||||
step_start_ema=2000,
|
||||
update_ema_every=10,
|
||||
save_and_sample_every=1000,
|
||||
results_folder="./results",
|
||||
):
|
||||
super().__init__()
|
||||
self.model = diffusion_model
|
||||
self.ema = EMA(ema_decay)
|
||||
self.ema_model = copy.deepcopy(self.model)
|
||||
self.update_ema_every = update_ema_every
|
||||
|
||||
self.step_start_ema = step_start_ema
|
||||
self.save_and_sample_every = save_and_sample_every
|
||||
|
||||
self.batch_size = train_batch_size
|
||||
self.image_size = diffusion_model.image_size
|
||||
self.gradient_accumulate_every = gradient_accumulate_every
|
||||
self.train_num_steps = train_num_steps
|
||||
|
||||
self.ds = Dataset(folder, image_size)
|
||||
self.dl = cycle(data.DataLoader(self.ds, batch_size=train_batch_size, shuffle=True, pin_memory=True))
|
||||
self.opt = Adam(diffusion_model.parameters(), lr=train_lr)
|
||||
|
||||
self.step = 0
|
||||
|
||||
self.amp = amp
|
||||
self.scaler = GradScaler(enabled=amp)
|
||||
|
||||
self.results_folder = Path(results_folder)
|
||||
self.results_folder.mkdir(exist_ok=True)
|
||||
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
self.ema_model.load_state_dict(self.model.state_dict())
|
||||
|
||||
def step_ema(self):
|
||||
if self.step < self.step_start_ema:
|
||||
self.reset_parameters()
|
||||
return
|
||||
self.ema.update_model_average(self.ema_model, self.model)
|
||||
|
||||
def save(self, milestone):
|
||||
data = {
|
||||
"step": self.step,
|
||||
"model": self.model.state_dict(),
|
||||
"ema": self.ema_model.state_dict(),
|
||||
"scaler": self.scaler.state_dict(),
|
||||
}
|
||||
torch.save(data, str(self.results_folder / f"model-{milestone}.pt"))
|
||||
|
||||
def load(self, milestone):
|
||||
data = torch.load(str(self.results_folder / f"model-{milestone}.pt"))
|
||||
|
||||
self.step = data["step"]
|
||||
self.model.load_state_dict(data["model"])
|
||||
self.ema_model.load_state_dict(data["ema"])
|
||||
self.scaler.load_state_dict(data["scaler"])
|
||||
|
||||
def train(self):
|
||||
with tqdm(initial=self.step, total=self.train_num_steps) as pbar:
|
||||
|
||||
while self.step < self.train_num_steps:
|
||||
for i in range(self.gradient_accumulate_every):
|
||||
data = next(self.dl).cuda()
|
||||
|
||||
with autocast(enabled=self.amp):
|
||||
loss = self.model(data)
|
||||
self.scaler.scale(loss / self.gradient_accumulate_every).backward()
|
||||
|
||||
pbar.set_description(f"loss: {loss.item():.4f}")
|
||||
|
||||
self.scaler.step(self.opt)
|
||||
self.scaler.update()
|
||||
self.opt.zero_grad()
|
||||
|
||||
if self.step % self.update_ema_every == 0:
|
||||
self.step_ema()
|
||||
|
||||
if self.step != 0 and self.step % self.save_and_sample_every == 0:
|
||||
self.ema_model.eval()
|
||||
|
||||
milestone = self.step // self.save_and_sample_every
|
||||
batches = num_to_groups(36, self.batch_size)
|
||||
all_images_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches))
|
||||
all_images = torch.cat(all_images_list, dim=0)
|
||||
utils.save_image(all_images, str(self.results_folder / f"sample-{milestone}.png"), nrow=6)
|
||||
self.save(milestone)
|
||||
|
||||
self.step += 1
|
||||
pbar.update(1)
|
||||
|
||||
print("training complete")
|
||||
62
tests/test_modeling_utils.py
Executable file
62
tests/test_modeling_utils.py
Executable file
@@ -0,0 +1,62 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import UNetConfig, UNetModel
|
||||
|
||||
|
||||
global_rng = random.Random()
|
||||
|
||||
|
||||
def floats_tensor(shape, scale=1.0, rng=None, name=None):
|
||||
"""Creates a random float32 tensor"""
|
||||
if rng is None:
|
||||
rng = global_rng
|
||||
|
||||
total_dims = 1
|
||||
for dim in shape:
|
||||
total_dims *= dim
|
||||
|
||||
values = []
|
||||
for _ in range(total_dims):
|
||||
values.append(rng.random() * scale)
|
||||
|
||||
return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous()
|
||||
|
||||
|
||||
class ModelTesterMixin(unittest.TestCase):
|
||||
def test_from_pretrained_save_pretrained(self):
|
||||
config = UNetConfig(dim=8, dim_mults=(1, 2), resnet_block_groups=2)
|
||||
model = UNetModel(config)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_pretrained(tmpdirname)
|
||||
new_model = UNetModel.from_pretrained(tmpdirname)
|
||||
|
||||
batch_size = 1
|
||||
num_channels = 3
|
||||
sizes = (32, 32)
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes)
|
||||
time_step = torch.tensor([10])
|
||||
|
||||
image = model(noise, time_step)
|
||||
new_image = new_model(noise, time_step)
|
||||
|
||||
assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
|
||||
Reference in New Issue
Block a user