From b1dad2e9d33b9c70bdadcc1263be47ed11699dc1 Mon Sep 17 00:00:00 2001
From: wfng92 <43742196+wfng92@users.noreply.github.com>
Date: Tue, 7 Feb 2023 18:58:31 +0800
Subject: [PATCH] Make center crop and random flip as args for unconditional
 image generation (#2259)

* Add center crop and horizontal flip to args

* Update command to use center crop and random flip

* Add center crop and horizontal flip to args

* Update command to use center crop and random flip
---
 .../unconditional_image_generation/README.md  |  4 +-
 .../train_unconditional.py                    | 37 +++++++++++--------
 .../unconditional_image_generation/README.md  |  6 +--
 .../train_unconditional.py                    | 37 +++++++++++--------
 4 files changed, 49 insertions(+), 35 deletions(-)

diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/README.md b/examples/research_projects/onnxruntime/unconditional_image_generation/README.md
index 7bf2ca443c..621e9a2fd6 100644
--- a/examples/research_projects/onnxruntime/unconditional_image_generation/README.md
+++ b/examples/research_projects/onnxruntime/unconditional_image_generation/README.md
@@ -36,7 +36,7 @@ The command to train a DDPM UNet model on the Oxford Flowers dataset with onnxru
 ```bash
 accelerate launch train_unconditional_ort.py \
   --dataset_name="huggan/flowers-102-categories" \
-  --resolution=64 \
+  --resolution=64 --center_crop --random_flip \
   --output_dir="ddpm-ema-flowers-64" \
   --use_ema \
   --train_batch_size=16 \
@@ -47,4 +47,4 @@ accelerate launch train_unconditional_ort.py \
   --mixed_precision=fp16
 ```
 
-Please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on github with any questions.
\ No newline at end of file
+Please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on github with any questions.
diff --git a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
index 80306dfccc..b26f2218f4 100644
--- a/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
+++ b/examples/research_projects/onnxruntime/unconditional_image_generation/train_unconditional.py
@@ -20,15 +20,7 @@ from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from onnxruntime.training.ortmodule import ORTModule
-from torchvision.transforms import (
-    CenterCrop,
-    Compose,
-    InterpolationMode,
-    Normalize,
-    RandomHorizontalFlip,
-    Resize,
-    ToTensor,
-)
+from torchvision import transforms
 from tqdm.auto import tqdm
 
 
@@ -105,6 +97,21 @@ def parse_args():
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--center_crop",
+        default=False,
+        action="store_true",
+        help=(
+            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+            " cropped. The images will be resized to the resolution first before cropping."
+        ),
+    )
+    parser.add_argument(
+        "--random_flip",
+        default=False,
+        action="store_true",
+        help="whether to randomly flip images horizontally",
+    )
     parser.add_argument(
         "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
     )
@@ -369,13 +376,13 @@ def main(args):
     # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
 
     # Preprocessing the datasets and DataLoaders creation.
-    augmentations = Compose(
+    augmentations = transforms.Compose(
         [
-            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
-            CenterCrop(args.resolution),
-            RandomHorizontalFlip(),
-            ToTensor(),
-            Normalize([0.5], [0.5]),
+            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),
         ]
     )
 
diff --git a/examples/unconditional_image_generation/README.md b/examples/unconditional_image_generation/README.md
index c8b933d533..db06d90116 100644
--- a/examples/unconditional_image_generation/README.md
+++ b/examples/unconditional_image_generation/README.md
@@ -34,7 +34,7 @@ The command to train a DDPM UNet model on the Oxford Flowers dataset:
 ```bash
 accelerate launch train_unconditional.py \
   --dataset_name="huggan/flowers-102-categories" \
-  --resolution=64 \
+  --resolution=64 --center_crop --random_flip \
   --output_dir="ddpm-ema-flowers-64" \
   --train_batch_size=16 \
   --num_epochs=100 \
@@ -59,7 +59,7 @@ The command to train a DDPM UNet model on the Pokemon dataset:
 ```bash
 accelerate launch train_unconditional.py \
   --dataset_name="huggan/pokemon" \
-  --resolution=64 \
+  --resolution=64 --center_crop --random_flip \
   --output_dir="ddpm-ema-pokemon-64" \
   --train_batch_size=16 \
   --num_epochs=100 \
@@ -139,4 +139,4 @@ dataset.push_to_hub("name_of_your_dataset", private=True)
 
 and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub.
 
-More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).
\ No newline at end of file
+More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 3ca92717f8..9a72463bb3 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -19,15 +19,7 @@ from diffusers.optimization import get_scheduler
 from diffusers.training_utils import EMAModel
 from diffusers.utils import check_min_version
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
-from torchvision.transforms import (
-    CenterCrop,
-    Compose,
-    InterpolationMode,
-    Normalize,
-    RandomHorizontalFlip,
-    Resize,
-    ToTensor,
-)
+from torchvision import transforms
 from tqdm.auto import tqdm
 
 
@@ -105,6 +97,21 @@ def parse_args():
             " resolution"
         ),
     )
+    parser.add_argument(
+        "--center_crop",
+        default=False,
+        action="store_true",
+        help=(
+            "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
+            " cropped. The images will be resized to the resolution first before cropping."
+        ),
+    )
+    parser.add_argument(
+        "--random_flip",
+        default=False,
+        action="store_true",
+        help="whether to randomly flip images horizontally",
+    )
     parser.add_argument(
         "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
     )
@@ -369,13 +376,13 @@ def main(args):
     # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
 
     # Preprocessing the datasets and DataLoaders creation.
-    augmentations = Compose(
+    augmentations = transforms.Compose(
         [
-            Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
-            CenterCrop(args.resolution),
-            RandomHorizontalFlip(),
-            ToTensor(),
-            Normalize([0.5], [0.5]),
+            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
+            transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),
         ]
     )
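
Note (not part of the patch): below is a minimal, self-contained sketch of the conditional-transform pattern this patch applies in both copies of `train_unconditional.py`. The `build_augmentations` helper, the demo image, and its size are illustrative assumptions rather than code from this diff; the transform calls themselves mirror the patched pipeline.

```python
# Illustrative sketch only: `build_augmentations` is a hypothetical helper,
# not part of this patch; the transform pipeline mirrors the patched script.
from PIL import Image
from torchvision import transforms


def build_augmentations(resolution, center_crop=False, random_flip=False):
    return transforms.Compose(
        [
            # Resize the shorter side to `resolution`, preserving aspect ratio.
            transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
            # Deterministic crop when requested, random crop otherwise.
            transforms.CenterCrop(resolution) if center_crop else transforms.RandomCrop(resolution),
            # Identity Lambda keeps the pipeline intact when flipping is off.
            transforms.RandomHorizontalFlip() if random_flip else transforms.Lambda(lambda x: x),
            transforms.ToTensor(),
            # Map pixel values from [0, 1] to [-1, 1].
            transforms.Normalize([0.5], [0.5]),
        ]
    )


# Both flags on, matching the updated README commands (--center_crop --random_flip).
augmentations = build_augmentations(64, center_crop=True, random_flip=True)
image = Image.new("RGB", (128, 96))  # stand-in for a dataset image
tensor = augmentations(image)
print(tensor.shape)  # torch.Size([3, 64, 64])
```

Using inline conditional expressions rather than appending to a list keeps the pipeline definition a single expression; the `transforms.Lambda(lambda x: x)` identity acts as a no-op placeholder so the flip slot always exists regardless of the flags.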