From 0ddc5bf7b97ee832c478ff7f4db22930b8f27d99 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lucca=20Zen=C3=B3bio?= Date: Tue, 25 Apr 2023 06:52:57 -0300 Subject: [PATCH] fix mixed precision training on train_dreambooth_inpaint_lora (#3138) cast to weight dtype --- .../dreambooth_inpaint/train_dreambooth_inpaint_lora.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py index 07df6f2011..821c66b723 100644 --- a/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py +++ b/examples/research_projects/dreambooth_inpaint/train_dreambooth_inpaint_lora.py @@ -735,7 +735,7 @@ def main(): torch.nn.functional.interpolate(mask, size=(args.resolution // 8, args.resolution // 8)) for mask in masks ] - ) + ).to(dtype=weight_dtype) mask = mask.reshape(-1, 1, args.resolution // 8, args.resolution // 8) # Sample noise that we'll add to the latents