From ca980fd0d1003407dbab7300ce15b0af2ded3bd9 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Thu, 16 Feb 2023 22:27:47 +0200 Subject: [PATCH] [Examples] Make sure EMA works with any device (#2382) * Fix EMA * up * update --- examples/text_to_image/train_text_to_image.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 879e788773..48a37c1b75 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -438,6 +438,7 @@ def main(): if args.use_ema: load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) del load_model for i in range(len(models)):