mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[training] use the lr when using 8bit adam. (#9796)
* use the lr when using 8bit adam. * remove lr as we pack it in params_to_optimize. --------- Co-authored-by: Linoy Tsaban <57615435+linoytsaban@users.noreply.github.com>
This commit is contained in:
@@ -1778,15 +1778,10 @@ def main(args):
|
||||
if not args.enable_t5_ti:
|
||||
# pure textual inversion - only clip
|
||||
if pure_textual_inversion:
|
||||
params_to_optimize = [
|
||||
text_parameters_one_with_lr,
|
||||
]
|
||||
params_to_optimize = [text_parameters_one_with_lr]
|
||||
te_idx = 0
|
||||
else: # regular te training or regular pivotal for clip
|
||||
params_to_optimize = [
|
||||
transformer_parameters_with_lr,
|
||||
text_parameters_one_with_lr,
|
||||
]
|
||||
params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]
|
||||
te_idx = 1
|
||||
elif args.enable_t5_ti:
|
||||
# pivotal tuning of clip & t5
|
||||
@@ -1809,9 +1804,7 @@ def main(args):
|
||||
]
|
||||
te_idx = 1
|
||||
else:
|
||||
params_to_optimize = [
|
||||
transformer_parameters_with_lr,
|
||||
]
|
||||
params_to_optimize = [transformer_parameters_with_lr]
|
||||
|
||||
# Optimizer creation
|
||||
if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
|
||||
@@ -1871,7 +1864,6 @@ def main(args):
|
||||
params_to_optimize[-1]["lr"] = args.learning_rate
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -1358,10 +1358,7 @@ def main(args):
|
||||
else args.adam_weight_decay,
|
||||
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
|
||||
}
|
||||
params_to_optimize = [
|
||||
unet_lora_parameters_with_lr,
|
||||
text_lora_parameters_one_with_lr,
|
||||
]
|
||||
params_to_optimize = [unet_lora_parameters_with_lr, text_lora_parameters_one_with_lr]
|
||||
else:
|
||||
params_to_optimize = [unet_lora_parameters_with_lr]
|
||||
|
||||
@@ -1423,7 +1420,6 @@ def main(args):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -1794,7 +1794,6 @@ def main(args):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -947,7 +947,6 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -969,7 +969,6 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -1226,10 +1226,7 @@ def main(args):
|
||||
"weight_decay": args.adam_weight_decay_text_encoder,
|
||||
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
|
||||
}
|
||||
params_to_optimize = [
|
||||
transformer_parameters_with_lr,
|
||||
text_parameters_one_with_lr,
|
||||
]
|
||||
params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]
|
||||
else:
|
||||
params_to_optimize = [transformer_parameters_with_lr]
|
||||
|
||||
@@ -1291,7 +1288,6 @@ def main(args):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -1335,10 +1335,7 @@ def main(args):
|
||||
"weight_decay": args.adam_weight_decay_text_encoder,
|
||||
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
|
||||
}
|
||||
params_to_optimize = [
|
||||
transformer_parameters_with_lr,
|
||||
text_parameters_one_with_lr,
|
||||
]
|
||||
params_to_optimize = [transformer_parameters_with_lr, text_parameters_one_with_lr]
|
||||
else:
|
||||
params_to_optimize = [transformer_parameters_with_lr]
|
||||
|
||||
@@ -1400,7 +1397,6 @@ def main(args):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -1468,7 +1468,6 @@ def main(args):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -1402,7 +1402,6 @@ def main(args):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -1328,7 +1328,6 @@ def main(args):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
@@ -1475,7 +1475,6 @@ def main(args):
|
||||
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
|
||||
Reference in New Issue
Block a user