Mirror of https://github.com/huggingface/diffusers.git
Min-SNR Gamma: follow-up fix for zero-terminal SNR models on v-prediction or epsilon (#5177)
* merge with main
* fix flax example
* fix onnx example

Co-authored-by: bghira <bghira@users.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
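Every hunk below makes the same change: name the Min-SNR weight explicitly as a base weight, branch on the scheduler's prediction type, and pin the weight at zero-SNR timesteps so zero-terminal SNR schedules no longer drive the loss to NaN. As a standalone sketch (the helper name and signature are illustrative, not part of the commit; compute_snr is assumed to return the per-timestep signal-to-noise ratio):

import torch

def min_snr_weights(snr: torch.Tensor, snr_gamma: float, prediction_type: str) -> torch.Tensor:
    # Min-SNR weighting, Section 3.4 of https://arxiv.org/abs/2303.09556:
    # clamp the SNR at gamma, then divide by the SNR.
    base_weight = torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
    if prediction_type == "v_prediction":
        # The velocity objective needs its weights floored to a minimum of one.
        weights = base_weight + 1
    else:
        # Epsilon and sample prediction use the base weights directly.
        weights = base_weight
    # Zero-terminal SNR schedules contain a timestep with snr == 0, where the
    # division above yields NaN; pin that weight to 1.0 so the loss stays finite.
    weights[snr == 0] = 1.0
    return weights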
@@ -907,10 +907,17 @@ def main():
             if args.snr_gamma is not None:
                 snr = jnp.array(compute_snr(timesteps))
-                snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
+                base_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
                 if noise_scheduler.config.prediction_type == "v_prediction":
                     # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-                    snr_loss_weights = snr_loss_weights + 1
+                    snr_loss_weights = base_weights + 1
+                else:
+                    # Epsilon and sample prediction use the base weights.
+                    snr_loss_weights = base_weights
+                # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                snr_loss_weights = snr_loss_weights.at[snr == 0].set(1.0)

                 loss = loss * snr_loss_weights

             loss = loss.mean()
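One detail specific to this Flax hunk: jnp arrays are immutable, so the NumPy-style item assignment used in the PyTorch scripts (weights[snr == 0] = 1.0) raises a TypeError under JAX. The zero-SNR guard therefore uses JAX's functional update idiom; a minimal repro (the values are made up for illustration):

import jax.numpy as jnp

snr = jnp.array([0.0, 0.5, 4.0])
base_weights = jnp.where(snr < 2.0, snr, jnp.ones_like(snr) * 2.0) / snr  # NaN at snr == 0
snr_loss_weights = base_weights.at[snr == 0].set(1.0)  # returns a new array with the guard applied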
@@ -801,9 +801,22 @@ def main():
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(timesteps)
-                    mse_loss_weights = (
+                    base_weight = (
                         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                     )
+
+                    if noise_scheduler.config.prediction_type == "v_prediction":
+                        # Velocity objective needs to be floored to an SNR weight of one.
+                        mse_loss_weights = base_weight + 1
+                    else:
+                        # Epsilon and sample both use the same loss weights.
+                        mse_loss_weights = base_weight
+
+                    # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                    # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                    mse_loss_weights[snr == 0] = 1.0
+
                     # We first calculate the original loss. Then we mean over the non-batch dimensions and
                     # rebalance the sample-wise losses with their respective loss weights.
                     # Finally, we take the mean of the rebalanced loss.
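The three trailing context comments describe the reduction that follows each of these hunks; in the diffusers training scripts it reads roughly as below (model_pred, target, and mse_loss_weights come from the surrounding training loop):

import torch.nn.functional as F

# Element-wise loss first, then collapse all non-batch dimensions so each
# sample keeps a scalar loss that can be rescaled by its own SNR weight.
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()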
@@ -654,9 +654,22 @@ def main():
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(timesteps)
-                    mse_loss_weights = (
+                    base_weight = (
                         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                     )
+
+                    if noise_scheduler.config.prediction_type == "v_prediction":
+                        # Velocity objective needs to be floored to an SNR weight of one.
+                        mse_loss_weights = base_weight + 1
+                    else:
+                        # Epsilon and sample both use the same loss weights.
+                        mse_loss_weights = base_weight
+
+                    # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                    # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                    mse_loss_weights[snr == 0] = 1.0
+
                     # We first calculate the original loss. Then we mean over the non-batch dimensions and
                     # rebalance the sample-wise losses with their respective loss weights.
                     # Finally, we take the mean of the rebalanced loss.
@@ -685,9 +685,22 @@ def main():
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(timesteps)
-                    mse_loss_weights = (
+                    base_weight = (
                         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                     )
+
+                    if noise_scheduler.config.prediction_type == "v_prediction":
+                        # Velocity objective needs to be floored to an SNR weight of one.
+                        mse_loss_weights = base_weight + 1
+                    else:
+                        # Epsilon and sample both use the same loss weights.
+                        mse_loss_weights = base_weight
+
+                    # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                    # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                    mse_loss_weights[snr == 0] = 1.0
+
                     # We first calculate the original loss. Then we mean over the non-batch dimensions and
                     # rebalance the sample-wise losses with their respective loss weights.
                     # Finally, we take the mean of the rebalanced loss.
@@ -833,9 +833,22 @@ def main():
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(timesteps)
-                    mse_loss_weights = (
+                    base_weight = (
                         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                     )
+
+                    if noise_scheduler.config.prediction_type == "v_prediction":
+                        # Velocity objective needs to be floored to an SNR weight of one.
+                        mse_loss_weights = base_weight + 1
+                    else:
+                        # Epsilon and sample both use the same loss weights.
+                        mse_loss_weights = base_weight
+
+                    # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                    # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                    mse_loss_weights[snr == 0] = 1.0
+
                     # We first calculate the original loss. Then we mean over the non-batch dimensions and
                     # rebalance the sample-wise losses with their respective loss weights.
                     # Finally, we take the mean of the rebalanced loss.
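To make the weights concrete, take snr_gamma = 5 as an assumed example value: a low-noise timestep with SNR 25 gets base weight min(25, 5) / 25 = 0.2, a timestep with SNR 1 gets min(1, 5) / 1 = 1, and the zero-terminal timestep computes min(0, 5) / 0 = 0 / 0 = NaN, which is exactly the value the snr == 0 guard overwrites with 1.0. Under v_prediction the first two shift to 1.2 and 2, while the NaN (NaN + 1 is still NaN) is likewise pinned to 1.0.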
@@ -872,12 +872,21 @@ def main():
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(timesteps)
-                    mse_loss_weights = (
+                    base_weight = (
                         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                     )
                     if noise_scheduler.config.prediction_type == "v_prediction":
                         # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-                        mse_loss_weights = mse_loss_weights + 1
+                        mse_loss_weights = base_weight + 1
+                    else:
+                        # Epsilon and sample prediction use the base weights.
+                        mse_loss_weights = base_weight
+
+                    # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                    # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                    mse_loss_weights[snr == 0] = 1.0
+
                     # We first calculate the original loss. Then we mean over the non-batch dimensions and
                     # rebalance the sample-wise losses with their respective loss weights.
                     # Finally, we take the mean of the rebalanced loss.
@@ -952,12 +952,22 @@ def main():
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(timesteps)
-                    mse_loss_weights = (
+                    base_weight = (
                         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                     )
+
                     if noise_scheduler.config.prediction_type == "v_prediction":
-                        # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-                        mse_loss_weights = mse_loss_weights + 1
+                        # Velocity objective needs to be floored to an SNR weight of one.
+                        mse_loss_weights = base_weight + 1
+                    else:
+                        # Epsilon and sample both use the same loss weights.
+                        mse_loss_weights = base_weight
+
+                    # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                    # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                    mse_loss_weights[snr == 0] = 1.0
+
                     # We first calculate the original loss. Then we mean over the non-batch dimensions and
                     # rebalance the sample-wise losses with their respective loss weights.
                     # Finally, we take the mean of the rebalanced loss.
@@ -783,12 +783,22 @@ def main():
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(timesteps)
-                    mse_loss_weights = (
+                    base_weight = (
                         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                     )
+
                     if noise_scheduler.config.prediction_type == "v_prediction":
-                        # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-                        mse_loss_weights = mse_loss_weights + 1
+                        # Velocity objective needs to be floored to an SNR weight of one.
+                        mse_loss_weights = base_weight + 1
+                    else:
+                        # Epsilon and sample both use the same loss weights.
+                        mse_loss_weights = base_weight
+
+                    # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                    # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                    mse_loss_weights[snr == 0] = 1.0
+
                     # We first calculate the original loss. Then we mean over the non-batch dimensions and
                     # rebalance the sample-wise losses with their respective loss weights.
                     # Finally, we take the mean of the rebalanced loss.
@@ -1072,12 +1072,22 @@ def main(args):
                     # Since we predict the noise instead of x_0, the original formulation is slightly changed.
                     # This is discussed in Section 4.2 of the same paper.
                     snr = compute_snr(timesteps)
-                    mse_loss_weights = (
+                    base_weight = (
                         torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
                     )
+
                     if noise_scheduler.config.prediction_type == "v_prediction":
-                        # velocity objective prediction requires SNR weights to be floored to a min value of 1.
-                        mse_loss_weights = mse_loss_weights + 1
+                        # Velocity objective needs to be floored to an SNR weight of one.
+                        mse_loss_weights = base_weight + 1
+                    else:
+                        # Epsilon and sample both use the same loss weights.
+                        mse_loss_weights = base_weight
+
+                    # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                    # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                    mse_loss_weights[snr == 0] = 1.0
+
                     # We first calculate the original loss. Then we mean over the non-batch dimensions and
                     # rebalance the sample-wise losses with their respective loss weights.
                     # Finally, we take the mean of the rebalanced loss.
@@ -1100,6 +1100,11 @@ def main(args):
                         # Epsilon and sample both use the same loss weights.
                         mse_loss_weights = base_weight
+
+                    # For zero-terminal SNR, we have to handle the case where a sigma of zero results in an Inf value.
+                    # When we run this, the MSE loss weight for this timestep is set unconditionally to 1.
+                    # If we do not run this, the loss value will go to NaN almost immediately, usually within one step.
+                    mse_loss_weights[snr == 0] = 1.0

                     # We first calculate the original loss. Then we mean over the non-batch dimensions and
                     # rebalance the sample-wise losses with their respective loss weights.
                     # Finally, we take the mean of the rebalanced loss.
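A quick sanity check of the pattern the commit converges on (illustrative, not part of the commit): with the guard in place, both branches produce finite weights even when the schedule's terminal SNR is exactly zero.

import torch

snr = torch.tensor([0.0, 1.0, 25.0])  # zero-terminal SNR schedule
base = torch.stack([snr, 5.0 * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
for weights in (base.clone(), base + 1):  # epsilon/sample vs. v_prediction
    weights[snr == 0] = 1.0
    assert torch.isfinite(weights).all()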