diff --git a/examples/dreambooth/train_dreambooth_lora_sd3.py b/examples/dreambooth/train_dreambooth_lora_sd3.py index e8414d24fa..21fd341726 100644 --- a/examples/dreambooth/train_dreambooth_lora_sd3.py +++ b/examples/dreambooth/train_dreambooth_lora_sd3.py @@ -473,7 +473,7 @@ def parse_args(input_args=None): ), ) parser.add_argument( - "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode"] + "--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode"] ) parser.add_argument("--logit_mean", type=float, default=0.0) parser.add_argument("--logit_std", type=float, default=1.0) diff --git a/examples/dreambooth/train_dreambooth_sd3.py b/examples/dreambooth/train_dreambooth_sd3.py index 80342e7f4d..ca41c27553 100644 --- a/examples/dreambooth/train_dreambooth_sd3.py +++ b/examples/dreambooth/train_dreambooth_sd3.py @@ -471,7 +471,7 @@ def parse_args(input_args=None): ), ) parser.add_argument( - "--weighting_scheme", type=str, default="sigma_sqrt", choices=["sigma_sqrt", "logit_normal", "mode"] + "--weighting_scheme", type=str, default="logit_normal", choices=["sigma_sqrt", "logit_normal", "mode"] ) parser.add_argument("--logit_mean", type=float, default=0.0) parser.add_argument("--logit_std", type=float, default=1.0)