From f1c3c8e5a48f22ec6e98439484bf475b28abaaf9 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 14:01:54 +0200 Subject: [PATCH] fix saving --- examples/dreambooth/train_dreambooth.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index ed0a85f1ea..9767adcf5d 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -359,17 +359,15 @@ def main(): sample_dataloader = accelerator.prepare(sample_dataloader) pipeline.to(accelerator.device) - all_images = [] context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext for example in tqdm( sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): with context: images = pipeline(example["prompt"]).images - all_images.extend((images, example["index"])) - for image, index in all_images: - image.save(class_images_dir / f"{index + cur_class_images}.jpg") + for image in images: + image.save(class_images_dir / f"{example['index'] + cur_class_images}.jpg") del pipeline if torch.cuda.is_available():