diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index b1aa63b60a..3abb58b433 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -972,8 +972,10 @@ def main(args): noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), ).sample # Get the target for loss depending on the prediction type