From e607a582cfaa7dfaf7913fc3bb54c35eceee583c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 12 Apr 2023 06:35:06 +0530 Subject: [PATCH] [Examples] Fix type-casting issue in the ControlNet training script (#2994) * fix: norm group test for UNet3D. * fix: type-casting issue in controlnet training. --- examples/controlnet/train_controlnet.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index b1aa63b60a..3abb58b433 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -972,8 +972,10 @@ def main(args): noisy_latents, timesteps, encoder_hidden_states=encoder_hidden_states, - down_block_additional_residuals=down_block_res_samples, - mid_block_additional_residual=mid_block_res_sample, + down_block_additional_residuals=[ + sample.to(dtype=weight_dtype) for sample in down_block_res_samples + ], + mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype), ).sample # Get the target for loss depending on the prediction type