1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

SNR gamma fixes for v_prediction training (#5106)

Co-authored-by: bghira <bghira@users.github.com>
This commit is contained in:
Bagheera
2023-09-20 13:18:56 -07:00
committed by GitHub
parent 914586f5b6
commit 24563ca654
5 changed files with 15 additions and 0 deletions

View File

@@ -908,6 +908,9 @@ def main():
if args.snr_gamma is not None:
snr = jnp.array(compute_snr(timesteps))
snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
if noise_scheduler.config.prediction_type == "v_prediction":
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
snr_loss_weights = snr_loss_weights + 1
loss = loss * snr_loss_weights
loss = loss.mean()

View File

@@ -875,6 +875,9 @@ def main():
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
mse_loss_weights = mse_loss_weights + 1
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.

View File

@@ -955,6 +955,9 @@ def main():
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
mse_loss_weights = mse_loss_weights + 1
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.

View File

@@ -786,6 +786,9 @@ def main():
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
mse_loss_weights = mse_loss_weights + 1
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.

View File

@@ -1075,6 +1075,9 @@ def main(args):
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
if noise_scheduler.config.prediction_type == "v_prediction":
# velocity objective prediction requires SNR weights to be floored to a min value of 1.
mse_loss_weights = mse_loss_weights + 1
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.