Commit "update" (mirror of https://github.com/huggingface/diffusers.git)
@@ -315,6 +315,18 @@ def get_args():
         help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
     )
     parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
+    parser.add_argument(
+        "--enable_slicing",
+        action="store_true",
+        default=False,
+        help="Whether or not to use VAE slicing for saving memory.",
+    )
+    parser.add_argument(
+        "--enable_tiling",
+        action="store_true",
+        default=False,
+        help="Whether or not to use VAE tiling for saving memory.",
+    )
 
     # Optimizer
     parser.add_argument(
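The two new flags are plain argparse store_true switches, off by default. A standalone sketch of how they parse (illustrative only, not part of the diff):

    # Hypothetical standalone check of the new store_true flags.
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--enable_slicing", action="store_true", default=False,
                        help="Whether or not to use VAE slicing for saving memory.")
    parser.add_argument("--enable_tiling", action="store_true", default=False,
                        help="Whether or not to use VAE tiling for saving memory.")

    args = parser.parse_args(["--enable_slicing", "--enable_tiling"])
    assert args.enable_slicing and args.enable_tiling  # both are False when the flags are omitted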
@@ -949,6 +961,11 @@ def main(args):
 
     scheduler = CogVideoXDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
 
+    if args.enable_slicing:
+        vae.enable_slicing()
+    if args.enable_tiling:
+        vae.enable_tiling()
+
     # We only train the additional adapter LoRA layers
     text_encoder.requires_grad_(False)
     transformer.requires_grad_(False)
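Both methods live on the CogVideoX VAE (AutoencoderKLCogVideoX in diffusers): enable_slicing() decodes a batch one sample at a time, and enable_tiling() decodes each latent frame in overlapping spatial tiles, trading some compute for lower peak memory. A minimal usage sketch outside the trainer, with the checkpoint id assumed:

    # Minimal sketch: memory-saving decode modes on the CogVideoX VAE.
    import torch
    from diffusers import AutoencoderKLCogVideoX

    vae = AutoencoderKLCogVideoX.from_pretrained(
        "THUDM/CogVideoX-2b", subfolder="vae", torch_dtype=torch.float16  # assumed checkpoint
    )
    vae.enable_slicing()  # decode one batch element at a time
    vae.enable_tiling()   # decode each frame in overlapping spatial tiles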
@@ -1190,10 +1207,10 @@ def main(args):
     logger.info(f" Num trainable parameters = {num_trainable_parameters}")
     logger.info(f" Num examples = {len(train_dataset)}")
     logger.info(f" Num batches each epoch = {len(train_dataloader)}")
-    logger.info(f" Num Epochs = {args.num_train_epochs}")
+    logger.info(f" Num epochs = {args.num_train_epochs}")
     logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
     logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
-    logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
+    logger.info(f" Gradient accumulation steps = {args.gradient_accumulation_steps}")
     logger.info(f" Total optimization steps = {args.max_train_steps}")
     global_step = 0
     first_epoch = 0
@@ -1295,34 +1312,25 @@ def main(args):
                 noisy_model_input = scheduler.add_noise(model_input, noise, timesteps)
 
                 # Predict the noise residual
-                model_pred = transformer(
+                model_output = transformer(
                     hidden_states=noisy_model_input,
                     encoder_hidden_states=prompt_embeds,
                     timestep=timesteps,
                     image_rotary_emb=image_rotary_emb,
                     return_dict=False,
                 )[0]
+                alphas_cumprod = scheduler.alphas_cumprod[timesteps]
+                alphas_cumprod_sqrt = alphas_cumprod**0.5
+                one_minus_alphas_cumprod_sqrt = (1 - alphas_cumprod) ** 0.5
+                model_pred = noisy_model_input * alphas_cumprod_sqrt - model_output * one_minus_alphas_cumprod_sqrt
 
                 # =====
-                weights = 1 / (1 - scheduler.alphas_cumprod[timesteps])
-                print(timesteps, weights)
-                weights = torch.clip(weights, min=1, max=5)  # TODO: weights blows up for lower timesteps
+                weights = 1 / (1 - alphas_cumprod)
                 while len(weights.shape) < len(model_pred.shape):
                     weights = weights.unsqueeze(-1)
                 # =====
 
                 target = model_input
-
-                # if scheduler.config.prediction_type == "epsilon":
-                #     target = noise
-                # elif scheduler.config.prediction_type == "v_prediction":
-                #     target = scheduler.get_velocity(model_input, noise, timesteps)
-                # else:
-                #     raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
-
-                # loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                 loss = torch.mean((weights * (model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
-                # loss = torch.mean(((model_pred - target) ** 2).reshape(batch_size, -1), dim=1)
                 accelerator.backward(loss)
-
                 if accelerator.sync_gradients:
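The arithmetic in this hunk is the standard v-prediction reconstruction: the CogVideoX scheduler's model predicts velocity, and the clean-latent estimate is x0_hat = sqrt(alpha_bar_t) * x_t - sqrt(1 - alpha_bar_t) * v_theta. The loss then compares that estimate against the original latents with a per-timestep weight of 1 / (1 - alpha_bar_t), replacing the earlier clipped weighting (the dropped TODO noted the weight blows up at low timesteps, where alpha_bar_t -> 1). Note that since x0_hat - x0 = sqrt(1 - alpha_bar_t) * (v_true - v_pred), this particular weight makes the weighted x0-space MSE algebraically identical to an unweighted MSE in v-space. A self-contained sketch of the computation, with shapes and names illustrative only:

    import torch

    def weighted_x0_loss(model_output, noisy_model_input, model_input,
                         alphas_cumprod, timesteps):
        # v-prediction -> x0: x0_hat = sqrt(a_t) * x_t - sqrt(1 - a_t) * v
        a = alphas_cumprod[timesteps]  # shape (B,)
        while a.dim() < model_output.dim():  # broadcast over (B, F, C, H, W) latents
            a = a.unsqueeze(-1)
        model_pred = noisy_model_input * a**0.5 - model_output * (1 - a) ** 0.5
        # Weight 1 / (1 - a_t) grows without bound as t -> 0 (a_t -> 1);
        # the previous revision clipped it to [1, 5] instead.
        weights = 1 / (1 - a)
        batch_size = model_pred.shape[0]
        return torch.mean(
            (weights * (model_pred - model_input) ** 2).reshape(batch_size, -1), dim=1
        )  # per-sample loss; reduce with .mean() before calling backward()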
@@ -1383,6 +1391,7 @@ def main(args):
                         transformer=unwrap_model(transformer),
                         text_encoder=unwrap_model(text_encoder),
                         vae=unwrap_model(vae),
+                        scheduler=scheduler,
                         revision=args.revision,
                         variant=args.variant,
                         torch_dtype=weight_dtype,
@@ -1425,17 +1434,19 @@ def main(args):
             text_encoder_lora_layers=text_encoder_lora_layers,
         )
 
-        # Final inference
+        # Final test inference
        pipe = CogVideoXPipeline.from_pretrained(
             args.pretrained_model_name_or_path,
             revision=args.revision,
             variant=args.variant,
             torch_dtype=weight_dtype,
         )
-        # load attention processors
+        pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
+
+        # Load LoRA weights
         pipe.load_lora_weights(args.output_dir)
 
-        # run inference
+        # Run inference
         validation_outputs = []
         if args.validation_prompt and args.num_validation_videos > 0:
             validation_prompts = args.validation_prompt.split(args.validation_prompt_separator)
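With this hunk, the final test-inference path swaps the pipeline's scheduler for CogVideoXDPMScheduler and loads the trained LoRA from args.output_dir. A minimal end-to-end sketch of that inference (model id, prompt, and paths are placeholders):

    import torch
    from diffusers import CogVideoXPipeline, CogVideoXDPMScheduler
    from diffusers.utils import export_to_video

    pipe = CogVideoXPipeline.from_pretrained(
        "THUDM/CogVideoX-2b", torch_dtype=torch.float16  # assumed base checkpoint
    )
    pipe.scheduler = CogVideoXDPMScheduler.from_config(pipe.scheduler.config)
    pipe.load_lora_weights("path/to/output_dir")  # placeholder LoRA directory
    pipe.to("cuda")

    frames = pipe(prompt="a validation prompt", num_inference_steps=50).frames[0]
    export_to_video(frames, "sample.mp4", fps=8)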
@@ -1481,35 +1492,3 @@ def main(args):
 if __name__ == "__main__":
     args = get_args()
     main(args)
-
-# train_dataset = VideoDataset(
-#     instance_data_root=args.instance_data_root,
-#     dataset_name=args.dataset_name,
-#     dataset_config_name=args.dataset_config_name,
-#     caption_column=args.caption_column,
-#     video_column=args.video_column,
-#     height=args.height,
-#     width=args.width,
-#     fps=args.fps,
-#     max_num_frames=args.max_num_frames,
-#     skip_frames_start=args.skip_frames_start,
-#     skip_frames_end=args.skip_frames_end,
-#     cache_dir=args.cache_dir,
-# )
-
-# train_dataloader = DataLoader(
-#     train_dataset,
-#     batch_size=args.train_batch_size,
-#     shuffle=True,
-#     collate_fn=collate_fn,
-#     num_workers=args.dataloader_num_workers,
-# )
-
-# for batch in train_dataloader:
-#     print(batch["prompts"])
-#     print(batch["videos"].min(), batch["videos"].max())
-#     result = CogVideoXPipeline(None, None, None, None, None).video_processor.postprocess_video(
-#         batch["videos"].permute(0, 2, 1, 3, 4), output_type="pil"
-#     )
-#     # print(result[0])
-#     export_to_video(result[0], "recon.mp4", fps=8)
@@ -28,7 +28,6 @@ from ...pipelines.pipeline_utils import DiffusionPipeline
 from ...schedulers import CogVideoXDDIMScheduler, CogVideoXDPMScheduler
 from ...utils import (
-    USE_PEFT_BACKEND,
     BaseOutput,
     logging,
     replace_example_docstring,
     scale_lora_layers,