diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index 2e902db7ff..eaeb697c64 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -1048,7 +1048,9 @@ def main(args): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to( + dtype=weight_dtype + ) # Get the text embedding for conditioning encoder_hidden_states = text_encoder(batch["input_ids"], return_dict=False)[0] diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index 877ca61358..ae627bb3a0 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -1210,7 +1210,9 @@ def main(args): # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(latents.float(), noise.float(), timesteps).to( + dtype=weight_dtype + ) # ControlNet conditioning. controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype)