1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

improve sampling more

This commit is contained in:
patil-suraj
2022-09-26 13:58:07 +02:00
parent 1acc6786e5
commit 627cc49447

View File

@@ -344,7 +344,7 @@ def main():
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, use_auth_token=args.use_auth_token, torch_dtype=torch_dtype
)
@@ -360,16 +360,16 @@ def main():
pipeline.to(accelerator.device)
all_images = []
context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext
for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
):
context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext
with context:
images = pipeline(example["prompt"]).images
all_images.extend(images)
all_images.extend((images, example["index"]))
for image, example in zip(all_images, sample_dataloader):
image.save(class_images_dir / f"{example['index'] + cur_class_images}.jpg")
for image, index in all_images:
image.save(class_images_dir / f"{index + cur_class_images}.jpg")
del pipeline
if torch.cuda.is_available():