mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
apply amp bf16 on textual inversion (#1465)
* add conf.yaml * enable bf16 enable amp bf16 for unet forward fix style fix readme remove useless file * change amp to full bf16 * align * make stype * fix format
This commit is contained in:
@@ -532,9 +532,15 @@ def main():
|
||||
)
|
||||
accelerator.register_for_checkpointing(lr_scheduler)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
# Move vae and unet to device
|
||||
vae.to(accelerator.device)
|
||||
unet.to(accelerator.device)
|
||||
unet.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# Keep vae and unet in eval model as we don't train these
|
||||
vae.eval()
|
||||
@@ -600,11 +606,11 @@ def main():
|
||||
|
||||
with accelerator.accumulate(text_encoder):
|
||||
# Convert images to latent space
|
||||
latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
|
||||
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
|
||||
latents = latents * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn(latents.shape).to(latents.device)
|
||||
noise = torch.randn(latents.shape).to(latents.device).to(dtype=weight_dtype)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
@@ -616,7 +622,7 @@ def main():
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
|
||||
|
||||
# Predict the noise residual
|
||||
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
@@ -629,7 +635,7 @@ def main():
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
|
||||
|
||||
loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
|
||||
accelerator.backward(loss)
|
||||
|
||||
optimizer.step()
|
||||
|
||||
Reference in New Issue
Block a user