From fb98acf03ba00e31cc5246915ebf2fef1a917e67 Mon Sep 17 00:00:00 2001 From: Oren WANG <371248882@qq.com> Date: Wed, 25 Jan 2023 21:56:13 +0800 Subject: [PATCH] [lora] Fix bug with training without validation (#2106) --- examples/dreambooth/train_dreambooth_lora.py | 26 ++++++++++---------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index 9d1d78beff..4c5a2bef5a 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -984,19 +984,19 @@ def main(args): prompt = args.num_validation_images * [args.validation_prompt] images = pipeline(prompt, num_inference_steps=25, generator=generator).images - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "test": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("test", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "test": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) if args.push_to_hub: save_model_card(