From 392fbf3c1d699edc5bf9131aa46388220fd1dcbf Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 14:33:02 +0200 Subject: [PATCH] fix collate fun --- examples/dreambooth/train_dreambooth.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index e5b410a70f..276a08780a 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -428,7 +428,7 @@ def main(): ).input_ids return input_ids, pixel_values - instance_prompt_ids, instance_images = _collate(example["instance_prompt_ids"], example["instance_images"]) + instance_prompt_ids, instance_images = _collate(examples["instance_prompt_ids"], examples["instance_images"]) batch = { "instance_images": instance_images, @@ -436,7 +436,7 @@ def main(): } if args.with_prior_preservation: - class_prompt_ids, class_images = _collate(example["class_prompt_ids"], example["class_images"]) + class_prompt_ids, class_images = _collate(examples["class_prompt_ids"], examples["class_images"]) batch["class_images"] = class_images batch["class_prompt_ids"] = class_prompt_ids