1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

fix collate_fn

This commit is contained in:
patil-suraj
2022-09-26 14:35:40 +02:00
parent 392fbf3c1d
commit d6c88f4ef6

View File

@@ -428,7 +428,9 @@ def main():
).input_ids
return input_ids, pixel_values
instance_prompt_ids, instance_images = _collate(examples["instance_prompt_ids"], examples["instance_images"])
instance_prompt_ids = [example["instance_prompt_ids"] for example in examples]
instance_images = [example["instance_images"] for example in examples]
instance_prompt_ids, instance_images = _collate(instance_prompt_ids, instance_images)
batch = {
"instance_images": instance_images,
@@ -436,7 +438,8 @@ def main():
}
if args.with_prior_preservation:
class_prompt_ids, class_images = _collate(examples["class_prompt_ids"], examples["class_images"])
class_prompt_ids = [example["class_prompt_ids"] for example in examples]
class_images = [example["class_images"] for example in examples]
batch["class_images"] = class_images
batch["class_prompt_ids"] = class_prompt_ids