update flax controlnet training script (#2951)
* load_from_disk + checkpointing_steps
* apply feedback
@@ -27,13 +27,13 @@ import optax
 import torch
 import torch.utils.checkpoint
 import transformers
-from datasets import load_dataset
+from datasets import load_dataset, load_from_disk
 from flax import jax_utils
 from flax.core.frozen_dict import unfreeze
 from flax.training import train_state
 from flax.training.common_utils import shard
 from huggingface_hub import create_repo, upload_folder
-from PIL import Image
+from PIL import Image, PngImagePlugin
 from torch.utils.data import IterableDataset
 from torchvision import transforms
 from tqdm.auto import tqdm
@@ -49,6 +49,11 @@ from diffusers import (
 from diffusers.utils import check_min_version, is_wandb_available
 
 
+# To prevent an error that occurs when there is an abnormally large compressed data chunk in the png image
+# see more https://github.com/python-pillow/Pillow/issues/5610
+LARGE_ENOUGH_NUMBER = 100
+PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2)
+
 if is_wandb_available():
     import wandb
 
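Note on the Pillow workaround above: Pillow caps a decompressed PNG text chunk at 1 MiB by default and raises ValueError ("Decompressed Data Too Large") for anything bigger, which can abort dataset loading on otherwise valid images. A minimal sketch of the same fix in isolation, with a hypothetical file name:

from PIL import Image, PngImagePlugin

# Default MAX_TEXT_CHUNK is 1 MiB; raise it to 100 MiB so PNGs carrying
# very large text/metadata chunks decode instead of raising ValueError.
PngImagePlugin.MAX_TEXT_CHUNK = 100 * (1024**2)

img = Image.open("image_with_huge_text_chunk.png")  # hypothetical file
img.load()  # decoding now succeeds for oversized text chunks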
@@ -246,6 +251,12 @@ def parse_args():
         default=None,
         help="Total number of training steps to perform.",
     )
+    parser.add_argument(
+        "--checkpointing_steps",
+        type=int,
+        default=5000,
+        help="Save a checkpoint of the training state every X updates.",
+    )
     parser.add_argument(
         "--learning_rate",
         type=float,
@@ -344,9 +355,17 @@ def parse_args():
         type=str,
         default=None,
         help=(
-            "A folder containing the training data. Folder contents must follow the structure described in"
-            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
-            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
+            "A folder containing the training dataset. By default the `load_dataset` method is used to load a custom dataset from the folder,"
+            " so the folder must contain a dataset script as described in https://huggingface.co/docs/datasets/dataset_script."
+            " If the `--load_from_disk` flag is passed, the `load_from_disk` method is used instead. Ignored if `dataset_name` is specified."
         ),
     )
+    parser.add_argument(
+        "--load_from_disk",
+        action="store_true",
+        help=(
+            "If True, will load a dataset that was previously saved using `save_to_disk` from `--train_data_dir`."
+            " See more at https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.load_from_disk"
+        ),
+    )
     parser.add_argument(
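For anyone trying the new flag: --load_from_disk expects a directory produced by datasets' save_to_disk. A minimal sketch of preparing such a dataset; the path and column names are hypothetical and must match the column arguments the script is run with:

from datasets import Dataset

# Toy dataset with image / conditioning image / caption columns.
ds = Dataset.from_dict(
    {
        "image": ["0001.png", "0002.png"],
        "conditioning_image": ["0001_canny.png", "0002_canny.png"],
        "text": ["a red circle", "a blue square"],
    }
)
ds.save_to_disk("./my_controlnet_dataset")

# Train with: --train_data_dir ./my_controlnet_dataset --load_from_disk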
@@ -478,10 +497,15 @@ def make_train_dataset(args, tokenizer, batch_size=None):
         )
     else:
         if args.train_data_dir is not None:
-            dataset = load_dataset(
-                args.train_data_dir,
-                cache_dir=args.cache_dir,
-            )
+            if args.load_from_disk:
+                dataset = load_from_disk(
+                    args.train_data_dir,
+                )
+            else:
+                dataset = load_dataset(
+                    args.train_data_dir,
+                    cache_dir=args.cache_dir,
+                )
     # See more about loading custom images at
     # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
 
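Design note on the branch above: load_from_disk reads the Arrow files written by save_to_disk directly, with no dataset builder involved, so it takes no cache_dir argument; load_dataset runs the dataset script found in the folder and caches the result. A sketch of the two call shapes with hypothetical paths:

from datasets import load_dataset, load_from_disk

# Arrow files from save_to_disk; no builder, no cache_dir.
dataset = load_from_disk("./my_controlnet_dataset")

# Dataset script in the folder; the built result goes to the cache.
dataset = load_dataset("./my_dataset_with_script", cache_dir="./hf_cache")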
@@ -545,6 +569,7 @@ def make_train_dataset(args, tokenizer, batch_size=None):
     image_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution),
             transforms.ToTensor(),
             transforms.Normalize([0.5], [0.5]),
         ]
@@ -553,6 +578,7 @@ def make_train_dataset(args, tokenizer, batch_size=None):
     conditioning_image_transforms = transforms.Compose(
         [
             transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
+            transforms.CenterCrop(args.resolution),
             transforms.ToTensor(),
         ]
     )
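The CenterCrop added to both pipelines guarantees square outputs: Resize alone only fixes the shorter edge, so non-square inputs would produce non-square tensors and break batching. A quick shape check, using 512 as a stand-in for args.resolution and a hypothetical input size:

from PIL import Image
from torchvision import transforms

resolution = 512  # stand-in for args.resolution

image_transforms = transforms.Compose(
    [
        transforms.Resize(resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),  # target image scaled to [-1, 1]
    ]
)

img = Image.new("RGB", (768, 640))  # hypothetical non-square input
print(image_transforms(img).shape)  # torch.Size([3, 512, 512])
# Without CenterCrop the result would be torch.Size([3, 512, 614]).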
@@ -981,6 +1007,11 @@ def main():
                         "train/loss": jax_utils.unreplicate(train_metric)["loss"],
                     }
                 )
+            if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0:
+                controlnet.save_pretrained(
+                    f"{args.output_dir}/{global_step}",
+                    params=get_params_to_save(state.params),
+                )
 
         train_metric = jax_utils.unreplicate(train_metric)
         train_step_progress_bar.close()
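With checkpointing in place, any intermediate {args.output_dir}/{global_step} folder is a regular save_pretrained output and can be reloaded later. A sketch assuming a hypothetical output path; Flax from_pretrained returns a (model, params) tuple:

import jax.numpy as jnp
from diffusers import FlaxControlNetModel

# Reload the ControlNet weights saved at global step 5000.
controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
    "./controlnet-output/5000", dtype=jnp.float32
)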