diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 75087284cb..d04c616c57 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -908,6 +908,9 @@ def main(): if args.snr_gamma is not None: snr = jnp.array(compute_snr(timesteps)) snr_loss_weights = jnp.where(snr < args.snr_gamma, snr, jnp.ones_like(snr) * args.snr_gamma) / snr + if noise_scheduler.config.prediction_type == "v_prediction": + # For the velocity objective (v-prediction), the SNR weights are offset by 1: this adds 1 to every weight rather than clamping to a minimum of 1. + snr_loss_weights = snr_loss_weights + 1 loss = loss * snr_loss_weights loss = loss.mean() diff --git a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py index 2548c3a286..b89de5e001 100644 --- a/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py +++ b/examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py @@ -875,6 +875,9 @@ def main(): mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + if noise_scheduler.config.prediction_type == "v_prediction": + # For the velocity objective (v-prediction), the SNR weights are offset by 1: this adds 1 to every weight rather than clamping to a minimum of 1. + mse_loss_weights = mse_loss_weights + 1 # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. 
diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index b00884bfb7..0d14e6ccd5 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -955,6 +955,9 @@ def main(): mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + if noise_scheduler.config.prediction_type == "v_prediction": + # For the velocity objective (v-prediction), the SNR weights are offset by 1: this adds 1 to every weight rather than clamping to a minimum of 1. + mse_loss_weights = mse_loss_weights + 1 # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index b9830a83ae..5845bda0e5 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -786,6 +786,9 @@ def main(): mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + if noise_scheduler.config.prediction_type == "v_prediction": + # For the velocity objective (v-prediction), the SNR weights are offset by 1: this adds 1 to every weight rather than clamping to a minimum of 1. + mse_loss_weights = mse_loss_weights + 1 # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss. 
diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 45ae1cc9ef..7a8c2c353e 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -1075,6 +1075,9 @@ def main(args): mse_loss_weights = ( torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr ) + if noise_scheduler.config.prediction_type == "v_prediction": + # For the velocity objective (v-prediction), the SNR weights are offset by 1: this adds 1 to every weight rather than clamping to a minimum of 1. + mse_loss_weights = mse_loss_weights + 1 # We first calculate the original loss. Then we mean over the non-batch dimensions and # rebalance the sample-wise losses with their respective loss weights. # Finally, we take the mean of the rebalanced loss.