From 80c87718add1dcbf856f658d3eaa17db6bd7741f Mon Sep 17 00:00:00 2001
From: Aryan
Date: Sat, 14 Sep 2024 02:36:07 +0200
Subject: [PATCH] apply suggestions from review; prodigy optimizer

Co-authored-by: YiYi Xu
---
 examples/cogvideo/train_cogvideox_lora.py | 91 ++++++++++++++++++++---
 1 file changed, 81 insertions(+), 10 deletions(-)

diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index 3c9d313425..59d7124b1d 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -360,6 +360,13 @@ def get_args():
     parser.add_argument(
         "--adam_beta2", type=float, default=0.95, help="The beta2 parameter for the Adam and Prodigy optimizers."
     )
+    parser.add_argument(
+        "--prodigy_beta3",
+        type=float,
+        default=None,
+        help="Coefficient for computing the Prodigy optimizer's stepsize using running averages. If set to None, uses the square root of beta2.",
+    )
+    parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW-style decoupled weight decay")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
     parser.add_argument(
         "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
@@ -371,6 +378,15 @@ def get_args():
         help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
     )
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
+    parser.add_argument(
+        "--prodigy_use_bias_correction", type=bool, default=True, help="Turn on Adam's bias correction."
+    )
+    parser.add_argument(
+        "--prodigy_safeguard_warmup",
+        type=bool,
+        default=True,
+        help="Remove lr from the denominator of the D estimate to avoid issues during the warm-up stage.",
+    )
 
     # Other information
     parser.add_argument("--tracker_name", type=str, default=None, help="Project tracker name")
@@ -851,9 +867,9 @@ def get_optimizer(args, params_to_optimize):
         )
         args.optimizer = "adamw"
 
-    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+    if args.use_8bit_adam and args.optimizer.lower() not in ["adam", "adamw"]:
         logger.warning(
-            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+            f"use_8bit_adam is ignored when optimizer is not set to 'Adam' or 'AdamW'. Optimizer was "
             f"set to {args.optimizer.lower()}"
         )
 
@@ -883,6 +899,38 @@ def get_optimizer(args, params_to_optimize):
             eps=args.adam_epsilon,
             weight_decay=args.adam_weight_decay,
         )
+    elif args.optimizer.lower() == "prodigy":
+        try:
+            import prodigyopt
+        except ImportError:
+            raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+        optimizer_class = prodigyopt.Prodigy
+
+        if args.learning_rate <= 0.1:
+            logger.warning(
+                "Learning rate is too low. When using Prodigy, it's generally better to set the learning rate around 1.0"
+            )
+        if args.train_text_encoder and args.text_encoder_lr:
+            logger.warning(
+                f"Learning rates were provided both for the transformer and the text encoder - e.g. text_encoder_lr:"
+                f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+                f"When using Prodigy, only learning_rate is used as the initial learning rate."
+            )
+            # Change the learning rate of text_encoder_parameters to be --learning_rate
+            params_to_optimize[1]["lr"] = args.learning_rate
+
+        optimizer = optimizer_class(
+            params_to_optimize,
+            lr=args.learning_rate,
+            betas=(args.adam_beta1, args.adam_beta2),
+            beta3=args.prodigy_beta3,
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            decouple=args.prodigy_decouple,
+            use_bias_correction=args.prodigy_use_bias_correction,
+            safeguard_warmup=args.prodigy_safeguard_warmup,
+        )
 
     return optimizer
 
@@ -958,8 +1006,15 @@ def main(args):
         args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
     )
 
+    # CogVideoX-2b weights are stored in float16
+    # CogVideoX-5b and CogVideoX-5b-I2V weights are stored in bfloat16
+    load_dtype = torch.bfloat16 if "5b" in args.pretrained_model_name_or_path.lower() else torch.float16
     transformer = CogVideoXTransformer3DModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant
+        args.pretrained_model_name_or_path,
+        subfolder="transformer",
+        torch_dtype=load_dtype,
+        revision=args.revision,
+        variant=args.variant,
     )
 
     vae = AutoencoderKLCogVideoX.from_pretrained(
@@ -981,10 +1036,23 @@ def main(args):
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
     weight_dtype = torch.float32
-    if accelerator.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
+    if accelerator.state.deepspeed_plugin:
+        # DeepSpeed is handling precision; use what's in the DeepSpeed config
+        if (
+            "fp16" in accelerator.state.deepspeed_plugin.deepspeed_config
+            and accelerator.state.deepspeed_plugin.deepspeed_config["fp16"]["enabled"]
+        ):
+            weight_dtype = torch.float16
+        if (
+            "bf16" in accelerator.state.deepspeed_plugin.deepspeed_config
+            and accelerator.state.deepspeed_plugin.deepspeed_config["bf16"]["enabled"]
+        ):
+            weight_dtype = torch.bfloat16
+    else:
+        if accelerator.mixed_precision == "fp16":
+            weight_dtype = torch.float16
+        elif accelerator.mixed_precision == "bf16":
+            weight_dtype = torch.bfloat16
 
     if torch.backends.mps.is_available() and weight_dtype == torch.bfloat16:
         # due to pytorch#99272, MPS does not yet support bfloat16.
@@ -1279,6 +1347,9 @@ def main(args):
     )
     vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)
 
+    # Read the config through .module when the model is wrapped (e.g. for DeepSpeed training)
+    model_config = transformer.module.config if hasattr(transformer, "module") else transformer.config
+
    for epoch in range(first_epoch, args.num_train_epochs):
         transformer.train()
         if args.train_text_encoder:
@@ -1322,11 +1393,11 @@ def main(args):
                         width=args.width,
                         num_frames=num_frames,
                         vae_scale_factor_spatial=vae_scale_factor_spatial,
-                        patch_size=transformer.config.patch_size,
-                        attention_head_dim=transformer.config.attention_head_dim,
+                        patch_size=model_config.patch_size,
+                        attention_head_dim=model_config.attention_head_dim,
                         device=accelerator.device,
                     )
-                    if transformer.config.use_rotary_positional_embeddings
+                    if model_config.use_rotary_positional_embeddings
                     else None
                 )
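
Note (commentary, not part of the patch): the `params_to_optimize[1]["lr"] = args.learning_rate` line in the
Prodigy branch assumes the text-encoder parameter group is the second entry of the list passed into
`get_optimizer`. A minimal sketch of that assumed layout, with hypothetical variable names not taken from
the script:

    # Hypothetical parameter groups as get_optimizer would receive them when
    # --train_text_encoder is enabled; index 1 is the text-encoder group.
    params_to_optimize = [
        {"params": transformer_lora_parameters, "lr": args.learning_rate},
        {"params": text_encoder_lora_parameters, "lr": args.text_encoder_lr},
    ]

Under Prodigy both groups then start from `--learning_rate` (typically around 1.0, per the warning above),
since the optimizer adapts the effective step size on its own.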