diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index eccf9f1425..ed0a85f1ea 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -344,7 +344,7 @@ def main(): cur_class_images = len(list(class_images_dir.iterdir())) if cur_class_images < args.num_class_images: - torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32 + torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32 pipeline = StableDiffusionPipeline.from_pretrained( args.pretrained_model_name_or_path, use_auth_token=args.use_auth_token, torch_dtype=torch_dtype ) @@ -360,16 +360,16 @@ def main(): pipeline.to(accelerator.device) all_images = [] + context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext for example in tqdm( sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process ): - context = torch.autocast("cuda") if accelerator.device.type == "cuda" else nullcontext with context: images = pipeline(example["prompt"]).images - all_images.extend(images) + all_images.extend((images, example["index"])) - for image, example in zip(all_images, sample_dataloader): - image.save(class_images_dir / f"{example['index'] + cur_class_images}.jpg") + for image, index in all_images: + image.save(class_images_dir / f"{index + cur_class_images}.jpg") del pipeline if torch.cuda.is_available():