diff --git a/md5sum.saved b/md5sum.saved deleted file mode 100644 index 2dc63471f2..0000000000 --- a/md5sum.saved +++ /dev/null @@ -1 +0,0 @@ -ce075df80e7ba2391d63d026be165c15 src/diffusers/dependency_versions_table.py diff --git a/models/ddpm/run_ddpm.py b/models/ddpm/run_ddpm.py index 668c2c301e..6bf131f7b5 100755 --- a/models/ddpm/run_ddpm.py +++ b/models/ddpm/run_ddpm.py @@ -1,19 +1,10 @@ #!/usr/bin/env python3 import torch -from diffusers import GaussianDiffusion, UNetConfig, UNetModel +from diffusers import GaussianDiffusion, UNetModel -config = UNetConfig(dim=64, dim_mults=(1, 2, 4, 8)) -model = UNetModel(config) -print(model.config) - -model.save_pretrained("/home/patrick/diffusion_example") - -import ipdb - - -ipdb.set_trace() +model = UNetModel(dim=64, dim_mults=(1, 2, 4, 8)) diffusion = GaussianDiffusion(model, image_size=128, timesteps=1000, loss_type="l1") # number of steps # L1 or L2 diff --git a/src/diffusers.egg-info/PKG-INFO b/src/diffusers.egg-info/PKG-INFO deleted file mode 100644 index 584c731df7..0000000000 --- a/src/diffusers.egg-info/PKG-INFO +++ /dev/null @@ -1,31 +0,0 @@ -Metadata-Version: 2.1 -Name: diffusers -Version: 0.0.1 -Summary: Diffusers -Home-page: https://github.com/huggingface/diffusers -Author: The HuggingFace team -Author-email: patrick@huggingface.co -License: Apache -Keywords: deep learning -Classifier: Development Status :: 5 - Production/Stable -Classifier: Intended Audience :: Developers -Classifier: Intended Audience :: Education -Classifier: Intended Audience :: Science/Research -Classifier: License :: OSI Approved :: Apache Software License -Classifier: Operating System :: OS Independent -Classifier: Programming Language :: Python :: 3 -Classifier: Programming Language :: Python :: 3.6 -Classifier: Programming Language :: Python :: 3.7 -Classifier: Programming Language :: Python :: 3.8 -Classifier: Programming Language :: Python :: 3.9 -Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence -Requires-Python: >=3.6.0 -Description-Content-Type: text/markdown -Provides-Extra: quality -Provides-Extra: docs -Provides-Extra: test -Provides-Extra: dev -Provides-Extra: sagemaker -License-File: LICENSE - -Super cool library about diffusion models diff --git a/src/diffusers.egg-info/SOURCES.txt b/src/diffusers.egg-info/SOURCES.txt deleted file mode 100644 index 2a49d211f3..0000000000 --- a/src/diffusers.egg-info/SOURCES.txt +++ /dev/null @@ -1,14 +0,0 @@ -LICENSE -README.md -pyproject.toml -setup.cfg -setup.py -src/diffusers/__init__.py -src/diffusers.egg-info/PKG-INFO -src/diffusers.egg-info/SOURCES.txt -src/diffusers.egg-info/dependency_links.txt -src/diffusers.egg-info/requires.txt -src/diffusers.egg-info/top_level.txt -src/diffusers/models/__init__.py -src/diffusers/models/unet.py -src/diffusers/samplers/__init__.py \ No newline at end of file diff --git a/src/diffusers.egg-info/dependency_links.txt b/src/diffusers.egg-info/dependency_links.txt deleted file mode 100644 index 8b13789179..0000000000 --- a/src/diffusers.egg-info/dependency_links.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/diffusers.egg-info/requires.txt b/src/diffusers.egg-info/requires.txt deleted file mode 100644 index 9d13868f00..0000000000 --- a/src/diffusers.egg-info/requires.txt +++ /dev/null @@ -1,31 +0,0 @@ -numpy>=1.17 -packaging>=20.0 -pyyaml -torch>=1.4.0 - -[dev] -black~=22.0 -isort>=5.5.4 -flake8>=3.8.3 -pytest -pytest-xdist -pytest-subtests -datasets -transformers - -[docs] - -[quality] -black~=22.0 -isort>=5.5.4 -flake8>=3.8.3 - -[sagemaker] -sagemaker - -[test] -pytest -pytest-xdist -pytest-subtests -datasets -transformers diff --git a/src/diffusers.egg-info/top_level.txt b/src/diffusers.egg-info/top_level.txt deleted file mode 100644 index 6033efb6db..0000000000 --- a/src/diffusers.egg-info/top_level.txt +++ /dev/null @@ -1 +0,0 @@ -diffusers diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py index ad66a4be34..41f680261c 100644 --- a/src/diffusers/__init__.py +++ b/src/diffusers/__init__.py @@ -4,5 +4,5 @@ __version__ = "0.0.1" -from .models.unet import GaussianDiffusion # TODO(PVP): move somewhere else -from .models.unet import UNetConfig, UNetModel +from .models.unet import UNetModel +from .samplers.gaussian import GaussianDiffusion diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py old mode 100755 new mode 100644 index 1c25c463b5..bb30751c8e --- a/src/diffusers/configuration_utils.py +++ b/src/diffusers/configuration_utils.py @@ -20,11 +20,11 @@ import copy import json import os import re +import inspect from typing import Any, Dict, Tuple, Union from requests import HTTPError from transformers.utils import ( - CONFIG_NAME, HUGGINGFACE_CO_RESOLVE_ENDPOINT, EntryNotFoundError, RepositoryNotFoundError, @@ -44,33 +44,34 @@ logger = logging.get_logger(__name__) _re_configuration_file = re.compile(r"config\.(.*)\.json") -class PretrainedConfig: +class Config: r""" Base class for all configuration classes. Handles a few parameters common to all models' configurations as well as methods for loading/downloading/saving configurations. """ - model_type: str = "" + config_name = None - def __init__(self, **kwargs): - # Name or path to the pretrained checkpoint - self._name_or_path = str(kwargs.pop("name_or_path", "")) + def register(self, **kwargs): + if self.config_name is None: + raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`") + kwargs["_class_name"] = self.__class__.__name__ + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err - # Drop the diffusers version info - self.diffusers_version = kwargs.pop("diffusers_version", None) + if not hasattr(self, "_dict_to_save"): + self._dict_to_save = {} - @property - def name_or_path(self) -> str: - return getattr(self, "_name_or_path", None) + self._dict_to_save.update(kwargs) - @name_or_path.setter - def name_or_path(self, value): - self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding) - - def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): """ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the - [`~PretrainedConfig.from_pretrained`] class method. + [`~Config.from_config`] class method. Args: save_directory (`str` or `os.PathLike`): @@ -83,119 +84,15 @@ class PretrainedConfig: os.makedirs(save_directory, exist_ok=True) - # If we save using the predefined names, we can load using `from_pretrained` - output_config_file = os.path.join(save_directory, CONFIG_NAME) + # If we save using the predefined names, we can load using `from_config` + output_config_file = os.path.join(save_directory, self.config_name) - self.to_json_file(output_config_file, use_diff=True) + self.to_json_file(output_config_file) logger.info(f"Configuration saved in {output_config_file}") @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": - r""" - Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration. - - Args: - pretrained_model_name_or_path (`str` or `os.PathLike`): - This can be either: - - - a string, the *model id* of a pretrained model configuration hosted inside a model repo on - huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or - namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`. - - a path to a *directory* containing a configuration file saved using the - [`~PretrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`. - - a path or url to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`. - cache_dir (`str` or `os.PathLike`, *optional*): - Path to a directory in which a downloaded pretrained model configuration should be cached if the - standard cache should not be used. - force_download (`bool`, *optional*, defaults to `False`): - Whether or not to force to (re-)download the configuration files and override the cached versions if - they exist. - resume_download (`bool`, *optional*, defaults to `False`): - Whether or not to delete incompletely received file. Attempts to resume the download if such a file - exists. - proxies (`Dict[str, str]`, *optional*): - A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', - 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. - use_auth_token (`str` or *bool*, *optional*): - The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated - when running `diffusers-cli login` (stored in `~/.huggingface`). - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - return_unused_kwargs (`bool`, *optional*, defaults to `False`): - If `False`, then this function returns just the final configuration object. - - If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a - dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the - part of `kwargs` which has not been used to update `config` and is otherwise ignored. - kwargs (`Dict[str, Any]`, *optional*): - The values in kwargs of any keys which are configuration attributes will be used to override the loaded - values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled - by the `return_unused_kwargs` keyword parameter. - - - - Passing `use_auth_token=True` is required when you want to use a private model. - - - - Returns: - [`PretrainedConfig`]: The configuration object instantiated from this pretrained model. - - Examples: - - ```python - # We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a - # derived class: BertConfig - config = BertConfig.from_pretrained( - "bert-base-uncased" - ) # Download configuration from huggingface.co and cache. - config = BertConfig.from_pretrained( - "./test/saved_model/" - ) # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')* - config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json") - config = BertConfig.from_pretrained("bert-base-uncased", output_attentions=True, foo=False) - assert config.output_attentions == True - config, unused_kwargs = BertConfig.from_pretrained( - "bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True - ) - assert config.output_attentions == True - assert unused_kwargs == {"foo": False} - ```""" - config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) - if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: - logger.warning( - f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " - f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." - ) - - return cls.from_dict(config_dict, **kwargs) - - @classmethod - def get_config_dict( - cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs - ) -> Tuple[Dict[str, Any], Dict[str, Any]]: - """ - From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a - [`PretrainedConfig`] using `from_dict`. - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`): - The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. - - Returns: - `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object. - - """ - # Get config dict associated with the base config file - config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs) - - return config_dict, kwargs - - @classmethod - def _get_config_dict( - cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + def from_config( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], return_unused_kwargs=False, **kwargs ) -> Tuple[Dict[str, Any], Dict[str, Any]]: cache_dir = kwargs.pop("cache_dir", None) force_download = kwargs.pop("force_download", False) @@ -215,7 +112,7 @@ class PretrainedConfig: if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): config_file = pretrained_model_name_or_path else: - configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) + configuration_file = cls.config_name if os.path.isdir(pretrained_model_name_or_path): config_file = os.path.join(pretrained_model_name_or_path, configuration_file) @@ -286,58 +183,27 @@ class PretrainedConfig: else: logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}") - return config_dict, kwargs + expected_keys = set(dict(inspect.signature(cls.__init__).parameters).keys()) + expected_keys.remove("self") - @classmethod - def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig": - """ - Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters. + passed_keys = set(config_dict.keys()) - Args: - config_dict (`Dict[str, Any]`): - Dictionary that will be used to instantiate the configuration object. Such a dictionary can be - retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method. - kwargs (`Dict[str, Any]`): - Additional parameters from which to initialize the configuration object. + unused_kwargs = kwargs + for key in passed_keys - expected_keys: + unused_kwargs[key] = config_dict.pop(key) - Returns: - [`PretrainedConfig`]: The configuration object instantiated from those parameters. - """ - return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) - # Those arguments may be passed along for our internal telemetry. - # We remove them so they don't appear in `return_unused_kwargs`. + if len(expected_keys - passed_keys) > 0: + logger.warn( + f"{expected_keys - passed_keys} was not found in config. " + f"Values will be initialized to default values." + ) - config = cls(**config_dict) + model = cls(**config_dict) - to_remove = [] - for key, value in kwargs.items(): - if hasattr(config, key): - setattr(config, key, value) - to_remove.append(key) - for key in to_remove: - kwargs.pop(key, None) - - logger.info(f"Model config {config}") if return_unused_kwargs: - return config, kwargs + return model, unused_kwargs else: - return config - - @classmethod - def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig": - """ - Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters. - - Args: - json_file (`str` or `os.PathLike`): - Path to the JSON file containing the parameters. - - Returns: - [`PretrainedConfig`]: The configuration object instantiated from that JSON file. - - """ - config_dict = cls._dict_from_json_file(json_file) - return cls(**config_dict) + return model @classmethod def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): @@ -351,38 +217,6 @@ class PretrainedConfig: def __repr__(self): return f"{self.__class__.__name__} {self.to_json_string()}" - def to_diff_dict(self) -> Dict[str, Any]: - """ - Removes all attributes from config which correspond to the default config attributes for better readability and - serializes to a Python dictionary. - - Returns: - `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, - """ - config_dict = self.to_dict() - - # get the default config dict - default_config_dict = PretrainedConfig().to_dict() - - # get class specific config dict - class_config_dict = self.__class__().to_dict() - - serializable_config_dict = {} - - # only serialize values that differ from the default config - for key, value in config_dict.items(): - if ( - key not in default_config_dict - or key == "diffusers_version" - or value != default_config_dict[key] - or (key in class_config_dict and value != class_config_dict[key]) - ): - serializable_config_dict[key] = value - - self.dict_torch_dtype_to_str(serializable_config_dict) - - return serializable_config_dict - def to_dict(self) -> Dict[str, Any]: """ Serializes this instance to a Python dictionary. @@ -391,106 +225,29 @@ class PretrainedConfig: `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. """ output = copy.deepcopy(self.__dict__) - if hasattr(self.__class__, "model_type"): - output["model_type"] = self.__class__.model_type - if "_auto_class" in output: - del output["_auto_class"] - # Transformers version when serializing the model + # Diffusion version when serializing the model output["diffusers_version"] = __version__ - self.dict_torch_dtype_to_str(output) - return output - def to_json_string(self, use_diff: bool = True) -> str: + def to_json_string(self) -> str: """ Serializes this instance to a JSON string. - Args: - use_diff (`bool`, *optional*, defaults to `True`): - If set to `True`, only the difference between the config instance and the default `PretrainedConfig()` - is serialized to JSON string. - Returns: `str`: String containing all the attributes that make up this configuration instance in JSON format. """ - if use_diff is True: - config_dict = self.to_diff_dict() - else: - config_dict = self.to_dict() + config_dict = self._dict_to_save return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" - def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True): + def to_json_file(self, json_file_path: Union[str, os.PathLike]): """ Save this instance to a JSON file. Args: json_file_path (`str` or `os.PathLike`): Path to the JSON file in which this configuration instance's parameters will be saved. - use_diff (`bool`, *optional*, defaults to `True`): - If set to `True`, only the difference between the config instance and the default `PretrainedConfig()` - is serialized to JSON file. """ with open(json_file_path, "w", encoding="utf-8") as writer: - writer.write(self.to_json_string(use_diff=use_diff)) - - def update(self, config_dict: Dict[str, Any]): - """ - Updates attributes of this class with attributes from `config_dict`. - - Args: - config_dict (`Dict[str, Any]`): Dictionary of attributes that should be updated for this class. - """ - for key, value in config_dict.items(): - setattr(self, key, value) - - def update_from_string(self, update_str: str): - """ - Updates attributes of this class with attributes from `update_str`. - - The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example: - "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index" - - The keys to change have to already exist in the config object. - - Args: - update_str (`str`): String with attributes that should be updated for this class. - - """ - - d = dict(x.split("=") for x in update_str.split(",")) - for k, v in d.items(): - if not hasattr(self, k): - raise ValueError(f"key {k} isn't in the original config dict") - - old_v = getattr(self, k) - if isinstance(old_v, bool): - if v.lower() in ["true", "1", "y", "yes"]: - v = True - elif v.lower() in ["false", "0", "n", "no"]: - v = False - else: - raise ValueError(f"can't derive true or false from {v} (key {k})") - elif isinstance(old_v, int): - v = int(v) - elif isinstance(old_v, float): - v = float(v) - elif not isinstance(old_v, str): - raise ValueError( - f"You can only update int, float, bool or string values in the config, got {v} for key {k}" - ) - - setattr(self, k, v) - - def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None: - """ - Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None, - converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"* - string, which can then be stored in the json format. - """ - if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str): - d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1] - for value in d.values(): - if isinstance(value, dict): - self.dict_torch_dtype_to_str(value) + writer.write(self.to_json_string()) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index a7e2b3895f..d346fb7400 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -33,10 +33,9 @@ from transformers.utils import ( is_offline_mode, is_remote_url, logging, + CONFIG_NAME, ) -from .configuration_utils import PretrainedConfig - WEIGHTS_NAME = "diffusion_model.pt" @@ -135,7 +134,7 @@ class PreTrainedModel(torch.nn.Module): Class attributes (overridden by derived classes): - - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class + - **config_class** ([`Config`]) -- A subclass of [`Config`] to use as configuration class for this model architecture. - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model, taking as arguments: @@ -150,35 +149,16 @@ class PreTrainedModel(torch.nn.Module): - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP models, `pixel_values` for vision models and `input_values` for speech models). """ - config_class = None + config_name = CONFIG_NAME - def __init__(self, config: PretrainedConfig): + def __init__(self): super().__init__() - if not isinstance(config, PretrainedConfig): - raise ValueError( - f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class " - "`PretrainedConfig`. To create a model from a pretrained model use " - f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`" - ) - # Save config and origin of the pretrained weights if given in model - self.config = config - self.name_or_path = config.name_or_path - - @classmethod - def _from_config(cls, config, **kwargs): - """ - All context managers that the model should be initialized under go here. - """ - model = cls(config, **kwargs) - - return model def save_pretrained( self, save_directory: Union[str, os.PathLike], is_main_process: bool = True, save_function: Callable = torch.save, - push_to_hub: bool = False, **kwargs, ): """ @@ -195,16 +175,6 @@ class PreTrainedModel(torch.nn.Module): save_function (`Callable`): The function to use to save the state dictionary. Useful on distributed training like TPUs when one need to replace `torch.save` by another method. - push_to_hub (`bool`, *optional*, defaults to `False`): - Whether or not to push your model to the Hugging Face model hub after saving it. - - - - Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`, - which requires `save_directory` to be a local clone of the repo you are pushing to if it's an existing - folder. Pass along `temp_dir=True` to use a temporary directory instead. - - kwargs: Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. @@ -218,11 +188,9 @@ class PreTrainedModel(torch.nn.Module): model_to_save = self # Attach architecture to the config - model_to_save.config.architectures = [model_to_save.__class__.__name__] - # Save the config if is_main_process: - model_to_save.config.save_pretrained(save_directory) + model_to_save.save_config(save_directory) # Save the model state_dict = model_to_save.state_dict() @@ -241,7 +209,7 @@ class PreTrainedModel(torch.nn.Module): logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}") @classmethod - def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs): r""" Instantiate a pretrained pytorch model from a pre-trained model configuration. @@ -265,11 +233,11 @@ class PreTrainedModel(torch.nn.Module): - A path to a *directory* containing model weights saved using [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. - config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*): + config (`Union[Config, str, os.PathLike]`, *optional*): Can be either: - - an instance of a class derived from [`PretrainedConfig`], - - a string or path valid as input to [`~PretrainedConfig.from_pretrained`]. + - an instance of a class derived from [`Config`], + - a string or path valid as input to [`~Config.from_pretrained`]. Configuration for the model to use instead of an automatically loaded configuration. Configuration can be automatically loaded when: @@ -327,7 +295,7 @@ class PreTrainedModel(torch.nn.Module): underlying model's `__init__` method (we assume all relevant updates to the configuration have already been done) - If a configuration is not provided, `kwargs` will be first passed to the configuration class - initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that + initialization function ([`~Config.from_pretrained`]). Each key of `kwargs` that corresponds to a configuration attribute will be used to override said attribute with the supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute will be passed to the underlying model's `__init__` function. @@ -356,7 +324,6 @@ class PreTrainedModel(torch.nn.Module): use_auth_token = kwargs.pop("use_auth_token", None) revision = kwargs.pop("revision", None) mirror = kwargs.pop("mirror", None) - from_pipeline = kwargs.pop("_from_pipeline", None) from_auto_class = kwargs.pop("_from_auto", False) user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} @@ -367,7 +334,7 @@ class PreTrainedModel(torch.nn.Module): # Load config if we don't provide a configuration config_path = pretrained_model_name_or_path - config, model_kwargs = cls.config_class.from_pretrained( + model, unused_kwargs = cls.from_config( config_path, cache_dir=cache_dir, return_unused_kwargs=True, @@ -377,12 +344,9 @@ class PreTrainedModel(torch.nn.Module): local_files_only=local_files_only, use_auth_token=use_auth_token, revision=revision, - _from_auto=from_auto_class, - _from_pipeline=from_pipeline, **kwargs, ) - model_kwargs = kwargs - + model.register(name_or_path=pretrained_model_name_or_path) # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the # Load model pretrained_model_name_or_path = str(pretrained_model_name_or_path) @@ -456,18 +420,8 @@ class PreTrainedModel(torch.nn.Module): else: logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}") - state_dict = load_state_dict(resolved_archive_file) - # set dtype to instantiate the model under: - # 1. If torch_dtype is not None, we use that dtype - # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first - # weights entry - we assume all weights are of the same dtype - # we also may have config.torch_dtype available, but we won't rely on it till v5 - - config.name_or_path = pretrained_model_name_or_path - - model = cls(config, *model_args, **model_kwargs) - # restore default dtype + state_dict = load_state_dict(resolved_archive_file) model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model( model, state_dict, diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 364653676d..54b1d3ead3 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -16,4 +16,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -from . import unet +from .unet import UNetModel diff --git a/src/diffusers/models/unet/modeling_unet.py b/src/diffusers/models/unet.py similarity index 62% rename from src/diffusers/models/unet/modeling_unet.py rename to src/diffusers/models/unet.py index 23c0921633..6b8069d1cb 100644 --- a/src/diffusers/models/unet/modeling_unet.py +++ b/src/diffusers/models/unet.py @@ -22,27 +22,22 @@ from inspect import isfunction from pathlib import Path import torch -import torch.nn.functional as F from torch import einsum, nn from torch.cuda.amp import GradScaler, autocast from torch.optim import Adam from torch.utils import data from einops import rearrange -from PIL import Image -from torchvision import transforms, utils +from torchvision import utils, transforms from tqdm import tqdm -from ...modeling_utils import PreTrainedModel -from .configuration_unet import UNetConfig - +from ..configuration_utils import Config +from ..modeling_utils import PreTrainedModel +from PIL import Image # NOTE: the following file is completely copied from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py - - - def exists(x): return x is not None @@ -246,13 +241,30 @@ class Attention(nn.Module): return self.to_out(out) -class UNetModel(PreTrainedModel): - - config_class = UNetConfig - - def __init__(self, config): - super().__init__(config) +class UNetModel(PreTrainedModel, Config): + def __init__( + self, + dim=64, + dim_mults=(1, 2, 4, 8), + init_dim=None, + out_dim=None, + channels=3, + with_time_emb=True, + resnet_block_groups=8, + learned_variance=False, + ): + super().__init__() + self.register( + dim=dim, + dim_mults=dim_mults, + init_dim=init_dim, + out_dim=out_dim, + channels=channels, + with_time_emb=with_time_emb, + resnet_block_groups=resnet_block_groups, + learned_variance=learned_variance, + ) init_dim = None out_dim = None channels = 3 @@ -262,9 +274,9 @@ class UNetModel(PreTrainedModel): # determine dimensions - dim_mults = config.dim_mults - dim = config.dim - self.channels = config.channels + dim_mults = dim_mults + dim = dim + self.channels = channels init_dim = default(init_dim, dim // 3 * 2) self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3) @@ -357,238 +369,6 @@ class UNetModel(PreTrainedModel): return self.final_conv(x) -# gaussian diffusion trainer class - - -def extract(a, t, x_shape): - b, *_ = t.shape - out = a.gather(-1, t) - return out.reshape(b, *((1,) * (len(x_shape) - 1))) - - -def noise_like(shape, device, repeat=False): - def repeat_noise(): - return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) - - def noise(): - return torch.randn(shape, device=device) - - return repeat_noise() if repeat else noise() - - -def linear_beta_schedule(timesteps): - scale = 1000 / timesteps - beta_start = scale * 0.0001 - beta_end = scale * 0.02 - return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64) - - -def cosine_beta_schedule(timesteps, s=0.008): - """ - cosine schedule - as proposed in https://openreview.net/forum?id=-NEXDKk8gZ - """ - steps = timesteps + 1 - x = torch.linspace(0, timesteps, steps, dtype=torch.float64) - alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 - alphas_cumprod = alphas_cumprod / alphas_cumprod[0] - betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) - return torch.clip(betas, 0, 0.999) - - -class GaussianDiffusion(nn.Module): - def __init__( - self, - denoise_fn, - *, - image_size, - channels=3, - timesteps=1000, - loss_type="l1", - objective="pred_noise", - beta_schedule="cosine", - ): - super().__init__() - assert not (type(self) == GaussianDiffusion and denoise_fn.channels != denoise_fn.out_dim) - - self.channels = channels - self.image_size = image_size - self.denoise_fn = denoise_fn - self.objective = objective - - if beta_schedule == "linear": - betas = linear_beta_schedule(timesteps) - elif beta_schedule == "cosine": - betas = cosine_beta_schedule(timesteps) - else: - raise ValueError(f"unknown beta schedule {beta_schedule}") - - alphas = 1.0 - betas - alphas_cumprod = torch.cumprod(alphas, axis=0) - alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) - - (timesteps,) = betas.shape - self.num_timesteps = int(timesteps) - self.loss_type = loss_type - - # helper function to register buffer from float64 to float32 - - def register_buffer(name, val): - self.register_buffer(name, val.to(torch.float32)) - - register_buffer("betas", betas) - register_buffer("alphas_cumprod", alphas_cumprod) - register_buffer("alphas_cumprod_prev", alphas_cumprod_prev) - - # calculations for diffusion q(x_t | x_{t-1}) and others - - register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod)) - register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)) - register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod)) - register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod)) - register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)) - - # calculations for posterior q(x_{t-1} | x_t, x_0) - - posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) - - # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) - - register_buffer("posterior_variance", posterior_variance) - - # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain - - register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20))) - register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)) - register_buffer( - "posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod) - ) - - def predict_start_from_noise(self, x_t, t, noise): - return ( - extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise - ) - - def q_posterior(self, x_start, x_t, t): - posterior_mean = ( - extract(self.posterior_mean_coef1, t, x_t.shape) * x_start - + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t - ) - posterior_variance = extract(self.posterior_variance, t, x_t.shape) - posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) - return posterior_mean, posterior_variance, posterior_log_variance_clipped - - def p_mean_variance(self, x, t, clip_denoised: bool): - model_output = self.denoise_fn(x, t) - - if self.objective == "pred_noise": - x_start = self.predict_start_from_noise(x, t=t, noise=model_output) - elif self.objective == "pred_x0": - x_start = model_output - else: - raise ValueError(f"unknown objective {self.objective}") - - if clip_denoised: - x_start.clamp_(-1.0, 1.0) - - model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t) - return model_mean, posterior_variance, posterior_log_variance - - @torch.no_grad() - def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): - b, *_, device = *x.shape, x.device - model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) - noise = noise_like(x.shape, device, repeat_noise) - # no noise when t == 0 - nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) - result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise - return result - - @torch.no_grad() - def p_sample_loop(self, shape): - device = self.betas.device - - b = shape[0] - img = torch.randn(shape, device=device) - - for i in tqdm( - reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps - ): - img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long)) - - img = unnormalize_to_zero_to_one(img) - return img - - @torch.no_grad() - def sample(self, batch_size=16): - image_size = self.image_size - channels = self.channels - return self.p_sample_loop((batch_size, channels, image_size, image_size)) - - @torch.no_grad() - def interpolate(self, x1, x2, t=None, lam=0.5): - b, *_, device = *x1.shape, x1.device - t = default(t, self.num_timesteps - 1) - - assert x1.shape == x2.shape - - t_batched = torch.stack([torch.tensor(t, device=device)] * b) - xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) - - img = (1 - lam) * xt1 + lam * xt2 - for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t): - img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long)) - - return img - - def q_sample(self, x_start, t, noise=None): - noise = default(noise, lambda: torch.randn_like(x_start)) - - return ( - extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start - + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise - ) - - @property - def loss_fn(self): - if self.loss_type == "l1": - return F.l1_loss - elif self.loss_type == "l2": - return F.mse_loss - else: - raise ValueError(f"invalid loss type {self.loss_type}") - - def p_losses(self, x_start, t, noise=None): - b, c, h, w = x_start.shape - noise = default(noise, lambda: torch.randn_like(x_start)) - - x = self.q_sample(x_start=x_start, t=t, noise=noise) - model_out = self.denoise_fn(x, t) - - if self.objective == "pred_noise": - target = noise - elif self.objective == "pred_x0": - target = x_start - else: - raise ValueError(f"unknown objective {self.objective}") - - loss = self.loss_fn(model_out, target) - return loss - - def forward(self, img, *args, **kwargs): - b, _, h, w, device, img_size, = ( - *img.shape, - img.device, - self.image_size, - ) - assert h == img_size and w == img_size, f"height and width of image must be {img_size}" - t = torch.randint(0, self.num_timesteps, (b,), device=device).long() - - img = normalize_to_neg_one_to_one(img) - return self.p_losses(img, t, *args, **kwargs) - - # dataset classes @@ -621,6 +401,7 @@ class Dataset(data.Dataset): class Trainer(object): + def __init__( self, diffusion_model, diff --git a/src/diffusers/models/unet/__init__.py b/src/diffusers/models/unet/__init__.py deleted file mode 100644 index 01b2a02857..0000000000 --- a/src/diffusers/models/unet/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# flake8: noqa -# There's no way to ignore "F401 '...' imported but unused" warnings in this -# module, but to preserve other warnings. So, don't check this module at all - -from .configuration_unet import UNetConfig -from .modeling_unet import GaussianDiffusion, UNetModel diff --git a/src/diffusers/models/unet/configuration_unet.py b/src/diffusers/models/unet/configuration_unet.py deleted file mode 100644 index a8a1e12c00..0000000000 --- a/src/diffusers/models/unet/configuration_unet.py +++ /dev/null @@ -1,45 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and - -# limitations under the License. - -# helpers functions - -# NOTE: the following file is completely copied from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py -from ...configuration_utils import PretrainedConfig - - -class UNetConfig(PretrainedConfig): - model_type = "unet" - - def __init__( - self, - dim=64, - dim_mults=(1, 2, 4, 8), - init_dim=None, - out_dim=None, - channels=3, - with_time_emb=True, - resnet_block_groups=8, - learned_variance=False, - **kwargs, - ): - super().__init__(**kwargs) - self.dim = dim - self.dim_mults = dim_mults - self.init_dim = init_dim - self.out_dim = out_dim - self.channels = channels - self.with_time_emb = with_time_emb - self.resnet_block_groups = resnet_block_groups - self.learned_variance = learned_variance diff --git a/src/diffusers/samplers/__init__.py b/src/diffusers/samplers/__init__.py index 99977b417e..76aa8aab0c 100644 --- a/src/diffusers/samplers/__init__.py +++ b/src/diffusers/samplers/__init__.py @@ -1,3 +1,7 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. So, don't check this module at all. + # Copyright 2022 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -11,3 +15,5 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .gaussian import GaussianDiffusion diff --git a/src/diffusers/samplers/gaussian.py b/src/diffusers/samplers/gaussian.py new file mode 100644 index 0000000000..4a7e350704 --- /dev/null +++ b/src/diffusers/samplers/gaussian.py @@ -0,0 +1,312 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch.nn.functional as F +from torch import nn +from inspect import isfunction +from tqdm import tqdm + +from ..configuration_utils import Config +SAMPLING_CONFIG_NAME = "sampler_config.json" + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def cycle(dl): + while True: + for data_dl in dl: + yield data_dl + + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + + +def normalize_to_neg_one_to_one(img): + return img * 2 - 1 + + +def unnormalize_to_zero_to_one(t): + return (t + 1) * 0.5 + + +# small helper modules + + +class EMA: + def __init__(self, beta): + super().__init__() + self.beta = beta + + def update_model_average(self, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + ma_params.data = self.update_average(old_weight, up_weight) + + def update_average(self, old, new): + if old is None: + return new + return old * self.beta + (1 - self.beta) * new + + +# gaussian diffusion trainer class + + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def noise_like(shape, device, repeat=False): + def repeat_noise(): + return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + + def noise(): + return torch.randn(shape, device=device) + + return repeat_noise() if repeat else noise() + + +def linear_beta_schedule(timesteps): + scale = 1000 / timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64) + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = torch.linspace(0, timesteps, steps, dtype=torch.float64) + alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0, 0.999) + + +class GaussianDiffusion(nn.Module, Config): + + config_name = SAMPLING_CONFIG_NAME + + def __init__( + self, + image_size, + channels=3, + timesteps=1000, + loss_type="l1", + objective="pred_noise", + beta_schedule="cosine", + ): + super().__init__() + self.register( + image_size=image_size, + channels=channels, + timesteps=timesteps, + loss_type=loss_type, + objective=objective, + beta_schedule=beta_schedule, + ) + + self.channels = channels + self.image_size = image_size + self.objective = objective + + if beta_schedule == "linear": + betas = linear_beta_schedule(timesteps) + elif beta_schedule == "cosine": + betas = cosine_beta_schedule(timesteps) + else: + raise ValueError(f"unknown beta schedule {beta_schedule}") + + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, axis=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + + # helper function to register buffer from float64 to float32 + + def register_buffer(name, val): + self.register_buffer(name, val.to(torch.float32)) + + register_buffer("betas", betas) + register_buffer("alphas_cumprod", alphas_cumprod) + register_buffer("alphas_cumprod_prev", alphas_cumprod_prev) + + # calculations for diffusion q(x_t | x_{t-1}) and others + + register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod)) + register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)) + register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod)) + register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod)) + register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + + posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) + + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + + register_buffer("posterior_variance", posterior_variance) + + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + + register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20))) + register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)) + register_buffer( + "posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod) + ) + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised: bool): + model_output = model(x, t) + + if self.objective == "pred_noise": + x_start = self.predict_start_from_noise(x, t=t, noise=model_output) + elif self.objective == "pred_x0": + x_start = model_output + else: + raise ValueError(f"unknown objective {self.objective}") + + if clip_denoised: + x_start.clamp_(-1.0, 1.0) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, model, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(model=model, x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + return result + + @torch.no_grad() + def p_sample_loop(self, model, shape): + device = self.betas.device + + b = shape[0] + img = torch.randn(shape, device=device) + + for i in tqdm( + reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps + ): + img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long)) + + img = unnormalize_to_zero_to_one(img) + return img + + @torch.no_grad() + def sample(self, model, batch_size=16): + image_size = self.image_size + channels = self.channels + return self.p_sample_loop(model, (batch_size, channels, image_size, image_size)) + + @torch.no_grad() + def interpolate(self, model, x1, x2, t=None, lam=0.5): + b, *_, device = *x1.shape, x1.device + t = default(t, self.num_timesteps - 1) + + assert x1.shape == x2.shape + + t_batched = torch.stack([torch.tensor(t, device=device)] * b) + xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) + + img = (1 - lam) * xt1 + lam * xt2 + for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t): + img = self.p_sample(model, img, torch.full((b,), i, device=device, dtype=torch.long)) + + return img + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + @property + def loss_fn(self): + if self.loss_type == "l1": + return F.l1_loss + elif self.loss_type == "l2": + return F.mse_loss + else: + raise ValueError(f"invalid loss type {self.loss_type}") + + def p_losses(self, model, x_start, t, noise=None): + b, c, h, w = x_start.shape + noise = default(noise, lambda: torch.randn_like(x_start)) + + x = self.q_sample(x_start=x_start, t=t, noise=noise) + model_out = model(x, t) + + if self.objective == "pred_noise": + target = noise + elif self.objective == "pred_x0": + target = x_start + else: + raise ValueError(f"unknown objective {self.objective}") + + loss = self.loss_fn(model_out, target) + return loss + + def forward(self, model, img, *args, **kwargs): + b, _, h, w, device, img_size, = ( + *img.shape, + img.device, + self.image_size, + ) + assert h == img_size and w == img_size, f"height and width of image must be {img_size}" + t = torch.randint(0, self.num_timesteps, (b,), device=device).long() + + img = normalize_to_neg_one_to_one(img) + return self.p_losses(model, img, t, *args, **kwargs) diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 92612f9831..1233980f39 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -19,7 +19,7 @@ import unittest import torch -from diffusers import UNetConfig, UNetModel +from diffusers import GaussianDiffusion, UNetModel global_rng = random.Random() @@ -42,7 +42,6 @@ def floats_tensor(shape, scale=1.0, rng=None, name=None): class ModelTesterMixin(unittest.TestCase): - @property def dummy_input(self): batch_size = 1 @@ -55,8 +54,7 @@ class ModelTesterMixin(unittest.TestCase): return (noise, time_step) def test_from_pretrained_save_pretrained(self): - config = UNetConfig(dim=8, dim_mults=(1, 2), resnet_block_groups=2) - model = UNetModel(config) + model = UNetModel(dim=8, dim_mults=(1, 2), resnet_block_groups=2) with tempfile.TemporaryDirectory() as tmpdirname: model.save_pretrained(tmpdirname) @@ -75,3 +73,34 @@ class ModelTesterMixin(unittest.TestCase): image = model(*self.dummy_input) assert image is not None, "Make sure output is not None" + + +class SamplerTesterMixin(unittest.TestCase): + + @property + def dummy_model(self): + return UNetModel.from_pretrained("fusing/ddpm_dummy") + + def test_from_pretrained_save_pretrained(self): + sampler = GaussianDiffusion(image_size=128, timesteps=3, loss_type="l1") + + with tempfile.TemporaryDirectory() as tmpdirname: + sampler.save_config(tmpdirname) + new_sampler = GaussianDiffusion.from_config(tmpdirname, return_unused=False) + + model = self.dummy_model + + torch.manual_seed(0) + sampled_out = sampler.sample(model, batch_size=1) + torch.manual_seed(0) + sampled_out_new = new_sampler.sample(model, batch_size=1) + + assert (sampled_out - sampled_out_new).abs().sum() < 1e-5, "Samplers don't give the same output" + + def test_from_pretrained_hub(self): + sampler = GaussianDiffusion.from_config("fusing/ddpm_dummy") + model = self.dummy_model + + sampled_out = sampler.sample(model, batch_size=1) + + assert sampled_out is not None, "Make sure output is not None"