diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 7b95f21963..2959b57231 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -434,7 +434,7 @@ def main(): batch = { "instance_images": instance_images, - "input_ids": instance_prompt_ids, + "instance_prompt_ids": instance_prompt_ids, } if args.with_prior_preservation: