diff --git a/examples/cogvideo/train_cogvideox_lora.py b/examples/cogvideo/train_cogvideox_lora.py
index 9b62022494..b083368bb5 100644
--- a/examples/cogvideo/train_cogvideox_lora.py
+++ b/examples/cogvideo/train_cogvideox_lora.py
@@ -23,7 +23,6 @@ from pathlib import Path
 from typing import List, Optional, Tuple, Union
 
 import torch
-import torch.nn.functional as F
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -333,7 +332,7 @@ def get_args():
         "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
     )
     parser.add_argument(
-        "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+        "--adam_beta2", type=float, default=0.95, help="The beta2 parameter for the Adam and Prodigy optimizers."
     )
     parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
     parser.add_argument(
@@ -512,16 +511,19 @@ class VideoDataset(Dataset):
         return instance_prompts, instance_videos
 
     def _preprocess_data(self):
-        import decord
+        try:
+            import decord
+        except ImportError:
+            raise ImportError(
+                "The `decord` package is required for loading the video dataset. Install with `pip install decord`"
+            )
 
         decord.bridge.set_bridge("torch")
 
         videos = []
-
         train_transforms = transforms.Compose(
             [
-                transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
+                transforms.Lambda(lambda x: x / (255 / 2) - 1),
             ]
         )
 
@@ -532,28 +534,29 @@ class VideoDataset(Dataset):
             start_frame = min(self.skip_frames_start, video_num_frames)
             end_frame = max(0, video_num_frames - self.skip_frames_end)
             if end_frame <= start_frame:
-                frames_numpy = video_reader.get_batch([start_frame]).numpy()
+                frames = video_reader.get_batch([start_frame])
             elif end_frame - start_frame <= self.max_num_frames:
-                frames_numpy = video_reader.get_batch(list(range(start_frame, end_frame))).numpy()
+                frames = video_reader.get_batch(list(range(start_frame, end_frame)))
             else:
                 indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames))
-                frames_numpy = video_reader.get_batch(indices).numpy()
+                frames = video_reader.get_batch(indices)
 
-            # Just to ensure that we don't go over the limit
-            frames_numpy = frames_numpy[: self.max_num_frames]
-            selected_num_frames = frames_numpy.shape[0]
+            # Ensure that we don't go over the limit
+            frames = frames[: self.max_num_frames]
+            selected_num_frames = frames.shape[0]
 
             # Choose first (4k + 1) frames as this is how many is required by the VAE
             remainder = (3 + (selected_num_frames % 4)) % 4
             if remainder != 0:
-                frames_numpy = frames_numpy[:-remainder]
-                selected_num_frames = frames_numpy.shape[0]
+                frames = frames[:-remainder]
+                selected_num_frames = frames.shape[0]
 
             assert (selected_num_frames - 1) % 4 == 0
 
             # Training transforms
-            frames_tensor = torch.stack([train_transforms(frame) for frame in frames_numpy], dim=0)
-            videos.append(frames_tensor)  # [F, C, H, W]
+            frames = frames.float()
+            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
+            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
 
         return videos
 
@@ -827,6 +830,44 @@ def prepare_rotary_positional_embeddings(
     return freqs_cos, freqs_sin
 
 
+def get_optimizer(args, params_to_optimize):
+    # Optimizer creation
+    if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+        logger.warning(
+            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. "
+            "Defaulting to adamW"
+        )
+        args.optimizer = "adamw"
+
+    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+        logger.warning(
+            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+            f"set to {args.optimizer.lower()}"
+        )
+
+    if args.optimizer.lower() == "adamw":
+        if args.use_8bit_adam:
+            try:
+                import bitsandbytes as bnb
+            except ImportError:
+                raise ImportError(
+                    "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+                )
+
+            optimizer_class = bnb.optim.AdamW8bit
+        else:
+            optimizer_class = torch.optim.AdamW
+
+        optimizer = optimizer_class(
+            params_to_optimize,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+        )
+
+    return optimizer
+
+
 def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(
@@ -909,9 +950,9 @@ def main(args):
     scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
 
     # We only train the additional adapter LoRA layers
+    text_encoder.requires_grad_(False)
     transformer.requires_grad_(False)
     vae.requires_grad_(False)
-    text_encoder.requires_grad_(False)
 
     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
@@ -927,9 +968,9 @@ def main(args):
             "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
         )
 
-    vae.to(accelerator.device, dtype=weight_dtype)
-    transformer.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)
+    transformer.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)
 
     if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()
@@ -940,7 +981,7 @@ def main(args):
     transformer_lora_config = LoraConfig(
         r=args.rank,
         lora_alpha=args.rank,
-        init_lora_weights="gaussian",
+        init_lora_weights=True,
        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
     )
     transformer.add_adapter(transformer_lora_config)
@@ -949,7 +990,7 @@ def main(args):
         text_lora_config = LoraConfig(
             r=args.rank,
             lora_alpha=args.rank,
-            init_lora_weights="gaussian",
+            init_lora_weights=True,
             target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
         )
         text_encoder.add_adapter(text_lora_config)
@@ -1066,39 +1107,7 @@ def main(args):
     else:
         params_to_optimize = [transformer_parameters_with_lr]
 
-    # Optimizer creation
-    if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
-        logger.warning(
-            f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
-            "Defaulting to adamW"
-        )
-        args.optimizer = "adamw"
-
-    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
-        logger.warning(
-            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
-            f"set to {args.optimizer.lower()}"
-        )
-
-    if args.optimizer.lower() == "adamw":
-        if args.use_8bit_adam:
-            try:
-                import bitsandbytes as bnb
-            except ImportError:
-                raise ImportError(
-                    "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
-                )
-
-            optimizer_class = bnb.optim.AdamW8bit
-        else:
-            optimizer_class = torch.optim.AdamW
-
-        optimizer = optimizer_class(
-            params_to_optimize,
-            betas=(args.adam_beta1, args.adam_beta2),
-            weight_decay=args.adam_weight_decay,
-            eps=args.adam_epsilon,
-        )
+    optimizer = get_optimizer(args, params_to_optimize)
 
     # Dataset and DataLoader
     train_dataset = VideoDataset(
@@ -1175,8 +1184,10 @@ def main(args):
 
     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+    num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
 
     logger.info("***** Running training *****")
+    logger.info(f"  Num trainable parameters = {num_trainable_parameters}")
     logger.info(f"  Num examples = {len(train_dataset)}")
     logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
     logger.info(f"  Num Epochs = {args.num_train_epochs}")
@@ -1224,6 +1235,7 @@ def main(args):
     vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)
 
     for epoch in range(first_epoch, args.num_train_epochs):
+        print("epoch:", epoch)
         transformer.train()
         if args.train_text_encoder:
             text_encoder.train()
@@ -1263,6 +1275,7 @@ def main(args):
                     0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device
                 )
                 timesteps = timesteps.long()
+                print(model_input.shape, timesteps, prompt_embeds.shape)
 
                 # Prepare rotary embeds
                 image_rotary_emb = (
@@ -1278,6 +1291,7 @@ def main(args):
                     if transformer.config.use_rotary_positional_embeddings
                    else None
                 )
+                print(image_rotary_emb)
 
                 # Add noise to the model input according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
@@ -1292,23 +1306,26 @@ def main(args):
                     return_dict=False,
                 )[0]
 
-                # =====
+                # # =====
                 # weights = 1 / (1 - scheduler.alphas_cumprod[timesteps])
+                # weights = torch.clip(weights, min=1, max=5)  # TODO: weights blows up for lower timesteps
+                # print(weights)
                 # while len(weights.shape) < len(model_pred.shape):
                 #     weights = weights.unsqueeze(-1)
-                # model_pred = model_pred * weights
-                # target = model_input * weights
-                # =====
-                target = model_input
+                # # =====
 
-                # if scheduler.config.prediction_type == "epsilon":
-                #     target = noise
-                # elif scheduler.config.prediction_type == "v_prediction":
-                #     target = scheduler.get_velocity(model_input, noise, timesteps)
-                # else:
-                #     raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
+                # target = model_input
 
-                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                if scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif scheduler.config.prediction_type == "v_prediction":
+                    target = scheduler.get_velocity(model_input, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
+
+                # loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                # loss = torch.mean((weights * (model_pred - model_input) ** 2).reshape(batch_size, -1), dim=1)
+                loss = torch.mean(((model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
@@ -1362,13 +1379,13 @@ def main(args):
                     break
         if accelerator.is_main_process:
-            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+            if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
                 # Create pipeline
                 pipe = CogVideoXPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     transformer=unwrap_model(transformer),
                     text_encoder=unwrap_model(text_encoder),
-                    vae=vae,
+                    vae=unwrap_model(vae),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,
                 )
@@ -1380,6 +1397,8 @@ def main(args):
                         "prompt": validation_prompt,
                         "guidance_scale": args.guidance_scale,
                         "use_dynamic_cfg": args.use_dynamic_cfg,
+                        "height": args.height,
+                        "width": args.width,
                     }
 
                     validation_outputs = log_validation(
@@ -1428,6 +1447,8 @@ def main(args):
                 "prompt": validation_prompt,
                 "guidance_scale": args.guidance_scale,
                 "use_dynamic_cfg": args.use_dynamic_cfg,
+                "height": args.height,
+                "width": args.width,
             }
 
             video = log_validation(
@@ -1463,3 +1484,35 @@ def main(args):
 if __name__ == "__main__":
     args = get_args()
     main(args)
+
+    # train_dataset = VideoDataset(
+    #     instance_data_root=args.instance_data_root,
+    #     dataset_name=args.dataset_name,
+    #     dataset_config_name=args.dataset_config_name,
+    #     caption_column=args.caption_column,
+    #     video_column=args.video_column,
+    #     height=args.height,
+    #     width=args.width,
+    #     fps=args.fps,
+    #     max_num_frames=args.max_num_frames,
+    #     skip_frames_start=args.skip_frames_start,
+    #     skip_frames_end=args.skip_frames_end,
+    #     cache_dir=args.cache_dir,
+    # )
+
+    # train_dataloader = DataLoader(
+    #     train_dataset,
+    #     batch_size=args.train_batch_size,
+    #     shuffle=True,
+    #     collate_fn=collate_fn,
+    #     num_workers=args.dataloader_num_workers,
+    # )
+
+    # for batch in train_dataloader:
+    #     print(batch["prompts"])
+    #     print(batch["videos"].min(), batch["videos"].max())
+    #     result = CogVideoXPipeline(None, None, None, None, None).video_processor.postprocess_video(
+    #         batch["videos"].permute(0, 2, 1, 3, 4), output_type="pil"
+    #     )
+    #     # print(result[0])
+    #     export_to_video(result[0], "recon.mp4", fps=8)
diff --git a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
index 17fa2bbf40..021a913ec8 100644
--- a/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
+++ b/src/diffusers/models/autoencoders/autoencoder_kl_cogvideox.py
@@ -1081,6 +1081,14 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         """
         self.use_slicing = False
 
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        # TODO: Implement context parallel cache
+        # TODO: Implement tiled encoding
+        h = self.encoder(x)
+        if self.quant_conv is not None:
+            h = self.quant_conv(h)
+        return h
+
     @apply_forward_hook
     def encode(
         self, x: torch.Tensor, return_dict: bool = True
@@ -1097,9 +1105,12 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 The latent representations of the encoded images. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self.encoder(x)
-        if self.quant_conv is not None:
-            h = self.quant_conv(h)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)
+
         posterior = DiagonalGaussianDistribution(h)
         if not return_dict:
             return (posterior,)
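
A minimal usage sketch (not part of the diff) of the sliced-encode path that the new `_encode` helper and the reworked `encode` enable: once `enable_slicing()` is called, batched video encodes are processed one sample at a time. The checkpoint id, input shape, and expected latent shape below are illustrative assumptions; the training script itself feeds `[B, C, F, H, W]` videos normalized to [-1, 1].

import torch

from diffusers import AutoencoderKLCogVideoX

# Illustrative checkpoint; any CogVideoX VAE weights stored under a "vae" subfolder work the same way.
vae = AutoencoderKLCogVideoX.from_pretrained("THUDM/CogVideoX-2b", subfolder="vae")
vae.enable_slicing()  # sets use_slicing=True

# Toy input: [B, C, F, H, W] with F = 4k + 1 frames, H and W divisible by 8, values in [-1, 1].
video = torch.randn(2, 3, 9, 64, 64)

with torch.no_grad():
    # With use_slicing enabled and a batch size > 1, encode() now calls _encode() once per
    # sample and concatenates the results before building the DiagonalGaussianDistribution.
    posterior = vae.encode(video).latent_dist
    latents = posterior.sample()

print(latents.shape)  # expected [2, 16, 3, 8, 8] given 4x temporal and 8x spatial compression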