mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
fix collate_fn
This commit is contained in:
@@ -428,7 +428,9 @@ def main():
|
||||
).input_ids
|
||||
return input_ids, pixel_values
|
||||
|
||||
instance_prompt_ids, instance_images = _collate(examples["instance_prompt_ids"], examples["instance_images"])
|
||||
instance_prompt_ids = [example["instance_prompt_ids"] for example in examples]
|
||||
instance_images = [example["instance_images"] for example in examples]
|
||||
instance_prompt_ids, instance_images = _collate(instance_prompt_ids, instance_images)
|
||||
|
||||
batch = {
|
||||
"instance_images": instance_images,
|
||||
@@ -436,7 +438,8 @@ def main():
|
||||
}
|
||||
|
||||
if args.with_prior_preservation:
|
||||
class_prompt_ids, class_images = _collate(examples["class_prompt_ids"], examples["class_images"])
|
||||
class_prompt_ids = [example["class_prompt_ids"] for example in examples]
|
||||
class_images = [example["class_images"] for example in examples]
|
||||
batch["class_images"] = class_images
|
||||
batch["class_prompt_ids"] = class_prompt_ids
|
||||
|
||||
|
||||
Reference in New Issue
Block a user