Mirror of https://github.com/huggingface/diffusers.git

commit 24c362ca4f (parent f12e669ed3)
Author: Aryan
Date: 2024-08-29 18:53:04 +02:00


@@ -361,10 +361,7 @@ def get_args():
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
help="Directory where logs are stored.",
)
parser.add_argument(
"--allow_tf32",
@@ -573,7 +570,7 @@ def save_model_card(
     widget_dict = []
     if videos is not None:
         for i, video in enumerate(videos):
-            export_to_video(video, os.path.join(repo_folder, f"video_{i}.mp4"), fps=fps)
+            export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps)
             widget_dict.append(
                 {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"final_video_{i}.mp4"}}
             )
@@ -673,10 +670,25 @@ def log_validation(
     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
         if tracker.name == "wandb":
+            video_filenames = []
+            for i, video in enumerate(videos):
+                prompt = (
+                    pipeline_args["prompt"][:25]
+                    .replace(" ", "_")
+                    .replace("'", "_")
+                    .replace('"', "_")
+                    .replace("/", "_")
+                )
+                filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
+                export_to_video(video, filename, fps=8)
+                video_filenames.append(filename)
+
             tracker.log(
                 {
                     phase_name: [
-                        wandb.Video(video, caption=f"{i}: {args.validation_prompt}") for i, video in enumerate(videos)
+                        wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
+                        for i, filename in enumerate(video_filenames)
                     ]
                 }
             )
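The chained .replace() calls above sanitize the prompt for use in a filename; the same effect in one translation table (a standalone sketch, with slugify_prompt being an illustrative helper, not part of the commit):

def slugify_prompt(prompt: str, max_len: int = 25) -> str:
    # Map the same characters the loop above replaces (space, quotes, slash)
    # to underscores so the prompt is safe inside a filename.
    table = str.maketrans({ch: "_" for ch in " '\"/"})
    return prompt[:max_len].translate(table)

print(slugify_prompt("a cat's \"slow-mo\" run"))  # a_cat_s__slow-mo__run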
@@ -763,6 +775,29 @@ def encode_prompt(
     return prompt_embeds


+def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, requires_grad: bool = False):
+    if requires_grad:
+        prompt_embeds = encode_prompt(
+            tokenizer,
+            text_encoder,
+            prompt,
+            num_videos_per_prompt=1,
+            device=device,
+            dtype=dtype,
+        )
+    else:
+        with torch.no_grad():
+            prompt_embeds = encode_prompt(
+                tokenizer,
+                text_encoder,
+                prompt,
+                num_videos_per_prompt=1,
+                device=device,
+                dtype=dtype,
+            )
+    return prompt_embeds
+
+
 def prepare_rotary_positional_embeddings(
     height: int,
     width: int,
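The new helper's only job is to toggle autograd tracking with one flag; here is the same pattern in miniature, with a dummy matmul standing in for encode_prompt (illustrative, not from the commit):

import torch

def encode(x: torch.Tensor, weight: torch.Tensor, requires_grad: bool = False) -> torch.Tensor:
    # With a frozen encoder, run under torch.no_grad() so no autograd
    # graph is recorded, reducing memory during training.
    if requires_grad:
        return x @ weight
    with torch.no_grad():
        return x @ weight

weight = torch.randn(8, 16, requires_grad=True)
frozen = encode(torch.randn(2, 8), weight)                         # no graph kept
trainable = encode(torch.randn(2, 8), weight, requires_grad=True)  # grads flow to weight
assert not frozen.requires_grad and trainable.requires_grad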
@@ -1089,20 +1124,6 @@ def main(args):
         num_workers=args.dataloader_num_workers,
     )
 
-    if not args.train_text_encoder:
-
-        def compute_text_embeddings(prompt):
-            with torch.no_grad():
-                prompt_embeds = encode_prompt(
-                    tokenizer,
-                    text_encoder,
-                    prompt,
-                    num_videos_per_prompt=1,
-                    device=accelerator.device,
-                    dtype=weight_dtype,
-                )
-            return prompt_embeds
-
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
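For concreteness, the step math in the context line above works out as follows (the dataloader length and accumulation factor are assumed, illustrative numbers):

import math

dataloader_len = 1000            # batches per epoch (assumed)
gradient_accumulation_steps = 4  # micro-batches per optimizer update (assumed)
num_update_steps_per_epoch = math.ceil(dataloader_len / gradient_accumulation_steps)
assert num_update_steps_per_epoch == 250  # 1000 batches -> 250 optimizer updates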
@@ -1219,34 +1240,19 @@ def main(args):
                 prompts = batch["prompts"]
 
                 # encode prompts
-                if not args.train_text_encoder:
-                    prompt_embeds = compute_text_embeddings(prompts)
-                else:
-                    text_inputs = tokenizer(
-                        prompts,
-                        padding="max_length",
-                        max_length=226,
-                        truncation=True,
-                        add_special_tokens=True,
-                        return_tensors="pt",
-                    )
-                    text_input_ids = text_inputs.input_ids
-                    prompt_embeds = encode_prompt(
-                        tokenizer=None,
-                        text_encoder=text_encoder,
-                        prompt=None,
-                        num_videos_per_prompt=1,
-                        device=accelerator.device,
-                        dtype=weight_dtype,
-                        text_input_ids=text_input_ids,
-                    )
+                prompt_embeds = compute_prompt_embeddings(
+                    tokenizer,
+                    text_encoder,
+                    prompts,
+                    accelerator.device,
+                    weight_dtype,
+                    requires_grad=args.train_text_encoder,
+                )
 
                 # Convert videos to latents
+                print("videos.shape:", videos.shape)
                 videos = videos.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
                 model_input = vae.encode(videos).latent_dist.sample() * vae.config.scaling_factor
                 model_input = model_input.permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)  # [B, F, C, H, W]
+                print("latents.shape:", model_input.shape)
 
                 # Sample noise that will be added to the latents
                 noise = torch.randn_like(model_input)
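As a shape sanity check for the two permutes above, using dummy tensors in place of the dataloader and VAE (the latent sizes assume CogVideoX-style 8x spatial and 4x temporal compression and are illustrative):

import torch

videos = torch.randn(1, 49, 3, 480, 720)      # [B, F, C, H, W] as loaded
videos = videos.permute(0, 2, 1, 3, 4)        # [B, C, F, H, W], the layout the VAE expects
latents = torch.randn(1, 16, 13, 60, 90)      # stand-in for vae.encode(videos).latent_dist.sample()
model_input = latents.permute(0, 2, 1, 3, 4)  # back to [B, F, C, H, W] for the transformer
print(model_input.shape)                      # torch.Size([1, 13, 16, 60, 90])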
@@ -1286,12 +1292,21 @@ def main(args):
                     return_dict=False,
                 )[0]
 
-                if scheduler.config.prediction_type == "epsilon":
-                    target = noise
-                elif scheduler.config.prediction_type == "v_prediction":
-                    target = scheduler.get_velocity(model_input, noise, timesteps)
-                else:
-                    raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
+                # =====
+                # weights = 1 / (1 - scheduler.alphas_cumprod[timesteps])
+                # while len(weights.shape) < len(model_pred.shape):
+                #     weights = weights.unsqueeze(-1)
+                # model_pred = model_pred * weights
+                # target = model_input * weights
+                # =====
+                target = model_input
+                # if scheduler.config.prediction_type == "epsilon":
+                #     target = noise
+                # elif scheduler.config.prediction_type == "v_prediction":
+                #     target = scheduler.get_velocity(model_input, noise, timesteps)
+                # else:
+                #     raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
 
                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                 accelerator.backward(loss)
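The commented-out block above sketches a per-timestep loss weighting of 1 / (1 - alpha_bar_t); in isolation it looks like this (the linspace schedule is a made-up stand-in for scheduler.alphas_cumprod):

import torch

alphas_cumprod = torch.linspace(0.9999, 0.006, 1000)  # assumed DDPM-style schedule
timesteps = torch.randint(0, 1000, (4,))
model_pred = torch.randn(4, 13, 16, 60, 90)

weights = 1.0 / (1.0 - alphas_cumprod[timesteps])
while weights.dim() < model_pred.dim():
    weights = weights.unsqueeze(-1)   # broadcast [B] -> [B, 1, 1, 1, 1]
weighted_pred = model_pred * weights  # weights blow up near t=0 (alpha_bar_t ~ 1), up-weighting low-noise samples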
@@ -1446,36 +1461,5 @@ def main(args):
 if __name__ == "__main__":
+    print("Hello, world!")
     args = get_args()
     main(args)
-
-# class args:
-#     instance_data_root = "./z"
-#     dataset_name = None
-#     dataset_config_name = None
-#     caption_column = "prompts.txt"
-#     video_column = "videos.txt"
-#     height = 480
-#     width = 720
-#     fps = 8
-#     max_num_frames = 49
-#     skip_frames_start = 0
-#     skip_frames_end = 0
-#     cache_dir = None
-
-# # Dataset and DataLoaders creation:
-# train_dataset = VideoDataset(
-#     instance_data_root=args.instance_data_root,
-#     dataset_name=args.dataset_name,
-#     dataset_config_name=args.dataset_config_name,
-#     caption_column=args.caption_column,
-#     video_column=args.video_column,
-#     height=args.height,
-#     width=args.width,
-#     fps=args.fps,
-#     max_num_frames=args.max_num_frames,
-#     skip_frames_start=args.skip_frames_start,
-#     skip_frames_end=args.skip_frames_end,
-#     cache_dir=args.cache_dir,
-# )