mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
Add --dataloader_num_workers to the DDPM training example (#1027)
This commit is contained in:
@@ -83,7 +83,16 @@ def parse_args():
|
||||
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--eval_batch_size", type=int, default=16, help="Batch size (per device) for the eval dataloader."
|
||||
"--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
|
||||
" process."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--num_epochs", type=int, default=100)
|
||||
parser.add_argument("--save_images_epochs", type=int, default=10, help="How often to save images during training.")
|
||||
@@ -249,7 +258,9 @@ def main(args):
|
||||
return {"input": images}
|
||||
|
||||
dataset.set_transform(transforms)
|
||||
train_dataloader = torch.utils.data.DataLoader(dataset, batch_size=args.train_batch_size, shuffle=True)
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
dataset, batch_size=args.train_batch_size, shuffle=True, num_workers=args.dataloader_num_workers
|
||||
)
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
|
||||
Reference in New Issue
Block a user