Map speedup (#6745)
* Speed up dataset mapping

* Fix missing columns

* Remove cache files cleanup

* Update examples/text_to_image/train_text_to_image_sdxl.py

* make style

* Fix code style

* style

* Empty-Commit

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Quentin Lhoest <42851186+lhoestq@users.noreply.github.com>
Co-authored-by: Quentin Lhoest <lhoest.q@gmail.com>
```diff
--- a/examples/text_to_image/train_text_to_image_sdxl.py
+++ b/examples/text_to_image/train_text_to_image_sdxl.py
@@ -35,7 +35,7 @@ import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import ProjectConfiguration, set_seed
-from datasets import load_dataset
+from datasets import concatenate_datasets, load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from torchvision import transforms
```
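The only import change adds `concatenate_datasets`, used further down to join two independently mapped datasets column-wise. A minimal sketch of the `axis=1` join it enables (toy columns, not the script's data):

```python
# Column-wise concatenation of two equal-length datasets via axis=1.
from datasets import Dataset, concatenate_datasets

left = Dataset.from_dict({"text": ["a photo", "a drawing"]})
right = Dataset.from_dict({"vae_latent": [[0.1, 0.2], [0.3, 0.4]]})

merged = concatenate_datasets([left, right], axis=1)
print(merged.column_names)  # ['text', 'vae_latent']
print(merged[0])            # {'text': 'a photo', 'vae_latent': [0.1, 0.2]}
```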
```diff
@@ -896,13 +896,19 @@ def main(args):
         # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401
         new_fingerprint = Hasher.hash(args)
         new_fingerprint_for_vae = Hasher.hash(vae_path)
-        train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint)
-        train_dataset = train_dataset.map(
+        train_dataset_with_embeddings = train_dataset.map(
+            compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint
+        )
+        train_dataset_with_vae = train_dataset.map(
             compute_vae_encodings_fn,
             batched=True,
             batch_size=args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps,
             new_fingerprint=new_fingerprint_for_vae,
         )
+        precomputed_dataset = concatenate_datasets(
+            [train_dataset_with_embeddings, train_dataset_with_vae.remove_columns(["image", "text"])], axis=1
+        )
+        precomputed_dataset = precomputed_dataset.with_transform(preprocess_train)

     del text_encoders, tokenizers, vae
     gc.collect()
```
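This hunk is the actual speedup: rather than chaining the VAE-encoding `map` onto the output of the embeddings `map` (which makes the second pass re-serialize every already-computed embedding column into its cache files), both expensive passes now map the original dataset independently and the results are joined column-wise, with `new_fingerprint` pinning each cache key so the other processes can load the cached result. A runnable sketch of the pattern, with hypothetical toy functions standing in for `compute_embeddings_fn` and `compute_vae_encodings_fn`:

```python
# Two independent cached map passes over the same base dataset, joined with axis=1.
from datasets import Dataset, concatenate_datasets
from datasets.fingerprint import Hasher

base = Dataset.from_dict({"image": [[0, 1], [2, 3]], "text": ["a cat", "a dog"]})

def toy_embeddings_fn(batch):  # hypothetical stand-in for compute_embeddings_fn
    return {"prompt_embeds": [[float(len(t))] for t in batch["text"]]}

def toy_vae_fn(batch):  # hypothetical stand-in for compute_vae_encodings_fn
    return {"model_input": [[float(sum(px))] for px in batch["image"]]}

# Each pass starts from `base`, so neither has to rewrite the columns the other produced.
with_embeddings = base.map(toy_embeddings_fn, batched=True, new_fingerprint=Hasher.hash("embeds-v1"))
with_vae = base.map(toy_vae_fn, batched=True, new_fingerprint=Hasher.hash("vae-v1"))

# axis=1 requires disjoint column names, hence remove_columns on one side.
precomputed = concatenate_datasets([with_embeddings, with_vae.remove_columns(["image", "text"])], axis=1)
print(precomputed.column_names)  # ['image', 'text', 'prompt_embeds', 'model_input']
```

Note that `remove_columns(["image", "text"])` drops the duplicated raw columns from one side before the join, since a column-wise `concatenate_datasets` rejects datasets that share column names.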
```diff
@@ -925,7 +931,7 @@ def main(args):

     # DataLoaders creation:
     train_dataloader = torch.utils.data.DataLoader(
-        train_dataset,
+        precomputed_dataset,
         shuffle=True,
         collate_fn=collate_fn,
         batch_size=args.train_batch_size,
```
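The dataloader now wraps the joined `precomputed_dataset`. Since `with_transform(preprocess_train)` was applied above, the transform runs lazily on each fetched row instead of being baked into the cache. A small sketch of that interaction, using a toy transform and collate function rather than the script's:

```python
# A with_transform dataset feeding a PyTorch DataLoader; the transform runs at access time.
import torch
from datasets import Dataset

ds = Dataset.from_dict({"pixels": [[0.0, 1.0], [2.0, 3.0]], "label": [0, 1]})

def to_tensors(examples):  # stands in for preprocess_train
    examples["pixels"] = [torch.tensor(p) for p in examples["pixels"]]
    return examples

ds = ds.with_transform(to_tensors)

def collate(batch):  # stands in for the script's collate_fn
    return {
        "pixels": torch.stack([example["pixels"] for example in batch]),
        "label": torch.tensor([example["label"] for example in batch]),
    }

loader = torch.utils.data.DataLoader(ds, shuffle=True, collate_fn=collate, batch_size=2)
for batch in loader:
    print(batch["pixels"].shape)  # torch.Size([2, 2])
```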
```diff
@@ -976,7 +982,7 @@ def main(args):
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

     logger.info("***** Running training *****")
-    logger.info(f"  Num examples = {len(train_dataset)}")
+    logger.info(f"  Num examples = {len(precomputed_dataset)}")
     logger.info(f"  Num Epochs = {args.num_train_epochs}")
     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
```
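Only the example count in the log changes, since `precomputed_dataset` replaces `train_dataset`. As context for the `total_batch_size` line at the top of this hunk, a quick worked example of the effective batch size arithmetic (assumed numbers, not the script's defaults):

```python
# Effective batch size per optimizer step = micro-batch x processes x accumulation steps.
train_batch_size = 4             # per-device micro-batch
num_processes = 8                # e.g. 8 GPUs launched via accelerate
gradient_accumulation_steps = 2  # gradients summed over 2 micro-batches

total_batch_size = train_batch_size * num_processes * gradient_accumulation_steps
print(total_batch_size)  # 64 examples contribute to each optimizer step
```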