From 24c362ca4fb7d8709e965a50f736b4850b20829b Mon Sep 17 00:00:00 2001
From: Aryan
Date: Thu, 29 Aug 2024 18:53:04 +0200
Subject: [PATCH] update

---
 examples/cogvideo/train_cogvideox_lora.py | 148 ++++++++++------------
 1 file changed, 66 insertions(+), 82 deletions(-)

diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index 427b6c6f76..9b62022494 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -361,10 +361,7 @@ def get_args():
         "--logging_dir",
         type=str,
         default="logs",
-        help=(
-            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
-            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
-        ),
+        help="Directory where logs are stored.",
     )
     parser.add_argument(
         "--allow_tf32",
@@ -573,7 +570,7 @@ def save_model_card(
     widget_dict = []
     if videos is not None:
         for i, video in enumerate(videos):
-            export_to_video(video, os.path.join(repo_folder, f"video_{i}.mp4", fps=fps))
+            export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps)
             widget_dict.append(
-                {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"video_{i}.mp4"}}
+                {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"final_video_{i}.mp4"}}
             )
@@ -673,10 +670,25 @@ def log_validation(
     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
         if tracker.name == "wandb":
+            video_filenames = []
+            for i, video in enumerate(videos):
+                prompt = (
+                    pipeline_args["prompt"][:25]
+                    .replace(" ", "_")
+                    .replace(" ", "_")
+                    .replace("'", "_")
+                    .replace('"', "_")
+                    .replace("/", "_")
+                )
+                filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
+                export_to_video(video, filename, fps=8)
+                video_filenames.append(filename)
+
             tracker.log(
                 {
                     phase_name: [
-                        wandb.Video(video, caption=f"{i}: {args.validation_prompt}") for i, video in enumerate(videos)
+                        wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
+                        for i, filename in enumerate(video_filenames)
                     ]
                 }
             )
@@ -763,6 +775,29 @@ def encode_prompt(
     return prompt_embeds
 
 
+def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, requires_grad: bool = False):
+    if requires_grad:
+        prompt_embeds = encode_prompt(
+            tokenizer,
+            text_encoder,
+            prompt,
+            num_videos_per_prompt=1,
+            device=device,
+            dtype=dtype,
+        )
+    else:
+        with torch.no_grad():
+            prompt_embeds = encode_prompt(
+                tokenizer,
+                text_encoder,
+                prompt,
+                num_videos_per_prompt=1,
+                device=device,
+                dtype=dtype,
+            )
+    return prompt_embeds
+
+
 def prepare_rotary_positional_embeddings(
     height: int,
     width: int,
@@ -1089,20 +1124,6 @@ def main(args):
         num_workers=args.dataloader_num_workers,
     )
 
-    if not args.train_text_encoder:
-
-        def compute_text_embeddings(prompt):
-            with torch.no_grad():
-                prompt_embeds = encode_prompt(
-                    tokenizer,
-                    text_encoder,
-                    prompt,
-                    num_videos_per_prompt=1,
-                    device=accelerator.device,
-                    dtype=weight_dtype,
-                )
-            return prompt_embeds
-
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1219,34 +1240,19 @@ def main(args):
                 prompts = batch["prompts"]
 
                 # encode prompts
-                if not args.train_text_encoder:
-                    prompt_embeds = compute_text_embeddings(prompts)
-                else:
-                    text_inputs = tokenizer(
-                        prompts,
-                        padding="max_length",
-                        max_length=226,
-                        truncation=True,
-                        add_special_tokens=True,
-                        return_tensors="pt",
-                    )
-                    text_input_ids = text_inputs.input_ids
-                    prompt_embeds = encode_prompt(
-                        tokenizer=None,
-                        text_encoder=text_encoder,
-                        prompt=None,
-                        num_videos_per_prompt=1,
-                        device=accelerator.device,
-                        dtype=weight_dtype,
-                        text_input_ids=text_input_ids,
-                    )
+                prompt_embeds = compute_prompt_embeddings(
+                    tokenizer,
+                    text_encoder,
+                    prompts,
+                    accelerator.device,
+                    weight_dtype,
+                    requires_grad=args.train_text_encoder,
+                )
 
                 # Convert videos to latents
-                print("videos.shape:", videos.shape)
                 videos = videos.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
                 model_input = vae.encode(videos).latent_dist.sample() * vae.config.scaling_factor
                 model_input = model_input.permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)  # [B, F, C, H, W]
-                print("latents.shape:", model_input.shape)
 
                 # Sample noise that will be added to the latents
-                noise = torch.rand_like(model_input)
+                noise = torch.randn_like(model_input)
@@ -1286,12 +1292,21 @@ def main(args):
                     return_dict=False,
                 )[0]
 
-                if scheduler.config.prediction_type == "epsilon":
-                    target = noise
-                elif scheduler.config.prediction_type == "v_prediction":
-                    target = scheduler.get_velocity(model_input, noise, timesteps)
-                else:
-                    raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
+                # =====
+                # weights = 1 / (1 - scheduler.alphas_cumprod[timesteps])
+                # while len(weights.shape) < len(model_pred.shape):
+                #     weights = weights.unsqueeze(-1)
+                # model_pred = model_pred * weights
+                # target = model_input * weights
+                # =====
+                target = model_input
+
+                # if scheduler.config.prediction_type == "epsilon":
+                #     target = noise
+                # elif scheduler.config.prediction_type == "v_prediction":
+                #     target = scheduler.get_velocity(model_input, noise, timesteps)
+                # else:
+                #     raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
 
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                 accelerator.backward(loss)
@@ -1446,36 +1461,5 @@ def main(args):
 
 
 if __name__ == "__main__":
-    print("Hello, world!")
     args = get_args()
     main(args)
-
-    # class args:
-    #     instance_data_root = "./z"
-    #     dataset_name = None
-    #     dataset_config_name = None
-    #     caption_column = "prompts.txt"
-    #     video_column = "videos.txt"
-    #     height = 480
-    #     width = 720
-    #     fps = 8
-    #     max_num_frames = 49
-    #     skip_frames_start = 0
-    #     skip_frames_end = 0
-    #     cache_dir = None
-
-    #     # Dataset and DataLoaders creation:
-    #     train_dataset = VideoDataset(
-    #         instance_data_root=args.instance_data_root,
-    #         dataset_name=args.dataset_name,
-    #         dataset_config_name=args.dataset_config_name,
-    #         caption_column=args.caption_column,
-    #         video_column=args.video_column,
-    #         height=args.height,
-    #         width=args.width,
-    #         fps=args.fps,
-    #         max_num_frames=args.max_num_frames,
-    #         skip_frames_start=args.skip_frames_start,
-    #         skip_frames_end=args.skip_frames_end,
-    #         cache_dir=args.cache_dir,
-    #     )
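
Note on the new compute_prompt_embeddings helper: it duplicates the
encode_prompt call across the grad and no-grad branches. A minimal,
branch-free sketch of the same behavior (assuming the encode_prompt
defined earlier in this file) could use torch.set_grad_enabled, a
standard PyTorch context manager:

    import torch

    def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, requires_grad: bool = False):
        # set_grad_enabled(requires_grad) records gradients only when the
        # text encoder is being trained, matching the if/else in the patch.
        with torch.set_grad_enabled(requires_grad):
            prompt_embeds = encode_prompt(
                tokenizer,
                text_encoder,
                prompt,
                num_videos_per_prompt=1,
                device=device,
                dtype=dtype,
            )
        return prompt_embeds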
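
Note on the loss target change: with target = model_input, the transformer
is trained to regress the clean latents directly (sample prediction) rather
than deriving the target from scheduler.config.prediction_type. The
commented-out experiment scales both model_pred and target by
w_t = 1 / (1 - alphas_cumprod[t]); since the MSE is quadratic, that is
equivalent to weighting each timestep's squared error by w_t**2, which
upweights low-noise timesteps (1 - alphas_cumprod[t] is the variance of
the noise added at step t). A small self-contained check of that
equivalence:

    import torch

    torch.manual_seed(0)
    pred = torch.randn(8)
    target = torch.randn(8)
    w = 100.0  # e.g. alphas_cumprod[t] = 0.99 gives w = 1 / (1 - 0.99)

    scaled = torch.mean((w * pred - w * target) ** 2)   # scale both sides
    weighted = torch.mean(w**2 * (pred - target) ** 2)  # weight the error
    assert torch.allclose(scaled, weighted)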