From 9a95d8de56bbb21eee9dab5ee7d089b4dbcdf9f1 Mon Sep 17 00:00:00 2001 From: Aryan Date: Tue, 3 Sep 2024 05:57:48 +0200 Subject: [PATCH] update --- examples/cogvideo/train_cogvideox_lora.py | 83 +++++++------------ .../pipelines/cogvideo/pipeline_cogvideox.py | 1 - 2 files changed, 31 insertions(+), 53 deletions(-) diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py index fa249eef16..3bc9694153 100644 --- a/examples/cogvideo/train_cogvideox_lora.py +++ b/examples/cogvideo/train_cogvideox_lora.py @@ -315,6 +315,18 @@ def get_args(): help="Number of hard resets of the lr in cosine_with_restarts scheduler.", ) parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--enable_slicing", + action="store_true", + default=False, + help="Whether or not to use VAE slicing for saving memory.", + ) + parser.add_argument( + "--enable_tiling", + action="store_true", + default=False, + help="Whether or not to use VAE tiling for saving memory.", + ) # Optimizer parser.add_argument( @@ -949,6 +961,11 @@ def main(args): scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + if args.enable_slicing: + vae.enable_slicing() + if args.enable_tiling: + vae.enable_tiling() + # We only train the additional adapter LoRA layers text_encoder.requires_grad_(False) transformer.requires_grad_(False) @@ -1190,10 +1207,10 @@ def main(args): logger.info(f" Num trainable parameters = {num_trainable_parameters}") logger.info(f" Num examples = {len(train_dataset)}") logger.info(f" Num batches each epoch = {len(train_dataloader)}") - logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Num epochs = {args.num_train_epochs}") logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") - logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Gradient accumulation steps = {args.gradient_accumulation_steps}") logger.info(f" Total optimization steps = {args.max_train_steps}") global_step = 0 first_epoch = 0 @@ -1295,34 +1312,25 @@ def main(args): noisy_model_input = scheduler.add_noise(model_input, noise, timesteps) # Predict the noise residual - model_pred = transformer( + model_output = transformer( hidden_states=noisy_model_input, encoder_hidden_states=prompt_embeds, timestep=timesteps, image_rotary_emb=image_rotary_emb, return_dict=False, )[0] + alphas_cumprod = scheduler.alphas_cumprod[timesteps] + alphas_cumprod_sqrt = alphas_cumprod**0.5 + one_minus_alphas_cumprod_sqrt = (1 - alphas_cumprod) ** 0.5 + model_pred = noisy_model_input * alphas_cumprod_sqrt - model_output * one_minus_alphas_cumprod_sqrt - # ===== - weights = 1 / (1 - scheduler.alphas_cumprod[timesteps]) - print(timesteps, weights) - weights = torch.clip(weights, min=1, max=5) # TODO: weights blows up for lower timesteps + weights = 1 / (1 - alphas_cumprod) while len(weights.shape) < len(model_pred.shape): weights = weights.unsqueeze(-1) - # ===== target = model_input - # if scheduler.config.prediction_type == "epsilon": - # target = noise - # elif scheduler.config.prediction_type == "v_prediction": - # target = scheduler.get_velocity(model_input, noise, timesteps) - # else: - # raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}") - - # loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1) - # loss = torch.mean(((model_pred - target) ** 2).reshape(batch_size, -1), dim=1) accelerator.backward(loss) if accelerator.sync_gradients: @@ -1383,6 +1391,7 @@ def main(args): transformer=unwrap_model(transformer), text_encoder=unwrap_model(text_encoder), vae=unwrap_model(vae), + scheduler=scheduler, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, @@ -1425,17 +1434,19 @@ def main(args): text_encoder_lora_layers=text_encoder_lora_layers, ) - # Final inference + # Final test inference pipe = CogVideoXPipeline.from_pretrained( args.pretrained_model_name_or_path, revision=args.revision, variant=args.variant, torch_dtype=weight_dtype, ) - # load attention processors + pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config) + + # Load LoRA weights pipe.load_lora_weights(args.output_dir) - # run inference + # Run inference validation_outputs = [] if args.validation_prompt and args.num_validation_videos > 0: validation_prompts = args.validation_prompt.split(args.validation_prompt_separator) @@ -1481,35 +1492,3 @@ def main(args): if __name__ == "__main__": args = get_args() main(args) - - # train_dataset = VideoDataset( - # instance_data_root=args.instance_data_root, - # dataset_name=args.dataset_name, - # dataset_config_name=args.dataset_config_name, - # caption_column=args.caption_column, - # video_column=args.video_column, - # height=args.height, - # width=args.width, - # fps=args.fps, - # max_num_frames=args.max_num_frames, - # skip_frames_start=args.skip_frames_start, - # skip_frames_end=args.skip_frames_end, - # cache_dir=args.cache_dir, - # ) - - # train_dataloader = DataLoader( - # train_dataset, - # batch_size=args.train_batch_size, - # shuffle=True, - # collate_fn=collate_fn, - # num_workers=args.dataloader_num_workers, - # ) - - # for batch in train_dataloader: - # print(batch["prompts"]) - # print(batch["videos"].min(), batch["videos"].max()) - # result = CogVideoXPipeline(None, None, None, None, None).video_processor.postprocess_video( - # batch["videos"].permute(0, 2, 1, 3, 4), output_type="pil" - # ) - # # print(result[0]) - # export_to_video(result[0], "recon.mp4", fps=8) diff --git a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py index 4944bb0076..be83eef3bc 100644 --- a/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py +++ b/src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py @@ -28,7 +28,6 @@ from ...pipelines.pipeline_utils import DiffusionPipeline from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler from ...utils import ( USE_PEFT_BACKEND, - BaseOutput, logging, replace_example_docstring, scale_lora_layers,