diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 879e788773..48a37c1b75 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -438,6 +438,7 @@ def main(): if args.use_ema: load_model = EMAModel.from_pretrained(os.path.join(input_dir, "unet_ema"), UNet2DConditionModel) ema_unet.load_state_dict(load_model.state_dict()) + ema_unet.to(accelerator.device) del load_model for i in range(len(models)):