1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

fix saving

This commit is contained in:
patil-suraj
2022-09-26 14:01:54 +02:00
parent 627cc49447
commit f1c3c8e5a4

View File

@@ -359,17 +359,15 @@ def main():
sample_dataloader = accelerator.prepare(sample_dataloader)
pipeline.to(accelerator.device)
all_images = []
context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext
for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
):
with context:
images = pipeline(example["prompt"]).images
all_images.extend((images, example["index"]))
for image, index in all_images:
image.save(class_images_dir / f"{index + cur_class_images}.jpg")
for image in images:
image.save(class_images_dir / f"{example['index'] + cur_class_images}.jpg")
del pipeline
if torch.cuda.is_available():