From 0c63c3839a8dbaf336f640db3ddc8462d4f6711a Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Tue, 4 Apr 2023 07:37:47 -1000 Subject: [PATCH] allow use custom local dataset for controlnet training scripts (#2928) use custom local datset Co-authored-by: yiyixuxu Co-authored-by: Patrick von Platen --- examples/controlnet/train_controlnet.py | 13 +++++-------- examples/controlnet/train_controlnet_flax.py | 13 +++++-------- 2 files changed, 10 insertions(+), 16 deletions(-) diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index b38b62c3e7..20c4fbe189 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -542,16 +542,13 @@ def make_train_dataset(args, tokenizer, accelerator): cache_dir=args.cache_dir, ) else: - data_files = {} if args.train_data_dir is not None: - data_files["train"] = os.path.join(args.train_data_dir, "**") - dataset = load_dataset( - "imagefolder", - data_files=data_files, - cache_dir=args.cache_dir, - ) + dataset = load_dataset( + args.train_data_dir, + cache_dir=args.cache_dir, + ) # See more about loading custom images at - # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script # Preprocessing the datasets. # We need to tokenize inputs and targets. diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py index 47944358e4..6181387fc8 100644 --- a/examples/controlnet/train_controlnet_flax.py +++ b/examples/controlnet/train_controlnet_flax.py @@ -477,16 +477,13 @@ def make_train_dataset(args, tokenizer, batch_size=None): streaming=args.streaming, ) else: - data_files = {} if args.train_data_dir is not None: - data_files["train"] = os.path.join(args.train_data_dir, "**") - dataset = load_dataset( - "imagefolder", - data_files=data_files, - cache_dir=args.cache_dir, - ) + dataset = load_dataset( + args.train_data_dir, + cache_dir=args.cache_dir, + ) # See more about loading custom images at - # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script # Preprocessing the datasets. # We need to tokenize inputs and targets.