1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

fix mixed precision training on train_dreambooth_inpaint_lora (#3138)

cast to weight dtype
This commit is contained in:
Lucca Zenóbio
2023-04-25 06:52:57 -03:00
committed by GitHub
parent c5933c9c89
commit 0ddc5bf7b9

View File

@@ -735,7 +735,7 @@ def main():
torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
for mask in masks
]
)
).to(dtype=weight_dtype)
mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
# Sample noise that we'll add to the latents