diff --git a/examples/textual_inversion/textual_inversion.py b/examples/textual_inversion/textual_inversion.py
index a2e023a5ce..9089caf134 100644
--- a/examples/textual_inversion/textual_inversion.py
+++ b/examples/textual_inversion/textual_inversion.py
@@ -532,9 +532,15 @@ def main():
     )
     accelerator.register_for_checkpointing(lr_scheduler)
 
+    weight_dtype = torch.float32
+    if accelerator.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     # Move vae and unet to device
-    vae.to(accelerator.device)
-    unet.to(accelerator.device)
+    unet.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)
 
     # Keep vae and unet in eval model as we don't train these
     vae.eval()
@@ -600,11 +606,11 @@ def main():
 
             with accelerator.accumulate(text_encoder):
                 # Convert images to latent space
-                latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
+                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
                 latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn(latents.shape).to(latents.device).to(dtype=weight_dtype)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(
@@ -616,7 +622,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
-                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                encoder_hidden_states = text_encoder(batch["input_ids"])[0].to(dtype=weight_dtype)
 
                 # Predict the noise residual
                 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -629,7 +635,7 @@ def main():
                 else:
                     raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-                loss = F.mse_loss(model_pred, target, reduction="none").mean([1, 2, 3]).mean()
+                loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
 
                 accelerator.backward(loss)
                 optimizer.step()
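
Note (not part of the patch): the pattern the diff applies is to resolve a weight_dtype from accelerator.mixed_precision, cast the frozen vae/unet and the inputs fed to them to that dtype, and still reduce the loss in float32. Below is a minimal standalone sketch of that dtype handling; the mixed_precision string and the tensor shapes are illustrative stand-ins, not values taken from the script.

    import torch
    import torch.nn.functional as F

    # In the script this comes from accelerator.mixed_precision ("no"/"fp16"/"bf16");
    # hard-coded here so the sketch runs without an Accelerator.
    mixed_precision = "fp16"

    # Same dtype mapping the patch introduces.
    weight_dtype = torch.float32
    if mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Frozen modules would be cast once and fed matching-dtype inputs, e.g.:
    #   unet.to(device, dtype=weight_dtype)
    #   latents = vae.encode(pixel_values.to(dtype=weight_dtype)).latent_dist.sample()

    # The loss is computed in float32 even when activations are half precision,
    # so the per-pixel squared errors are reduced without fp16 rounding loss.
    model_pred = torch.randn(4, 4, 64, 64, dtype=weight_dtype)  # stand-in for the unet output
    target = torch.randn(4, 4, 64, 64, dtype=weight_dtype)      # stand-in for the noise target
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none").mean([1, 2, 3]).mean()
    print(loss.dtype)  # torch.float32

Casting only the frozen vae and unet is safe here because their parameters are never optimized; only the text-encoder embeddings are trained, and the final hunk upcasts before the loss reduction so the backward pass starts from a float32 scalar.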