From 418331094deb3bd4407d86ef0176f074b892579f Mon Sep 17 00:00:00 2001 From: Yuta Hayashibe Date: Tue, 24 Jan 2023 22:19:22 +0900 Subject: [PATCH] Run inference on a specific condition and fix call of manual_seed() (#2074) --- examples/dreambooth/train_dreambooth_lora.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index fb9f2c832e..d0b7b9573b 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -972,9 +972,10 @@ def main(args): pipeline.unet.load_attn_procs(args.output_dir) # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) - prompt = args.num_validation_images * [args.validation_prompt] - images = pipeline(prompt, num_inference_steps=25, generator=generator).images + if args.validation_prompt and args.num_validation_images > 0: + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None + prompt = args.num_validation_images * [args.validation_prompt] + images = pipeline(prompt, num_inference_steps=25, generator=generator).images for tracker in accelerator.trackers: if tracker.name == "tensorboard":