1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Add DREAM training (#6381)

A new function compute_dream_and_update_latents has been added to the
training utilities that allows you to do DREAM rectified training in line
with the paper https://arxiv.org/abs/2312.00210. The method can be used
with an extra argument in the train_text_to_image.py script.

Co-authored-by: Jimmy <39@🇺🇸.com>
This commit is contained in:
39th president of the United States, probably
2024-04-26 21:49:15 -04:00
committed by GitHub
parent 8e4ca1b6b2
commit 9d16daaf64
3 changed files with 88 additions and 2 deletions

View File

@@ -170,6 +170,11 @@ For our small Pokemons dataset, the effects of Min-SNR weighting strategy might
Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds.
#### Training with DREAM
We support training epsilon (noise) prediction models using the [DREAM (Diffusion Rectification and Estimation-Adaptive Models) strategy](https://arxiv.org/abs/2312.00210). DREAM claims to increase model fidelity for the performance cost of an extra grad-less unet `forward` step in the training loop. You can turn on DREAM training by using the `--dream_training` argument. The `--dream_detail_preservation` argument controls the detail preservation variable p and is the default of 1 from the paper.
## Training with LoRA
Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.

View File

@@ -45,7 +45,7 @@ from transformers.utils import ContextManagers
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel, compute_snr
from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr
from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
@@ -361,6 +361,20 @@ def parse_args():
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
parser.add_argument(
"--dream_training",
action="store_true",
help=(
"Use the DREAM training method, which makes training more efficient and accurate at the ",
"expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210",
),
)
parser.add_argument(
"--dream_detail_preservation",
type=float,
default=1.0,
help="Dream detail preservation factor p (should be greater than 0; default=1.0, as suggested in the paper)",
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
@@ -948,6 +962,18 @@ def main():
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
if args.dream_training:
noisy_latents, target = compute_dream_and_update_latents(
unet,
noise_scheduler,
timesteps,
noise,
noisy_latents,
target,
encoder_hidden_states,
args.dream_detail_preservation,
)
# Predict the noise residual and compute loss
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

View File

@@ -1,12 +1,13 @@
import contextlib
import copy
import random
from typing import Any, Dict, Iterable, List, Optional, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import torch
from .models import UNet2DConditionModel
from .schedulers import SchedulerMixin
from .utils import (
convert_state_dict_to_diffusers,
convert_state_dict_to_peft,
@@ -117,6 +118,60 @@ def resolve_interpolation_mode(interpolation_type: str):
return interpolation_mode
def compute_dream_and_update_latents(
unet: UNet2DConditionModel,
noise_scheduler: SchedulerMixin,
timesteps: torch.Tensor,
noise: torch.Tensor,
noisy_latents: torch.Tensor,
target: torch.Tensor,
encoder_hidden_states: torch.Tensor,
dream_detail_preservation: float = 1.0,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from http://arxiv.org/abs/2312.00210.
DREAM helps align training with sampling to help training be more efficient and accurate at the cost of an extra
forward step without gradients.
Args:
`unet`: The state unet to use to make a prediction.
`noise_scheduler`: The noise scheduler used to add noise for the given timestep.
`timesteps`: The timesteps for the noise_scheduler to user.
`noise`: A tensor of noise in the shape of noisy_latents.
`noisy_latents`: Previously noise latents from the training loop.
`target`: The ground-truth tensor to predict after eps is removed.
`encoder_hidden_states`: Text embeddings from the text model.
`dream_detail_preservation`: A float value that indicates detail preservation level.
See reference.
Returns:
`tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target.
"""
alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
# The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation
pred = None
with torch.no_grad():
pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
noisy_latents, target = (None, None)
if noise_scheduler.config.prediction_type == "epsilon":
predicted_noise = pred
delta_noise = (noise - predicted_noise).detach()
delta_noise.mul_(dream_lambda)
noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
target = target.add(delta_noise)
elif noise_scheduler.config.prediction_type == "v_prediction":
raise NotImplementedError("DREAM has not been implemented for v-prediction")
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
return noisy_latents, target
def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
r"""
Returns: