From ce08cb72fbe94d571acaac0c9c686c341b855c45 Mon Sep 17 00:00:00 2001 From: Ruizhe Wang <88331091+Mr-Philo@users.noreply.github.com> Date: Fri, 10 Mar 2023 21:15:16 +0800 Subject: [PATCH] [Dreambooth] Editable number of class images (#2251) * [Dreambooth] Editable number of class images * 'class_num=None' bug fix --------- Co-authored-by: Patrick von Platen --- examples/dreambooth/train_dreambooth.py | 7 ++++++- examples/dreambooth/train_dreambooth_flax.py | 7 ++++++- examples/dreambooth/train_dreambooth_lora.py | 7 ++++++- 3 files changed, 18 insertions(+), 3 deletions(-) diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index 4d921900c0..0ccde2fc5e 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -454,6 +454,7 @@ class DreamBoothDataset(Dataset): tokenizer, class_data_root=None, class_prompt=None, + class_num=None, size=512, center_crop=False, ): @@ -474,7 +475,10 @@ class DreamBoothDataset(Dataset): self.class_data_root = Path(class_data_root) self.class_data_root.mkdir(parents=True, exist_ok=True) self.class_images_path = list(self.class_data_root.iterdir()) - self.num_class_images = len(self.class_images_path) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) self._length = max(self.num_class_images, self.num_instance_images) self.class_prompt = class_prompt else: @@ -814,6 +818,7 @@ def main(args): instance_prompt=args.instance_prompt, class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_prompt=args.class_prompt, + class_num=args.num_class_images, tokenizer=tokenizer, size=args.resolution, center_crop=args.center_crop, diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index 9dcd20939c..46edd5399e 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -231,6 +231,7 @@ class DreamBoothDataset(Dataset): tokenizer, class_data_root=None, class_prompt=None, + class_num=None, size=512, center_crop=False, ): @@ -251,7 +252,10 @@ class DreamBoothDataset(Dataset): self.class_data_root = Path(class_data_root) self.class_data_root.mkdir(parents=True, exist_ok=True) self.class_images_path = list(self.class_data_root.iterdir()) - self.num_class_images = len(self.class_images_path) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) self._length = max(self.num_class_images, self.num_instance_images) self.class_prompt = class_prompt else: @@ -419,6 +423,7 @@ def main(): instance_prompt=args.instance_prompt, class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_prompt=args.class_prompt, + class_num=args.num_class_images, tokenizer=tokenizer, size=args.resolution, center_crop=args.center_crop, diff --git a/examples/dreambooth/train_dreambooth_lora.py b/examples/dreambooth/train_dreambooth_lora.py index c932198232..92d08b64b6 100644 --- a/examples/dreambooth/train_dreambooth_lora.py +++ b/examples/dreambooth/train_dreambooth_lora.py @@ -417,6 +417,7 @@ class DreamBoothDataset(Dataset): tokenizer, class_data_root=None, class_prompt=None, + class_num=None, size=512, center_crop=False, ): @@ -437,7 +438,10 @@ class DreamBoothDataset(Dataset): self.class_data_root = Path(class_data_root) self.class_data_root.mkdir(parents=True, exist_ok=True) self.class_images_path = list(self.class_data_root.iterdir()) - self.num_class_images = len(self.class_images_path) + if class_num is not None: + self.num_class_images = min(len(self.class_images_path), class_num) + else: + self.num_class_images = len(self.class_images_path) self._length = max(self.num_class_images, self.num_instance_images) self.class_prompt = class_prompt else: @@ -771,6 +775,7 @@ def main(args): instance_prompt=args.instance_prompt, class_data_root=args.class_data_dir if args.with_prior_preservation else None, class_prompt=args.class_prompt, + class_num=args.num_class_images, tokenizer=tokenizer, size=args.resolution, center_crop=args.center_crop,