From 661ca4677e6dccc4ad596c2ee6ca4baad4159e95 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 22:41:47 +0200 Subject: [PATCH] do two forward passes for prior preservation --- examples/dreambooth/train_dreambooth.py | 71 +++++++++++++++---------- 1 file changed, 42 insertions(+), 29 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 1829929bbf..2f4bab89f5 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -442,20 +442,25 @@ def main(): input_ids = [example["instance_prompt_ids"] for example in examples] pixel_values = [example["instance_images"] for example in examples] - # concat class and instance examples for prior preservation - if args.with_prior_preservation: - input_ids += [example["class_prompt_ids"] for example in examples] - pixel_values += [example["class_images"] for example in examples] - - pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() - + pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float() input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids batch = { "input_ids": input_ids, "pixel_values": pixel_values, } + + if args.with_prior_preservation: + class_input_ids = [example["class_prompt_ids"] for example in examples] + class_pixel_values = [example["class_images"] for example in examples] + + class_pixel_values = torch.stack(class_pixel_values).to(memory_format=torch.contiguous_format).float() + class_input_ids = tokenizer.pad( + {"input_ids": class_input_ids}, padding=True, return_tensors="pt" + ).input_ids + batch["class_input_ids"] = class_input_ids + batch["class_pixel_values"] = class_pixel_values + return batch train_dataloader = torch.utils.data.DataLoader( @@ -516,33 +521,41 @@ def main(): unet.train() for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): - # Convert images to latent space - with torch.no_grad(): - latents = vae.encode(batch["pixel_values"]).latent_dist.sample() - latents = latents * 0.18215 - # Sample noise that we'll add to the latents - noise = torch.randn(latents.shape).to(latents.device) - bsz = latents.shape[0] - # Sample a random timestep for each image - timesteps = torch.randint( - 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device - ).long() + def _forward(input_ids, pixel_values): + # Convert images to latent space + with torch.no_grad(): + latents = vae.encode(pixel_values).latent_dist.sample() + latents = latents * 0.18215 - # Add noise to the latents according to the noise magnitude at each timestep - # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + # Sample noise that we'll add to the latents + noise = torch.randn(latents.shape).to(latents.device) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ).long() - # Get the text embedding for conditioning - with torch.no_grad(): - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) - # Predict the noise residual - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + # Get the text embedding for conditioning + with torch.no_grad(): + encoder_hidden_states = text_encoder(input_ids)[0] + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + return loss + + loss = _forward(batch["input_ids"], batch["pixel_values"]) + + if args.with_prior_preservation: + prior_loss = _forward(batch["class_input_ids"], batch["class_pixel_values"]) + loss = loss + prior_loss - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() accelerator.backward(loss) - optimizer.step() lr_scheduler.step() optimizer.zero_grad()