
apply amp bf16 on textual inversion (#1465)

* add conf.yaml

* enable bf16

enable amp bf16 for unet forward

fix style

fix readme

remove useless file

* change amp to full bf16

* align

* make style

* fix format
Author: jiqing-feng (committed by GitHub)
Date:   2022-12-16 04:15:23 +08:00
Parent: 61dec53356
Commit: c5f04d4e34


@@ -532,9 +532,15 @@ def main():
     )
     accelerator.register_for_checkpointing(lr_scheduler)
 
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     # Move vae and unet to device
-    vae.to(accelerator.device)
-    unet.to(accelerator.device)
+    unet.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)
 
     # Keep vae and unet in eval model as we don't train these
     vae.eval()
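
Aside: the dtype-selection pattern this hunk adds can be tried in isolation. A minimal sketch, with a toy Linear layer standing in for the real vae/unet (the Accelerator construction and module below are illustrative, not code from this commit):

    # Minimal sketch: derive the frozen-model dtype from Accelerate's setting.
    import torch
    from accelerate import Accelerator

    accelerator = Accelerator(mixed_precision="bf16")  # "no" | "fp16" | "bf16"

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Frozen modules can be cast wholesale; the trainable text_encoder stays
    # in fp32 so its gradients and optimizer state keep full precision.
    frozen = torch.nn.Linear(4, 4).to(accelerator.device, dtype=weight_dtype)
    print(next(frozen.parameters()).dtype)  # torch.bfloat16
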
@@ -600,11 +606,11 @@ def main():
 
             with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
-                latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
+                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
                 latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn(latents.shape).to(latents.device).to(dtype=weight_dtype)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(
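
Aside: these two casts are load-bearing because the vae now holds bf16 weights, and outside of autocast PyTorch layers reject mixed-dtype inputs; the noise must likewise match the latents' dtype for noise_scheduler.add_noise. A small illustrative reproducer (toy conv layer, not from this commit):

    # A float32 input into a bfloat16 conv raises a dtype-mismatch RuntimeError.
    import torch

    conv = torch.nn.Conv2d(3, 8, kernel_size=3).to(dtype=torch.bfloat16)
    x = torch.randn(1, 3, 64, 64)          # float32, like an uncast batch
    # conv(x)                              # RuntimeError: mismatched dtypes
    y = conv(x.to(dtype=torch.bfloat16))   # cast first, as the hunk does
    print(y.dtype)                         # torch.bfloat16
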
@@ -616,7 +622,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
 
                 # Predict the noise residual
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
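
Aside: the text_encoder is the one module being trained here, so it stays in fp32 and emits float32 hidden states; the cast keeps the bf16 unet from seeing a mixed-dtype input. A toy illustration with stand-in Linear layers (not from this commit):

    # Stand-ins: an fp32 "text encoder" feeding a bf16 "unet" projection.
    import torch

    text_proj = torch.nn.Linear(16, 16)                     # trainable, fp32
    unet_proj = torch.nn.Linear(16, 16).to(torch.bfloat16)  # frozen, bf16

    hidden = text_proj(torch.randn(1, 16))                  # float32 output
    out = unet_proj(hidden.to(dtype=torch.bfloat16))        # cast, as above
    print(hidden.dtype, out.dtype)                          # float32, bfloat16
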
@@ -629,7 +635,7 @@ def main():
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
accelerator.backward(loss)
optimizer.step()
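
Aside: upcasting model_pred and target before the loss keeps the squared-error reduction in fp32; bf16 carries only 7 explicit mantissa bits, so accumulating many small residuals in bf16 is noticeably coarser. A minimal demonstration with random tensors (illustrative only):

    # Compare the MSE reduction done in bf16 vs. upcast-to-fp32-first.
    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    pred = torch.randn(4, 4, 64, 64, dtype=torch.bfloat16)
    target = torch.randn(4, 4, 64, 64, dtype=torch.bfloat16)

    loss_bf16 = F.mse_loss(pred, target, reduction="none").mean([1, 2, 3]).mean()
    loss_fp32 = F.mse_loss(pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
    print(loss_bf16.item(), loss_fp32.item())  # the bf16 path loses precision
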