diff --git a/examples/train_ddpm.py b/examples/train_ddpm.py index 7eb0b9d34e..cff515efad 100644 --- a/examples/train_ddpm.py +++ b/examples/train_ddpm.py @@ -12,7 +12,7 @@ from torchvision.transforms import ( Compose, InterpolationMode, Lambda, - RandomCrop, + CenterCrop, RandomHorizontalFlip, Resize, ToTensor, @@ -39,7 +39,7 @@ def main(args): augmentations = Compose( [ Resize(args.resolution, interpolation=InterpolationMode.BILINEAR), - RandomCrop(args.resolution), + CenterCrop(args.resolution), RandomHorizontalFlip(), ToTensor(), Lambda(lambda x: x * 2 - 1), @@ -136,7 +136,7 @@ if __name__ == "__main__": parser.add_argument("--output_path", type=str, default="ddpm-model") parser.add_argument("--batch_size", type=int, default=16) parser.add_argument("--num_epochs", type=int, default=100) - parser.add_argument("--gradient_accumulation_steps", type=int, default=2) + parser.add_argument("--gradient_accumulation_steps", type=int, default=1) parser.add_argument("--lr", type=float, default=1e-4) parser.add_argument("--warmup_steps", type=int, default=500) parser.add_argument(