From 8ead643bb786fe6bc80c9a4bd1730372d410a9df Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Andreas=20J=C3=B6rg?= <60151338+andjoer@users.noreply.github.com>
Date: Fri, 14 Mar 2025 13:03:15 +0100
Subject: [PATCH] [examples/controlnet/train_controlnet_sd3.py] Fixes #11050 - Cast prompt_embeds and pooled_prompt_embeds to weight_dtype to prevent dtype mismatch (#11051)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Fix: dtype mismatch of prompt embeddings in sd3 controlnet training

Co-authored-by: Andreas Jörg
Co-authored-by: Sayak Paul
---
 examples/controlnet/train_controlnet_sd3.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py
index f4aadc2577..ffe460d72d 100644
--- a/examples/controlnet/train_controlnet_sd3.py
+++ b/examples/controlnet/train_controlnet_sd3.py
@@ -1283,8 +1283,8 @@ def main(args):
                 noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise
 
                 # Get the text embedding for conditioning
-                prompt_embeds = batch["prompt_embeds"]
-                pooled_prompt_embeds = batch["pooled_prompt_embeds"]
+                prompt_embeds = batch["prompt_embeds"].to(dtype=weight_dtype)
+                pooled_prompt_embeds = batch["pooled_prompt_embeds"].to(dtype=weight_dtype)
 
                 # controlnet(s) inference
                 controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)