mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add DREAM training (#6381)
A new function compute_dream_and_update_latents has been added to the training utilities that allows you to do DREAM rectified training in line with the paper https://arxiv.org/abs/2312.00210. The method can be used with an extra argument in the train_text_to_image.py script. Co-authored-by: Jimmy <39@🇺🇸.com>
This commit is contained in:
committed by
GitHub
parent
8e4ca1b6b2
commit
9d16daaf64
@@ -170,6 +170,11 @@ For our small Pokemons dataset, the effects of Min-SNR weighting strategy might
|
||||
|
||||
Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds.
|
||||
|
||||
#### Training with DREAM
|
||||
|
||||
We support training epsilon (noise) prediction models using the [DREAM (Diffusion Rectification and Estimation-Adaptive Models) strategy](https://arxiv.org/abs/2312.00210). DREAM claims to increase model fidelity for the performance cost of an extra grad-less unet `forward` step in the training loop. You can turn on DREAM training by using the `--dream_training` argument. The `--dream_detail_preservation` argument controls the detail preservation variable p and defaults to 1, the value used in the paper.
|
||||
|
||||
|
||||
## Training with LoRA
|
||||
|
||||
Low-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
|
||||
|
||||
@@ -45,7 +45,7 @@ from transformers.utils import ContextManagers
|
||||
import diffusers
|
||||
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import EMAModel, compute_snr
|
||||
from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr
|
||||
from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
@@ -361,6 +361,20 @@ def parse_args():
|
||||
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
|
||||
"More details here: https://arxiv.org/abs/2303.09556.",
|
||||
)
|
||||
# Opt-in flag for DREAM training (https://arxiv.org/abs/2312.00210).
# The help text relies on implicit string-literal concatenation: there must be
# no commas between the pieces, otherwise `help` becomes a tuple and argparse
# fails with a TypeError when it formats the `--help` output.
parser.add_argument(
    "--dream_training",
    action="store_true",
    help=(
        "Use the DREAM training method, which makes training more efficient and accurate at the "
        "expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210"
    ),
)
|
||||
parser.add_argument(
|
||||
"--dream_detail_preservation",
|
||||
type=float,
|
||||
default=1.0,
|
||||
help="Dream detail preservation factor p (should be greater than 0; default=1.0, as suggested in the paper)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
||||
)
|
||||
@@ -948,6 +962,18 @@ def main():
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
if args.dream_training:
|
||||
noisy_latents, target = compute_dream_and_update_latents(
|
||||
unet,
|
||||
noise_scheduler,
|
||||
timesteps,
|
||||
noise,
|
||||
noisy_latents,
|
||||
target,
|
||||
encoder_hidden_states,
|
||||
args.dream_detail_preservation,
|
||||
)
|
||||
|
||||
# Predict the noise residual and compute loss
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
|
||||
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
import contextlib
|
||||
import copy
|
||||
import random
|
||||
from typing import Any, Dict, Iterable, List, Optional, Union
|
||||
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .models import UNet2DConditionModel
|
||||
from .schedulers import SchedulerMixin
|
||||
from .utils import (
|
||||
convert_state_dict_to_diffusers,
|
||||
convert_state_dict_to_peft,
|
||||
@@ -117,6 +118,60 @@ def resolve_interpolation_mode(interpolation_type: str):
|
||||
return interpolation_mode
|
||||
|
||||
|
||||
def compute_dream_and_update_latents(
    unet: UNet2DConditionModel,
    noise_scheduler: SchedulerMixin,
    timesteps: torch.Tensor,
    noise: torch.Tensor,
    noisy_latents: torch.Tensor,
    target: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    dream_detail_preservation: float = 1.0,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from http://arxiv.org/abs/2312.00210.
    DREAM helps align training with sampling to help training be more efficient and accurate at the cost of an extra
    forward step without gradients.

    Args:
        `unet`: The unet, in its current state, used to make the gradient-free prediction.
        `noise_scheduler`: The noise scheduler used to add noise for the given timesteps. Its `alphas_cumprod` and
            `config.prediction_type` attributes are read.
        `timesteps`: The timesteps for the noise_scheduler to use; also used to index `alphas_cumprod`.
        `noise`: A tensor of noise in the shape of noisy_latents.
        `noisy_latents`: Previously noised latents from the training loop.
        `target`: The ground-truth tensor to predict after eps is removed.
        `encoder_hidden_states`: Text embeddings from the text model.
        `dream_detail_preservation`: A float value that indicates detail preservation level (the exponent `p`
            applied to `sqrt(1 - alphas_cumprod)`). See reference.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target. The input tensors are left
        unmodified; new tensors are returned.

    Raises:
        `NotImplementedError`: If the scheduler is configured for `v_prediction`.
        `ValueError`: If the scheduler's prediction type is anything other than `epsilon` or `v_prediction`.
    """
    # Reject unsupported prediction types up front so the extra unet forward pass is not wasted.
    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type == "v_prediction":
        raise NotImplementedError("DREAM has not been implemented for v-prediction")
    if prediction_type != "epsilon":
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    # Indexing with [timesteps, None, None, None] yields one value per timestep with three trailing
    # singleton dimensions, so it broadcasts against the latents.
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
    dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation

    # Extra forward pass; no gradients are tracked for it.
    with torch.no_grad():
        predicted_noise = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    delta_noise = (noise - predicted_noise).detach()
    delta_noise.mul_(dream_lambda)

    # Build the results under separate names: the input `noisy_latents` and `target` must stay bound to the
    # caller's tensors, since they are the operands of the out-of-place `.add` calls below.
    dream_noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
    dream_target = target.add(delta_noise)

    return dream_noisy_latents, dream_target
|
||||
|
||||
|
||||
def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
|
||||
r"""
|
||||
Returns:
|
||||
|
||||
Reference in New Issue
Block a user