From df76a39e1bc1de5bec647ce56a7fe4d8d1b6a643 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?apolin=C3=A1rio?=
Date: Fri, 22 Dec 2023 06:42:04 -0600
Subject: [PATCH] Fix Prodigy optimizer in SDXL Dreambooth script (#6290)

* Fix ProdigyOPT in SDXL Dreambooth script

* style

* style
---
 .../dreambooth/train_dreambooth_lora_sdxl.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 9992292e30..8a3ac294fe 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -1144,10 +1144,26 @@ def main(args):
 
         optimizer_class = prodigyopt.Prodigy
 
+        if args.learning_rate <= 0.1:
+            logger.warn(
+                "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+            )
+        if args.train_text_encoder and args.text_encoder_lr:
+            logger.warn(
+                f"Learning rates were provided both for the unet and the text encoder- e.g. text_encoder_lr:"
+                f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+                f"When using prodigy only learning_rate is used as the initial learning rate."
+            )
+            # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
+            # --learning_rate
+            params_to_optimize[1]["lr"] = args.learning_rate
+            params_to_optimize[2]["lr"] = args.learning_rate
+
         optimizer = optimizer_class(
             params_to_optimize,
             lr=args.learning_rate,
             betas=(args.adam_beta1, args.adam_beta2),
+            beta3=args.prodigy_beta3,
             weight_decay=args.adam_weight_decay,
             eps=args.adam_epsilon,
             decouple=args.prodigy_decouple,
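
For context, a minimal, self-contained sketch of what the patched setup does outside the script: three parameter groups all forced to the same initial learning rate, and beta3 passed through to Prodigy. The placeholder parameter lists and hyperparameter values below are assumptions for illustration, not values taken from the training script; only the param-group override and the beta3 argument mirror the diff above.

# Sketch only (not part of the patch): Prodigy over three parameter groups,
# with every group sharing the same initial lr, as the diff enforces.
import torch
from prodigyopt import Prodigy

# Placeholders standing in for the script's LoRA parameter lists.
unet_lora_parameters = [torch.nn.Parameter(torch.zeros(8, 8))]
text_encoder_one_parameters = [torch.nn.Parameter(torch.zeros(8, 8))]
text_encoder_two_parameters = [torch.nn.Parameter(torch.zeros(8, 8))]

learning_rate = 1.0  # Prodigy adapts step sizes itself; ~1.0 is the usual starting point

# With Prodigy, a separate text-encoder lr is not honored: groups 1 and 2 are
# set to learning_rate, just as params_to_optimize[1]/[2] are in the patch.
params_to_optimize = [
    {"params": unet_lora_parameters, "lr": learning_rate},
    {"params": text_encoder_one_parameters, "lr": learning_rate},
    {"params": text_encoder_two_parameters, "lr": learning_rate},
]

optimizer = Prodigy(
    params_to_optimize,
    lr=learning_rate,
    betas=(0.9, 0.999),
    beta3=None,          # None lets Prodigy default beta3 to sqrt(beta2)
    weight_decay=1e-4,
    eps=1e-8,
    decouple=True,       # decoupled (AdamW-style) weight decay
)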