From 0deeb06aac1d4303029b208331d8b04080bb5c0a Mon Sep 17 00:00:00 2001 From: anton-l Date: Wed, 15 Jun 2022 14:36:43 +0200 Subject: [PATCH] better defaults --- examples/train_ddpm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/train_ddpm.py b/examples/train_ddpm.py index 7eb0b9d34e..cff515efad 100644 --- a/examples/train_ddpm.py +++ b/examples/train_ddpm.py @@ -12,7 +12,7 @@ from torchvision.transforms import ( Compose, InterpolationMode, Lambda, - RandomCrop, + CenterCrop, RandomHorizontalFlip, Resize, ToTensor, @@ -39,7 +39,7 @@ def main(args): augmentations = Compose( [ Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), - RandomCrop(args.resolution), + CenterCrop(args.resolution), RandomHorizontalFlip(), ToTensor(), Lambda(lambda x: x * 2 - 1), @@ -136,7 +136,7 @@ if __name__ == "__main__": parser.add_argument("--output_path", type=str, default="ddpm-model") parser.add_argument("--batch_size", type=int, default=16) parser.add_argument("--num_epochs", type=int, default=100) - parser.add_argument("--gradient_accumulation_steps", type=int, default=2) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--warmup_steps", type=int, default=500) parser.add_argument(