mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-29 07:22:12 +03:00)

update
@@ -23,7 +23,6 @@ from pathlib import Path
 from typing import List, Optional, Tuple, Union
-
 import torch
 import torch.nn.functional as F
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
@@ -333,7 +332,7 @@ def get_args():
         "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
     )
     parser.add_argument(
-        "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+        "--adam_beta2", type=float, default=0.95, help="The beta2 parameter for the Adam and Prodigy optimizers."
     )
     parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
     parser.add_argument(
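
For context on the new beta2 default (a rule-of-thumb aside, not from the commit): Adam's second-moment estimate is an exponential moving average with an effective horizon of roughly 1 / (1 - beta2) steps, so lowering beta2 makes the optimizer adapt much faster to recent gradient magnitudes.

    # Effective averaging horizon of Adam's second moment, ~1 / (1 - beta2)
    for beta2 in (0.999, 0.95):
        print(beta2, "->", round(1 / (1 - beta2)))  # 0.999 -> 1000, 0.95 -> 20
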
@@ -512,16 +511,19 @@ class VideoDataset(Dataset):
         return instance_prompts, instance_videos

     def _preprocess_data(self):
-        import decord
+        try:
+            import decord
+        except ImportError:
+            raise ImportError(
+                "The `decord` package is required for loading the video dataset. Install with `pip install decord`"
+            )

         decord.bridge.set_bridge("torch")

         videos = []

         train_transforms = transforms.Compose(
             [
-                transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
+                transforms.Lambda(lambda x: x / (255 / 2) - 1),
             ]
         )
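
The Lambda replaces ToTensor() + Normalize([0.5], [0.5]) because decord (with the torch bridge set) already yields tensors rather than PIL images; assuming uint8 pixel values, x / (255 / 2) - 1 maps [0, 255] to [-1, 1] directly. A minimal sanity check (a sketch, not part of the commit):

    import torch

    frames = torch.tensor([0.0, 127.5, 255.0])  # representative pixel values
    normalized = frames / (255 / 2) - 1
    print(normalized)  # tensor([-1., 0., 1.])
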
@@ -532,28 +534,29 @@ class VideoDataset(Dataset):
             start_frame = min(self.skip_frames_start, video_num_frames)
             end_frame = max(0, video_num_frames - self.skip_frames_end)
             if end_frame <= start_frame:
-                frames_numpy = video_reader.get_batch([start_frame]).numpy()
+                frames = video_reader.get_batch([start_frame])
             elif end_frame - start_frame <= self.max_num_frames:
-                frames_numpy = video_reader.get_batch(list(range(start_frame, end_frame))).numpy()
+                frames = video_reader.get_batch(list(range(start_frame, end_frame)))
             else:
                 indices = list(range(start_frame, end_frame, (end_frame - start_frame) // self.max_num_frames))
-                frames_numpy = video_reader.get_batch(indices).numpy()
+                frames = video_reader.get_batch(indices)

-            # Just to ensure that we don't go over the limit
-            frames_numpy = frames_numpy[: self.max_num_frames]
-            selected_num_frames = frames_numpy.shape[0]
+            # Ensure that we don't go over the limit
+            frames = frames[: self.max_num_frames]
+            selected_num_frames = frames.shape[0]

             # Choose first (4k + 1) frames as this is how many is required by the VAE
             remainder = (3 + (selected_num_frames % 4)) % 4
             if remainder != 0:
-                frames_numpy = frames_numpy[:-remainder]
-                selected_num_frames = frames_numpy.shape[0]
+                frames = frames[:-remainder]
+                selected_num_frames = frames.shape[0]

             assert (selected_num_frames - 1) % 4 == 0

             # Training transforms
-            frames_tensor = torch.stack([train_transforms(frame) for frame in frames_numpy], dim=0)
-            videos.append(frames_tensor)  # [F, C, H, W]
+            frames = frames.float()
+            frames = torch.stack([train_transforms(frame) for frame in frames], dim=0)
+            videos.append(frames.permute(0, 3, 1, 2).contiguous())  # [F, C, H, W]

         return videos
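
For reference, a worked example (not part of the commit) of the remainder trick: it trims the clip to the nearest 4k + 1 frames, the temporal shape the CogVideoX VAE expects.

    for n in (48, 49, 50, 51):
        remainder = (3 + (n % 4)) % 4
        kept = n - remainder if remainder != 0 else n
        assert (kept - 1) % 4 == 0
        print(n, "->", kept)  # 48 -> 45, 49 -> 49, 50 -> 49, 51 -> 49
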
@@ -827,6 +830,44 @@ def prepare_rotary_positional_embeddings(
     return freqs_cos, freqs_sin


+def get_optimizer(args, params_to_optimize):
+    # Optimizer creation
+    if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+        logger.warning(
+            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. "
+            "Defaulting to adamW."
+        )
+        args.optimizer = "adamw"
+
+    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+        logger.warning(
+            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+            f"set to {args.optimizer.lower()}"
+        )
+
+    if args.optimizer.lower() == "adamw":
+        if args.use_8bit_adam:
+            try:
+                import bitsandbytes as bnb
+            except ImportError:
+                raise ImportError(
+                    "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+                )
+
+            optimizer_class = bnb.optim.AdamW8bit
+        else:
+            optimizer_class = torch.optim.AdamW
+
+        optimizer = optimizer_class(
+            params_to_optimize,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+        )
+
+    return optimizer
+
+
 def main(args):
     if args.report_to == "wandb" and args.hub_token is not None:
         raise ValueError(
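
A standalone usage sketch of the new helper (not from the commit; the SimpleNamespace stands in for parsed CLI args and the tiny parameter list for the LoRA weights):

    import torch
    from types import SimpleNamespace

    args = SimpleNamespace(
        optimizer="adamw", use_8bit_adam=False,
        adam_beta1=0.9, adam_beta2=0.95, adam_weight_decay=1e-4, adam_epsilon=1e-8,
    )
    lora_params = [torch.nn.Parameter(torch.zeros(8, 8))]
    params_to_optimize = [{"params": lora_params, "lr": 1e-4}]
    optimizer = get_optimizer(args, params_to_optimize)  # -> torch.optim.AdamW
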
@@ -909,9 +950,9 @@ def main(args):
     scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")

     # We only train the additional adapter LoRA layers
-    text_encoder.requires_grad_(False)
     transformer.requires_grad_(False)
     vae.requires_grad_(False)
+    text_encoder.requires_grad_(False)

     # For mixed precision training we cast all non-trainable weights (vae, text_encoder and transformer) to half-precision
     # as these weights are only used for inference, keeping weights in full precision is not required.
@@ -927,9 +968,9 @@ def main(args):
             "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead."
         )

-    vae.to(accelerator.device, dtype=weight_dtype)
-    transformer.to(accelerator.device, dtype=weight_dtype)
     text_encoder.to(accelerator.device, dtype=weight_dtype)
+    transformer.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)

     if args.gradient_checkpointing:
         transformer.enable_gradient_checkpointing()
@@ -940,7 +981,7 @@ def main(args):
     transformer_lora_config = LoraConfig(
         r=args.rank,
         lora_alpha=args.rank,
-        init_lora_weights="gaussian",
+        init_lora_weights=True,
         target_modules=["to_k", "to_q", "to_v", "to_out.0"],
     )
     transformer.add_adapter(transformer_lora_config)
@@ -949,7 +990,7 @@ def main(args):
         text_lora_config = LoraConfig(
             r=args.rank,
             lora_alpha=args.rank,
-            init_lora_weights="gaussian",
+            init_lora_weights=True,
             target_modules=["q_proj", "k_proj", "v_proj", "out_proj"],
         )
         text_encoder.add_adapter(text_lora_config)
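
Context for the init change (an aside, not from the commit): in PEFT, init_lora_weights=True is the library default, initializing the LoRA A matrix with Kaiming-uniform and B with zeros, whereas "gaussian" draws A from a normal distribution scaled by the rank. Either way B starts at zero, so the adapter's initial contribution is zero. A minimal sketch of the two variants:

    from peft import LoraConfig

    default_init = LoraConfig(r=16, lora_alpha=16, init_lora_weights=True,
                              target_modules=["to_k", "to_q", "to_v", "to_out.0"])
    gaussian_init = LoraConfig(r=16, lora_alpha=16, init_lora_weights="gaussian",
                               target_modules=["to_k", "to_q", "to_v", "to_out.0"])
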
@@ -1066,39 +1107,7 @@ def main(args):
     else:
         params_to_optimize = [transformer_parameters_with_lr]

-    # Optimizer creation
-    if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
-        logger.warning(
-            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. "
-            "Defaulting to adamW."
-        )
-        args.optimizer = "adamw"
-
-    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
-        logger.warning(
-            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
-            f"set to {args.optimizer.lower()}"
-        )
-
-    if args.optimizer.lower() == "adamw":
-        if args.use_8bit_adam:
-            try:
-                import bitsandbytes as bnb
-            except ImportError:
-                raise ImportError(
-                    "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
-                )
-
-            optimizer_class = bnb.optim.AdamW8bit
-        else:
-            optimizer_class = torch.optim.AdamW
-
-        optimizer = optimizer_class(
-            params_to_optimize,
-            betas=(args.adam_beta1, args.adam_beta2),
-            weight_decay=args.adam_weight_decay,
-            eps=args.adam_epsilon,
-        )
+    optimizer = get_optimizer(args, params_to_optimize)

     # Dataset and DataLoader
     train_dataset = VideoDataset(
@@ -1175,8 +1184,10 @@ def main(args):

     # Train!
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
+    num_trainable_parameters = sum(param.numel() for model in params_to_optimize for param in model["params"])

     logger.info("***** Running training *****")
+    logger.info(f"  Num trainable parameters = {num_trainable_parameters}")
     logger.info(f"  Num examples = {len(train_dataset)}")
     logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
     logger.info(f"  Num Epochs = {args.num_train_epochs}")
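
How the new count works (a sketch, not from the commit): params_to_optimize is a list of optimizer param-group dicts, so the generator iterates groups first, then the tensors inside each group.

    import torch

    groups = [{"params": [torch.zeros(4, 4), torch.zeros(8)], "lr": 1e-4}]
    num_trainable = sum(p.numel() for g in groups for p in g["params"])
    print(num_trainable)  # 24
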
@@ -1224,6 +1235,7 @@ def main(args):
     vae_scale_factor_spatial = 2 ** (len(vae.config.block_out_channels) - 1)

     for epoch in range(first_epoch, args.num_train_epochs):
+        print("epoch:", epoch)
         transformer.train()
         if args.train_text_encoder:
             text_encoder.train()
@@ -1263,6 +1275,7 @@ def main(args):
                     0, scheduler.config.num_train_timesteps, (batch_size,), device=model_input.device
                 )
                 timesteps = timesteps.long()
+                print(model_input.shape, timesteps, prompt_embeds.shape)

                 # Prepare rotary embeds
                 image_rotary_emb = (
@@ -1278,6 +1291,7 @@ def main(args):
                     if transformer.config.use_rotary_positional_embeddings
                     else None
                 )
+                print(image_rotary_emb)

                 # Add noise to the model input according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
@@ -1292,23 +1306,26 @@ def main(args):
                     return_dict=False,
                 )[0]

-                # =====
+                # # =====
                 # weights = 1 / (1 - scheduler.alphas_cumprod[timesteps])
                 # weights = torch.clip(weights, min=1, max=5)  # TODO: weights blows up for lower timesteps
                 # print(weights)
                 # while len(weights.shape) < len(model_pred.shape):
                 #     weights = weights.unsqueeze(-1)
                 # model_pred = model_pred * weights
                 # target = model_input * weights
-                # =====
-                target = model_input
+                # # =====

-                # if scheduler.config.prediction_type == "epsilon":
-                #     target = noise
-                # elif scheduler.config.prediction_type == "v_prediction":
-                #     target = scheduler.get_velocity(model_input, noise, timesteps)
-                # else:
-                #     raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
-                # target = model_input
+                if scheduler.config.prediction_type == "epsilon":
+                    target = noise
+                elif scheduler.config.prediction_type == "v_prediction":
+                    target = scheduler.get_velocity(model_input, noise, timesteps)
+                else:
+                    raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")

-                loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                # loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                # loss = torch.mean((weights * (model_pred - model_input) ** 2).reshape(batch_size, -1), dim=1)
+                loss = torch.mean(((model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
                 accelerator.backward(loss)

                 if accelerator.sync_gradients:
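
One thing to note about the new loss (an observation, not from the commit): torch.mean(..., dim=1) yields a per-sample tensor of shape [batch_size], and Tensor.backward() only works on scalars without an explicit gradient argument, so a final reduction such as loss.mean() is presumably still needed before accelerator.backward(loss). A minimal repro:

    import torch

    model_pred = torch.randn(2, 4, 8, requires_grad=True)
    target = torch.randn(2, 4, 8)
    batch_size = model_pred.shape[0]
    loss = torch.mean(((model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
    print(loss.shape)  # torch.Size([2])
    loss.mean().backward()  # reduce to a scalar before backprop
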
@@ -1362,13 +1379,13 @@ def main(args):
                 break

         if accelerator.is_main_process:
-            if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
+            if args.validation_prompt is not None and (epoch + 1) % args.validation_epochs == 0:
                 # Create pipeline
                 pipe = CogVideoXPipeline.from_pretrained(
                     args.pretrained_model_name_or_path,
                     transformer=unwrap_model(transformer),
                     text_encoder=unwrap_model(text_encoder),
-                    vae=vae,
+                    vae=unwrap_model(vae),
                     revision=args.revision,
                     variant=args.variant,
                     torch_dtype=weight_dtype,
@@ -1380,6 +1397,8 @@ def main(args):
                     "prompt": validation_prompt,
                     "guidance_scale": args.guidance_scale,
                     "use_dynamic_cfg": args.use_dynamic_cfg,
+                    "height": args.height,
+                    "width": args.width,
                 }

                 validation_outputs = log_validation(
@@ -1428,6 +1447,8 @@ def main(args):
                 "prompt": validation_prompt,
                 "guidance_scale": args.guidance_scale,
                 "use_dynamic_cfg": args.use_dynamic_cfg,
+                "height": args.height,
+                "width": args.width,
             }

             video = log_validation(
@@ -1463,3 +1484,35 @@ def main(args):
 if __name__ == "__main__":
     args = get_args()
     main(args)
+
+# train_dataset = VideoDataset(
+#     instance_data_root=args.instance_data_root,
+#     dataset_name=args.dataset_name,
+#     dataset_config_name=args.dataset_config_name,
+#     caption_column=args.caption_column,
+#     video_column=args.video_column,
+#     height=args.height,
+#     width=args.width,
+#     fps=args.fps,
+#     max_num_frames=args.max_num_frames,
+#     skip_frames_start=args.skip_frames_start,
+#     skip_frames_end=args.skip_frames_end,
+#     cache_dir=args.cache_dir,
+# )
+
+# train_dataloader = DataLoader(
+#     train_dataset,
+#     batch_size=args.train_batch_size,
+#     shuffle=True,
+#     collate_fn=collate_fn,
+#     num_workers=args.dataloader_num_workers,
+# )
+
+# for batch in train_dataloader:
+#     print(batch["prompts"])
+#     print(batch["videos"].min(), batch["videos"].max())
+#     result = CogVideoXPipeline(None, None, None, None, None).video_processor.postprocess_video(
+#         batch["videos"].permute(0, 2, 1, 3, 4), output_type="pil"
+#     )
+#     # print(result[0])
+#     export_to_video(result[0], "recon.mp4", fps=8)
@@ -1081,6 +1081,14 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         """
         self.use_slicing = False

+    def _encode(self, x: torch.Tensor) -> torch.Tensor:
+        # TODO: Implement context parallel cache
+        # TODO: Implement tiled encoding
+        h = self.encoder(x)
+        if self.quant_conv is not None:
+            h = self.quant_conv(h)
+        return h
+
     @apply_forward_hook
     def encode(
         self, x: torch.Tensor, return_dict: bool = True
@@ -1097,9 +1105,12 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
                 The latent representations of the encoded images. If `return_dict` is True, a
                 [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
         """
-        h = self.encoder(x)
-        if self.quant_conv is not None:
-            h = self.quant_conv(h)
+        if self.use_slicing and x.shape[0] > 1:
+            encoded_slices = [self._encode(x_slice) for x_slice in x.split(1)]
+            h = torch.cat(encoded_slices)
+        else:
+            h = self._encode(x)

         posterior = DiagonalGaussianDistribution(h)
         if not return_dict:
             return (posterior,)
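
The refactor routes single-pass encoding through _encode() so that encode() can optionally slice the batch: with use_slicing enabled, each sample is encoded independently and the results concatenated, trading a little speed for lower peak memory. A minimal standalone sketch of the pattern (sliced_apply is a made-up name, not part of the commit):

    import torch

    def sliced_apply(fn, x: torch.Tensor) -> torch.Tensor:
        # Process one sample at a time to bound peak activation memory.
        if x.shape[0] > 1:
            return torch.cat([fn(x_slice) for x_slice in x.split(1)])
        return fn(x)

    x = torch.randn(4, 3, 9, 32, 32)  # [B, C, F, H, W]
    out = sliced_apply(lambda t: t * 2.0, x)
    assert torch.allclose(out, x * 2.0)
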