mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Correct SNR weighted loss in v-prediction case by only adding 1 to SNR on the denominator (#6307)
* fix minsnr implementation for v-prediction case * format code * always compute snr when snr_gamma is specified --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
@@ -907,10 +907,12 @@ def main():
|
||||
|
||||
if args.snr_gamma is not None:
|
||||
snr = jnp.array(compute_snr(timesteps))
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr
|
||||
snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma)
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
snr_loss_weights = snr_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
snr_loss_weights = snr_loss_weights / (snr + 1)
|
||||
|
||||
loss = loss * snr_loss_weights
|
||||
|
||||
loss = loss.mean()
|
||||
|
||||
@@ -781,12 +781,13 @@ def main():
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
|
||||
@@ -631,12 +631,13 @@ def main():
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
|
||||
@@ -664,12 +664,13 @@ def main():
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
|
||||
@@ -811,12 +811,13 @@ def main():
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
|
||||
@@ -848,12 +848,13 @@ def main():
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
|
||||
@@ -943,12 +943,13 @@ def main():
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
|
||||
@@ -759,12 +759,13 @@ def main():
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
|
||||
@@ -1062,12 +1062,13 @@ def main(args):
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
|
||||
@@ -1087,12 +1087,13 @@ def main(args):
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(noise_scheduler, timesteps)
|
||||
if noise_scheduler.config.prediction_type == "v_prediction":
|
||||
# Velocity objective requires that we add one to SNR values before we divide by them.
|
||||
snr = snr + 1
|
||||
mse_loss_weights = (
|
||||
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
|
||||
)
|
||||
mse_loss_weights = torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
|
||||
Reference in New Issue
Block a user