From 63a0c9e5f7a56d49dc142e643b7237fc9082ff59 Mon Sep 17 00:00:00 2001 From: "G.O.D" <32255912+gameofdimension@users.noreply.github.com> Date: Tue, 22 Oct 2024 07:26:05 +0800 Subject: [PATCH] [bugfix] reduce float value error when adding noise (#9004) * Update train_controlnet.py reduce float value error for bfloat16 * Update train_controlnet_sdxl.py * style --------- Co-authored-by: Sayak Paul Co-authored-by: yiyixuxu --- examples/controlnet/train_controlnet.py | 4 +++- examples/controlnet/train_controlnet_sdxl.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index 2e902db7ff..eaeb697c64 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -1048,7 +1048,9 @@ def main(args): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to( + dtype=weight_dtype + ) # Get the text embedding for conditioning encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0] diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 877ca61358..ae627bb3a0 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -1210,7 +1210,9 @@ def main(args): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to( + dtype=weight_dtype + ) # ControlNet conditioning. controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)