remove option to train text encoder
Co-Authored-By: bghira <bghira@users.github.com>
examples/cogvideo/train_cogvideox_lora.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 import argparse
-import itertools
 import logging
 import math
 import os
@@ -40,7 +39,6 @@ from diffusers.models.embeddings import get_3d_rotary_pos_embed
 from diffusers.optimization import get_scheduler
 from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid
 from diffusers.training_utils import (
-    _set_state_dict_into_text_encoder,
     cast_training_params,
     clear_objs_and_retain_memory,
 )
@@ -240,11 +238,6 @@ def get_args():
         action="store_true",
         help="whether to randomly flip videos horizontally",
     )
-    parser.add_argument(
-        "--train_text_encoder",
-        action="store_true",
-        help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
-    )
     parser.add_argument(
         "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
     )
@@ -297,12 +290,6 @@ def get_args():
         default=1e-4,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
-    parser.add_argument(
-        "--text_encoder_lr",
-        type=float,
-        default=5e-6,
-        help="Text encoder learning rate to use.",
-    )
     parser.add_argument(
         "--scale_lr",
         action="store_true",
@@ -368,9 +355,6 @@ def get_args():
     )
     parser.add_argument("--prodigy_decouple", action="store_true", help="Use AdamW style decoupled weight decay")
     parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
-    parser.add_argument(
-        "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
-    )
     parser.add_argument(
         "--adam_epsilon",
         type=float,
@@ -606,7 +590,6 @@ def save_model_card(
     repo_id: str,
     videos=None,
     base_model: str = None,
-    train_text_encoder=False,
     validation_prompt=None,
     repo_folder=None,
     fps=8,
@@ -630,7 +613,7 @@ These are {repo_id} LoRA weights for {base_model}.
 
 The weights were trained using the [CogVideoX Diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/cogvideo/train_cogvideox_lora.py).
 
-Was LoRA for the text encoder enabled? {train_text_encoder}.
+Was LoRA for the text encoder enabled? No.
 
 ## Download model
 
@@ -931,14 +914,6 @@ def get_optimizer(args, params_to_optimize, use_deepspeed: bool = False):
             logger.warning(
                 "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
             )
-        if args.train_text_encoder and args.text_encoder_lr:
-            logger.warning(
-                f"Learning rates were provided both for the transformer and the text encoder - e.g. text_encoder_lr:"
-                f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
-                f"When using prodigy only learning_rate is used as the initial learning rate."
-            )
-            # Changes the learning rate of text_encoder_parameters to be --learning_rate
-            params_to_optimize[1]["lr"] = args.learning_rate
 
         optimizer = optimizer_class(
             params_to_optimize,
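Since the script now only ever optimizes a single parameter group (the transformer's LoRA parameters), the Prodigy branch no longer needs to patch a second group's learning rate via `params_to_optimize[1]["lr"]`. A minimal sketch of constructing Prodigy for that single group, assuming the third-party `prodigyopt` package; the hyperparameter values are illustrative, not taken from the script:

import torch
from prodigyopt import Prodigy  # assumed dependency when --optimizer=prodigy is selected

# Toy stand-in: in the real script only the transformer's LoRA parameters have requires_grad=True.
transformer = torch.nn.Linear(16, 16)

lora_params = [p for p in transformer.parameters() if p.requires_grad]
params_to_optimize = [{"params": lora_params, "lr": 1.0}]  # a single group, so there is no params_to_optimize[1] to patch

optimizer = Prodigy(
    params_to_optimize,
    lr=1.0,            # Prodigy adapts the step size itself; a learning rate around 1.0 is the usual starting point
    betas=(0.9, 0.99),
    weight_decay=1e-4,
    decouple=True,     # AdamW-style decoupled weight decay (cf. --prodigy_decouple)
)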
@@ -1086,8 +1061,6 @@ def main(args):
 
     if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()
-        if args.train_text_encoder:
-            text_encoder.gradient_checkpointing_enable()
 
     # now we will add new LoRA weights to the attention layers
     transformer_lora_config = LoraConfig(
@@ -1098,15 +1071,6 @@ def main(args):
     )
     transformer.add_adapter(transformer_lora_config)
 
-    if args.train_text_encoder:
-        text_lora_config = LoraConfig(
-            r=args.rank,
-            lora_alpha=args.lora_alpha,
-            init_lora_weights=True,
-            target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
-        )
-        text_encoder.add_adapter(text_lora_config)
-
     def unwrap_model(model):
         model = accelerator.unwrap_model(model)
         model = model._orig_mod if is_compiled_module(model) else model
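With the text-encoder branch gone, only the transformer receives a PEFT LoRA adapter. A minimal sketch of that pattern using `peft.LoraConfig` and the `add_adapter` method that diffusers models expose when PEFT is installed; the rank, alpha, model id, and target module names below are illustrative assumptions, not values read from the script:

import torch
from peft import LoraConfig
from diffusers import CogVideoXTransformer3DModel

# Example base model; the trainer works on whichever checkpoint --pretrained_model_name_or_path points at.
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-2b", subfolder="transformer", torch_dtype=torch.float16
)
transformer.requires_grad_(False)  # freeze the base weights

transformer_lora_config = LoraConfig(
    r=128,                 # illustrative rank
    lora_alpha=128,
    init_lora_weights=True,
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # assumed attention projection names
)
transformer.add_adapter(transformer_lora_config)  # only the injected LoRA weights are trainable now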
@@ -1116,13 +1080,10 @@ def main(args):
     def save_model_hook(models, weights, output_dir):
         if accelerator.is_main_process:
             transformer_lora_layers_to_save = None
-            text_encoder_lora_layers_to_save = None
 
             for model in models:
                 if isinstance(model, type(unwrap_model(transformer))):
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                elif isinstance(model, type(unwrap_model(text_encoder))):
-                    text_encoder_lora_layers_to_save = get_peft_model_state_dict(model)
                 else:
                     raise ValueError(f"unexpected save model: {model.__class__}")
 
@@ -1132,22 +1093,18 @@ def main(args):
             CogVideoXPipeline.save_lora_weights(
                 output_dir,
                 transformer_lora_layers=transformer_lora_layers_to_save,
-                text_encoder_lora_layers=text_encoder_lora_layers_to_save,
             )
 
     def load_model_hook(models, input_dir):
         transformer_ = None
-        text_encoder_ = None
 
         while len(models) > 0:
             model = models.pop()
 
             if isinstance(model, type(unwrap_model(transformer))):
                 transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder))):
-                text_encoder_ = model
             else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                raise ValueError(f"Unexpected save model: {model.__class__}")
 
         lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
 
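Because the checkpoints written by `save_model_hook` now contain only transformer LoRA weights, `load_model_hook` only has to pull the entries carrying the `transformer.` prefix out of `CogVideoXPipeline.lora_state_dict(...)`. A rough sketch of that filtering step; the prefix convention follows diffusers' LoRA serialization, and the helper name is illustrative:

from diffusers import CogVideoXPipeline

def load_transformer_lora_state(input_dir: str) -> dict:
    # lora_state_dict returns a flat {name: tensor} mapping; transformer entries
    # are stored under a "transformer." prefix in the saved LoRA checkpoint.
    lora_state_dict = CogVideoXPipeline.lora_state_dict(input_dir)
    return {
        k.removeprefix("transformer."): v
        for k, v in lora_state_dict.items()
        if k.startswith("transformer.")
    }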
@@ -1164,19 +1121,13 @@ def main(args):
                     f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
                     f" {unexpected_keys}. "
                 )
-        if args.train_text_encoder:
-            # Do we need to call `scale_lora_layers()` here?
-            _set_state_dict_into_text_encoder(lora_state_dict, prefix="text_encoder.", text_encoder=text_encoder_)
 
         # Make sure the trainable params are in float32. This is again needed since the base models
         # are in `weight_dtype`. More details:
         # https://github.com/huggingface/diffusers/pull/6514#discussion_r1449796804
         if args.mixed_precision == "fp16":
-            models = [transformer_]
-            if args.train_text_encoder:
-                models.extend([text_encoder_])
             # only upcast trainable parameters (LoRA) into fp32
-            cast_training_params(models)
+            cast_training_params([transformer_])
 
     accelerator.register_save_state_pre_hook(save_model_hook)
     accelerator.register_load_state_pre_hook(load_model_hook)
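`cast_training_params` is what keeps mixed-precision LoRA training numerically stable: the frozen base weights stay in fp16 while only the trainable LoRA parameters are upcast to fp32 for the optimizer. A rough sketch of the idea, not the diffusers implementation:

import torch

def upcast_trainable_params(models, dtype=torch.float32):
    # Only parameters with requires_grad=True (the LoRA weights) are upcast;
    # the frozen fp16 base weights are left untouched to save memory.
    if not isinstance(models, list):
        models = [models]
    for model in models:
        for param in model.parameters():
            if param.requires_grad:
                param.data = param.data.to(dtype)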
@@ -1193,31 +1144,14 @@ def main(args):
 
     # Make sure the trainable params are in float32.
     if args.mixed_precision == "fp16":
-        models = [transformer]
-        if args.train_text_encoder:
-            models.extend([text_encoder])
         # only upcast trainable parameters (LoRA) into fp32
-        cast_training_params(models, dtype=torch.float32)
+        cast_training_params([transformer], dtype=torch.float32)
 
     transformer_lora_parameters = list(filter(lambda p: p.requires_grad, transformer.parameters()))
-    if args.train_text_encoder:
-        text_encoder_lora_parameters = list(filter(lambda p: p.requires_grad, text_encoder.parameters()))
 
     # Optimization parameters
     transformer_parameters_with_lr = {"params": transformer_lora_parameters, "lr": args.learning_rate}
-    if args.train_text_encoder:
-        # different learning rate for text encoder and unet
-        text_encoder_parameters_with_lr = {
-            "params": text_encoder_lora_parameters,
-            "weight_decay": args.adam_weight_decay_text_encoder,
-            "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
-        }
-        params_to_optimize = [
-            transformer_parameters_with_lr,
-            text_encoder_parameters_with_lr,
-        ]
-    else:
-        params_to_optimize = [transformer_parameters_with_lr]
+    params_to_optimize = [transformer_parameters_with_lr]
 
     use_deepspeed_optimizer = (
         accelerator.state.deepspeed_plugin is not None
@@ -1302,24 +1236,9 @@ def main(args):
     )
 
     # Prepare everything with our `accelerator`.
-    if args.train_text_encoder:
-        (
-            transformer,
-            text_encoder,
-            optimizer,
-            train_dataloader,
-            lr_scheduler,
-        ) = accelerator.prepare(
-            transformer,
-            text_encoder,
-            optimizer,
-            train_dataloader,
-            lr_scheduler,
-        )
-    else:
-        transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-            transformer, optimizer, train_dataloader, lr_scheduler
-        )
+    transformer, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        transformer, optimizer, train_dataloader, lr_scheduler
+    )
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1391,15 +1310,9 @@ def main(args):
 
     for epoch in range(first_epoch, args.num_train_epochs):
         transformer.train()
-        if args.train_text_encoder:
-            text_encoder.train()
-            # set top parameter requires_grad = True for gradient checkpointing works
-            accelerator.unwrap_model(text_encoder).text_model.embeddings.requires_grad_(True)
 
         for step, batch in enumerate(train_dataloader):
             models_to_accumulate = [transformer]
-            if args.train_text_encoder:
-                models_to_accumulate.extend([text_encoder])
 
             with accelerator.accumulate(models_to_accumulate):
                 model_input = batch["videos"].permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)  # [B, F, C, H, W]
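With only the transformer being trained, the inner loop reduces to the standard accelerate gradient-accumulation pattern: accumulate over `models_to_accumulate = [transformer]`, run `backward` through the accelerator, and clip and step only when gradients are synchronized. A compact, self-contained sketch of that pattern with a toy model and dataloader standing in for the real ones:

import torch
from torch.utils.data import DataLoader, TensorDataset
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)

# Toy stand-ins for the transformer and the video dataloader.
model = torch.nn.Linear(8, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
dataloader = DataLoader(TensorDataset(torch.randn(64, 8), torch.randn(64, 1)), batch_size=4)

model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)

model.train()
for inputs, targets in dataloader:
    with accelerator.accumulate(model):  # scopes gradient accumulation to the single trained model
        loss = torch.nn.functional.mse_loss(model(inputs), targets)
        accelerator.backward(loss)
        if accelerator.sync_gradients:   # True only on steps that actually update the weights
            accelerator.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        optimizer.zero_grad()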
@@ -1413,7 +1326,7 @@ def main(args):
                     model_config.max_text_seq_length,
                     accelerator.device,
                     weight_dtype,
-                    requires_grad=args.train_text_encoder,
+                    requires_grad=False,
                 )
 
                 # Sample noise that will be added to the latents
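Since the text encoder is now always frozen, the prompt-embedding helper is always called with `requires_grad=False`, which in practice means the encoding runs under `torch.no_grad()` and no activations are kept for a backward pass. A small illustrative sketch of how such a flag is typically honored; the helper name and signature here are hypothetical, not the script's:

import torch

def encode_prompt(text_encoder, input_ids, requires_grad: bool = False):
    # Hypothetical helper: with a frozen text encoder, skip autograd bookkeeping entirely.
    if requires_grad:
        return text_encoder(input_ids)[0]
    with torch.no_grad():
        return text_encoder(input_ids)[0]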
@@ -1467,11 +1380,7 @@ def main(args):
                 accelerator.backward(loss)
 
                 if accelerator.sync_gradients:
-                    params_to_clip = (
-                        itertools.chain(transformer.parameters(), text_encoder.parameters())
-                        if args.train_text_encoder
-                        else transformer.parameters()
-                    )
+                    params_to_clip = transformer.parameters()
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
 
                 if accelerator.state.deepspeed_plugin is None:
@@ -1565,16 +1474,9 @@ def main(args):
         transformer = transformer.to(dtype)
         transformer_lora_layers = get_peft_model_state_dict(transformer)
 
-        if args.train_text_encoder:
-            text_encoder = unwrap_model(text_encoder)
-            text_encoder_lora_layers = get_peft_model_state_dict(text_encoder.to(dtype))
-        else:
-            text_encoder_lora_layers = None
-
         CogVideoXPipeline.save_lora_weights(
             save_directory=args.output_dir,
             transformer_lora_layers=transformer_lora_layers,
-            text_encoder_lora_layers=text_encoder_lora_layers,
         )
 
         # Final test inference
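After this change the final artifact written to `args.output_dir` is a transformer-only LoRA. A brief sketch of how such weights are typically loaded back for inference with diffusers; the model id, output path, adapter name, prompt, and sampling settings are illustrative:

import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-2b", torch_dtype=torch.float16)
pipe.load_lora_weights("path/to/output_dir", adapter_name="cogvideox-lora")  # transformer-only LoRA
pipe.to("cuda")

video = pipe(
    prompt="A panda playing a guitar in a bamboo forest",
    num_inference_steps=50,
    guidance_scale=6.0,
).frames[0]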
@@ -1624,7 +1526,6 @@ def main(args):
                 repo_id,
                 videos=validation_outputs,
                 base_model=args.pretrained_model_name_or_path,
-                train_text_encoder=args.train_text_encoder,
                 validation_prompt=args.validation_prompt,
                 repo_folder=args.output_dir,
                 fps=args.fps,