diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 9767adcf5d..e615fe5416 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -366,8 +366,8 @@ def main(): with context: images = pipeline(example["prompt"]).images - for image in images: - image.save(class_images_dir / f"{example['index'] + cur_class_images}.jpg") + for image, index in (images, example["index"]): + image.save(class_images_dir / f"{index + cur_class_images}.jpg") del pipeline if torch.cuda.is_available():