From 74e6f90097e1c9e8d82220cbcfcb53b562d44284 Mon Sep 17 00:00:00 2001 From: Aryan Date: Sun, 1 Sep 2024 02:35:19 +0200 Subject: [PATCH] update --- examples/cogvideo/train_cogvideox_lora.py | 35 +++++++++++------------ 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index b083368bb5..fa249eef16 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ b/examples/cogvideo/train_cogvideox_lora.py @@ -1235,7 +1235,6 @@ def main(args): vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1) for epoch in range(first_epoch, args.num_train_epochs): - print("epoch:", epoch) transformer.train() if args.train_text_encoder: text_encoder.train() @@ -1275,7 +1274,6 @@ def main(args): 0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device ) timesteps = timesteps.long() - print(model_input.shape, timesteps, prompt_embeds.shape) # Prepare rotary embeds image_rotary_emb = ( @@ -1291,7 +1289,6 @@ def main(args): if transformer.config.use_rotary_positional_embeddings else None ) - print(image_rotary_emb) # Add noise to the model input according to the noise magnitude at each timestep # (this is the forward diffusion process) @@ -1306,26 +1303,26 @@ def main(args): return_dict=False, )[0] - # # ===== - # weights = 1 / (1 - scheduler.alphas_cumprod[timesteps]) - # weights = torch.clip(weights, min=1, max=5) # TODO: weights blows up for lower timesteps - # print(weights) - # while len(weights.shape) < len(model_pred.shape): - # weights = weights.unsqueeze(-1) - # # ===== + # ===== + weights = 1 / (1 - scheduler.alphas_cumprod[timesteps]) + print(timesteps, weights) + weights = torch.clip(weights, min=1, max=5) # TODO: weights blows up for lower timesteps + while len(weights.shape) < len(model_pred.shape): + weights = weights.unsqueeze(-1) + # ===== - # target = model_input + target = model_input - if scheduler.config.prediction_type == "epsilon": - target = noise - elif scheduler.config.prediction_type == "v_prediction": - target = scheduler.get_velocity(model_input, noise, timesteps) - else: - raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}") + # if scheduler.config.prediction_type == "epsilon": + # target = noise + # elif scheduler.config.prediction_type == "v_prediction": + # target = scheduler.get_velocity(model_input, noise, timesteps) + # else: + # raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}") # loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") - # loss = torch.mean((weights * (model_pred - model_input) ** 2).reshape(batch_size, -1), dim=1) - loss = torch.mean(((model_pred - target) ** 2).reshape(batch_size, -1), dim=1) + loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1) + # loss = torch.mean(((model_pred - target) ** 2).reshape(batch_size, -1), dim=1) accelerator.backward(loss) if accelerator.sync_gradients: