diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py index 295ec4120e..8de8bf953a 100644 --- a/examples/dreambooth/train_dreambooth_lora_sdxl.py +++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py @@ -1311,14 +1311,13 @@ def main(args): pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config, **scheduler_args) - pipeline = pipeline.to(accelerator.device) - # load attention processors pipeline.load_lora_weights(args.output_dir) # run inference images = [] if args.validation_prompt and args.num_validation_images > 0: + pipeline = pipeline.to(accelerator.device) generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None images = [ pipeline(args.validation_prompt, num_inference_steps=25, generator=generator).images[0]