
Add --dataloader_num_workers to the DDPM training example (#1027)

Author: Anton Lozhkov
Date: 2022-10-27 15:55:36 +02:00
Committed by: GitHub
Parent: 939ec17e91
Commit: a6314a8d4e


@@ -83,7 +83,16 @@ def parse_args():
         "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
     )
     parser.add_argument(
-        "--eval_batch_size", type=int, default=16, help="Batch size (per device) for the eval dataloader."
+        "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
     )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
+            " process."
+        ),
+    )
     parser.add_argument("--num_epochs", type=int, default=100)
     parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
@@ -249,7 +258,9 @@ def main(args):
         return {"input": images}
 
     dataset.set_transform(transforms)
-    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)
+    train_dataloader = torch.utils.data.DataLoader(
+        dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+    )
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
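
For context on what the new flag controls, the minimal sketch below (not taken from this commit; the dummy dataset and script structure are illustrative assumptions) wires an argparse flag into DataLoader's num_workers the same way the diff does. Keeping default=0 preserves the previous single-process loading behaviour, so existing launch commands are unaffected, while a positive value lets image decoding and augmentation run in worker subprocesses alongside the training step.

# Minimal standalone sketch (not part of this commit): shows how a
# --dataloader_num_workers flag is forwarded to torch.utils.data.DataLoader.
# The dummy dataset and default values here are illustrative assumptions.
import argparse

import torch


class DummyImageDataset(torch.utils.data.Dataset):
    # Stands in for the image dataset loaded by the training example.
    def __len__(self):
        return 64

    def __getitem__(self, idx):
        return {"input": torch.randn(3, 64, 64)}


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_batch_size", type=int, default=16)
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help="0 loads batches in the main process; >0 spawns that many worker subprocesses.",
    )
    args = parser.parse_args()

    train_dataloader = torch.utils.data.DataLoader(
        DummyImageDataset(),
        batch_size=args.train_batch_size,
        shuffle=True,
        num_workers=args.dataloader_num_workers,
    )

    for batch in train_dataloader:
        print(batch["input"].shape)  # torch.Size([16, 3, 64, 64]) with the defaults
        break


if __name__ == "__main__":
    # The __main__ guard matters when num_workers > 0 on platforms that use the
    # "spawn" start method (Windows, macOS), where workers re-import this module.
    main()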