diff --git a/models/ddpm/run_ddpm.py b/models/ddpm/run_ddpm.py
new file mode 100755
index 0000000000..668c2c301e
--- /dev/null
+++ b/models/ddpm/run_ddpm.py
@@ -0,0 +1,26 @@
+#!/usr/bin/env python3
+import torch
+
+from diffusers import GaussianDiffusion, UNetConfig, UNetModel
+
+
+config = UNetConfig(dim=64, dim_mults=(1, 2, 4, 8))
+model = UNetModel(config)
+print(model.config)
+
+model.save_pretrained("/home/patrick/diffusion_example")
+
+diffusion = GaussianDiffusion(
+    model,
+    image_size=128,
+    timesteps=1000,  # number of steps
+    loss_type="l1",  # L1 or L2
+)
+
+training_images = torch.randn(8, 3, 128, 128)  # your images need to be normalized to a range of -1 to +1
+loss = diffusion(training_images)
+loss.backward()
+# after a lot of training
+
+sampled_images = diffusion.sample(batch_size=4)
+sampled_images.shape  # (4, 3, 128, 128)
diff --git a/src/diffusers/__init__.py b/src/diffusers/__init__.py
index e438c924a8..ad66a4be34 100644
--- a/src/diffusers/__init__.py
+++ b/src/diffusers/__init__.py
@@ -4,4 +4,5 @@
 __version__ = "0.0.1"
 
-from .models import UNetModel
+from .models.unet import GaussianDiffusion  # TODO(PVP): move somewhere else
+from .models.unet import UNetConfig, UNetModel
 
diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
new file mode 100755
index 0000000000..1c25c463b5
--- /dev/null
+++ b/src/diffusers/configuration_utils.py
@@ -0,0 +1,496 @@
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" Configuration base class and utilities."""
+
+
+import copy
+import json
+import os
+import re
+from typing import Any, Dict, Tuple, Union
+
+from requests import HTTPError
+from transformers.utils import (
+    CONFIG_NAME,
+    HUGGINGFACE_CO_RESOLVE_ENDPOINT,
+    EntryNotFoundError,
+    RepositoryNotFoundError,
+    RevisionNotFoundError,
+    cached_path,
+    hf_bucket_url,
+    is_offline_mode,
+    is_remote_url,
+    logging,
+)
+
+from . import __version__
+
+
+logger = logging.get_logger(__name__)
+
+_re_configuration_file = re.compile(r"config\.(.*)\.json")
+
+
+class PretrainedConfig:
+    r"""
+    Base class for all configuration classes. Handles a few parameters common to all models' configurations as well
+    as methods for loading/downloading/saving configurations.
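+
+    A minimal round-trip sketch (`UNetConfig` is the concrete subclass added later in this diff; the path is
+    illustrative):
+
+    ```python
+    from diffusers import UNetConfig
+
+    config = UNetConfig(dim=32)
+    config.save_pretrained("./unet-config")  # writes ./unet-config/config.json
+    config = UNetConfig.from_pretrained("./unet-config")
+    ```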
+ + """ + model_type: str = "" + + def __init__(self, **kwargs): + # Name or path to the pretrained checkpoint + self._name_or_path = str(kwargs.pop("name_or_path", "")) + + # Drop the diffusers version info + self.diffusers_version = kwargs.pop("diffusers_version", None) + + @property + def name_or_path(self) -> str: + return getattr(self, "_name_or_path", None) + + @name_or_path.setter + def name_or_path(self, value): + self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs): + """ + Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~PretrainedConfig.from_pretrained`] class method. + + Args: + save_directory (`str` or `os.PathLike`): + Directory where the configuration JSON file will be saved (will be created if it does not exist). + kwargs: + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + # If we save using the predefined names, we can load using `from_pretrained` + output_config_file = os.path.join(save_directory, CONFIG_NAME) + + self.to_json_file(output_config_file, use_diff=True) + logger.info(f"Configuration saved in {output_config_file}") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs) -> "PretrainedConfig": + r""" + Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration. + + Args: + pretrained_model_name_or_path (`str` or `os.PathLike`): + This can be either: + + - a string, the *model id* of a pretrained model configuration hosted inside a model repo on + huggingface.co. Valid model ids can be located at the root-level, like `bert-base-uncased`, or + namespaced under a user or organization name, like `dbmdz/bert-base-german-cased`. + - a path to a *directory* containing a configuration file saved using the + [`~PretrainedConfig.save_pretrained`] method, e.g., `./my_model_directory/`. + - a path or url to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`. + cache_dir (`str` or `os.PathLike`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. + force_download (`bool`, *optional*, defaults to `False`): + Whether or not to force to (re-)download the configuration files and override the cached versions if + they exist. + resume_download (`bool`, *optional*, defaults to `False`): + Whether or not to delete incompletely received file. Attempts to resume the download if such a file + exists. + proxies (`Dict[str, str]`, *optional*): + A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128', + 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request. + use_auth_token (`str` or *bool*, *optional*): + The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated + when running `diffusers-cli login` (stored in `~/.huggingface`). + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. 
It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + return_unused_kwargs (`bool`, *optional*, defaults to `False`): + If `False`, then this function returns just the final configuration object. + + If `True`, then this functions returns a `Tuple(config, unused_kwargs)` where *unused_kwargs* is a + dictionary consisting of the key/value pairs whose keys are not configuration attributes: i.e., the + part of `kwargs` which has not been used to update `config` and is otherwise ignored. + kwargs (`Dict[str, Any]`, *optional*): + The values in kwargs of any keys which are configuration attributes will be used to override the loaded + values. Behavior concerning key/value pairs whose keys are *not* configuration attributes is controlled + by the `return_unused_kwargs` keyword parameter. + + + + Passing `use_auth_token=True` is required when you want to use a private model. + + + + Returns: + [`PretrainedConfig`]: The configuration object instantiated from this pretrained model. + + Examples: + + ```python + # We can't instantiate directly the base class *PretrainedConfig* so let's show the examples on a + # derived class: BertConfig + config = BertConfig.from_pretrained( + "bert-base-uncased" + ) # Download configuration from huggingface.co and cache. + config = BertConfig.from_pretrained( + "./test/saved_model/" + ) # E.g. config (or model) was saved using *save_pretrained('./test/saved_model/')* + config = BertConfig.from_pretrained("./test/saved_model/my_configuration.json") + config = BertConfig.from_pretrained("bert-base-uncased", output_attentions=True, foo=False) + assert config.output_attentions == True + config, unused_kwargs = BertConfig.from_pretrained( + "bert-base-uncased", output_attentions=True, foo=False, return_unused_kwargs=True + ) + assert config.output_attentions == True + assert unused_kwargs == {"foo": False} + ```""" + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + @classmethod + def get_config_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a + [`PretrainedConfig`] using `from_dict`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. + + Returns: + `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object. 
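+
+        Example (a sketch, assuming a directory produced by [`~PretrainedConfig.save_pretrained`]; kwargs that are
+        not resolved locally are passed through unchanged):
+
+        ```python
+        config_dict, unused_kwargs = UNetConfig.get_config_dict("./unet-config", foo=False)
+        assert unused_kwargs == {"foo": False}
+        ```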
+ + """ + # Get config dict associated with the base config file + config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs) + + return config_dict, kwargs + + @classmethod + def _get_config_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + use_auth_token = kwargs.pop("use_auth_token", None) + local_files_only = kwargs.pop("local_files_only", False) + revision = kwargs.pop("revision", None) + + user_agent = {"file_type": "config"} + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): + config_file = pretrained_model_name_or_path + else: + configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) + + if os.path.isdir(pretrained_model_name_or_path): + config_file = os.path.join(pretrained_model_name_or_path, configuration_file) + else: + config_file = hf_bucket_url( + pretrained_model_name_or_path, filename=configuration_file, revision=revision, mirror=None + ) + + try: + # Load from URL or cache if already cached + resolved_config_file = cached_path( + config_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + ) + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier listed on " + "'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a token having " + "permission to this repo with `use_auth_token` or log in with `huggingface-cli login` and pass " + "`use_auth_token=True`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for this " + f"model name. Check the model page at 'https://huggingface.co/{pretrained_model_name_or_path}' for " + "available revisions." + ) + except EntryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named {configuration_file}." + ) + except HTTPError as err: + raise EnvironmentError( + f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}" + ) + except ValueError: + raise EnvironmentError( + f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it in" + f" the cached files and it looks like {pretrained_model_name_or_path} is not the path to a directory" + f" containing a {configuration_file} file.\nCheckout your internet connection or see how to run the" + " library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'." + ) + except EnvironmentError: + raise EnvironmentError( + f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from " + "'https://huggingface.co/models', make sure you don't have a local directory with the same name. 
" + f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory " + f"containing a {configuration_file} file" + ) + + try: + # Load config dict + config_dict = cls._dict_from_json_file(resolved_config_file) + except (json.JSONDecodeError, UnicodeDecodeError): + raise EnvironmentError( + f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file." + ) + + if resolved_config_file == config_file: + logger.info(f"loading configuration file {config_file}") + else: + logger.info(f"loading configuration file {config_file} from cache at {resolved_config_file}") + + return config_dict, kwargs + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "PretrainedConfig": + """ + Instantiates a [`PretrainedConfig`] from a Python dictionary of parameters. + + Args: + config_dict (`Dict[str, Any]`): + Dictionary that will be used to instantiate the configuration object. Such a dictionary can be + retrieved from a pretrained checkpoint by leveraging the [`~PretrainedConfig.get_config_dict`] method. + kwargs (`Dict[str, Any]`): + Additional parameters from which to initialize the configuration object. + + Returns: + [`PretrainedConfig`]: The configuration object instantiated from those parameters. + """ + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + # Those arguments may be passed along for our internal telemetry. + # We remove them so they don't appear in `return_unused_kwargs`. + + config = cls(**config_dict) + + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + logger.info(f"Model config {config}") + if return_unused_kwargs: + return config, kwargs + else: + return config + + @classmethod + def from_json_file(cls, json_file: Union[str, os.PathLike]) -> "PretrainedConfig": + """ + Instantiates a [`PretrainedConfig`] from the path to a JSON file of parameters. + + Args: + json_file (`str` or `os.PathLike`): + Path to the JSON file containing the parameters. + + Returns: + [`PretrainedConfig`]: The configuration object instantiated from that JSON file. + + """ + config_dict = cls._dict_from_json_file(json_file) + return cls(**config_dict) + + @classmethod + def _dict_from_json_file(cls, json_file: Union[str, os.PathLike]): + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + def __eq__(self, other): + return self.__dict__ == other.__dict__ + + def __repr__(self): + return f"{self.__class__.__name__} {self.to_json_string()}" + + def to_diff_dict(self) -> Dict[str, Any]: + """ + Removes all attributes from config which correspond to the default config attributes for better readability and + serializes to a Python dictionary. 
+
+        Returns:
+            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+        """
+        config_dict = self.to_dict()
+
+        # get the default config dict
+        default_config_dict = PretrainedConfig().to_dict()
+
+        # get class specific config dict
+        class_config_dict = self.__class__().to_dict()
+
+        serializable_config_dict = {}
+
+        # only serialize values that differ from the default config
+        for key, value in config_dict.items():
+            if (
+                key not in default_config_dict
+                or key == "diffusers_version"
+                or value != default_config_dict[key]
+                or (key in class_config_dict and value != class_config_dict[key])
+            ):
+                serializable_config_dict[key] = value
+
+        self.dict_torch_dtype_to_str(serializable_config_dict)
+
+        return serializable_config_dict
+
+    def to_dict(self) -> Dict[str, Any]:
+        """
+        Serializes this instance to a Python dictionary.
+
+        Returns:
+            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+        """
+        output = copy.deepcopy(self.__dict__)
+        if hasattr(self.__class__, "model_type"):
+            output["model_type"] = self.__class__.model_type
+        if "_auto_class" in output:
+            del output["_auto_class"]
+
+        # Diffusers version when serializing the model
+        output["diffusers_version"] = __version__
+
+        self.dict_torch_dtype_to_str(output)
+
+        return output
+
+    def to_json_string(self, use_diff: bool = True) -> str:
+        """
+        Serializes this instance to a JSON string.
+
+        Args:
+            use_diff (`bool`, *optional*, defaults to `True`):
+                If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
+                is serialized to JSON string.
+
+        Returns:
+            `str`: String containing all the attributes that make up this configuration instance in JSON format.
+        """
+        if use_diff is True:
+            config_dict = self.to_diff_dict()
+        else:
+            config_dict = self.to_dict()
+        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+    def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
+        """
+        Save this instance to a JSON file.
+
+        Args:
+            json_file_path (`str` or `os.PathLike`):
+                Path to the JSON file in which this configuration instance's parameters will be saved.
+            use_diff (`bool`, *optional*, defaults to `True`):
+                If set to `True`, only the difference between the config instance and the default `PretrainedConfig()`
+                is serialized to JSON file.
+        """
+        with open(json_file_path, "w", encoding="utf-8") as writer:
+            writer.write(self.to_json_string(use_diff=use_diff))
+
+    def update(self, config_dict: Dict[str, Any]):
+        """
+        Updates attributes of this class with attributes from `config_dict`.
+
+        Args:
+            config_dict (`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
+        """
+        for key, value in config_dict.items():
+            setattr(self, key, value)
+
+    def update_from_string(self, update_str: str):
+        """
+        Updates attributes of this class with attributes from `update_str`.
+
+        The expected format is ints, floats and strings as is, and for booleans use `true` or `false`. For example:
+        "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
+
+        The keys to change have to already exist in the config object.
+
+        Args:
+            update_str (`str`): String with attributes that should be updated for this class.
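+
+        Example (a sketch using the `UNetConfig` fields added in this diff):
+
+        ```python
+        config = UNetConfig()
+        config.update_from_string("dim=32,with_time_emb=false")
+        assert config.dim == 32 and config.with_time_emb is False
+        ```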
+ + """ + + d = dict(x.split("=") for x in update_str.split(",")) + for k, v in d.items(): + if not hasattr(self, k): + raise ValueError(f"key {k} isn't in the original config dict") + + old_v = getattr(self, k) + if isinstance(old_v, bool): + if v.lower() in ["true", "1", "y", "yes"]: + v = True + elif v.lower() in ["false", "0", "n", "no"]: + v = False + else: + raise ValueError(f"can't derive true or false from {v} (key {k})") + elif isinstance(old_v, int): + v = int(v) + elif isinstance(old_v, float): + v = float(v) + elif not isinstance(old_v, str): + raise ValueError( + f"You can only update int, float, bool or string values in the config, got {v} for key {k}" + ) + + setattr(self, k, v) + + def dict_torch_dtype_to_str(self, d: Dict[str, Any]) -> None: + """ + Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None, + converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"* + string, which can then be stored in the json format. + """ + if d.get("torch_dtype", None) is not None and not isinstance(d["torch_dtype"], str): + d["torch_dtype"] = str(d["torch_dtype"]).split(".")[1] + for value in d.values(): + if isinstance(value, dict): + self.dict_torch_dtype_to_str(value) diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py new file mode 100644 index 0000000000..a7e2b3895f --- /dev/null +++ b/src/diffusers/modeling_utils.py @@ -0,0 +1,637 @@ +# coding=utf-8 +# Copyright 2022 The HuggingFace Inc. team. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from typing import Callable, List, Optional, Tuple, Union + +import torch +from torch import Tensor, device + +from requests import HTTPError + +# CHANGE to diffusers.utils +from transformers.utils import ( + HUGGINGFACE_CO_RESOLVE_ENDPOINT, + EntryNotFoundError, + RepositoryNotFoundError, + RevisionNotFoundError, + cached_path, + hf_bucket_url, + is_offline_mode, + is_remote_url, + logging, +) + +from .configuration_utils import PretrainedConfig + + +WEIGHTS_NAME = "diffusion_model.pt" + + +logger = logging.get_logger(__name__) + + +def get_parameter_device(parameter: torch.nn.Module): + try: + return next(parameter.parameters()).device + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].device + + +def get_parameter_dtype(parameter: torch.nn.Module): + try: + return next(parameter.parameters()).dtype + except StopIteration: + # For torch.nn.DataParallel compatibility in PyTorch 1.5 + + def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]: + tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)] + return tuples + + gen = parameter._named_members(get_members_fn=find_tensor_attributes) + first_tuple = next(gen) + return first_tuple[1].dtype + + +def load_state_dict(checkpoint_file: Union[str, os.PathLike]): + """ + Reads a PyTorch checkpoint file, returning properly formatted errors if they arise. + """ + try: + return torch.load(checkpoint_file, map_location="cpu") + except Exception as e: + try: + with open(checkpoint_file) as f: + if f.read().startswith("version"): + raise OSError( + "You seem to have cloned a repository without having git-lfs installed. Please install " + "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder " + "you cloned." + ) + else: + raise ValueError( + f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained " + "model. Make sure you have saved the model properly." + ) from e + except (UnicodeDecodeError, ValueError): + raise OSError( + f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' " + f"at '{checkpoint_file}'. " + "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True." + ) + + +def _load_state_dict_into_model(model_to_load, state_dict): + # Convert old format to new format if needed from a PyTorch state_dict + # copy state_dict so _load_from_state_dict can modify it + state_dict = state_dict.copy() + error_msgs = [] + + # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants + # so we need to apply the function recursively. + def load(module: torch.nn.Module, prefix=""): + args = (state_dict, prefix, {}, True, [], [], error_msgs) + module._load_from_state_dict(*args) + + for name, child in module._modules.items(): + if child is not None: + load(child, prefix + name + ".") + + load(model_to_load) + + return error_msgs + + +class PreTrainedModel(torch.nn.Module): + r""" + Base class for all models. 
+
+    [`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
+    downloading and saving models as well as a few methods common to all models to:
+
+        - resize the input embeddings,
+        - prune heads in the self-attention heads.
+
+    Class attributes (overridden by derived classes):
+
+        - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
+          for this model architecture.
+        - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch
+          model, taking as arguments:
+
+            - **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint.
+            - **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model.
+            - **path** (`str`) -- A path to the TensorFlow checkpoint.
+
+        - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
+          classes of the same architecture adding modules on top of the base model.
+        - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
+        - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
+          models, `pixel_values` for vision models and `input_values` for speech models).
+    """
+    config_class = None
+
+    def __init__(self, config: PretrainedConfig):
+        super().__init__()
+        if not isinstance(config, PretrainedConfig):
+            raise ValueError(
+                f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
+                "`PretrainedConfig`. To create a model from a pretrained model use "
+                f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
+            )
+        # Save config and origin of the pretrained weights if given in model
+        self.config = config
+        self.name_or_path = config.name_or_path
+
+    @classmethod
+    def _from_config(cls, config, **kwargs):
+        """
+        All context managers that the model should be initialized under go here.
+        """
+        model = cls(config, **kwargs)
+
+        return model
+
+    def save_pretrained(
+        self,
+        save_directory: Union[str, os.PathLike],
+        is_main_process: bool = True,
+        save_function: Callable = torch.save,
+        push_to_hub: bool = False,
+        **kwargs,
+    ):
+        """
+        Save a model and its configuration file to a directory, so that it can be re-loaded using the
+        [`~PreTrainedModel.from_pretrained`] class method.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to which to save. Will be created if it doesn't exist.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful when in distributed training,
+                e.g. on TPUs, when this function needs to be called on all processes. In this case, set
+                `is_main_process=True` only on the main process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
+                needs to replace `torch.save` by another method.
+            push_to_hub (`bool`, *optional*, defaults to `False`):
+                Whether or not to push your model to the Hugging Face model hub after saving it.
+
+                Using `push_to_hub=True` will synchronize the repository you are pushing to with `save_directory`,
+                which requires `save_directory` to be a local clone of the repo you are pushing to if it's an
+                existing folder. Pass along `temp_dir=True` to use a temporary directory instead.
+ + + + kwargs: + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + model_to_save = self + + # Attach architecture to the config + model_to_save.config.architectures = [model_to_save.__class__.__name__] + + # Save the config + if is_main_process: + model_to_save.config.save_pretrained(save_directory) + + # Save the model + state_dict = model_to_save.state_dict() + + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + # If we have a shard file that is not going to be replaced, we delete it, but only from the main process + # in distributed settings to avoid race conditions. + if filename.startswith(WEIGHTS_NAME[:-4]) and os.path.isfile(full_filename) and is_main_process: + os.remove(full_filename) + + # Save the model + save_function(state_dict, os.path.join(save_directory, WEIGHTS_NAME)) + + logger.info(f"Model weights saved in {os.path.join(save_directory, WEIGHTS_NAME)}") + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], *model_args, **kwargs): + r""" + Instantiate a pretrained pytorch model from a pre-trained model configuration. + + The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train + the model, you should first set it back in training mode with `model.train()`. + + The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come + pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning + task. + + The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those + weights are discarded. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): + Can be either: + + - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co. + Valid model ids can be located at the root-level, like `bert-base-uncased`, or namespaced under a + user or organization name, like `dbmdz/bert-base-german-cased`. + - A path to a *directory* containing model weights saved using + [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`. + + config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*): + Can be either: + + - an instance of a class derived from [`PretrainedConfig`], + - a string or path valid as input to [`~PretrainedConfig.from_pretrained`]. + + Configuration for the model to use instead of an automatically loaded configuration. Configuration can + be automatically loaded when: + + - The model is a model provided by the library (loaded with the *model id* string of a pretrained + model). + - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the + save directory. + - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a + configuration JSON file named *config.json* is found in the directory. + cache_dir (`Union[str, os.PathLike]`, *optional*): + Path to a directory in which a downloaded pretrained model configuration should be cached if the + standard cache should not be used. 
+            from_tf (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a TensorFlow checkpoint save file (see docstring of
+                `pretrained_model_name_or_path` argument).
+            from_flax (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a Flax checkpoint save file (see docstring of
+                `pretrained_model_name_or_path` argument).
+            ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
+                Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
+                as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
+                checkpoint with 3 labels).
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding
+                the cached versions if they exist.
+            resume_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
+                file exists.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error
+                messages.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether or not to only look at local files (i.e., do not try to download the model).
+            use_auth_token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token
+                generated when running `diffusers-cli login` (stored in `~/.huggingface`).
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use
+                a git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
+                identifier allowed by git.
+            mirror (`str`, *optional*):
+                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
+                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or
+                safety. Please refer to the mirror site for more information.
+
+            kwargs (remaining dictionary of keyword arguments, *optional*):
+                Can be used to update the configuration object (after it is loaded) and initiate the model (e.g.,
+                `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
+                automatically loaded:
+
+                    - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
+                      underlying model's `__init__` method (we assume all relevant updates to the configuration have
+                      already been done).
+                    - If a configuration is not provided, `kwargs` will be first passed to the configuration class
+                      initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
+                      corresponds to a configuration attribute will be used to override said attribute with the
+                      supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
+                      will be passed to the underlying model's `__init__` function (see the sketch below).
+
+        Passing `use_auth_token=True` is required when you want to use a private model.
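+
+        Example (a sketch, assuming a checkpoint previously saved locally with
+        [`~PreTrainedModel.save_pretrained`]):
+
+        ```python
+        model = UNetModel.from_pretrained("./diffusion_example")
+        model, loading_info = UNetModel.from_pretrained("./diffusion_example", output_loading_info=True)
+        ```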
+ + + + + + Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to + use this method in a firewalled environment. + + + + """ + cache_dir = kwargs.pop("cache_dir", None) + ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + output_loading_info = kwargs.pop("output_loading_info", False) + local_files_only = kwargs.pop("local_files_only", False) + use_auth_token = kwargs.pop("use_auth_token", None) + revision = kwargs.pop("revision", None) + mirror = kwargs.pop("mirror", None) + from_pipeline = kwargs.pop("_from_pipeline", None) + from_auto_class = kwargs.pop("_from_auto", False) + + user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class} + + if is_offline_mode() and not local_files_only: + logger.info("Offline mode: forcing local_files_only=True") + local_files_only = True + + # Load config if we don't provide a configuration + config_path = pretrained_model_name_or_path + config, model_kwargs = cls.config_class.from_pretrained( + config_path, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + revision=revision, + _from_auto=from_auto_class, + _from_pipeline=from_pipeline, + **kwargs, + ) + model_kwargs = kwargs + + # This variable will flag if we're loading a sharded checkpoint. In this case the archive file is just the + # Load model + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + if os.path.isdir(pretrained_model_name_or_path): + if os.path.isfile(os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME)): + # Load from a PyTorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, WEIGHTS_NAME) + else: + raise EnvironmentError( + f"Error no file named {WEIGHTS_NAME} found in directory {pretrained_model_name_or_path}." + ) + elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path): + archive_file = pretrained_model_name_or_path + else: + filename = WEIGHTS_NAME + + archive_file = hf_bucket_url( + pretrained_model_name_or_path, filename=filename, revision=revision, mirror=mirror + ) + + try: + # Load from URL or cache if already cached + resolved_archive_file = cached_path( + archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + use_auth_token=use_auth_token, + user_agent=user_agent, + ) + + except RepositoryNotFoundError: + raise EnvironmentError( + f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier " + "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a " + "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli " + "login` and pass `use_auth_token=True`." + ) + except RevisionNotFoundError: + raise EnvironmentError( + f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for " + "this model name. Check the model page at " + f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions." 
+            )
+        except EntryNotFoundError:
+            raise EnvironmentError(f"{pretrained_model_name_or_path} does not appear to have a file named {filename}.")
+        except HTTPError as err:
+            raise EnvironmentError(
+                f"There was a specific connection error when trying to load {pretrained_model_name_or_path}:\n{err}"
+            )
+        except ValueError:
+            raise EnvironmentError(
+                f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
+                f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
+                f" directory containing a file named {WEIGHTS_NAME}."
+                "\nCheck your internet connection or see how to run the library in"
+                " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
+            )
+        except EnvironmentError:
+            raise EnvironmentError(
+                f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
+                "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
+                f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
+                f"containing a file named {WEIGHTS_NAME}"
+            )
+
+        if resolved_archive_file == archive_file:
+            logger.info(f"loading weights file {archive_file}")
+        else:
+            logger.info(f"loading weights file {archive_file} from cache at {resolved_archive_file}")
+
+        state_dict = load_state_dict(resolved_archive_file)
+        # set dtype to instantiate the model under:
+        # 1. If torch_dtype is not None, we use that dtype
+        # 2. If torch_dtype is "auto", we auto-detect dtype from the loaded state_dict, by checking its first
+        #    weights entry - we assume all weights are of the same dtype
+        # we also may have config.torch_dtype available, but we won't rely on it till v5
+
+        config.name_or_path = pretrained_model_name_or_path
+
+        model = cls(config, *model_args, **model_kwargs)
+
+        # restore default dtype
+        model, missing_keys, unexpected_keys, mismatched_keys, error_msgs = cls._load_pretrained_model(
+            model,
+            state_dict,
+            resolved_archive_file,
+            pretrained_model_name_or_path,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+        )
+
+        # Set model in evaluation mode to deactivate Dropout modules by default
+        model.eval()
+
+        if output_loading_info:
+            loading_info = {
+                "missing_keys": missing_keys,
+                "unexpected_keys": unexpected_keys,
+                "mismatched_keys": mismatched_keys,
+                "error_msgs": error_msgs,
+            }
+            return model, loading_info
+
+        return model
+
+    @classmethod
+    def _load_pretrained_model(
+        cls,
+        model,
+        state_dict,
+        resolved_archive_file,
+        pretrained_model_name_or_path,
+        ignore_mismatched_sizes=False,
+    ):
+        # Retrieve missing & unexpected keys
+        model_state_dict = model.state_dict()
+        loaded_keys = list(state_dict.keys())
+
+        expected_keys = list(model_state_dict.keys())
+
+        original_loaded_keys = loaded_keys
+
+        missing_keys = list(set(expected_keys) - set(loaded_keys))
+        unexpected_keys = list(set(loaded_keys) - set(expected_keys))
+
+        # Make sure we are able to load base models as well as derived models (with heads)
+        model_to_load = model
+
+        def _find_mismatched_keys(
+            state_dict,
+            model_state_dict,
+            loaded_keys,
+            ignore_mismatched_sizes,
+        ):
+            mismatched_keys = []
+            if ignore_mismatched_sizes:
+                for checkpoint_key in loaded_keys:
+                    model_key = checkpoint_key
+
+                    if (
+                        model_key in model_state_dict
+                        and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
+                    ):
+                        mismatched_keys.append(
+                            (checkpoint_key, state_dict[checkpoint_key].shape, 
model_state_dict[model_key].shape) + ) + del state_dict[checkpoint_key] + return mismatched_keys + + if state_dict is not None: + # Whole checkpoint + mismatched_keys = _find_mismatched_keys( + state_dict, + model_state_dict, + original_loaded_keys, + ignore_mismatched_sizes, + ) + error_msgs = _load_state_dict_into_model(model_to_load, state_dict) + + if len(error_msgs) > 0: + error_msg = "\n\t".join(error_msgs) + if "size mismatch" in error_msg: + error_msg += ( + "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method." + ) + raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}") + + if len(unexpected_keys) > 0: + logger.warning( + f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when" + f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are" + f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or" + " with another architecture (e.g. initializing a BertForSequenceClassification model from a" + " BertForPreTraining model).\n- This IS NOT expected if you are initializing" + f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical" + " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)." + ) + else: + logger.info(f"All model checkpoint weights were used when initializing {model.__class__.__name__}.\n") + if len(missing_keys) > 0: + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably" + " TRAIN this model on a down-stream task to be able to use it for predictions and inference." + ) + elif len(mismatched_keys) == 0: + logger.info( + f"All the weights of {model.__class__.__name__} were initialized from the model checkpoint at" + f" {pretrained_model_name_or_path}.\nIf your task is similar to the task the model of the checkpoint" + f" was trained on, you can already use {model.__class__.__name__} for predictions without further" + " training." + ) + if len(mismatched_keys) > 0: + mismatched_warning = "\n".join( + [ + f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated" + for key, shape1, shape2 in mismatched_keys + ] + ) + logger.warning( + f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at" + f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not" + f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able" + " to use it for predictions and inference." + ) + + return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs + + @property + def device(self) -> device: + """ + `torch.device`: The device on which the module is (assuming that all the module parameters are on the same + device). + """ + return get_parameter_device(self) + + @property + def dtype(self) -> torch.dtype: + """ + `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return get_parameter_dtype(self) + + def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int: + """ + Get number of (optionally, trainable or non-embeddings) parameters in the module. 
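+
+        For example (a sketch):
+
+        ```python
+        model = UNetModel(UNetConfig(dim=64))
+        n_total = model.num_parameters()
+        n_trainable = model.num_parameters(only_trainable=True)
+        ```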
+ + Args: + only_trainable (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of trainable parameters + + exclude_embeddings (`bool`, *optional*, defaults to `False`): + Whether or not to return only the number of non-embeddings parameters + + Returns: + `int`: The number of parameters. + """ + + if exclude_embeddings: + embedding_param_names = [ + f"{name}.weight" + for name, module_type in self.named_modules() + if isinstance(module_type, torch.nn.Embedding) + ] + non_embedding_parameters = [ + parameter for name, parameter in self.named_parameters() if name not in embedding_param_names + ] + return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable) + else: + return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable) diff --git a/src/diffusers/models/__init__.py b/src/diffusers/models/__init__.py index 4e4f83ea91..364653676d 100644 --- a/src/diffusers/models/__init__.py +++ b/src/diffusers/models/__init__.py @@ -1,5 +1,19 @@ # flake8: noqa # There's no way to ignore "F401 '...' imported but unused" warnings in this -# module, but to preserve other warnings. So, don't check this module at all +# module, but to preserve other warnings. So, don't check this module at all. -from .unet import UNetModel +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from . import unet diff --git a/src/diffusers/models/unet.py b/src/diffusers/models/unet.py deleted file mode 100644 index eff3474127..0000000000 --- a/src/diffusers/models/unet.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2022 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -class UNetModel: - def __init__(self, config): - self.config = config - print("I can diffuse!") diff --git a/src/diffusers/models/unet/__init__.py b/src/diffusers/models/unet/__init__.py new file mode 100644 index 0000000000..01b2a02857 --- /dev/null +++ b/src/diffusers/models/unet/__init__.py @@ -0,0 +1,6 @@ +# flake8: noqa +# There's no way to ignore "F401 '...' imported but unused" warnings in this +# module, but to preserve other warnings. 
So, don't check this module at all.
+
+from .configuration_unet import UNetConfig
+from .modeling_unet import GaussianDiffusion, UNetModel
diff --git a/src/diffusers/models/unet/configuration_unet.py b/src/diffusers/models/unet/configuration_unet.py
new file mode 100644
index 0000000000..a8a1e12c00
--- /dev/null
+++ b/src/diffusers/models/unet/configuration_unet.py
@@ -0,0 +1,42 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# NOTE: the following file is completely copied from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+from ...configuration_utils import PretrainedConfig
+
+
+class UNetConfig(PretrainedConfig):
+    model_type = "unet"
+
+    def __init__(
+        self,
+        dim=64,
+        dim_mults=(1, 2, 4, 8),
+        init_dim=None,
+        out_dim=None,
+        channels=3,
+        with_time_emb=True,
+        resnet_block_groups=8,
+        learned_variance=False,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+        self.dim = dim
+        self.dim_mults = dim_mults
+        self.init_dim = init_dim
+        self.out_dim = out_dim
+        self.channels = channels
+        self.with_time_emb = with_time_emb
+        self.resnet_block_groups = resnet_block_groups
+        self.learned_variance = learned_variance
diff --git a/src/diffusers/models/unet/modeling_unet.py b/src/diffusers/models/unet/modeling_unet.py
new file mode 100644
index 0000000000..23c0921633
--- /dev/null
+++ b/src/diffusers/models/unet/modeling_unet.py
@@ -0,0 +1,728 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
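+
+# A quick orientation sketch (illustrative, not part of the upstream copy): how `dim` and `dim_mults`
+# from `UNetConfig` expand into the channel progression used by `UNetModel` below.
+#
+#   dim, dim_mults = 64, (1, 2, 4, 8)
+#   init_dim = dim // 3 * 2                            # 42
+#   dims = [init_dim, *(dim * m for m in dim_mults)]   # [42, 64, 128, 256, 512]
+#   in_out = list(zip(dims[:-1], dims[1:]))            # [(42, 64), (64, 128), (128, 256), (256, 512)]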
+ +# helpers functions + +import copy +import math +from functools import partial +from inspect import isfunction +from pathlib import Path + +import torch +import torch.nn.functional as F +from torch import einsum, nn +from torch.cuda.amp import GradScaler, autocast +from torch.optim import Adam +from torch.utils import data + +from einops import rearrange +from PIL import Image +from torchvision import transforms, utils +from tqdm import tqdm + +from ...modeling_utils import PreTrainedModel +from .configuration_unet import UNetConfig + + +# NOTE: the following file is completely copied from https://github.com/lucidrains/denoising-diffusion-pytorch/blob/master/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py + + + + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def cycle(dl): + while True: + for data_dl in dl: + yield data_dl + + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + + +def normalize_to_neg_one_to_one(img): + return img * 2 - 1 + + +def unnormalize_to_zero_to_one(t): + return (t + 1) * 0.5 + + +# small helper modules + + +class EMA: + def __init__(self, beta): + super().__init__() + self.beta = beta + + def update_model_average(self, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + ma_params.data = self.update_average(old_weight, up_weight) + + def update_average(self, old, new): + if old is None: + return new + return old * self.beta + (1 - self.beta) * new + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +def Upsample(dim): + return nn.ConvTranspose2d(dim, dim, 4, 2, 1) + + +def Downsample(dim): + return nn.Conv2d(dim, dim, 4, 2, 1) + + +class LayerNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.g = nn.Parameter(torch.ones(1, dim, 1, 1)) + self.b = nn.Parameter(torch.zeros(1, dim, 1, 1)) + + def forward(self, x): + var = torch.var(x, dim=1, unbiased=False, keepdim=True) + mean = torch.mean(x, dim=1, keepdim=True) + return (x - mean) / (var + self.eps).sqrt() * self.g + self.b + + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = LayerNorm(dim) + + def forward(self, x): + x = self.norm(x) + return self.fn(x) + + +# building block modules + + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.proj = nn.Conv2d(dim, dim_out, 3, padding=1) + self.norm = nn.GroupNorm(groups, dim_out) + self.act = nn.SiLU() + + def forward(self, x, scale_shift=None): + x = self.proj(x) + x = self.norm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + x = self.act(x) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, dim, dim_out, 
*, time_emb_dim=None, groups=8): + super().__init__() + self.mlp = nn.Sequential(nn.SiLU(), nn.Linear(time_emb_dim, dim_out * 2)) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) + self.res_conv = nn.Conv2d(dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb=None): + + scale_shift = None + if exists(self.mlp) and exists(time_emb): + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, "b c -> b c 1 1") + scale_shift = time_emb.chunk(2, dim=1) + + h = self.block1(x, scale_shift=scale_shift) + + h = self.block2(h) + return h + self.res_conv(x) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + + self.to_out = nn.Sequential(nn.Conv2d(hidden_dim, dim, 1), LayerNorm(dim)) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv) + + q = q.softmax(dim=-2) + k = k.softmax(dim=-1) + + q = q * self.scale + context = torch.einsum("b h d n, b h e n -> b h d e", k, v) + + out = torch.einsum("b h d e, b h d n -> b h e n", context, q) + out = rearrange(out, "b h c (x y) -> b (h c) x y", h=self.heads, x=h, y=w) + return self.to_out(out) + + +class Attention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.scale = dim_head**-0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = map(lambda t: rearrange(t, "b (h c) x y -> b h c (x y)", h=self.heads), qkv) + q = q * self.scale + + sim = einsum("b h d i, b h d j -> b h i j", q, k) + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + out = einsum("b h i j, b h d j -> b h i d", attn, v) + out = rearrange(out, "b h (x y) d -> b (h d) x y", x=h, y=w) + return self.to_out(out) + + +class UNetModel(PreTrainedModel): + + config_class = UNetConfig + + def __init__(self, config): + super().__init__(config) + + init_dim = None + out_dim = None + channels = 3 + with_time_emb = True + resnet_block_groups = 8 + learned_variance = False + + # determine dimensions + + dim_mults = config.dim_mults + dim = config.dim + self.channels = config.channels + + init_dim = default(init_dim, dim // 3 * 2) + self.init_conv = nn.Conv2d(channels, init_dim, 7, padding=3) + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + block_klass = partial(ResnetBlock, groups=resnet_block_groups) + + # time embeddings + + if with_time_emb: + time_dim = dim * 4 + self.time_mlp = nn.Sequential( + SinusoidalPosEmb(dim), nn.Linear(dim, time_dim), nn.GELU(), nn.Linear(time_dim, time_dim) + ) + else: + time_dim = None + self.time_mlp = None + + # layers + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + num_resolutions = len(in_out) + + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append( + nn.ModuleList( + [ + block_klass(dim_in, dim_out, time_emb_dim=time_dim), + block_klass(dim_out, dim_out, time_emb_dim=time_dim), + Residual(PreNorm(dim_out, 
LinearAttention(dim_out))), + Downsample(dim_out) if not is_last else nn.Identity(), + ] + ) + ) + + mid_dim = dims[-1] + self.mid_block1 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) + self.mid_attn = Residual(PreNorm(mid_dim, Attention(mid_dim))) + self.mid_block2 = block_klass(mid_dim, mid_dim, time_emb_dim=time_dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])): + is_last = ind >= (num_resolutions - 1) + + self.ups.append( + nn.ModuleList( + [ + block_klass(dim_out * 2, dim_in, time_emb_dim=time_dim), + block_klass(dim_in, dim_in, time_emb_dim=time_dim), + Residual(PreNorm(dim_in, LinearAttention(dim_in))), + Upsample(dim_in) if not is_last else nn.Identity(), + ] + ) + ) + + default_out_dim = channels * (1 if not learned_variance else 2) + self.out_dim = default(out_dim, default_out_dim) + + self.final_conv = nn.Sequential(block_klass(dim, dim), nn.Conv2d(dim, self.out_dim, 1)) + + def forward(self, x, time): + x = self.init_conv(x) + + t = self.time_mlp(time) if exists(self.time_mlp) else None + + h = [] + + for block1, block2, attn, downsample in self.downs: + x = block1(x, t) + x = block2(x, t) + x = attn(x) + h.append(x) + x = downsample(x) + + x = self.mid_block1(x, t) + x = self.mid_attn(x) + x = self.mid_block2(x, t) + + for block1, block2, attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim=1) + x = block1(x, t) + x = block2(x, t) + x = attn(x) + x = upsample(x) + + return self.final_conv(x) + + +# gaussian diffusion trainer class + + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def noise_like(shape, device, repeat=False): + def repeat_noise(): + return torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + + def noise(): + return torch.randn(shape, device=device) + + return repeat_noise() if repeat else noise() + + +def linear_beta_schedule(timesteps): + scale = 1000 / timesteps + beta_start = scale * 0.0001 + beta_end = scale * 0.02 + return torch.linspace(beta_start, beta_end, timesteps, dtype=torch.float64) + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = torch.linspace(0, timesteps, steps, dtype=torch.float64) + alphas_cumprod = torch.cos(((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0, 0.999) + + +class GaussianDiffusion(nn.Module): + def __init__( + self, + denoise_fn, + *, + image_size, + channels=3, + timesteps=1000, + loss_type="l1", + objective="pred_noise", + beta_schedule="cosine", + ): + super().__init__() + assert not (type(self) == GaussianDiffusion and denoise_fn.channels != denoise_fn.out_dim) + + self.channels = channels + self.image_size = image_size + self.denoise_fn = denoise_fn + self.objective = objective + + if beta_schedule == "linear": + betas = linear_beta_schedule(timesteps) + elif beta_schedule == "cosine": + betas = cosine_beta_schedule(timesteps) + else: + raise ValueError(f"unknown beta schedule {beta_schedule}") + + alphas = 1.0 - betas + alphas_cumprod = torch.cumprod(alphas, axis=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.0) + + (timesteps,) = betas.shape + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + + # helper function to register buffer from float64 to float32 + + def 
register_buffer(name, val): + self.register_buffer(name, val.to(torch.float32)) + + register_buffer("betas", betas) + register_buffer("alphas_cumprod", alphas_cumprod) + register_buffer("alphas_cumprod_prev", alphas_cumprod_prev) + + # calculations for diffusion q(x_t | x_{t-1}) and others + + register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod)) + register_buffer("sqrt_one_minus_alphas_cumprod", torch.sqrt(1.0 - alphas_cumprod)) + register_buffer("log_one_minus_alphas_cumprod", torch.log(1.0 - alphas_cumprod)) + register_buffer("sqrt_recip_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod)) + register_buffer("sqrt_recipm1_alphas_cumprod", torch.sqrt(1.0 / alphas_cumprod - 1)) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + + posterior_variance = betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod) + + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + + register_buffer("posterior_variance", posterior_variance) + + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + + register_buffer("posterior_log_variance_clipped", torch.log(posterior_variance.clamp(min=1e-20))) + register_buffer("posterior_mean_coef1", betas * torch.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)) + register_buffer( + "posterior_mean_coef2", (1.0 - alphas_cumprod_prev) * torch.sqrt(alphas) / (1.0 - alphas_cumprod) + ) + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool): + model_output = self.denoise_fn(x, t) + + if self.objective == "pred_noise": + x_start = self.predict_start_from_noise(x, t=t, noise=model_output) + elif self.objective == "pred_x0": + x_start = model_output + else: + raise ValueError(f"unknown objective {self.objective}") + + if clip_denoised: + x_start.clamp_(-1.0, 1.0) + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_start, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.no_grad() + def p_sample(self, x, t, clip_denoised=True, repeat_noise=False): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance(x=x, t=t, clip_denoised=clip_denoised) + noise = noise_like(x.shape, device, repeat_noise) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1))) + result = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + return result + + @torch.no_grad() + def p_sample_loop(self, shape): + device = self.betas.device + + b = shape[0] + img = torch.randn(shape, device=device) + + for i in tqdm( + reversed(range(0, self.num_timesteps)), desc="sampling loop time step", total=self.num_timesteps + ): + img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long)) + + img = unnormalize_to_zero_to_one(img) + return img + + @torch.no_grad() + def sample(self, batch_size=16): + 
image_size = self.image_size
+        channels = self.channels
+        return self.p_sample_loop((batch_size, channels, image_size, image_size))
+
+    @torch.no_grad()
+    def interpolate(self, x1, x2, t=None, lam=0.5):
+        # noise both images up to step t, mix them, then denoise the mixture
+        b, *_, device = *x1.shape, x1.device
+        t = default(t, self.num_timesteps - 1)
+
+        assert x1.shape == x2.shape
+
+        t_batched = torch.stack([torch.tensor(t, device=device)] * b)
+        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
+
+        img = (1 - lam) * xt1 + lam * xt2
+        for i in tqdm(reversed(range(0, t)), desc="interpolation sample time step", total=t):
+            img = self.p_sample(img, torch.full((b,), i, device=device, dtype=torch.long))
+
+        return img
+
+    def q_sample(self, x_start, t, noise=None):
+        # forward process: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise
+        noise = default(noise, lambda: torch.randn_like(x_start))
+
+        return (
+            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+            + extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+        )
+
+    @property
+    def loss_fn(self):
+        if self.loss_type == "l1":
+            return F.l1_loss
+        elif self.loss_type == "l2":
+            return F.mse_loss
+        else:
+            raise ValueError(f"invalid loss type {self.loss_type}")
+
+    def p_losses(self, x_start, t, noise=None):
+        b, c, h, w = x_start.shape
+        noise = default(noise, lambda: torch.randn_like(x_start))
+
+        x = self.q_sample(x_start=x_start, t=t, noise=noise)
+        model_out = self.denoise_fn(x, t)
+
+        if self.objective == "pred_noise":
+            target = noise
+        elif self.objective == "pred_x0":
+            target = x_start
+        else:
+            raise ValueError(f"unknown objective {self.objective}")
+
+        loss = self.loss_fn(model_out, target)
+        return loss
+
+    def forward(self, img, *args, **kwargs):
+        b, _, h, w = img.shape
+        device, img_size = img.device, self.image_size
+        assert h == img_size and w == img_size, f"height and width of image must both be {img_size}"
+        t = torch.randint(0, self.num_timesteps, (b,), device=device).long()
+
+        img = normalize_to_neg_one_to_one(img)
+        return self.p_losses(img, t, *args, **kwargs)
+
+
+# dataset classes
+
+
+class Dataset(data.Dataset):
+    def __init__(self, folder, image_size, exts=("jpg", "jpeg", "png")):
+        super().__init__()
+        self.folder = folder
+        self.image_size = image_size
+        self.paths = [p for ext in exts for p in Path(f"{folder}").glob(f"**/*.{ext}")]
+
+        self.transform = transforms.Compose(
+            [
+                transforms.Resize(image_size),
+                transforms.RandomHorizontalFlip(),
+                transforms.CenterCrop(image_size),
+                transforms.ToTensor(),
+            ]
+        )
+
+    def __len__(self):
+        return len(self.paths)
+
+    def __getitem__(self, index):
+        path = self.paths[index]
+        # force 3 channels so grayscale or RGBA images don't break the 3-channel UNet
+        img = Image.open(path).convert("RGB")
+        return self.transform(img)
+
+
+# trainer class
+
+
+class Trainer(object):
+    def __init__(
+        self,
+        diffusion_model,
+        folder,
+        *,
+        ema_decay=0.995,
+        image_size=128,
+        train_batch_size=32,
+        train_lr=1e-4,
+        train_num_steps=100000,
+        gradient_accumulate_every=2,
+        amp=False,
+        step_start_ema=2000,
+        update_ema_every=10,
+        save_and_sample_every=1000,
+        results_folder="./results",
+    ):
+        super().__init__()
+        self.model = diffusion_model
+        self.ema = EMA(ema_decay)
+        self.ema_model = copy.deepcopy(self.model)
+        self.update_ema_every = update_ema_every
+
+        self.step_start_ema = step_start_ema
+        self.save_and_sample_every = save_and_sample_every
+
+        self.batch_size = train_batch_size
+        self.image_size = diffusion_model.image_size
+        self.gradient_accumulate_every = gradient_accumulate_every
+        self.train_num_steps = train_num_steps
+
+        self.ds = Dataset(folder, image_size)
+        self.dl = cycle(data.DataLoader(self.ds, batch_size=train_batch_size, 
shuffle=True, pin_memory=True)) + self.opt = Adam(diffusion_model.parameters(), lr=train_lr) + + self.step = 0 + + self.amp = amp + self.scaler = GradScaler(enabled=amp) + + self.results_folder = Path(results_folder) + self.results_folder.mkdir(exist_ok=True) + + self.reset_parameters() + + def reset_parameters(self): + self.ema_model.load_state_dict(self.model.state_dict()) + + def step_ema(self): + if self.step < self.step_start_ema: + self.reset_parameters() + return + self.ema.update_model_average(self.ema_model, self.model) + + def save(self, milestone): + data = { + "step": self.step, + "model": self.model.state_dict(), + "ema": self.ema_model.state_dict(), + "scaler": self.scaler.state_dict(), + } + torch.save(data, str(self.results_folder / f"model-{milestone}.pt")) + + def load(self, milestone): + data = torch.load(str(self.results_folder / f"model-{milestone}.pt")) + + self.step = data["step"] + self.model.load_state_dict(data["model"]) + self.ema_model.load_state_dict(data["ema"]) + self.scaler.load_state_dict(data["scaler"]) + + def train(self): + with tqdm(initial=self.step, total=self.train_num_steps) as pbar: + + while self.step < self.train_num_steps: + for i in range(self.gradient_accumulate_every): + data = next(self.dl).cuda() + + with autocast(enabled=self.amp): + loss = self.model(data) + self.scaler.scale(loss / self.gradient_accumulate_every).backward() + + pbar.set_description(f"loss: {loss.item():.4f}") + + self.scaler.step(self.opt) + self.scaler.update() + self.opt.zero_grad() + + if self.step % self.update_ema_every == 0: + self.step_ema() + + if self.step != 0 and self.step % self.save_and_sample_every == 0: + self.ema_model.eval() + + milestone = self.step // self.save_and_sample_every + batches = num_to_groups(36, self.batch_size) + all_images_list = list(map(lambda n: self.ema_model.sample(batch_size=n), batches)) + all_images = torch.cat(all_images_list, dim=0) + utils.save_image(all_images, str(self.results_folder / f"sample-{milestone}.png"), nrow=6) + self.save(milestone) + + self.step += 1 + pbar.update(1) + + print("training complete") diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py new file mode 100755 index 0000000000..eb9debd18b --- /dev/null +++ b/tests/test_modeling_utils.py @@ -0,0 +1,62 @@ +# coding=utf-8 +# Copyright 2022 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
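+
+# Smoke tests for the modeling utilities added in this PR: build a tiny
+# UNetModel, round-trip it through `save_pretrained` / `from_pretrained`,
+# and check that both models produce the same forward pass.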
+ +import random +import tempfile +import unittest + +import torch + +from diffusers import UNetConfig, UNetModel + + +global_rng = random.Random() + + +def floats_tensor(shape, scale=1.0, rng=None, name=None): + """Creates a random float32 tensor""" + if rng is None: + rng = global_rng + + total_dims = 1 + for dim in shape: + total_dims *= dim + + values = [] + for _ in range(total_dims): + values.append(rng.random() * scale) + + return torch.tensor(data=values, dtype=torch.float).view(shape).contiguous() + + +class ModelTesterMixin(unittest.TestCase): + def test_from_pretrained_save_pretrained(self): + config = UNetConfig(dim=8, dim_mults=(1, 2), resnet_block_groups=2) + model = UNetModel(config) + + with tempfile.TemporaryDirectory() as tmpdirname: + model.save_pretrained(tmpdirname) + new_model = UNetModel.from_pretrained(tmpdirname) + + batch_size = 1 + num_channels = 3 + sizes = (32, 32) + noise = floats_tensor((batch_size, num_channels) + sizes) + time_step = torch.tensor([10]) + + image = model(noise, time_step) + new_image = new_model(noise, time_step) + + assert (image - new_image).abs().sum() < 1e-5, "Models don't give the same forward pass"
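+
+    # The two tests below are illustrative sketches, not part of the upstream
+    # lucidrains code this PR copies: they smoke-test the GaussianDiffusion
+    # wrapper and its beta schedules through the API this PR exposes.
+    # `timesteps` is kept tiny in the sampling test purely to keep it fast.
+    def test_gaussian_diffusion_smoke(self):
+        from diffusers import GaussianDiffusion
+
+        config = UNetConfig(dim=8, dim_mults=(1, 2), resnet_block_groups=2)
+        model = UNetModel(config)
+        diffusion = GaussianDiffusion(model, image_size=32, timesteps=10, loss_type="l2")
+
+        # training step: images are expected in [0, 1]; `forward` rescales them to [-1, 1]
+        images = torch.rand(2, 3, 32, 32)
+        loss = diffusion(images)
+        loss.backward()
+
+        # sampling returns a batch of images at the configured image size
+        samples = diffusion.sample(batch_size=1)
+        assert samples.shape == (1, 3, 32, 32)
+
+    def test_beta_schedules(self):
+        from diffusers import GaussianDiffusion
+
+        config = UNetConfig(dim=8, dim_mults=(1, 2), resnet_block_groups=2)
+        model = UNetModel(config)
+
+        # both named schedules should register one beta per timestep, each in (0, 1)
+        for schedule in ("linear", "cosine"):
+            diffusion = GaussianDiffusion(model, image_size=32, timesteps=1000, beta_schedule=schedule)
+            assert diffusion.betas.shape == (1000,)
+            assert (diffusion.betas > 0).all() and (diffusion.betas < 1).all()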