mirror of https://github.com/huggingface/diffusers.git

update
@@ -361,10 +361,7 @@ def get_args():
         "--logging_dir",
         type=str,
         default="logs",
-        help=(
-            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
-            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
-        ),
+        help="Directory where logs are stored.",
     )
     parser.add_argument(
         "--allow_tf32",
@@ -573,7 +570,7 @@ def save_model_card(
     widget_dict = []
     if videos is not None:
         for i, video in enumerate(videos):
-            export_to_video(video, os.path.join(repo_folder, f"video_{i}.mp4"), fps=fps)
+            export_to_video(video, os.path.join(repo_folder, f"final_video_{i}.mp4"), fps=fps)
             widget_dict.append(
                 {"text": validation_prompt if validation_prompt else " ", "output": {"url": f"video_{i}.mp4"}}
             )
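One detail worth flagging in the renamed call above: `fps` is a keyword argument of `export_to_video`, not of `os.path.join`, so the path expression has to be closed off before it. A minimal, self-contained sketch of the intended pattern (dummy frames and an illustrative path):

```python
import os

import numpy as np
from diffusers.utils import export_to_video

# Dummy clip standing in for pipeline output: 8 RGB frames with values in [0, 1].
frames = [np.random.rand(64, 64, 3).astype(np.float32) for _ in range(8)]

# Build the path first; fps is consumed by export_to_video itself.
output_path = os.path.join(".", "final_video_0.mp4")
export_to_video(frames, output_path, fps=8)
```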
@@ -673,10 +670,25 @@ def log_validation(
     for tracker in accelerator.trackers:
         phase_name = "test" if is_final_validation else "validation"
         if tracker.name == "wandb":
+            video_filenames = []
+            for i, video in enumerate(videos):
+                prompt = (
+                    pipeline_args["prompt"][:25]
+                    .replace(" ", "_")
+                    .replace(" ", "_")
+                    .replace("'", "_")
+                    .replace('"', "_")
+                    .replace("/", "_")
+                )
+                filename = os.path.join(args.output_dir, f"{phase_name}_video_{i}_{prompt}.mp4")
+                export_to_video(video, filename, fps=8)
+                video_filenames.append(filename)
+
             tracker.log(
                 {
                     phase_name: [
-                        wandb.Video(video, caption=f"{i}: {args.validation_prompt}") for i, video in enumerate(videos)
+                        wandb.Video(filename, caption=f"{i}: {pipeline_args['prompt']}")
+                        for i, filename in enumerate(video_filenames)
                     ]
                 }
             )
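For context on this logging change: `wandb.Video` accepts either in-memory data or a path to an already-encoded file, so logging the `.mp4` files written by `export_to_video` avoids a second encode and keeps each caption tied to its prompt. A standalone sketch of the same pattern (project name and filenames are illustrative; the files must already exist on disk):

```python
import wandb

# Offline mode writes locally instead of requiring a W&B account.
run = wandb.init(project="cogvideox-lora", mode="offline")

# Paths as produced by export_to_video in log_validation above.
video_filenames = ["validation_video_0_a_cat_walking.mp4"]

run.log({"validation": [wandb.Video(f, caption=f"{i}: a cat walking") for i, f in enumerate(video_filenames)]})
run.finish()
```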
@@ -763,6 +775,29 @@ def encode_prompt(
     return prompt_embeds


+def compute_prompt_embeddings(tokenizer, text_encoder, prompt, device, dtype, requires_grad: bool = False):
+    if requires_grad:
+        prompt_embeds = encode_prompt(
+            tokenizer,
+            text_encoder,
+            prompt,
+            num_videos_per_prompt=1,
+            device=device,
+            dtype=dtype,
+        )
+    else:
+        with torch.no_grad():
+            prompt_embeds = encode_prompt(
+                tokenizer,
+                text_encoder,
+                prompt,
+                num_videos_per_prompt=1,
+                device=device,
+                dtype=dtype,
+            )
+    return prompt_embeds
+
+
 def prepare_rotary_positional_embeddings(
     height: int,
     width: int,
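The new `compute_prompt_embeddings` helper collapses the two previous call sites into one: the same `encode_prompt` call runs either with gradients (when the text encoder is being trained) or under `torch.no_grad()` (frozen text encoder). A toy reproduction of just that control flow, with a stand-in encoder instead of the real one:

```python
import torch

def encode(text_encoder, x, requires_grad: bool = False):
    # Same branching as compute_prompt_embeddings: identical call,
    # optionally wrapped in torch.no_grad().
    if requires_grad:
        return text_encoder(x)
    with torch.no_grad():
        return text_encoder(x)

encoder = torch.nn.Linear(4, 4)
frozen = encode(encoder, torch.randn(1, 4), requires_grad=False)
trainable = encode(encoder, torch.randn(1, 4), requires_grad=True)
print(frozen.requires_grad, trainable.requires_grad)  # False True
```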
@@ -1089,20 +1124,6 @@ def main(args):
         num_workers=args.dataloader_num_workers,
     )

-    if not args.train_text_encoder:
-
-        def compute_text_embeddings(prompt):
-            with torch.no_grad():
-                prompt_embeds = encode_prompt(
-                    tokenizer,
-                    text_encoder,
-                    prompt,
-                    num_videos_per_prompt=1,
-                    device=accelerator.device,
-                    dtype=weight_dtype,
-                )
-            return prompt_embeds
-
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1219,34 +1240,19 @@ def main(args):
                 prompts = batch["prompts"]

                 # encode prompts
-                if not args.train_text_encoder:
-                    prompt_embeds = compute_text_embeddings(prompts)
-                else:
-                    text_inputs = tokenizer(
-                        prompts,
-                        padding="max_length",
-                        max_length=226,
-                        truncation=True,
-                        add_special_tokens=True,
-                        return_tensors="pt",
-                    )
-                    text_input_ids = text_inputs.input_ids
-                    prompt_embeds = encode_prompt(
-                        tokenizer=None,
-                        text_encoder=text_encoder,
-                        prompt=None,
-                        num_videos_per_prompt=1,
-                        device=accelerator.device,
-                        dtype=weight_dtype,
-                        text_input_ids=text_input_ids,
-                    )
+                prompt_embeds = compute_prompt_embeddings(
+                    tokenizer,
+                    text_encoder,
+                    prompts,
+                    accelerator.device,
+                    weight_dtype,
+                    requires_grad=args.train_text_encoder,
+                )

                 # Convert videos to latents
-                print("videos.shape:", videos.shape)
                 videos = videos.permute(0, 2, 1, 3, 4)  # [B, C, F, H, W]
                 model_input = vae.encode(videos).latent_dist.sample() * vae.config.scaling_factor
                 model_input = model_input.permute(0, 2, 1, 3, 4).to(dtype=weight_dtype)  # [B, F, C, H, W]
-                print("latents.shape:", model_input.shape)

                 # Sample noise that will be added to the latents
                 noise = torch.randn_like(model_input)
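On the two `permute` calls retained above: the 3D VAE consumes channels-first video, `[B, C, F, H, W]`, while the dataloader and the transformer side of this script keep frames first, `[B, F, C, H, W]`; the deleted `print` calls were debug output for exactly these shapes. A shapes-only sketch of the round trip (identity stands in for the real VAE, and spatial size is shrunk from 480x720 to keep it cheap):

```python
import torch

B, F, C, H, W = 1, 49, 3, 48, 72
videos = torch.randn(B, F, C, H, W)           # as loaded: frames first

vae_in = videos.permute(0, 2, 1, 3, 4)        # [B, C, F, H, W] for the 3D VAE
latents = vae_in                              # identity stand-in for vae.encode(...).sample()

model_input = latents.permute(0, 2, 1, 3, 4)  # back to [B, F, C, H, W]
print(model_input.shape)                      # torch.Size([1, 49, 3, 48, 72])
```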
@@ -1286,12 +1292,21 @@ def main(args):
                     return_dict=False,
                 )[0]

-                if scheduler.config.prediction_type == "epsilon":
-                    target = noise
-                elif scheduler.config.prediction_type == "v_prediction":
-                    target = scheduler.get_velocity(model_input, noise, timesteps)
-                else:
-                    raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")
+                # =====
+                # weights = 1 / (1 - scheduler.alphas_cumprod[timesteps])
+                # while len(weights.shape) < len(model_pred.shape):
+                #     weights = weights.unsqueeze(-1)
+                # model_pred = model_pred * weights
+                # target = model_input * weights
+                # =====
+                target = model_input
+
+                # if scheduler.config.prediction_type == "epsilon":
+                #     target = noise
+                # elif scheduler.config.prediction_type == "v_prediction":
+                #     target = scheduler.get_velocity(model_input, noise, timesteps)
+                # else:
+                #     raise ValueError(f"Unknown prediction type {scheduler.config.prediction_type}")

                 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
                 accelerator.backward(loss)
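The replaced branch above chose the regression target from `scheduler.config.prediction_type` (the sampled noise for epsilon, `scheduler.get_velocity(...)` for v-prediction); the new code instead regresses the model output directly onto the clean latents `model_input`, i.e. a sample/x0-style objective, with an SNR-like weighting left commented out. A dummy-tensor sketch of that weighting, mirroring the commented lines (all tensors are random and `alphas_cumprod` is a made-up schedule, not the scheduler's):

```python
import torch
import torch.nn.functional as F

# Made-up DDPM-style cumulative-alpha schedule and random stand-ins.
alphas_cumprod = torch.linspace(0.9999, 0.01, 1000)
timesteps = torch.tensor([10, 500])
model_pred = torch.randn(2, 13, 16, 6, 9)   # [B, F, C, H, W]-shaped latents
model_input = torch.randn(2, 13, 16, 6, 9)  # clean latents, the new target

# weights = 1 / (1 - alpha_bar_t), broadcast over the latent dimensions.
weights = 1 / (1 - alphas_cumprod[timesteps])
while len(weights.shape) < len(model_pred.shape):
    weights = weights.unsqueeze(-1)

# Weighted x0-regression; the diff currently trains the unweighted version.
loss = F.mse_loss((model_pred * weights).float(), (model_input * weights).float(), reduction="mean")
print(loss.item())
```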
@@ -1446,36 +1461,5 @@ def main(args):


 if __name__ == "__main__":
-    print("Hello, world!")
     args = get_args()
     main(args)
-
-    # class args:
-    #     instance_data_root = "./z"
-    #     dataset_name = None
-    #     dataset_config_name = None
-    #     caption_column = "prompts.txt"
-    #     video_column = "videos.txt"
-    #     height = 480
-    #     width = 720
-    #     fps = 8
-    #     max_num_frames = 49
-    #     skip_frames_start = 0
-    #     skip_frames_end = 0
-    #     cache_dir = None
-
-    # # Dataset and DataLoaders creation:
-    # train_dataset = VideoDataset(
-    #     instance_data_root=args.instance_data_root,
-    #     dataset_name=args.dataset_name,
-    #     dataset_config_name=args.dataset_config_name,
-    #     caption_column=args.caption_column,
-    #     video_column=args.video_column,
-    #     height=args.height,
-    #     width=args.width,
-    #     fps=args.fps,
-    #     max_num_frames=args.max_num_frames,
-    #     skip_frames_start=args.skip_frames_start,
-    #     skip_frames_end=args.skip_frames_end,
-    #     cache_dir=args.cache_dir,
-    # )
-