From 074a7cc3c56a409e191fce29db0b5c414e89c19e Mon Sep 17 00:00:00 2001 From: Bagheera <59658056+bghira@users.noreply.github.com> Date: Tue, 18 Jun 2024 07:15:19 -0600 Subject: [PATCH] SD3: update default training timestep / loss weighting distribution to logit_normal (#8592) Co-authored-by: bghira Co-authored-by: Kashif Rasul --- examples/dreambooth/train_dreambooth_lora_sd3.py | 2 +- examples/dreambooth/train_dreambooth_sd3.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index e8414d24fa..21fd341726 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -473,7 +473,7 @@ def parse_args(input_args=None): ), ) parser.add_argument( - "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode"] + "--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode"] ) parser.add_argument("--logit_mean", type=float, default=0.0) parser.add_argument("--logit_std", type=float, default=1.0) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 80342e7f4d..ca41c27553 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -471,7 +471,7 @@ def parse_args(input_args=None): ), ) parser.add_argument( - "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode"] + "--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode"] ) parser.add_argument("--logit_mean", type=float, default=0.0) parser.add_argument("--logit_std", type=float, default=1.0)