mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
fix mixed precision training on train_dreambooth_inpaint_lora (#3138)
cast to weight dtype
This commit is contained in:
@@ -735,7 +735,7 @@ def main():
|
||||
torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8))
|
||||
for mask in masks
|
||||
]
|
||||
)
|
||||
).to(dtype=weight_dtype)
|
||||
mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8)
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
|
||||
Reference in New Issue
Block a user