mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
fix batch unmarshalling
This commit is contained in:
@@ -1066,7 +1066,9 @@ def main(args):
|
||||
|
||||
# 15. Prepare for training
|
||||
# Prepare everything with our `accelerator`.
|
||||
unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, train_dataloader, lr_scheduler)
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(train_dataloader.num_batches / args.gradient_accumulation_steps)
|
||||
@@ -1136,7 +1138,12 @@ def main(args):
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
image, text, orig_size, crop_coords = batch
|
||||
image, text, orig_size, crop_coords = (
|
||||
batch["pixel_values"],
|
||||
batch["captions"],
|
||||
batch["original_sizes"],
|
||||
batch["crop_top_lefts"],
|
||||
)
|
||||
|
||||
image = image.to(accelerator.device, non_blocking=True)
|
||||
encoded_text = compute_embeddings_fn(text, orig_size, crop_coords)
|
||||
|
||||
Reference in New Issue
Block a user