1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Examples] Fix type-casting issue in the ControlNet training script (#2994)

* fix: norm group test for UNet3D.

* fix: type-casting issue in controlnet training.
This commit is contained in:
Sayak Paul
2023-04-12 06:35:06 +05:30
committed by GitHub
parent ea39cd7e64
commit e607a582cf

View File

@@ -972,8 +972,10 @@ def main(args):
noisy_latents,
timesteps,
encoder_hidden_states=encoder_hidden_states,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
down_block_additional_residuals=[
sample.to(dtype=weight_dtype) for sample in down_block_res_samples
],
mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
).sample
# Get the target for loss depending on the prediction type