diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 6a1495b34c..2bc8114cac 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -83,7 +83,16 @@ def parse_args():
         "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
     )
     parser.add_argument(
-        "--eval_batch_size", type=int, default=16, help="Batch size (per device) for the eval dataloader."
+        "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
+    )
+    parser.add_argument(
+        "--dataloader_num_workers",
+        type=int,
+        default=0,
+        help=(
+            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
+            " process."
+        ),
     )
     parser.add_argument("--num_epochs", type=int, default=100)
     parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
@@ -249,7 +258,9 @@ def main(args):
         return {"input": images}
 
     dataset.set_transform(transforms)
-    train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)
+    train_dataloader = torch.utils.data.DataLoader(
+        dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
+    )
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
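
For reference, a minimal standalone sketch of the pattern this change introduces: exposing the DataLoader worker count as a CLI flag and forwarding it to torch.utils.data.DataLoader. The toy TensorDataset below is only a stand-in for the Hugging Face datasets pipeline used in train_unconditional.py; it is not the script's actual data path.

import argparse

import torch
from torch.utils.data import DataLoader, TensorDataset


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--train_batch_size", type=int, default=16)
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help="Number of subprocesses for data loading; 0 loads data in the main process.",
    )
    return parser.parse_args()


def main():
    args = parse_args()

    # Stand-in dataset; the training script instead wraps a `datasets` dataset with image transforms.
    dataset = TensorDataset(torch.randn(256, 3, 64, 64))

    # num_workers > 0 spawns worker processes so batches are prepared in parallel with the training loop.
    train_dataloader = DataLoader(
        dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        num_workers=args.dataloader_num_workers,
    )

    for batch in train_dataloader:
        pass  # training step would go here


if __name__ == "__main__":
    main()

Invoked with, say, --dataloader_num_workers 4, batch preparation moves off the main process, which is the same effect the diff adds to the training script.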