diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index e615fe5416..0c7deadb33 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -366,8 +366,8 @@ def main(): with context: images = pipeline(example["prompt"]).images - for image, index in (images, example["index"]): - image.save(class_images_dir / f"{index + cur_class_images}.jpg") + for i, image in enumerate(images): + image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg") del pipeline if torch.cuda.is_available():