diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index ff9183d780..0999a7ef7c 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 import argparse
-import itertools
 import logging
 import math
 import os
@@ -40,7 +39,6 @@ from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
 from diffusers.training_utils import (
-    _set_state_dict_into_text_encoder,
     cast_training_params,
     clear_objs_and_retain_memory,
 )
@@ -240,11 +238,6 @@ def get_args():
         action="store_true",
         help="whether to randomly flip videos horizontally",
     )
-    parser.add_argument(
-        "--train_text_encoder",
-        action="store_true",
-        help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
-    )
     parser.add_argument(
         "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
     )
@@ -297,12 +290,6 @@
         default=1e-4,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
-    parser.add_argument(
-        "--text_encoder_lr",
-        type=float,
-        default=5e-6,
-        help="Text encoder learning rate to use.",
-    )
     parser.add_argument(
         "--scale_lr",
         action="store_true",
@@ -368,9 +355,6 @@
     )
     parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
-    parser.add_argument(
-        "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
-    )
     parser.add_argument(
         "--adam_epsilon",
         type=float,
@@ -606,7 +590,6 @@ def save_model_card(
     repo_id: str,
     videos=None,
     base_model: str = None,
-    train_text_encoder=False,
     validation_prompt=None,
     repo_folder=None,
     fps=8,
@@ -630,7 +613,7 @@ These are {repo_id} LoRA weights for {base_model}.
 
 The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
 
-Was LoRA for the text encoder enabled? {train_text_encoder}.
+Was LoRA for the text encoder enabled? No.
 
 ## Download model
 
@@ -931,14 +914,6 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
             logger.warning(
                 "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
             )
-        if args.train_text_encoder and args.text_encoder_lr:
-            logger.warning(
-                f"Learning rates were provided both for the transformer and the text encoder - e.g. text_encoder_lr:"
-                f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
-                f"When using prodigy only learning_rate is used as the initial learning rate."
-            )
-            # Changes the learning rate of text_encoder_parameters to be --learning_rate
-            params_to_optimize[1]["lr"] = args.learning_rate
 
         optimizer = optimizer_class(
             params_to_optimize,
@@ -1086,8 +1061,6 @@ def main(args):
 
     if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()
-        if args.train_text_encoder:
-            text_encoder.gradient_checkpointing_enable()
 
     # now we will add new LoRA weights to the attention layers
     transformer_lora_config = LoraConfig(
@@ -1098,15 +1071,6 @@
     )
     transformer.add_adapter(transformer_lora_config)
 
-    if args.train_text_encoder:
-        text_lora_config = LoraConfig(
-            r=args.rank,
-            lora_alpha=args.lora_alpha,
-            init_lora_weights=True,
-            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
-        )
-        text_encoder.add_adapter(text_lora_config)
-
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
@@ -1116,13 +1080,10 @@
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
-            text_encoder_lora_layers_to_save = None
 
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                elif isinstance(model, type(unwrap_model(text_encoder))):
-                    text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -1132,22 +1093,18 @@
             CogVideoXPipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
-                text_encoder_lora_layers=text_encoder_lora_layers_to_save,
             )
 
     def load_model_hook(models, input_dir):
         transformer_ = None
-        text_encoder_ = None
 
         while len(models) > 0:
             model = models.pop()
 
             if isinstance(model, type(unwrap_model(transformer))):
                 transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder))):
-                text_encoder_ = model
             else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                raise ValueError(f"Unexpected save model: {model.__class__}")
 
         lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
 
@@ -1164,19 +1121,13 @@
                 f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                 f" {unexpected_keys}. "
             )
-        if args.train_text_encoder:
-            # Do we need to call `scale_lora_layers()` here?
-            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_)
 
         # Make sure the trainable params are in float32. This is again needed since the base models
         # are in `weight_dtype`. More details:
         # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
         if args.mixed_precision == "fp16":
-            models = [transformer_]
-            if args.train_text_encoder:
-                models.extend([text_encoder_])
             # only upcast trainable parameters (LoRA) into fp32
-            cast_training_params(models)
+            cast_training_params([transformer_])
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
@@ -1193,31 +1144,14 @@
 
     # Make sure the trainable params are in float32.
     if args.mixed_precision == "fp16":
-        models = [transformer]
-        if args.train_text_encoder:
-            models.extend([text_encoder])
         # only upcast trainable parameters (LoRA) into fp32
-        cast_training_params(models, dtype=torch.float32)
+        cast_training_params([transformer], dtype=torch.float32)
 
     transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
-    if args.train_text_encoder:
-        text_encoder_lora_parameters = list(filter(lambda p: p.requires_grad, text_encoder.parameters()))
 
     # Optimization parameters
     transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
-    if args.train_text_encoder:
-        # different learning rate for text encoder and unet
-        text_encoder_parameters_with_lr = {
-            "params": text_encoder_lora_parameters,
-            "weight_decay": args.adam_weight_decay_text_encoder,
-            "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
-        }
-        params_to_optimize = [
-            transformer_parameters_with_lr,
-            text_encoder_parameters_with_lr,
-        ]
-    else:
-        params_to_optimize = [transformer_parameters_with_lr]
+    params_to_optimize = [transformer_parameters_with_lr]
 
     use_deepspeed_optimizer = (
         accelerator.state.deepspeed_plugin is not None
@@ -1302,24 +1236,9 @@
     )
 
     # Prepare everything with our `accelerator`.
-    if args.train_text_encoder:
-        (
-            transformer,
-            text_encoder,
-            optimizer,
-            train_dataloader,
-            lr_scheduler,
-        ) = accelerator.prepare(
-            transformer,
-            text_encoder,
-            optimizer,
-            train_dataloader,
-            lr_scheduler,
-        )
-    else:
-        transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            transformer, optimizer, train_dataloader, lr_scheduler
-        )
+    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        transformer, optimizer, train_dataloader, lr_scheduler
+    )
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1391,15 +1310,9 @@
 
     for epoch in range(first_epoch, args.num_train_epochs):
         transformer.train()
-        if args.train_text_encoder:
-            text_encoder.train()
-            # set top parameter requires_grad = True for gradient checkpointing works
-            accelerator.unwrap_model(text_encoder).text_model.embeddings.requires_grad_(True)
 
         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
-            if args.train_text_encoder:
-                models_to_accumulate.extend([text_encoder])
 
             with accelerator.accumulate(models_to_accumulate):
                 model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)  # [B, F, C, H, W]
@@ -1413,7 +1326,7 @@
                     model_config.max_text_seq_length,
                     accelerator.device,
                     weight_dtype,
-                    requires_grad=args.train_text_encoder,
+                    requires_grad=False,
                 )
 
                 # Sample noise that will be added to the latents
@@ -1467,11 +1380,7 @@
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = (
-                        itertools.chain(transformer.parameters(), text_encoder.parameters())
-                        if args.train_text_encoder
-                        else transformer.parameters()
-                    )
+                    params_to_clip = transformer.parameters()
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 if accelerator.state.deepspeed_plugin is None:
@@ -1565,16 +1474,9 @@
         transformer = transformer.to(dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
 
-        if args.train_text_encoder:
-            text_encoder = unwrap_model(text_encoder)
-            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder.to(dtype))
-        else:
-            text_encoder_lora_layers = None
-
         CogVideoXPipeline.save_lora_weights(
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
-            text_encoder_lora_layers=text_encoder_lora_layers,
         )
 
     # Final test inference
@@ -1624,7 +1526,6 @@
             repo_id,
             videos=validation_outputs,
            base_model=args.pretrained_model_name_or_path,
-            train_text_encoder=args.train_text_encoder,
             validation_prompt=args.validation_prompt,
             repo_folder=args.output_dir,
             fps=args.fps,
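
After this patch the checkpoint written by `CogVideoXPipeline.save_lora_weights` contains transformer LoRA layers only, so it can be loaded for inference without any text encoder adapter. A minimal sketch of that (the base-model id, output directory, and prompt below are illustrative placeholders, not values taken from this patch):

```py
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

# Hypothetical ids/paths, shown for illustration only.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16).to("cuda")
pipe.load_lora_weights("cogvideox-lora-output")  # directory that was passed as --output_dir during training

# Generate a clip with the transformer-only LoRA applied and write it to disk.
video = pipe(prompt="a panda playing a guitar in a bamboo forest", num_inference_steps=50).frames[0]
export_to_video(video, "sample.mp4", fps=8)
```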