diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index e5b410a70f..276a08780a 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -428,7 +428,7 @@ def main(): ).input_ids return input_ids, pixel_values - instance_prompt_ids, instance_images = _collate(example["instance_prompt_ids"], example["instance_images"]) + instance_prompt_ids, instance_images = _collate(examples["instance_prompt_ids"], examples["instance_images"]) batch = { "instance_images": instance_images, @@ -436,7 +436,7 @@ def main(): } if args.with_prior_preservation: - class_prompt_ids, class_images = _collate(example["class_prompt_ids"], example["class_images"]) + class_prompt_ids, class_images = _collate(examples["class_prompt_ids"], examples["class_images"]) batch["class_images"] = class_images batch["class_prompt_ids"] = class_prompt_ids