1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

update flax controlnet training script (#2951)

* load_from_disk + checkpointing_steps

* apply feedback
This commit is contained in:
YiYi Xu
2023-04-04 15:49:44 -10:00
committed by Daniel Gu
parent 72f37316ee
commit fe03d5bce4

View File

@@ -27,13 +27,13 @@ import optax
import torch
import torch.utils.checkpoint
import transformers
from datasets import load_dataset
from datasets import load_dataset, load_from_disk
from flax import jax_utils
from flax.core.frozen_dict import unfreeze
from flax.training import train_state
from flax.training.common_utils import shard
from huggingface_hub import create_repo, upload_folder
from PIL import Image
from PIL import Image, PngImagePlugin
from torch.utils.data import IterableDataset
from torchvision import transforms
from tqdm.auto import tqdm
@@ -49,6 +49,11 @@ from diffusers import (
from diffusers.utils import check_min_version, is_wandb_available
# To prevent an error that occurs when there are abnormally large compressed data chunk in the png image
# see more https://github.com/python-pillow/Pillow/issues/5610
LARGE_ENOUGH_NUMBER = 100
PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
if is_wandb_available():
import wandb
@@ -246,6 +251,12 @@ def parse_args():
default=None,
help="Total number of training steps to perform.",
)
parser.add_argument(
"--checkpointing_steps",
type=int,
default=5000,
help=("Save a checkpoint of the training state every X updates."),
)
parser.add_argument(
"--learning_rate",
type=float,
@@ -344,9 +355,17 @@ def parse_args():
type=str,
default=None,
help=(
"A folder containing the training data. Folder contents must follow the structure described in"
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
"A folder containing the training dataset. By default it will use `load_dataset` method to load a custom dataset from the folder."
"Folder must contain a dataset script as described here https://huggingface.co/docs/datasets/dataset_script) ."
"If `--load_from_disk` flag is passed, it will use `load_from_disk` method instead. Ignored if `dataset_name` is specified."
),
)
parser.add_argument(
"--load_from_disk",
action="store_true",
help=(
"If True, will load a dataset that was previously saved using `save_to_disk` from `--train_data_dir`"
"See more https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.load_from_disk"
),
)
parser.add_argument(
@@ -478,10 +497,15 @@ def make_train_dataset(args, tokenizer, batch_size=None):
)
else:
if args.train_data_dir is not None:
dataset = load_dataset(
args.train_data_dir,
cache_dir=args.cache_dir,
)
if args.load_from_disk:
dataset = load_from_disk(
args.train_data_dir,
)
else:
dataset = load_dataset(
args.train_data_dir,
cache_dir=args.cache_dir,
)
# See more about loading custom images at
# https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
@@ -545,6 +569,7 @@ def make_train_dataset(args, tokenizer, batch_size=None):
image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
@@ -553,6 +578,7 @@ def make_train_dataset(args, tokenizer, batch_size=None):
conditioning_image_transforms = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution),
transforms.ToTensor(),
]
)
@@ -981,6 +1007,11 @@ def main():
"train/loss": jax_utils.unreplicate(train_metric)["loss"],
}
)
if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
controlnet.save_pretrained(
f"{args.output_dir}/{global_step}",
params=get_params_to_save(state.params),
)
train_metric = jax_utils.unreplicate(train_metric)
train_step_progress_bar.close()