From a6314a8d4e1c301bce4e45c10f325f594220617f Mon Sep 17 00:00:00 2001
From: Anton Lozhkov
Date: Thu, 27 Oct 2022 15:55:36 +0200
Subject: [PATCH] Add `--dataloader_num_workers` to the DDPM training example
 (#1027)

---
 .../train_unconditional.py | 15 +++++++++++++--
 1 file changed, 13 insertions(+), 2 deletions(-)

diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 6a1495b34c..2bc8114cac 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -83,7 +83,16 @@ def parse_args():
         "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
     )
     parser.add_argument(
-        "--eval_batch_size", type=int, default=16, help="Batch size (per device) for the eval dataloader."
+        "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
+    )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
+            " process."
+        ),
     )
     parser.add_argument("--num_epochs", type=int, default=100)
     parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
@@ -249,7 +258,9 @@ def main(args):
         return {"input": images}
 
     dataset.set_transform(transforms)
-    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)
+    train_dataloader = torch.utils.data.DataLoader(
+        dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+    )
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
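
For reference, below is a minimal, self-contained sketch (not part of the patch) of what the new flag wires up: the argparse definition feeding `num_workers` into `torch.utils.data.DataLoader`, exactly as the two hunks above do. The `TensorDataset` is a hypothetical stand-in for the image dataset the training script actually loads; batch size and shapes are illustrative only.

# Minimal sketch (not the training script): --dataloader_num_workers -> DataLoader.
import argparse

import torch


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
            " process."
        ),
    )
    return parser.parse_args()


def main(args):
    # Toy stand-in for the real image dataset: 64 random "images" of shape 3x64x64.
    dataset = torch.utils.data.TensorDataset(torch.randn(64, 3, 64, 64))

    train_dataloader = torch.utils.data.DataLoader(
        dataset, batch_size=16, shuffle=True, num_workers=args.dataloader_num_workers
    )

    for (batch,) in train_dataloader:
        print(batch.shape)  # torch.Size([16, 3, 64, 64])
        break


if __name__ == "__main__":
    # The __main__ guard matters once num_workers > 0: worker subprocesses
    # re-import this module on platforms that spawn rather than fork.
    main(parse_args())

Note that the default of 0 keeps data loading in the main process, so the patch preserves the script's previous single-process behavior unless the flag is passed explicitly.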