mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Training] SD3 training fixes (#8917)
* SD3 training fixes Co-authored-by: bghira <59658056+bghira@users.noreply.github.com> * rewrite noise addition part to respect the eqn. * styler * Update examples/dreambooth/README_sd3.md Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com> --------- Co-authored-by: bghira <59658056+bghira@users.noreply.github.com> Co-authored-by: Kashif Rasul <kashif.rasul@gmail.com>
This commit is contained in:
@@ -183,4 +183,6 @@ accelerate launch train_dreambooth_lora_sd3.py \
|
||||
|
||||
## Other notes
|
||||
|
||||
We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities.
|
||||
1. We default to the "logit_normal" weighting scheme for the loss following the SD3 paper. Thanks to @bghira for helping us discover that for other weighting schemes supported from the training script, training may incur numerical instabilities.
|
||||
2. Thanks to `bghira`, `JinxuXiang`, and `bendanzzc` for helping us discover a bug in how VAE encoding was being done previously. This has been fixed in [#8917](https://github.com/huggingface/diffusers/pull/8917).
|
||||
3. Additionally, we now have the option to control if we want to apply preconditioning to the model outputs via a `--precondition_outputs` CLI arg. It affects how the model `target` is calculated as well.
|
||||
@@ -523,6 +523,13 @@ def parse_args(input_args=None):
|
||||
default=1.29,
|
||||
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precondition_outputs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how "
|
||||
"model `target` is calculated.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--optimizer",
|
||||
type=str,
|
||||
@@ -1636,7 +1643,7 @@ def main(args):
|
||||
|
||||
# Convert images to latent space
|
||||
model_input = vae.encode(pixel_values).latent_dist.sample()
|
||||
model_input = model_input * vae.config.scaling_factor
|
||||
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
model_input = model_input.to(dtype=weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
@@ -1656,8 +1663,9 @@ def main(args):
|
||||
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
|
||||
|
||||
# Add noise according to flow matching.
|
||||
# zt = (1 - texp) * x + texp * z1
|
||||
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
|
||||
noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input
|
||||
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = transformer(
|
||||
@@ -1670,14 +1678,18 @@ def main(args):
|
||||
|
||||
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
|
||||
# Preconditioning of the model outputs.
|
||||
model_pred = model_pred * (-sigmas) + noisy_model_input
|
||||
if args.precondition_outputs:
|
||||
model_pred = model_pred * (-sigmas) + noisy_model_input
|
||||
|
||||
# these weighting schemes use a uniform timestep sampling
|
||||
# and instead post-weight the loss
|
||||
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
||||
|
||||
# flow matching loss
|
||||
target = model_input
|
||||
if args.precondition_outputs:
|
||||
target = model_input
|
||||
else:
|
||||
target = noise - model_input
|
||||
|
||||
if args.with_prior_preservation:
|
||||
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
||||
|
||||
@@ -494,6 +494,13 @@ def parse_args(input_args=None):
|
||||
default=1.29,
|
||||
help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--precondition_outputs",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how "
|
||||
"model `target` is calculated.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--optimizer",
|
||||
type=str,
|
||||
@@ -1549,7 +1556,7 @@ def main(args):
|
||||
|
||||
# Convert images to latent space
|
||||
model_input = vae.encode(pixel_values).latent_dist.sample()
|
||||
model_input = model_input * vae.config.scaling_factor
|
||||
model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor
|
||||
model_input = model_input.to(dtype=weight_dtype)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
@@ -1569,8 +1576,9 @@ def main(args):
|
||||
timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device)
|
||||
|
||||
# Add noise according to flow matching.
|
||||
# zt = (1 - texp) * x + texp * z1
|
||||
sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype)
|
||||
noisy_model_input = sigmas * noise + (1.0 - sigmas) * model_input
|
||||
noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
|
||||
|
||||
# Predict the noise residual
|
||||
if not args.train_text_encoder:
|
||||
@@ -1598,13 +1606,18 @@ def main(args):
|
||||
|
||||
# Follow: Section 5 of https://arxiv.org/abs/2206.00364.
|
||||
# Preconditioning of the model outputs.
|
||||
model_pred = model_pred * (-sigmas) + noisy_model_input
|
||||
if args.precondition_outputs:
|
||||
model_pred = model_pred * (-sigmas) + noisy_model_input
|
||||
|
||||
# these weighting schemes use a uniform timestep sampling
|
||||
# and instead post-weight the loss
|
||||
weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas)
|
||||
|
||||
# flow matching loss
|
||||
target = model_input
|
||||
if args.precondition_outputs:
|
||||
target = model_input
|
||||
else:
|
||||
target = noise - model_input
|
||||
|
||||
if args.with_prior_preservation:
|
||||
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
|
||||
|
||||
Reference in New Issue
Block a user