diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index d04c616c57..34e8c69ff6 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -907,10 +907,17 @@ def main(): if args.snr_gamma is not None: snr = jnp.array(compute_snr(timesteps)) - snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr + base_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr if noise_scheduler.config.prediction_type == "v_prediction": - # velocity objective prediction requires SNR weights to be floored to a min value of 1. - snr_loss_weights = snr_loss_weights + 1 + snr_loss_weights = base_weights + 1 + else: + # Epsilon and sample prediction use the base weights. + snr_loss_weights = base_weights + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + snr_loss_weights[snr == 0] = 1.0 + loss = loss * snr_loss_weights loss = loss.mean() diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py index 364ed7e031..affc26101b 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_decoder.py @@ -801,9 +801,22 @@ def main(): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(timesteps) - mse_loss_weights = ( + base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + mse_loss_weights[snr == 0] = 1.0 + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py index 9d96a936d0..0a38c98f51 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_decoder.py @@ -654,9 +654,22 @@ def main(): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(timesteps) - mse_loss_weights = ( + base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + mse_loss_weights[snr == 0] = 1.0 + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py index e4aec111b8..aaa8792af0 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_lora_prior.py @@ -685,9 +685,22 @@ def main(): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(timesteps) - mse_loss_weights = ( + base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + mse_loss_weights[snr == 0] = 1.0 + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py index d451e1bfe4..38aa5eee8f 100644 --- a/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py +++ b/examples/kandinsky2_2/text_to_image/train_text_to_image_prior.py @@ -833,9 +833,22 @@ def main(): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(timesteps) - mse_loss_weights = ( + base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + + if noise_scheduler.config.prediction_type == "v_prediction": + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + mse_loss_weights[snr == 0] = 1.0 + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py index b89de5e001..7e4a93dc03 100644 --- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py +++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py @@ -872,12 +872,21 @@ def main(): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(timesteps) - mse_loss_weights = ( + base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) if noise_scheduler.config.prediction_type == "v_prediction": # velocity objective prediction requires SNR weights to be floored to a min value of 1. - mse_loss_weights = mse_loss_weights + 1 + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample prediction use the base weights. + mse_loss_weights = base_weight + + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + mse_loss_weights[snr == 0] = 1.0 + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 0d14e6ccd5..82201b2291 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -952,12 +952,22 @@ def main(): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(timesteps) - mse_loss_weights = ( + base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + if noise_scheduler.config.prediction_type == "v_prediction": - # velocity objective prediction requires SNR weights to be floored to a min value of 1. - mse_loss_weights = mse_loss_weights + 1 + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + mse_loss_weights[snr == 0] = 1.0 + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 5845bda0e5..0d562bc59d 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -783,12 +783,22 @@ def main(): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(timesteps) - mse_loss_weights = ( + base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + if noise_scheduler.config.prediction_type == "v_prediction": - # velocity objective prediction requires SNR weights to be floored to a min value of 1. - mse_loss_weights = mse_loss_weights + 1 + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + mse_loss_weights[snr == 0] = 1.0 + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 7a8c2c353e..6b870a3ab5 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -1072,12 +1072,22 @@ def main(args): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(timesteps) - mse_loss_weights = ( + base_weight = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + if noise_scheduler.config.prediction_type == "v_prediction": - # velocity objective prediction requires SNR weights to be floored to a min value of 1. - mse_loss_weights = mse_loss_weights + 1 + # Velocity objective needs to be floored to an SNR weight of one. + mse_loss_weights = base_weight + 1 + else: + # Epsilon and sample both use the same loss weights. + mse_loss_weights = base_weight + + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + mse_loss_weights[snr == 0] = 1.0 + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index fd37301d8f..1c579ef1fb 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -1100,6 +1100,11 @@ def main(args): # Epsilon and sample both use the same loss weights. mse_loss_weights = base_weight + # For zero-terminal SNR, we have to handle the case where a sigma of Zero results in a Inf value. + # When we run this, the MSE loss weights for this timestep is set unconditionally to 1. + # If we do not run this, the loss value will go to NaN almost immediately, usually within one step. + mse_loss_weights[snr == 0] = 1.0 + # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss.