diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py
index b38b62c3e7..20c4fbe189 100644
--- a/examples/controlnet/train_controlnet.py
+++ b/examples/controlnet/train_controlnet.py
@@ -542,16 +542,13 @@ def make_train_dataset(args, tokenizer, accelerator):
             cache_dir=args.cache_dir,
         )
     else:
-        data_files = {}
         if args.train_data_dir is not None:
-            data_files["train"] = os.path.join(args.train_data_dir, "**")
-        dataset = load_dataset(
-            "imagefolder",
-            data_files=data_files,
-            cache_dir=args.cache_dir,
-        )
+            dataset = load_dataset(
+                args.train_data_dir,
+                cache_dir=args.cache_dir,
+            )
         # See more about loading custom images at
-        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
 
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
diff --git a/examples/controlnet/train_controlnet_flax.py b/examples/controlnet/train_controlnet_flax.py
index 47944358e4..6181387fc8 100644
--- a/examples/controlnet/train_controlnet_flax.py
+++ b/examples/controlnet/train_controlnet_flax.py
@@ -477,16 +477,13 @@ def make_train_dataset(args, tokenizer, batch_size=None):
             streaming=args.streaming,
         )
     else:
-        data_files = {}
         if args.train_data_dir is not None:
-            data_files["train"] = os.path.join(args.train_data_dir, "**")
-        dataset = load_dataset(
-            "imagefolder",
-            data_files=data_files,
-            cache_dir=args.cache_dir,
-        )
+            dataset = load_dataset(
+                args.train_data_dir,
+                cache_dir=args.cache_dir,
+            )
         # See more about loading custom images at
-        # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
+        # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
 
     # Preprocessing the datasets.
     # We need to tokenize inputs and targets.
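With this change, `args.train_data_dir` is handed directly to `datasets.load_dataset`, so the directory can contain either a dataset loading script or plain data files that `datasets` resolves to a packaged builder such as `imagefolder`. Below is a minimal sketch of the resulting call path; the folder name `./my_controlnet_dataset` is a hypothetical example, not something from the diff.

```python
# Minimal sketch, not part of the diff: illustrates the new load_dataset call path.
from datasets import load_dataset

# Hypothetical paths for illustration only.
train_data_dir = "./my_controlnet_dataset"  # folder with a loading script or raw data files
cache_dir = None

# datasets.load_dataset accepts a local directory: it uses a loading script if one
# is present, otherwise it infers a packaged builder (e.g. imagefolder) from the files.
dataset = load_dataset(
    train_data_dir,
    cache_dir=cache_dir,
)

# The ControlNet training scripts read dataset["train"] and look up the image,
# conditioning image, and caption columns (configurable via --image_column, etc.).
print(dataset["train"].column_names)
```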