
Map speedup (#6745)

* Speed up dataset mapping

* Fix missing columns

* Remove cache files cleanup

* Update examples/text_to_image/train_text_to_image_sdxl.py

* make style

* Fix code style

* style

* Empty-Commit

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
Author: Oleh
Date: 2024-03-01 18:08:21 +02:00
Committed by: GitHub
Parent: 5f150c4cef
Commit: 9a2600ede9

examples/text_to_image/train_text_to_image_sdxl.py

@@ -35,7 +35,7 @@ import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
-from datasets import load_dataset
+from datasets import concatenate_datasets, load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from torchvision import transforms
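
The new import is the only API addition: `concatenate_datasets` from `datasets` can join two datasets column-wise when called with `axis=1`, provided both sides have the same number of rows. A minimal sketch of that behavior (toy column names, not from this diff):

```python
# Toy illustration of datasets.concatenate_datasets with axis=1:
# it joins datasets column-wise, so both sides must have equal row counts.
from datasets import Dataset, concatenate_datasets

left = Dataset.from_dict({"text": ["a photo", "a drawing"]})
right = Dataset.from_dict({"embedding": [[0.1, 0.2], [0.3, 0.4]]})

combined = concatenate_datasets([left, right], axis=1)
print(combined.column_names)  # ['text', 'embedding']
```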
@@ -896,13 +896,19 @@ def main(args):
         # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
         new_fingerprint = Hasher.hash(args)
         new_fingerprint_for_vae = Hasher.hash(vae_path)
-        train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
-        train_dataset = train_dataset.map(
+        train_dataset_with_embeddings = train_dataset.map(
+            compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
+        )
+        train_dataset_with_vae = train_dataset.map(
             compute_vae_encodings_fn,
             batched=True,
             batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps,
             new_fingerprint=new_fingerprint_for_vae,
         )
+        precomputed_dataset = concatenate_datasets(
+            [train_dataset_with_embeddings, train_dataset_with_vae.remove_columns(["image", "text"])], axis=1
+        )
+        precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)
 
     del text_encoders, tokenizers, vae
     gc.collect()
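
The speedup comes from running the text-embedding map and the VAE-encoding map over the same base dataset instead of chaining them, so each `map` result is cached under its own fingerprint; the new columns are then stitched back together with `concatenate_datasets(..., axis=1)`, and `remove_columns` drops the duplicated `image` and `text` columns from one side (the "Fix missing columns" step in the commit message). A self-contained sketch of the pattern, with toy stand-ins for `compute_embeddings_fn` and `compute_vae_encodings_fn`:

```python
# Hedged sketch of the two-map-then-concatenate pattern; the functions
# below are toy stand-ins, not the script's real embedding/VAE encoders.
from datasets import Dataset, concatenate_datasets

base = Dataset.from_dict({"image": ["img0", "img1"], "text": ["cat", "dog"]})

def add_prompt_embeds(batch):
    # stand-in for compute_embeddings_fn: adds one new column
    return {"prompt_embeds": [[float(len(t))] for t in batch["text"]]}

def add_model_input(batch):
    # stand-in for compute_vae_encodings_fn: adds another new column
    return {"model_input": [[float(len(i))] for i in batch["image"]]}

# Both maps start from the SAME base dataset, so each result can be
# cached (and recomputed) independently of the other.
with_embeds = base.map(add_prompt_embeds, batched=True)
with_vae = base.map(add_model_input, batched=True)

# axis=1 joins columns row-for-row; dropping "image"/"text" from one
# side avoids duplicate column names in the result.
precomputed = concatenate_datasets(
    [with_embeds, with_vae.remove_columns(["image", "text"])], axis=1
)
print(precomputed.column_names)
# ['image', 'text', 'prompt_embeds', 'model_input']
```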
@@ -925,7 +931,7 @@ def main(args):
 
     # DataLoaders creation:
     train_dataloader = torch.utils.data.DataLoader(
-        train_dataset,
+        precomputed_dataset,
         shuffle=True,
         collate_fn=collate_fn,
         batch_size=args.train_batch_size,
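
`precomputed_dataset` can be handed to the `DataLoader` directly because a `datasets.Dataset` is a map-style dataset, and the function registered via `with_transform` runs lazily on whatever examples the loader fetches, rather than being baked into the cache. A small sketch (toy data and transform, not the script's `preprocess_train`) of that interaction:

```python
# Toy illustration: with_transform applies its function on the fly to the
# examples a DataLoader actually fetches, instead of writing to the cache.
import torch
from datasets import Dataset

ds = Dataset.from_dict({"value": [1, 2, 3, 4]})
ds = ds.with_transform(lambda batch: {"value": [v * 10 for v in batch["value"]]})

def collate(examples):
    return torch.tensor([ex["value"] for ex in examples])

loader = torch.utils.data.DataLoader(ds, batch_size=2, shuffle=False, collate_fn=collate)
for batch in loader:
    print(batch)  # tensor([10, 20]) then tensor([30, 40])
```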
@@ -976,7 +982,7 @@ def main(args):
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
     logger.info("***** Running training *****")
-    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num examples = {len(precomputed_dataset)}")
     logger.info(f"  Num Epochs = {args.num_train_epochs}")
     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")