1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Make center crop and random flip as args for unconditional image generation (#2259)

* Add center crop and horizontal flip to args

* Update command to use center crop and random flip

* Add center crop and horizontal flip to args

* Update command to use center crop and random flip
This commit is contained in:
wfng92
2023-02-07 18:58:31 +08:00
committed by GitHub
parent cd52475560
commit b1dad2e9d3
4 changed files with 49 additions and 35 deletions

View File

@@ -36,7 +36,7 @@ The command to train a DDPM UNet model on the Oxford Flowers dataset with onnxru
```bash
accelerate launch train_unconditional_ort.py \
--dataset_name="huggan/flowers-102-categories" \
--resolution=64 \
--resolution=64 --center_crop --random_flip \
--output_dir="ddpm-ema-flowers-64" \
--use_ema \
--train_batch_size=16 \
@@ -47,4 +47,4 @@ accelerate launch train_unconditional_ort.py \
--mixed_precision=fp16
```
Please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on github with any questions.
Please contact Prathik Rao (prathikr), Sunghoon Choi (hanbitmyths), Ashwini Khade (askhade), or Peng Wang (pengwa) on github with any questions.

View File

@@ -20,15 +20,7 @@ from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from onnxruntime.training.ortmodule import ORTModule
from torchvision.transforms import (
CenterCrop,
Compose,
InterpolationMode,
Normalize,
RandomHorizontalFlip,
Resize,
ToTensor,
)
from torchvision import transforms
from tqdm.auto import tqdm
@@ -105,6 +97,21 @@ def parse_args():
" resolution"
),
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
default=False,
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument(
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
)
@@ -369,13 +376,13 @@ def main(args):
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets and DataLoaders creation.
augmentations = Compose(
augmentations = transforms.Compose(
[
Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
CenterCrop(args.resolution),
RandomHorizontalFlip(),
ToTensor(),
Normalize([0.5], [0.5]),
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)

View File

@@ -34,7 +34,7 @@ The command to train a DDPM UNet model on the Oxford Flowers dataset:
```bash
accelerate launch train_unconditional.py \
--dataset_name="huggan/flowers-102-categories" \
--resolution=64 \
--resolution=64 --center_crop --random_flip \
--output_dir="ddpm-ema-flowers-64" \
--train_batch_size=16 \
--num_epochs=100 \
@@ -59,7 +59,7 @@ The command to train a DDPM UNet model on the Pokemon dataset:
```bash
accelerate launch train_unconditional.py \
--dataset_name="huggan/pokemon" \
--resolution=64 \
--resolution=64 --center_crop --random_flip \
--output_dir="ddpm-ema-pokemon-64" \
--train_batch_size=16 \
--num_epochs=100 \
@@ -139,4 +139,4 @@ dataset.push_to_hub("name_of_your_dataset", private=True)
and that's it! You can now train your model by simply setting the `--dataset_name` argument to the name of your dataset on the hub.
More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).
More on this can also be found in [this blog post](https://huggingface.co/blog/image-search-datasets).

View File

@@ -19,15 +19,7 @@ from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from torchvision.transforms import (
CenterCrop,
Compose,
InterpolationMode,
Normalize,
RandomHorizontalFlip,
Resize,
ToTensor,
)
from torchvision import transforms
from tqdm.auto import tqdm
@@ -105,6 +97,21 @@ def parse_args():
" resolution"
),
)
parser.add_argument(
"--center_crop",
default=False,
action="store_true",
help=(
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--random_flip",
default=False,
action="store_true",
help="whether to randomly flip images horizontally",
)
parser.add_argument(
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
)
@@ -369,13 +376,13 @@ def main(args):
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
# Preprocessing the datasets and DataLoaders creation.
augmentations = Compose(
augmentations = transforms.Compose(
[
Resize(args.resolution, interpolation=InterpolationMode.BILINEAR),
CenterCrop(args.resolution),
RandomHorizontalFlip(),
ToTensor(),
Normalize([0.5], [0.5]),
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)