diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index fbbbe35d8d..a61531a4cb 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -866,13 +866,16 @@ def prepare_rotary_positional_embeddings(
     return freqs_cos, freqs_sin


-def get_optimizer(accelerator, args, params_to_optimize):
+def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
     # Use DeepSpeed optimzer
-    if (
-        accelerator.state.deepspeed_plugin is not None
-        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
-    ):
-        return DummyOptim(params_to_optimize, lr=args.learning_rate)
+    if use_deepspeed:
+        return DummyOptim(
+            params_to_optimize,
+            lr=args.learning_rate,
+            betas=(args.adam_beta1, args.adam_beta2),
+            eps=args.adam_epsilon,
+            weight_decay=args.adam_weight_decay,
+        )

     # Optimizer creation
     supported_optimizers = ["adam", "adamw", "prodigy"]
@@ -1214,7 +1217,16 @@ def main(args):
     else:
         params_to_optimize = [transformer_parameters_with_lr]

-    optimizer = get_optimizer(accelerator, args, params_to_optimize)
+    use_deepspeed_optimizer = (
+        accelerator.state.deepspeed_plugin is not None
+        and "optimizer" in accelerator.state.deepspeed_plugin.deepspeed_config
+    )
+    use_deepspeed_scheduler = (
+        accelerator.state.deepspeed_plugin is not None
+        and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
+    )
+
+    optimizer = get_optimizer(args, params_to_optimize, use_deepspeed=use_deepspeed_optimizer)

     # Dataset and DataLoader
     train_dataset = VideoDataset(
@@ -1268,15 +1280,12 @@ def main(args):
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True

-    if (
-        accelerator.state.deepspeed_plugin is not None
-        and "scheduler" not in accelerator.state.deepspeed_plugin.deepspeed_config
-    ):
+    if use_deepspeed_scheduler:
         lr_scheduler = DummyScheduler(
             name=args.lr_scheduler,
             optimizer=optimizer,
             total_num_steps=args.max_train_steps * accelerator.num_processes,
-            num_warmup_steps=args.lr_awrmup_steps * accelerator.num_processes,
+            num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
         )
     else:
         lr_scheduler = get_scheduler(
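
For reference, a minimal standalone sketch of what the refactored DeepSpeed branch now constructs: the DummyOptim placeholder (from accelerate.utils, later replaced by the DeepSpeed-configured optimizer during accelerator.prepare) carries the same Adam hyperparameters as the non-DeepSpeed optimizers, not just the learning rate. The argument values and the parameter group below are hypothetical stand-ins for the script's CLI arguments and LoRA parameters.

    # Sketch only: assumed values mirroring the script's --learning_rate,
    # --adam_beta1/2, --adam_epsilon, and --adam_weight_decay flags.
    from types import SimpleNamespace

    import torch
    from accelerate.utils import DummyOptim

    args = SimpleNamespace(
        learning_rate=1e-3,
        adam_beta1=0.9,
        adam_beta2=0.95,
        adam_epsilon=1e-8,
        adam_weight_decay=1e-4,
    )

    # Placeholder parameter group standing in for the transformer LoRA parameters.
    params_to_optimize = [
        {"params": [torch.nn.Parameter(torch.zeros(1))], "lr": args.learning_rate}
    ]

    # Same call shape as the new `if use_deepspeed:` branch in get_optimizer.
    optimizer = DummyOptim(
        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        eps=args.adam_epsilon,
        weight_decay=args.adam_weight_decay,
    )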