From 26e80e014331410087a76cc7979ea99fb736f30a Mon Sep 17 00:00:00 2001 From: Ethan Smith <98723285+ethansmith2000@users.noreply.github.com> Date: Wed, 11 Dec 2024 20:25:59 -0800 Subject: [PATCH] fix min-snr implementation (#8466) * fix min-snr implementation https://github.com/kohya-ss/sd-scripts/blob/main/library/custom_train_functions.py#L66 * Update train_dreambooth.py fix variable name mse_loss_weights * fix divisor * make style --------- Co-authored-by: Sayak Paul Co-authored-by: YiYi Xu --- examples/dreambooth/train_dreambooth.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 4b614807cf..a38146d6e9 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -1300,16 +1300,17 @@ def main(args): # Since we predict the noise instead of x_0, the original formulation is slightly changed. # This is discussed in Section 4.2 of the same paper. snr = compute_snr(noise_scheduler, timesteps) - base_weight = ( - torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr - ) if noise_scheduler.config.prediction_type == "v_prediction": # Velocity objective needs to be floored to an SNR weight of one. - mse_loss_weights = base_weight + 1 + divisor = snr + 1 else: - # Epsilon and sample both use the same loss weights. - mse_loss_weights = base_weight + divisor = snr + + mse_loss_weights = ( + torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / divisor + ) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights loss = loss.mean()