diff --git a/examples/controlnet/train_controlnet.py b/examples/controlnet/train_controlnet.py index a2aa266cdf..1ddddd18b6 100644 --- a/examples/controlnet/train_controlnet.py +++ b/examples/controlnet/train_controlnet.py @@ -571,9 +571,6 @@ def parse_args(input_args=None): if args.dataset_name is None and args.train_data_dir is None: raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") - if args.dataset_name is not None and args.train_data_dir is not None: - raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") - if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") @@ -615,6 +612,7 @@ def make_train_dataset(args, tokenizer, accelerator): args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, + data_dir=args.train_data_dir, ) else: if args.train_data_dir is not None: diff --git a/examples/controlnet/train_controlnet_sdxl.py b/examples/controlnet/train_controlnet_sdxl.py index c034c027cb..df4ef0f7dd 100644 --- a/examples/controlnet/train_controlnet_sdxl.py +++ b/examples/controlnet/train_controlnet_sdxl.py @@ -598,9 +598,6 @@ def parse_args(input_args=None): if args.dataset_name is None and args.train_data_dir is None: raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") - if args.dataset_name is not None and args.train_data_dir is not None: - raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") - if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") @@ -642,6 +639,7 @@ def get_train_dataset(args, accelerator): args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, + data_dir=args.train_data_dir, ) else: if args.train_data_dir is not None: diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index b34feb6f71..398e793c04 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -483,7 +483,6 @@ def parse_args(input_args=None): # Sanity checks if args.dataset_name is None and args.train_data_dir is None: raise ValueError("Need either a dataset name or a training folder.") - if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") @@ -824,9 +823,7 @@ def main(args): if args.dataset_name is not None: # Downloading and loading a dataset from the hub. dataset = load_dataset( - args.dataset_name, - args.dataset_config_name, - cache_dir=args.cache_dir, + args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir ) else: data_files = {}