
do two forward passes for prior preservation

patil-suraj
2022-09-26 22:41:47 +02:00
parent 87bc75231a
commit 661ca4677e

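For context: prior preservation is the DreamBooth regularizer that trains on generated images of the generic class alongside the instance images, so the fine-tuned model keeps its prior for the class while learning the new subject. Previously the script concatenated both sets into a single batch and ran one forward pass; this commit keeps them separate and runs one pass per set, summing the two losses. A minimal sketch of that pattern, with a stand-in linear model in place of the real VAE/text-encoder/UNet stack (all names here are illustrative, not the script's):

import torch
import torch.nn.functional as F

model = torch.nn.Linear(8, 8)  # stand-in for the trainable UNet
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

def _forward(x, target):
    # One full denoising pass in the real script; a linear map here.
    return F.mse_loss(model(x), target)

x_inst, y_inst = torch.randn(4, 8), torch.randn(4, 8)  # "instance" batch
x_cls, y_cls = torch.randn(4, 8), torch.randn(4, 8)    # "class" batch

loss = _forward(x_inst, y_inst)
loss = loss + _forward(x_cls, y_cls)  # second pass for prior preservation
loss.backward()   # gradients from both passes accumulate before the step
optimizer.step()
optimizer.zero_grad()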

@@ -442,20 +442,25 @@ def main():
         input_ids = [example["instance_prompt_ids"] for example in examples]
         pixel_values = [example["instance_images"] for example in examples]
 
-        # concat class and instance examples for prior preservation
-        if args.with_prior_preservation:
-            input_ids += [example["class_prompt_ids"] for example in examples]
-            pixel_values += [example["class_images"] for example in examples]
-
-        pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
+        pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
 
         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
 
         batch = {
             "input_ids": input_ids,
             "pixel_values": pixel_values,
         }
+
+        if args.with_prior_preservation:
+            class_input_ids = [example["class_prompt_ids"] for example in examples]
+            class_pixel_values = [example["class_images"] for example in examples]
+            class_pixel_values = torch.stack(class_pixel_values).to(memory_format=torch.contiguous_format).float()
+            class_input_ids = tokenizer.pad(
+                {"input_ids": class_input_ids}, padding=True, return_tensors="pt"
+            ).input_ids
+            batch["class_input_ids"] = class_input_ids
+            batch["class_pixel_values"] = class_pixel_values
+
         return batch
 
     train_dataloader = torch.utils.data.DataLoader(
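To make the new collation concrete, here is a self-contained sketch of its behaviour on dummy tensors. It uses a plain with_prior_preservation flag in place of args and leaves out the tokenizer half; an illustration of the diff above, not the script's exact code:

import torch

def collate_sketch(examples, with_prior_preservation=True):
    # Instance images: always stacked into the batch.
    pixel_values = torch.stack([e["instance_images"] for e in examples])
    batch = {"pixel_values": pixel_values.to(memory_format=torch.contiguous_format).float()}

    # Class images: kept under separate keys so the training loop can run
    # its own forward pass over them instead of one concatenated pass.
    if with_prior_preservation:
        class_pixel_values = torch.stack([e["class_images"] for e in examples])
        batch["class_pixel_values"] = class_pixel_values.to(memory_format=torch.contiguous_format).float()
    return batch

examples = [
    {"instance_images": torch.randn(3, 64, 64), "class_images": torch.randn(3, 64, 64)}
    for _ in range(2)
]
batch = collate_sketch(examples)
print(batch["pixel_values"].shape)        # torch.Size([2, 3, 64, 64])
print(batch["class_pixel_values"].shape)  # torch.Size([2, 3, 64, 64])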
@@ -516,33 +521,41 @@ def main():
         unet.train()
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
-                # Convert images to latent space
-                with torch.no_grad():
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                    latents = latents * 0.18215
-
-                # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
-                bsz = latents.shape[0]
-                # Sample a random timestep for each image
-                timesteps = torch.randint(
-                    0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
-                ).long()
-
-                # Add noise to the latents according to the noise magnitude at each timestep
-                # (this is the forward diffusion process)
-                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-                # Get the text embedding for conditioning
-                with torch.no_grad():
-                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
-
-                # Predict the noise residual
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                def _forward(input_ids, pixel_values):
+                    # Convert images to latent space
+                    with torch.no_grad():
+                        latents = vae.encode(pixel_values).latent_dist.sample()
+                        latents = latents * 0.18215
+
+                    # Sample noise that we'll add to the latents
+                    noise = torch.randn(latents.shape).to(latents.device)
+                    bsz = latents.shape[0]
+                    # Sample a random timestep for each image
+                    timesteps = torch.randint(
+                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
+                    ).long()
+
+                    # Add noise to the latents according to the noise magnitude at each timestep
+                    # (this is the forward diffusion process)
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                    # Get the text embedding for conditioning
+                    with torch.no_grad():
+                        encoder_hidden_states = text_encoder(input_ids)[0]
+
+                    # Predict the noise residual
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    return loss
+
+                loss = _forward(batch["input_ids"], batch["pixel_values"])
+
+                if args.with_prior_preservation:
+                    prior_loss = _forward(batch["class_input_ids"], batch["class_pixel_values"])
+                    loss = loss + prior_loss
 
                 accelerator.backward(loss)
                 optimizer.step()
                 lr_scheduler.step()
                 optimizer.zero_grad()
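One side effect worth flagging: the old concatenated batch took a single mean over instance and class examples together, while the new code averages each half separately and sums the results, so with equal halves the total loss is roughly twice as large; noise and timesteps are also now drawn independently for each pass inside _forward. A toy check of the scaling, with hypothetical per-example losses:

import torch

# Hypothetical per-example losses: 2 instance examples, then 2 class examples.
per_example = torch.tensor([1.0, 3.0, 5.0, 7.0])

# Old behaviour: one concatenated batch, one mean over all four examples.
concat_loss = per_example.mean()                                  # 4.0

# New behaviour: one mean per half, then a sum of the two losses.
two_pass_loss = per_example[:2].mean() + per_example[2:].mean()   # 2.0 + 6.0 = 8.0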