mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
fix saving
This commit is contained in:
@@ -359,17 +359,15 @@ def main():
|
||||
sample_dataloader = accelerator.prepare(sample_dataloader)
|
||||
pipeline.to(accelerator.device)
|
||||
|
||||
all_images = []
|
||||
context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext
|
||||
for example in tqdm(
|
||||
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
|
||||
):
|
||||
with context:
|
||||
images = pipeline(example["prompt"]).images
|
||||
all_images.extend((images, example["index"]))
|
||||
|
||||
for image, index in all_images:
|
||||
image.save(class_images_dir / f"{index + cur_class_images}.jpg")
|
||||
for image in images:
|
||||
image.save(class_images_dir / f"{example['index'] + cur_class_images}.jpg")
|
||||
|
||||
del pipeline
|
||||
if torch.cuda.is_available():
|
||||
|
||||
Reference in New Issue
Block a user