mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
improve sampling more
This commit is contained in:
@@ -344,7 +344,7 @@ def main():
|
||||
cur_class_images = len(list(class_images_dir.iterdir()))
|
||||
|
||||
if cur_class_images < args.num_class_images:
|
||||
torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
||||
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path, use_auth_token=args.use_auth_token, torch_dtype=torch_dtype
|
||||
)
|
||||
@@ -360,16 +360,16 @@ def main():
|
||||
pipeline.to(accelerator.device)
|
||||
|
||||
all_images = []
|
||||
context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext
|
||||
for example in tqdm(
|
||||
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
|
||||
):
|
||||
context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext
|
||||
with context:
|
||||
images = pipeline(example["prompt"]).images
|
||||
all_images.extend(images)
|
||||
all_images.extend((images, example["index"]))
|
||||
|
||||
for image, example in zip(all_images, sample_dataloader):
|
||||
image.save(class_images_dir / f"{example['index'] + cur_class_images}.jpg")
|
||||
for image, index in all_images:
|
||||
image.save(class_images_dir / f"{index + cur_class_images}.jpg")
|
||||
|
||||
del pipeline
|
||||
if torch.cuda.is_available():
|
||||
|
||||
Reference in New Issue
Block a user