[Examples] Add streaming support to the ControlNet training example in JAX (#2859)
* improve stable unclip doc.
* feat: add streaming support to controlnet flax training script.
* fix: CLI arg.
* fix: torch dataloader shuffle setting.
* fix: dataset length.
* fix: wandb config.
* fix: steps_per_epoch in the training loop.
* add: entry about streaming in the readme
* get column names from iterable dataset + fix final logging

---------

Co-authored-by: yiyixuxu <yixu310@gmail.com>
--- a/examples/controlnet/README.md
+++ b/examples/controlnet/README.md
@@ -335,7 +335,7 @@ huggingface-cli login
 Make sure you have the `MODEL_DIR`, `OUTPUT_DIR` and `HUB_MODEL_ID` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables specify where to save the model on the Hub:
 
-```
+```bash
 export MODEL_DIR="runwayml/stable-diffusion-v1-5"
 export OUTPUT_DIR="control_out"
 export HUB_MODEL_ID="fill-circle-controlnet"
@@ -343,7 +343,7 @@ export HUB_MODEL_ID="fill-circle-controlnet"
 
 And finally start the training
 
-```
+```bash
 python3 train_controlnet_flax.py \
  --pretrained_model_name_or_path=$MODEL_DIR \
  --output_dir=$OUTPUT_DIR \
@@ -363,3 +363,30 @@ python3 train_controlnet_flax.py \
 ```
 
 Since we passed the `--push_to_hub` flag, it will automatically create a model repo under your huggingface account based on `$HUB_MODEL_ID`. By the end of training, the final checkpoint will be automatically stored on the hub. You can find an example model repo [here](https://huggingface.co/YiYiXu/fill-circle-controlnet).
+
+Our training script also provides limited support for streaming large datasets from the Hugging Face Hub. In order to enable streaming, one must also set `--max_train_samples`. Here is an example command:
+
+```bash
+python3 train_controlnet_flax.py \
+ --pretrained_model_name_or_path=$MODEL_DIR \
+ --output_dir=$OUTPUT_DIR \
+ --dataset_name=multimodalart/facesyntheticsspigacaptioned \
+ --streaming \
+ --conditioning_image_column=spiga_seg \
+ --image_column=image \
+ --caption_column=image_caption \
+ --resolution=512 \
+ --max_train_samples 50 \
+ --max_train_steps 5 \
+ --learning_rate=1e-5 \
+ --validation_steps=2 \
+ --train_batch_size=1 \
+ --revision="flax" \
+ --report_to="wandb"
+```
+
+Note, however, that the performance of the TPUs might get bottlenecked as streaming with `datasets` is not optimized for images. For ensuring maximum throughput, we encourage you to explore the following options:
+
+* [Webdataset](https://webdataset.github.io/webdataset/)
+* [TorchData](https://github.com/pytorch/data)
+* [TensorFlow Datasets](https://www.tensorflow.org/datasets/tfless_tfds)
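
For orientation, here is a minimal sketch of the `datasets` streaming pattern that the `--streaming` flag switches on internally (the dataset name and sample cap mirror the example command above; the rest is illustrative, not code from this commit):

```python
from datasets import load_dataset

# streaming=True returns an IterableDataset: nothing is downloaded up front,
# and examples are fetched lazily as they are consumed.
dataset = load_dataset("multimodalart/facesyntheticsspigacaptioned", streaming=True)

# A streamed dataset has no known length, so it must be capped explicitly,
# which is why --max_train_samples is mandatory together with --streaming.
train_ds = dataset["train"].shuffle(seed=0).take(50)  # 50 == --max_train_samples

for example in train_ds:  # iterating triggers the actual network reads
    break
```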
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -35,6 +35,7 @@ from flax.training import train_state
 from flax.training.common_utils import shard
 from huggingface_hub import HfFolder, Repository, create_repo, whoami
 from PIL import Image
+from torch.utils.data import IterableDataset
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTokenizer, FlaxCLIPTextModel, set_seed
@@ -206,7 +207,7 @@ def parse_args():
     parser.add_argument(
         "--from_pt",
         action="store_true",
-        help="Load the pretrained model from a pytorch checkpoint.",
+        help="Load the pretrained model from a PyTorch checkpoint.",
     )
     parser.add_argument(
         "--tokenizer_name",
@@ -332,6 +333,7 @@ def parse_args():
             " or to a folder containing files that 🤗 Datasets can understand."
         ),
     )
+    parser.add_argument("--streaming", action="store_true", help="To stream a large dataset from Hub.")
     parser.add_argument(
         "--dataset_config_name",
         type=str,
@@ -369,7 +371,7 @@ def parse_args():
         default=None,
         help=(
             "For debugging purposes or quicker training, truncate the number of training examples to this "
-            "value if set."
+            "value if set. Needed if `streaming` is set to True."
         ),
     )
     parser.add_argument(
@@ -453,10 +455,15 @@ def parse_args():
             " or the same number of `--validation_prompt`s and `--validation_image`s"
         )
 
+    # This idea comes from
+    # https://github.com/borisdayma/dalle-mini/blob/d2be512d4a6a9cda2d63ba04afc33038f98f705f/src/dalle_mini/data.py#L370
+    if args.streaming and args.max_train_samples is None:
+        raise ValueError("You must specify `max_train_samples` when using dataset streaming.")
+
     return args
 
 
-def make_train_dataset(args, tokenizer):
+def make_train_dataset(args, tokenizer, batch_size=None):
     # Get the datasets: you can either provide your own training and evaluation files (see below)
     # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
@@ -468,6 +475,7 @@ def make_train_dataset(args, tokenizer):
             args.dataset_name,
             args.dataset_config_name,
             cache_dir=args.cache_dir,
+            streaming=args.streaming,
         )
     else:
         data_files = {}
@@ -483,7 +491,10 @@ def make_train_dataset(args, tokenizer):
 
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
-    column_names = dataset["train"].column_names
+    if isinstance(dataset["train"], IterableDataset):
+        column_names = next(iter(dataset["train"])).keys()
+    else:
+        column_names = dataset["train"].column_names
 
     # 6. Get the column names for input/target.
     if args.image_column is None:
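
A note on the hunk above: a streamed dataset may not expose its schema until one example has been read, so the streaming branch peeks at the first sample instead of reading `column_names`. A small illustration, reusing the dataset from the README example:

```python
from datasets import load_dataset

streamed = load_dataset("multimodalart/facesyntheticsspigacaptioned", streaming=True)["train"]

# Pull a single example over the network and use its keys as the column names.
first_example = next(iter(streamed))
column_names = list(first_example.keys())  # e.g. image, spiga_seg, image_caption
```

The `isinstance(dataset["train"], IterableDataset)` check against the torch class works because `datasets`, as far as we can tell, registers `torch.utils.data.IterableDataset` as a parent class of its streamed datasets whenever PyTorch is installed.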
@@ -565,9 +576,20 @@ def make_train_dataset(args, tokenizer):
 
     if jax.process_index() == 0:
         if args.max_train_samples is not None:
-            dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
+            if args.streaming:
+                dataset["train"] = dataset["train"].shuffle(seed=args.seed).take(args.max_train_samples)
+            else:
+                dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
     # Set the training transforms
-    train_dataset = dataset["train"].with_transform(preprocess_train)
+    if args.streaming:
+        train_dataset = dataset["train"].map(
+            preprocess_train,
+            batched=True,
+            batch_size=batch_size,
+            remove_columns=list(dataset["train"].features.keys()),
+        )
+    else:
+        train_dataset = dataset["train"].with_transform(preprocess_train)
 
     return train_dataset
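
Two details in this hunk are easy to miss. `.select()` needs random access and therefore only exists on map-style datasets; the streaming branch uses `.take()`, which caps the lazy iterator instead. And shuffling a streamed dataset is approximate: `IterableDataset.shuffle` samples from a finite buffer rather than permuting the whole dataset. A hedged sketch of the contrast:

```python
from datasets import Dataset, load_dataset

# Map-style: random access, global shuffle, indexed subselection.
ds = Dataset.from_dict({"x": list(range(100))})
subset = ds.shuffle(seed=0).select(range(10))  # materialized; len(subset) == 10

# Streamed: no random access; shuffle draws from a buffer (buffer_size
# defaults to 1000 in `datasets`) and .take() caps the lazy iterator.
stream = load_dataset("multimodalart/facesyntheticsspigacaptioned", streaming=True)["train"]
capped = stream.shuffle(seed=0).take(10)  # still lazy; no len() available
```

Likewise, `.with_transform()` hooks into `__getitem__` and so is unavailable on iterable datasets, which is why the streaming branch attaches `preprocess_train` with a batched `.map()` instead.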
@@ -661,12 +683,12 @@ def main():
         raise NotImplementedError("No tokenizer specified!")
 
     # Get the datasets: you can either provide your own training and evaluation files (see below)
-    train_dataset = make_train_dataset(args, tokenizer)
+    total_train_batch_size = args.train_batch_size * jax.local_device_count() * args.gradient_accumulation_steps
+    train_dataset = make_train_dataset(args, tokenizer, batch_size=total_train_batch_size)
 
     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
-        shuffle=True,
+        shuffle=not args.streaming,
         collate_fn=collate_fn,
         batch_size=total_train_batch_size,
         num_workers=args.dataloader_num_workers,
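
The `shuffle=not args.streaming` change is more than a tuning choice: PyTorch's `DataLoader` raises a `ValueError` when handed an `IterableDataset` together with `shuffle=True`, since there is no index space to permute. A minimal reproduction:

```python
from torch.utils.data import DataLoader, IterableDataset

class Stream(IterableDataset):
    def __iter__(self):
        yield from range(8)

try:
    DataLoader(Stream(), batch_size=4, shuffle=True)
except ValueError as err:
    print(err)  # "DataLoader with IterableDataset: expected unspecified shuffle option..."

loader = DataLoader(Stream(), batch_size=4, shuffle=False)
for batch in loader:
    print(batch)  # tensor([0, 1, 2, 3]) then tensor([4, 5, 6, 7])
```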
@@ -897,7 +919,11 @@ def main():
     vae_params = jax_utils.replicate(vae_params)
 
     # Train!
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    if args.streaming:
+        dataset_length = args.max_train_samples
+    else:
+        dataset_length = len(train_dataloader)
+    num_update_steps_per_epoch = math.ceil(dataset_length / args.gradient_accumulation_steps)
 
     # Scheduler and math around the number of training steps.
     if args.max_train_steps is None:
@@ -906,7 +932,7 @@ def main():
         args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
 
     logger.info("***** Running training *****")
-    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num examples = {args.max_train_samples if args.streaming else len(train_dataset)}")
     logger.info(f"  Num Epochs = {args.num_train_epochs}")
     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
     logger.info(f"  Total train batch size (w. parallel & distributed) = {total_train_batch_size}")
@@ -916,7 +942,7 @@ def main():
         wandb.define_metric("*", step_metric="train/step")
         wandb.config.update(
             {
-                "num_train_examples": len(train_dataset),
+                "num_train_examples": args.max_train_samples if args.streaming else len(train_dataset),
                 "total_train_batch_size": total_train_batch_size,
                 "total_optimization_step": args.num_train_epochs * num_update_steps_per_epoch,
                 "num_devices": jax.device_count(),
@@ -935,7 +961,11 @@ def main():
 
         train_metrics = []
 
-        steps_per_epoch = len(train_dataset) // total_train_batch_size
+        steps_per_epoch = (
+            args.max_train_samples // total_train_batch_size
+            if args.streaming
+            else len(train_dataset) // total_train_batch_size
+        )
         train_step_progress_bar = tqdm(
             total=steps_per_epoch,
             desc="Training...",
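
To make the streaming arithmetic concrete, here is `steps_per_epoch` worked out for the README example command, assuming a single TPU host with 8 local devices and one gradient-accumulation step (both are assumptions, not values fixed by this commit):

```python
train_batch_size = 1             # --train_batch_size=1
local_device_count = 8           # assumption: one v*-8 TPU host
gradient_accumulation_steps = 1  # assumption: the script's default
max_train_samples = 50           # --max_train_samples 50

total_train_batch_size = train_batch_size * local_device_count * gradient_accumulation_steps
steps_per_epoch = max_train_samples // total_train_batch_size
print(total_train_batch_size, steps_per_epoch)  # 8 6
```

With `--max_train_steps 5` from the same command, training stops after 5 optimizer steps, before the 6-step epoch completes.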
@@ -980,7 +1010,8 @@ def main():
 
     # Create the pipeline using the trained modules and save it.
     if jax.process_index() == 0:
-        image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
+        if args.validation_prompt is not None:
+            image_logs = log_validation(controlnet, state.params, tokenizer, args, validation_rng, weight_dtype)
 
         controlnet.save_pretrained(
             args.output_dir,