From eda36c4c286d281f216dfeb79e64adad3f85d37a Mon Sep 17 00:00:00 2001 From: Leo Jiang <74156916+leisuzz@users.noreply.github.com> Date: Tue, 20 Aug 2024 09:28:50 +0800 Subject: [PATCH] Fix dtype error for StableDiffusionXL (#9217) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix dtype error Co-authored-by: 蒋硕 Co-authored-by: Sayak Paul --- examples/text_to_image/train_text_to_image_sdxl.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 7f4917b546..2ca511c857 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -1084,7 +1084,7 @@ def main(args): # Add noise to the model input according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps) + noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps).to(dtype=weight_dtype) # time ids def compute_time_ids(original_size, crops_coords_top_left): @@ -1101,7 +1101,7 @@ def main(args): # Predict the noise residual unet_added_conditions = {"time_ids": add_time_ids} - prompt_embeds = batch["prompt_embeds"].to(accelerator.device) + prompt_embeds = batch["prompt_embeds"].to(accelerator.device, dtype=weight_dtype) pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(accelerator.device) unet_added_conditions.update({"text_embeds": pooled_prompt_embeds}) model_pred = unet(