mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00
This commit is contained in:
Aryan
2024-08-31 15:41:26 +02:00
parent 24c362ca4f
commit 588c6ee602
2 changed files with 135 additions and 71 deletions


@@ -23,7 +23,6 @@ from pathlib import Path
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
-import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
@@ -333,7 +332,7 @@ def get_args():
"--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
)
parser.add_argument(
"--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
"--adam_beta2", type=float, default=0.95, help="The beta2 parameter for the Adam and Prodigy optimizers."
)
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
parser.add_argument(
@@ -512,16 +511,19 @@ class VideoDataset(Dataset):
return instance_prompts, instance_videos
def _preprocess_data(self):
-import decord
+try:
+    import decord
+except ImportError:
+    raise ImportError(
+        "The `decord` package is required for loading the video dataset. Install with `pip install decord`."
+    )
decord.bridge.set_bridge("torch")
videos = []
train_transforms = transforms.Compose(
[
-        transforms.ToTensor(),
-        transforms.Normalize([0.5], [0.5]),
+        transforms.Lambda(lambda x: x / (255 / 2) - 1),
]
)
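With decord's torch bridge set, `VideoReader.get_batch` returns uint8 torch tensors of shape [F, H, W, C] directly, which is why the `ToTensor`/`Normalize` pair (they expect PIL or [C, H, W] inputs) gives way to a plain rescaling `Lambda` that maps [0, 255] to [-1, 1]. A standalone sketch of the rescaling (shapes illustrative):

import torch
from torchvision import transforms

# Standalone sketch: the Lambda maps pixel values in [0, 255] to [-1, 1],
# matching what ToTensor + Normalize([0.5], [0.5]) produced, but it works
# directly on [H, W, C] torch tensors from decord.
rescale = transforms.Lambda(lambda x: x / (255 / 2) - 1)
frame = torch.randint(0, 256, (480, 720, 3)).float()  # one decoded frame
out = rescale(frame)
print(out.min().item() >= -1.0, out.max().item() <= 1.0)  # True True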
@@ -532,28 +534,29 @@ class VideoDataset(Dataset):
start_frame = min(self.skip_frames_start, video_num_frames)
end_frame = max(0, video_num_frames - self.skip_frames_end)
if end_frame <= start_frame:
-    frames_numpy = video_reader.get_batch([start_frame]).numpy()
+    frames = video_reader.get_batch([start_frame])
elif end_frame - start_frame <= self.max_num_frames:
-    frames_numpy = video_reader.get_batch(list(range(start_frame, end_frame))).numpy()
+    frames = video_reader.get_batch(list(range(start_frame, end_frame)))
else:
    indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames))
-    frames_numpy = video_reader.get_batch(indices).numpy()
+    frames = video_reader.get_batch(indices)

-# Just to ensure that we don't go over the limit
-frames_numpy = frames_numpy[: self.max_num_frames]
-selected_num_frames = frames_numpy.shape[0]
+# Ensure that we don't go over the limit
+frames = frames[: self.max_num_frames]
+selected_num_frames = frames.shape[0]
# Choose the first (4k + 1) frames, as the VAE requires frame counts of this form
remainder = (3 + (selected_num_frames % 4)) % 4
if remainder != 0:
-    frames_numpy = frames_numpy[:-remainder]
-    selected_num_frames = frames_numpy.shape[0]
+    frames = frames[:-remainder]
+    selected_num_frames = frames.shape[0]
assert (selected_num_frames - 1) % 4 == 0

# Training transforms
-frames_tensor = torch.stack([train_transforms(frame) for frame in frames_numpy], dim=0)
-videos.append(frames_tensor)  # [F, C, H, W]
+frames = frames.float()
+frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
+videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]
return videos
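The trimming logic above first caps the clip at `max_num_frames` and then drops frames until the count has the form 4k + 1, which is what the CogVideoX VAE consumes (49 input frames become 13 latent frames under its 4x temporal compression). A standalone sketch of the arithmetic:

import torch

# Standalone sketch of the 4k + 1 trimming: (3 + (n % 4)) % 4 is the number
# of frames to drop so that (n - remainder - 1) is divisible by 4.
def trim_to_4k_plus_1(frames: torch.Tensor, max_num_frames: int) -> torch.Tensor:
    frames = frames[:max_num_frames]
    remainder = (3 + (frames.shape[0] % 4)) % 4
    if remainder != 0:
        frames = frames[:-remainder]
    assert (frames.shape[0] - 1) % 4 == 0
    return frames

frames = torch.zeros(50, 480, 720, 3)  # [F, H, W, C] clip
print(trim_to_4k_plus_1(frames, max_num_frames=49).shape[0])  # 49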
@@ -827,6 +830,44 @@ def prepare_rotary_positional_embeddings(
return freqs_cos, freqs_sin
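`prepare_rotary_positional_embeddings` returns cosine and sine tables that attention layers use to rotate query/key feature pairs by position-dependent angles. A generic standalone sketch of building such tables (not CogVideoX's exact 3D grid layout, which also factors positions over frames, height, and width):

import torch

# Generic RoPE sketch: one rotation angle per (position, feature-pair).
def rope_tables(num_positions: int, dim: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))  # [dim / 2]
    angles = torch.outer(torch.arange(num_positions).float(), freqs)  # [P, dim / 2]
    return angles.cos(), angles.sin()

freqs_cos, freqs_sin = rope_tables(num_positions=128, dim=64)
print(freqs_cos.shape, freqs_sin.shape)  # torch.Size([128, 32]) twice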
+def get_optimizer(args, params_to_optimize):
+    # Optimizer creation
+    if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+        logger.warning(
+            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. "
+            "Defaulting to adamW."
+        )
+        args.optimizer = "adamw"
+
+    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+        logger.warning(
+            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+            f"set to {args.optimizer.lower()}"
+        )
+
+    if args.optimizer.lower() == "adamw":
+        if args.use_8bit_adam:
+            try:
+                import bitsandbytes as bnb
+            except ImportError:
+                raise ImportError(
+                    "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+                )
+
+            optimizer_class = bnb.optim.AdamW8bit
+        else:
+            optimizer_class = torch.optim.AdamW
+
+        optimizer = optimizer_class(
+            params_to_optimize,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+        )
+
+    return optimizer
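The helper expects `params_to_optimize` in the param-group format `torch.optim` accepts: a list of dicts, each carrying a `params` list plus optional per-group settings such as `lr`. A standalone sketch with a toy module (hyperparameters illustrative, mirroring the script's flags):

import torch
from torch import nn

# Standalone sketch: one param group per trained module, as torch.optim expects.
model = nn.Linear(8, 8)  # stand-in for the LoRA-wrapped transformer
params_to_optimize = [{"params": [p for p in model.parameters() if p.requires_grad], "lr": 1e-3}]
optimizer = torch.optim.AdamW(params_to_optimize, betas=(0.9, 0.95), weight_decay=1e-4, eps=1e-8)
print(len(optimizer.param_groups), optimizer.param_groups[0]["lr"])  # 1 0.001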
def main(args):
if args.report_to == "wandb" and args.hub_token is not None:
raise ValueError(
@@ -909,9 +950,9 @@ def main(args):
scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
# We only train the additional adapter LoRA layers
-text_encoder.requires_grad_(False)
transformer.requires_grad_(False)
vae.requires_grad_(False)
+text_encoder.requires_grad_(False)
# For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
@@ -927,9 +968,9 @@ def main(args):
"Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
)
-vae.to(accelerator.device, dtype=weight_dtype)
-transformer.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
+transformer.to(accelerator.device, dtype=weight_dtype)
+vae.to(accelerator.device, dtype=weight_dtype)
if args.gradient_checkpointing:
transformer.enable_gradient_checkpointing()
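Casting the frozen modules to `weight_dtype` is safe because they only run forward passes here; the trainable LoRA parameters are added afterwards in full precision. A minimal sketch of the dtype convention these scripts follow (assumed mapping from the accelerate mixed-precision flag):

import torch

# Minimal sketch (assumed convention): map the mixed-precision setting to the
# dtype used for non-trainable weights.
def weight_dtype_for(mixed_precision: str) -> torch.dtype:
    if mixed_precision == "fp16":
        return torch.float16
    if mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32

print(weight_dtype_for("fp16"))  # torch.float16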
@@ -940,7 +981,7 @@ def main(args):
transformer_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
init_lora_weights=True,
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
transformer.add_adapter(transformer_lora_config)
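`init_lora_weights=True` selects PEFT's default initialization (Kaiming-uniform for the A matrix) rather than `"gaussian"` (normal-distributed A); in both schemes B starts at zero, so the adapter is initially a no-op. A standalone sketch using `peft` directly on a toy module (the script itself attaches the adapter via `add_adapter`):

import torch
from torch import nn
from peft import LoraConfig, get_peft_model

# Standalone sketch: with init_lora_weights=True, lora_B is zero-initialized,
# so the adapted layer initially computes the same function as the base layer.
class TinyAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.to_q = nn.Linear(64, 64)

    def forward(self, x):
        return self.to_q(x)

model = TinyAttention()
x = torch.randn(2, 64)
before = model(x)
peft_model = get_peft_model(model, LoraConfig(r=16, lora_alpha=16, init_lora_weights=True, target_modules=["to_q"]))
print(torch.allclose(before, peft_model(x)))  # True: the adapter starts as a no-op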
@@ -949,7 +990,7 @@ def main(args):
text_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
init_lora_weights=True,
target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
)
text_encoder.add_adapter(text_lora_config)
@@ -1066,39 +1107,7 @@ def main(args):
else:
params_to_optimize = [transformer_parameters_with_lr]
-# Optimizer creation
-if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
-    logger.warning(
-        f"Unsupported choice of optimizer: {args.optimizer}.Supported optimizers include [adamW, prodigy]."
-        "Defaulting to adamW"
-    )
-    args.optimizer = "adamw"
-
-if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
-    logger.warning(
-        f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
-        f"set to {args.optimizer.lower()}"
-    )
-
-if args.optimizer.lower() == "adamw":
-    if args.use_8bit_adam:
-        try:
-            import bitsandbytes as bnb
-        except ImportError:
-            raise ImportError(
-                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
-            )
-
-        optimizer_class = bnb.optim.AdamW8bit
-    else:
-        optimizer_class = torch.optim.AdamW
-
-    optimizer = optimizer_class(
-        params_to_optimize,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-    )
-
+optimizer = get_optimizer(args, params_to_optimize)
# Dataset and DataLoader
train_dataset = VideoDataset(
@@ -1175,8 +1184,10 @@ def main(args):
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])
logger.info("***** Running training *****")
logger.info(f" Num trainable parameters = {num_trainable_parameters}")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
@@ -1224,6 +1235,7 @@ def main(args):
vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)
for epoch in range(first_epoch, args.num_train_epochs):
print("epoch:", epoch)
transformer.train()
if args.train_text_encoder:
text_encoder.train()
@@ -1263,6 +1275,7 @@ def main(args):
0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device
)
timesteps = timesteps.long()
+print(model_input.shape, timesteps, prompt_embeds.shape)
# Prepare rotary embeds
image_rotary_emb = (
@@ -1278,6 +1291,7 @@ def main(args):
if transformer.config.use_rotary_positional_embeddings
else None
)
+print(image_rotary_emb)
# Add noise to the model input according to the noise magnitude at each timestep
# (this is the forward diffusion process)
@@ -1292,23 +1306,26 @@ def main(args):
return_dict=False,
)[0]
-# =====
+# # =====
# weights = 1 / (1 - scheduler.alphas_cumprod[timesteps])
# weights = torch.clip(weights, min=1, max=5)  # TODO: weights blows up for lower timesteps
# print(weights)
# while len(weights.shape) < len(model_pred.shape):
#     weights = weights.unsqueeze(-1)
# model_pred = model_pred * weights
# target = model_input * weights
-# =====
-target = model_input
+# # =====
# if scheduler.config.prediction_type == "epsilon":
#     target = noise
# elif scheduler.config.prediction_type == "v_prediction":
#     target = scheduler.get_velocity(model_input, noise, timesteps)
# else:
#     raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
# target = model_input
-loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+if scheduler.config.prediction_type == "epsilon":
+    target = noise
+elif scheduler.config.prediction_type == "v_prediction":
+    target = scheduler.get_velocity(model_input, noise, timesteps)
+else:
+    raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
+# loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+# loss = torch.mean((weights * (model_pred - model_input) ** 2).reshape(batch_size, -1), dim=1)
+loss = torch.mean(((model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
accelerator.backward(loss)
if accelerator.sync_gradients:
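On the new targets: for "epsilon" the network regresses the sampled noise, while for "v_prediction" `scheduler.get_velocity` computes v_t = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_0. A standalone sketch of that target and the per-sample loss (shapes and schedule illustrative; note a per-sample loss vector still needs a final scalar reduction before backpropagation):

import torch

# Standalone sketch: v-prediction target and per-sample MSE, with stand-ins
# for the scheduler's cumulative alphas and the model tensors.
alphas_cumprod = torch.linspace(0.9999, 0.006, 1000)
batch_size = 2
model_input = torch.randn(batch_size, 13, 16, 60, 90)  # [B, F, C, H, W] latents
noise = torch.randn_like(model_input)
model_pred = torch.randn_like(model_input).requires_grad_(True)
timesteps = torch.randint(0, 1000, (batch_size,))

a = alphas_cumprod[timesteps].view(batch_size, 1, 1, 1, 1)
target = a.sqrt() * noise - (1 - a).sqrt() * model_input  # what get_velocity returns

loss = torch.mean(((model_pred - target) ** 2).reshape(batch_size, -1), dim=1)  # [B]
loss.mean().backward()  # reduce to a scalar before backpropagating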
@@ -1362,13 +1379,13 @@ def main(args):
break
if accelerator.is_main_process:
-if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
# Create pipeline
pipe = CogVideoXPipeline.from_pretrained(
args.pretrained_model_name_or_path,
transformer=unwrap_model(transformer),
text_encoder=unwrap_model(text_encoder),
-    vae=vae,
+    vae=unwrap_model(vae),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
@@ -1380,6 +1397,8 @@ def main(args):
"prompt": validation_prompt,
"guidance_scale": args.guidance_scale,
"use_dynamic_cfg": args.use_dynamic_cfg,
"height": args.height,
"width": args.width,
}
validation_outputs = log_validation(
@@ -1428,6 +1447,8 @@ def main(args):
"prompt": validation_prompt,
"guidance_scale": args.guidance_scale,
"use_dynamic_cfg": args.use_dynamic_cfg,
"height": args.height,
"width": args.width,
}
video = log_validation(
@@ -1463,3 +1484,35 @@ def main(args):
if __name__ == "__main__":
args = get_args()
main(args)
+# train_dataset = VideoDataset(
+#     instance_data_root=args.instance_data_root,
+#     dataset_name=args.dataset_name,
+#     dataset_config_name=args.dataset_config_name,
+#     caption_column=args.caption_column,
+#     video_column=args.video_column,
+#     height=args.height,
+#     width=args.width,
+#     fps=args.fps,
+#     max_num_frames=args.max_num_frames,
+#     skip_frames_start=args.skip_frames_start,
+#     skip_frames_end=args.skip_frames_end,
+#     cache_dir=args.cache_dir,
+# )
+# train_dataloader = DataLoader(
+#     train_dataset,
+#     batch_size=args.train_batch_size,
+#     shuffle=True,
+#     collate_fn=collate_fn,
+#     num_workers=args.dataloader_num_workers,
+# )
+# for batch in train_dataloader:
+#     print(batch["prompts"])
+#     print(batch["videos"].min(), batch["videos"].max())
+#     result = CogVideoXPipeline(None, None, None, None, None).video_processor.postprocess_video(
+#         batch["videos"].permute(0, 2, 1, 3, 4), output_type="pil"
+#     )
+#     # print(result[0])
+#     export_to_video(result[0], "recon.mp4", fps=8)


@@ -1081,6 +1081,14 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
"""
self.use_slicing = False
+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        # TODO: Implement context parallel cache
+        # TODO: Implement tiled encoding
+        h = self.encoder(x)
+        if self.quant_conv is not None:
+            h = self.quant_conv(h)
+        return h
+
@apply_forward_hook
def encode(
self, x: torch.Tensor, return_dict: bool = True
@@ -1097,9 +1105,12 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
The latent representations of the encoded images. If `return_dict` is True, a
[`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
"""
-h = self.encoder(x)
-if self.quant_conv is not None:
-    h = self.quant_conv(h)
+if self.use_slicing and x.shape[0] > 1:
+    encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+    h = torch.cat(encoded_slices)
+else:
+    h = self._encode(x)
posterior = DiagonalGaussianDistribution(h)
if not return_dict:
return (posterior,)
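Sliced encoding trades a little throughput for peak memory: each sample in the batch goes through the encoder alone and the results are concatenated afterwards, so peak activation memory scales with one sample rather than the whole batch. A standalone sketch of the same pattern with a toy encoder (names illustrative):

import torch
from torch import nn

# Standalone sketch of the slicing pattern used by _encode/encode above.
class SlicedEncoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv3d(3, 8, kernel_size=3, padding=1)  # toy stand-in
        self.use_slicing = True

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.encoder(x)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        if self.use_slicing and x.shape[0] > 1:
            return torch.cat([self._encode(s) for s in x.split(1)])
        return self._encode(x)

model = SlicedEncoder()
video = torch.randn(4, 3, 9, 32, 32)  # [B, C, F, H, W]
print(model.encode(video).shape)  # torch.Size([4, 8, 9, 32, 32])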