diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 6181387fc8..8d316fd048 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -27,13 +27,13 @@ import optax import torch import torch.utils.checkpoint import transformers -from datasets import load_dataset +from datasets import load_dataset, load_from_disk from flax import jax_utils from flax.core.frozen_dict import unfreeze from flax.training import train_state from flax.training.common_utils import shard from huggingface_hub import create_repo, upload_folder -from PIL import Image +from PIL import Image, PngImagePlugin from torch.utils.data import IterableDataset from torchvision import transforms from tqdm.auto import tqdm @@ -49,6 +49,11 @@ from diffusers import ( from diffusers.utils import check_min_version, is_wandb_available +# To prevent an error that occurs when there are abnormally large compressed data chunk in the png image +# see more https://github.com/python-pillow/Pillow/issues/5610 +LARGE_ENOUGH_NUMBER = 100 +PngImagePlugin.MAX_TEXT_CHUNK = LARGE_ENOUGH_NUMBER * (1024**2) + if is_wandb_available(): import wandb @@ -246,6 +251,12 @@ def parse_args(): default=None, help="Total number of training steps to perform.", ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=5000, + help=("Save a checkpoint of the training state every X updates."), + ) parser.add_argument( "--learning_rate", type=float, @@ -344,9 +355,17 @@ def parse_args(): type=str, default=None, help=( - "A folder containing the training data. Folder contents must follow the structure described in" - " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" - " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + "A folder containing the training dataset. By default it will use `load_dataset` method to load a custom dataset from the folder." + "Folder must contain a dataset script as described here https://huggingface.co/docs/datasets/dataset_script) ." + "If `--load_from_disk` flag is passed, it will use `load_from_disk` method instead. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--load_from_disk", + action="store_true", + help=( + "If True, will load a dataset that was previously saved using `save_to_disk` from `--train_data_dir`" + "See more https://huggingface.co/docs/datasets/package_reference/main_classes#datasets.Dataset.load_from_disk" ), ) parser.add_argument( @@ -478,10 +497,15 @@ def make_train_dataset(args, tokenizer, batch_size=None): ) else: if args.train_data_dir is not None: - dataset = load_dataset( - args.train_data_dir, - cache_dir=args.cache_dir, - ) + if args.load_from_disk: + dataset = load_from_disk( + args.train_data_dir, + ) + else: + dataset = load_dataset( + args.train_data_dir, + cache_dir=args.cache_dir, + ) # See more about loading custom images at # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script @@ -545,6 +569,7 @@ def make_train_dataset(args, tokenizer, batch_size=None): image_transforms = transforms.Compose( [ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] @@ -553,6 +578,7 @@ def make_train_dataset(args, tokenizer, batch_size=None): conditioning_image_transforms = transforms.Compose( [ transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), transforms.ToTensor(), ] ) @@ -981,6 +1007,11 @@ def main(): "train/loss": jax_utils.unreplicate(train_metric)["loss"], } ) + if global_step % args.checkpointing_steps == 0 and jax.process_index() == 0: + controlnet.save_pretrained( + f"{args.output_dir}/{global_step}", + params=get_params_to_save(state.params), + ) train_metric = jax_utils.unreplicate(train_metric) train_step_progress_bar.close()