
[Training] SD3 training fixes (#8917)

* SD3 training fixes

Co-authored-by: bghira <59658056+bghira@users.noreply.github.com>

* rewrite noise addition part to respect the eqn.

* styler

* Update examples/dreambooth/README_sd3.md

Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>

---------

Co-authored-by: bghira <59658056+bghira@users.noreply.github.com>
Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
Sayak Paul authored on 2024-07-21 16:24:04 +05:30, committed by GitHub
parent 56e772ab7e
commit 1a8b3c2ee8
3 changed files with 36 additions and 9 deletions

examples/dreambooth/README_sd3.md

@@ -183,4 +183,6 @@ accelerate launch train_dreambooth_lora_sd3.py \
 ## Other notes
-We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities.
+1. We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities.
+2. Thanks to `bghira`, `JinxuXiang`, and `bendanzzc` for helping us discover a bug in how VAE encoding was being done previously. This has been fixed in [#8917](https://github.com/huggingface/diffusers/pull/8917).
+3. Additionally, we now have the option to control if we want to apply preconditioning to the model outputs via a `--precondition_outputs` CLI arg. It affects how the model `target` is calculated as well.
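
For context on note 1: the "logit_normal" scheme biases timestep sampling toward the middle of the noise schedule, following the SD3 paper. Below is a minimal sketch of that sampling, with illustrative helper names and defaults rather than the training script's exact code:

```python
import torch

def sample_logit_normal_timesteps(batch_size, num_train_timesteps=1000, logit_mean=0.0, logit_std=1.0):
    # Draw u ~ Normal(logit_mean, logit_std) and squash it through a sigmoid so that
    # u lies in (0, 1); mid-range timesteps are sampled more often than the extremes.
    u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,))
    u = torch.sigmoid(u)
    # Map the continuous density to discrete scheduler indices.
    return (u * num_train_timesteps).long().clamp(0, num_train_timesteps - 1)

indices = sample_logit_normal_timesteps(batch_size=4)
```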

examples/dreambooth/train_dreambooth_lora_sd3.py

@@ -523,6 +523,13 @@ def parse_args(input_args=None):
         default=1.29,
         help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
     )
+    parser.add_argument(
+        "--precondition_outputs",
+        type=int,
+        default=1,
+        help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how "
+        "model `target` is calculated.",
+    )
     parser.add_argument(
         "--optimizer",
         type=str,
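
Because the new flag is declared as an integer rather than a boolean `store_true` switch, it can be turned off explicitly from launch scripts. A minimal, self-contained sketch of that behavior (the parser below is a stand-in, not the training script itself):

```python
import argparse

# Stand-in parser that mirrors only the flag added above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--precondition_outputs",
    type=int,
    default=1,
    help="1 enables EDM-style preconditioning of the model outputs; 0 disables it.",
)

# e.g. `--precondition_outputs 0` switches the training loop to the raw flow-matching target.
args = parser.parse_args(["--precondition_outputs", "0"])
assert args.precondition_outputs == 0  # the training loop branches on this truthy/falsy value
```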
@@ -1636,7 +1643,7 @@ def main(args):
                 # Convert images to latent space
                 model_input = vae.encode(pixel_values).latent_dist.sample()
-                model_input = model_input * vae.config.scaling_factor
+                model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
 
                 # Sample noise that we'll add to the latents
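
The one-line change above is the VAE-encoding bug called out in the README notes: SD3's VAE latents have to be shifted before they are scaled, and decoding has to invert both steps in the opposite order. A minimal sketch of the transform pair, with illustrative helper names and example config values:

```python
import torch

def normalize_latents(latents: torch.Tensor, shift_factor: float, scaling_factor: float) -> torch.Tensor:
    # Encoding direction: centre the latent distribution, then scale it.
    return (latents - shift_factor) * scaling_factor

def denormalize_latents(latents: torch.Tensor, shift_factor: float, scaling_factor: float) -> torch.Tensor:
    # Decoding direction: undo the scale first, then add the shift back.
    return latents / scaling_factor + shift_factor

# Round-trip sanity check with illustrative SD3-style values.
x = torch.randn(2, 16, 64, 64)
y = denormalize_latents(normalize_latents(x, 0.0609, 1.5305), 0.0609, 1.5305)
assert torch.allclose(x, y, atol=1e-5)
```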
@@ -1656,8 +1663,9 @@ def main(args):
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
 
                 # Add noise according to flow matching.
+                # zt = (1 - texp) * x + texp * z1
                 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
-                noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input
+                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
 
                 # Predict the noise residual
                 model_pred = transformer(
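
The rewritten line implements the interpolation from the new comment, zt = (1 - t) * x + t * z1: at sigma = 0 the input is the clean latent, at sigma = 1 it is pure noise (the two orderings are mathematically equivalent; the rewrite simply matches the equation as written). A minimal sketch, assuming per-sample sigmas arrive as a 1-D tensor:

```python
import torch

def add_flow_matching_noise(model_input: torch.Tensor, noise: torch.Tensor, sigmas: torch.Tensor) -> torch.Tensor:
    # Reshape (batch,) sigmas to (batch, 1, 1, ...) so they broadcast over channel/spatial
    # dimensions, mirroring what get_sigmas does in the script.
    sigmas = sigmas.reshape(-1, *([1] * (model_input.ndim - 1))).to(model_input.dtype)
    # zt = (1 - t) * x + t * z1
    return (1.0 - sigmas) * model_input + sigmas * noise

latents = torch.randn(4, 16, 64, 64)
noise = torch.randn_like(latents)
sigmas = torch.rand(4)  # one noise level per sample
noisy = add_flow_matching_noise(latents, noise, sigmas)
```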
@@ -1670,14 +1678,18 @@ def main(args):
                 # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
                 # Preconditioning of the model outputs.
-                model_pred = model_pred * (-sigmas) + noisy_model_input
+                if args.precondition_outputs:
+                    model_pred = model_pred * (-sigmas) + noisy_model_input
 
                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
                 weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
 
                 # flow matching loss
-                target = model_input
+                if args.precondition_outputs:
+                    target = model_input
+                else:
+                    target = noise - model_input
 
                 if args.with_prior_preservation:
                     # Chunk the noise and model_pred into two parts and compute the loss on each part separately.

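Putting the two new branches together: with preconditioning enabled, the transformer output (a velocity estimate) is converted into an estimate of the clean latents (EDM-style, Section 5 of the Karras et al. paper linked above) and regressed against them; with it disabled, the raw output is regressed against the flow-matching velocity target (noise - x). A minimal consolidated sketch with illustrative names, using the per-sample weighted MSE the scripts compute:

```python
import torch

def compute_sd3_loss(model_pred, model_input, noise, noisy_model_input, sigmas, weighting, precondition_outputs=True):
    if precondition_outputs:
        # zt - sigma * v_pred is an estimate of the clean latents x.
        model_pred = model_pred * (-sigmas) + noisy_model_input
        target = model_input
    else:
        # Regress the raw output against the flow-matching velocity.
        target = noise - model_input
    # Per-sample weighted MSE, averaged over the batch.
    loss = torch.mean(
        (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1),
        dim=1,
    ).mean()
    return loss
```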
examples/dreambooth/train_dreambooth_sd3.py

@@ -494,6 +494,13 @@ def parse_args(input_args=None):
         default=1.29,
         help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
     )
+    parser.add_argument(
+        "--precondition_outputs",
+        type=int,
+        default=1,
+        help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how "
+        "model `target` is calculated.",
+    )
     parser.add_argument(
         "--optimizer",
         type=str,
@@ -1549,7 +1556,7 @@ def main(args):
                 # Convert images to latent space
                 model_input = vae.encode(pixel_values).latent_dist.sample()
-                model_input = model_input * vae.config.scaling_factor
+                model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
                 model_input = model_input.to(dtype=weight_dtype)
 
                 # Sample noise that we'll add to the latents
@@ -1569,8 +1576,9 @@ def main(args):
                 timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
 
                 # Add noise according to flow matching.
+                # zt = (1 - texp) * x + texp * z1
                 sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
-                noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input
+                noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
 
                 # Predict the noise residual
                 if not args.train_text_encoder:
@@ -1598,13 +1606,18 @@ def main(args):
                 # Follow: Section 5 of https://arxiv.org/abs/2206.00364.
                 # Preconditioning of the model outputs.
-                model_pred = model_pred * (-sigmas) + noisy_model_input
+                if args.precondition_outputs:
+                    model_pred = model_pred * (-sigmas) + noisy_model_input
 
                 # these weighting schemes use a uniform timestep sampling
                 # and instead post-weight the loss
                 weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
 
                 # flow matching loss
-                target = model_input
+                if args.precondition_outputs:
+                    target = model_input
+                else:
+                    target = noise - model_input
 
                 if args.with_prior_preservation:
                     # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
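
The prior-preservation branch that starts here (in both training scripts) splits the concatenated instance/class batch back apart and combines the two losses. A minimal sketch with illustrative names; the actual scripts apply the sigma-dependent weighting shown above rather than a plain MSE:

```python
import torch
import torch.nn.functional as F

def prior_preservation_loss(model_pred: torch.Tensor, target: torch.Tensor, prior_loss_weight: float = 1.0) -> torch.Tensor:
    # Instance and class (prior) examples were concatenated along the batch dimension,
    # so split the prediction and target back into their two halves.
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)
    instance_loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
    # Standard DreamBooth combination: instance loss plus weighted prior-preservation loss.
    return instance_loss + prior_loss_weight * prior_loss
```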