From 9d16daaf640462a0580dd1d503e71d246809a09a Mon Sep 17 00:00:00 2001
From: "39th president of the United States, probably" <110263573+AmericanPresidentJimmyCarter@users.noreply.github.com>
Date: Fri, 26 Apr 2024 21:49:15 -0400
Subject: [PATCH] Add DREAM training (#6381)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

A new function, compute_dream_and_update_latents, has been added to the
training utilities. It lets you apply DREAM rectified training, in line
with the paper https://arxiv.org/abs/2312.00210, and can be enabled with
an extra argument in the train_text_to_image.py script.

Co-authored-by: Jimmy <39@🇺🇸.com>
---
 examples/text_to_image/README.md              | 15 +++++++++++++++
 examples/text_to_image/train_text_to_image.py | 28 +++++++++++++++++++++++++++-
 src/diffusers/training_utils.py               | 59 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 3 files changed, 100 insertions(+), 2 deletions(-)

diff --git a/examples/text_to_image/README.md b/examples/text_to_image/README.md
index f2931d3f34..fd6e50bc37 100644
--- a/examples/text_to_image/README.md
+++ b/examples/text_to_image/README.md
@@ -170,6 +170,21 @@ For our small Pokemons dataset, the effects of Min-SNR weighting strategy might

 Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds.

+#### Training with DREAM
+
+We support training epsilon (noise) prediction models using the [DREAM (Diffusion Rectification and Estimation-Adaptive Models) strategy](https://arxiv.org/abs/2312.00210). DREAM claims to increase model fidelity at the performance cost of one extra gradient-free unet `forward` step per training iteration. You can turn on DREAM training with the `--dream_training` argument. The `--dream_detail_preservation` argument controls the detail preservation variable p; its default of 1.0 matches the value used in the paper.
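+
+For example, a DREAM-enabled run might look like the following sketch, where `$MODEL_NAME` and `$DATASET_NAME` stand in for your base model and dataset, and the usual training flags are elided:
+
+```bash
+accelerate launch train_text_to_image.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --dataset_name=$DATASET_NAME \
+  --dream_training \
+  --dream_detail_preservation=1.0
+```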
+
+
 ## Training with LoRA

 Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py
index 84f4c6514c..aa704ba8ca 100644
--- a/examples/text_to_image/train_text_to_image.py
+++ b/examples/text_to_image/train_text_to_image.py
@@ -45,7 +45,7 @@ from transformers.utils import ContextManagers
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import EMAModel, compute_snr
+from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr
 from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
 from diffusers.utils.import_utils import is_xformers_available
@@ -361,6 +361,20 @@ def parse_args():
         help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
         "More details here: https://arxiv.org/abs/2303.09556.",
     )
+    parser.add_argument(
+        "--dream_training",
+        action="store_true",
+        help=(
+            "Use the DREAM training method, which makes training more efficient and accurate at the "
+            "expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210"
+        ),
+    )
+    parser.add_argument(
+        "--dream_detail_preservation",
+        type=float,
+        default=1.0,
+        help="Dream detail preservation factor p (should be greater than 0; default=1.0, as suggested in the paper)",
+    )
     parser.add_argument(
         "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
     )
@@ -948,6 +962,18 @@ def main():
             else:
                 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

+            if args.dream_training:
+                noisy_latents, target = compute_dream_and_update_latents(
+                    unet,
+                    noise_scheduler,
+                    timesteps,
+                    noise,
+                    noisy_latents,
+                    target,
+                    encoder_hidden_states,
+                    args.dream_detail_preservation,
+                )
+
             # Predict the noise residual and compute loss
             model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

diff --git a/src/diffusers/training_utils.py b/src/diffusers/training_utils.py
index 25e02a3d14..b617dd2eef 100644
--- a/src/diffusers/training_utils.py
+++ b/src/diffusers/training_utils.py
@@ -1,12 +1,13 @@
 import contextlib
 import copy
 import random
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union

 import numpy as np
 import torch

 from .models import UNet2DConditionModel
+from .schedulers import SchedulerMixin
 from .utils import (
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
@@ -117,6 +118,62 @@ def resolve_interpolation_mode(interpolation_type: str):
     return interpolation_mode


+def compute_dream_and_update_latents(
+    unet: UNet2DConditionModel,
+    noise_scheduler: SchedulerMixin,
+    timesteps: torch.Tensor,
+    noise: torch.Tensor,
+    noisy_latents: torch.Tensor,
+    target: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    dream_detail_preservation: float = 1.0,
+) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    """
+    Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from http://arxiv.org/abs/2312.00210.
+    DREAM aligns training with sampling, making training more efficient and accurate, at the cost of an extra
+    forward step without gradients.
+
+    Args:
+        `unet`: The unet used to make the prediction.
+        `noise_scheduler`: The noise scheduler used to add noise for the given timestep.
+        `timesteps`: The timesteps for the noise_scheduler to use.
+        `noise`: A tensor of noise in the shape of noisy_latents.
+        `noisy_latents`: Previously noised latents from the training loop.
+        `target`: The ground-truth target tensor (e.g., the noise for epsilon prediction).
+        `encoder_hidden_states`: Text embeddings from the text model.
+        `dream_detail_preservation`: A float value that indicates detail preservation level.
+          See reference.
+
+    Returns:
+        `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target.
+    """
+    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
+    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+    # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
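+    # dream_lambda broadcasts from shape (batch, 1, 1, 1) over the channel and spatial dims. Because
+    # sqrt(1 - alphas_cumprod) lies in (0, 1), a larger p shrinks the correction, most at low-noise timesteps.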
+    dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation
+
+    pred = None
+    with torch.no_grad():
+        pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+    _noisy_latents, _target = (None, None)
+    if noise_scheduler.config.prediction_type == "epsilon":
+        predicted_noise = pred
+        delta_noise = (noise - predicted_noise).detach()
+        delta_noise.mul_(dream_lambda)
+        _noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
+        _target = target.add(delta_noise)
+    elif noise_scheduler.config.prediction_type == "v_prediction":
+        raise NotImplementedError("DREAM has not been implemented for v-prediction")
+    else:
+        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+    return _noisy_latents, _target
+
+
 def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
     r"""
     Returns:
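
Below is a minimal sketch of how the new helper composes with a standard epsilon-prediction training step. The `dream_training_step` wrapper is hypothetical (not part of this patch), and `unet`, `noise_scheduler`, `latents`, and `encoder_hidden_states` are assumed to be prepared as in train_text_to_image.py:

```python
import torch
import torch.nn.functional as F

from diffusers.training_utils import compute_dream_and_update_latents


def dream_training_step(unet, noise_scheduler, latents, encoder_hidden_states, dream_detail_preservation=1.0):
    # Sample noise and per-example timesteps, then produce the noisy latents.
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
    ).long()
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    target = noise  # epsilon-prediction target

    # The helper runs one extra no-grad unet forward pass and returns the rectified pair.
    noisy_latents, target = compute_dream_and_update_latents(
        unet,
        noise_scheduler,
        timesteps,
        noise,
        noisy_latents,
        target,
        encoder_hidden_states,
        dream_detail_preservation,
    )

    # Standard denoising loss on the rectified latents and target.
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
    return F.mse_loss(model_pred.float(), target.float(), reduction="mean")
```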