mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[Examples] Fix type-casting issue in the ControlNet training script (#2994)
* fix: norm group test for UNet3D. * fix: type-casting issue in controlnet training.
This commit is contained in:
@@ -972,8 +972,10 @@ def main(args):
|
||||
noisy_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
down_block_additional_residuals=[
|
||||
sample.to(dtype=weight_dtype) for sample in down_block_res_samples
|
||||
],
|
||||
mid_block_additional_residual=mid_block_res_sample.to(dtype=weight_dtype),
|
||||
).sample
|
||||
|
||||
# Get the target for loss depending on the prediction type
|
||||
|
||||
Reference in New Issue
Block a user