From d6c88f4ef67eb17308ef2efe6954fa1286d8b891 Mon Sep 17 00:00:00 2001 From: patil-suraj Date: Mon, 26 Sep 2022 14:35:40 +0200 Subject: [PATCH] fix collate_fn --- examples/dreambooth/train_dreambooth.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 276a08780a..197acea441 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -428,7 +428,9 @@ def main(): ).input_ids return input_ids, pixel_values - instance_prompt_ids, instance_images = _collate(examples["instance_prompt_ids"], examples["instance_images"]) + instance_prompt_ids = [example["instance_prompt_ids"] for example in examples] + instance_images = [example["instance_images"] for example in examples] + instance_prompt_ids, instance_images = _collate(instance_prompt_ids, instance_images) batch = { "instance_images": instance_images, @@ -436,7 +438,8 @@ def main(): } if args.with_prior_preservation: - class_prompt_ids, class_images = _collate(examples["class_prompt_ids"], examples["class_images"]) + class_prompt_ids = [example["class_prompt_ids"] for example in examples] + class_images = [example["class_images"] for example in examples] batch["class_images"] = class_images batch["class_prompt_ids"] = class_prompt_ids