From 25f850a23bf7cce15f308b157b237554bbfdc8ed Mon Sep 17 00:00:00 2001 From: Will Berman Date: Fri, 2 Dec 2022 03:12:28 -0800 Subject: [PATCH] [docs] [dreambooth training] num_class_images clarification (#1508) --- examples/dreambooth/README.md | 2 +- examples/dreambooth/train_dreambooth.py | 4 ++-- examples/dreambooth/train_dreambooth_flax.py | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md index 9f89cd31d2..b68ca6d367 100644 --- a/examples/dreambooth/README.md +++ b/examples/dreambooth/README.md @@ -70,7 +70,7 @@ accelerate launch train_dreambooth.py \ ### Training with prior-preservation loss Prior-preservation is used to avoid overfitting and language-drift. Refer to the paper to learn more about it. For prior-preservation we first generate images using the model with a class prompt and then use those during training along with our data. -According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. +According to the paper, it's recommended to generate `num_epochs * num_samples` images for prior-preservation. 200-300 works well for most cases. The `num_class_images` flag sets the number of images to generate with the class prompt. You can place existing images in `class_data_dir`, and the training script will generate any additional images so that `num_class_images` are present in `class_data_dir` during training time. ```bash export MODEL_NAME="CompVis/stable-diffusion-v1-4" diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index cbd000ac8d..ccacc46679 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -107,8 +107,8 @@ def parse_args(input_args=None): type=int, default=100, help=( - "Minimal class images for prior preservation loss. If not have enough images, additional images will be" - " sampled with class_prompt." + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." ), ) parser.add_argument( diff --git a/examples/dreambooth/train_dreambooth_flax.py b/examples/dreambooth/train_dreambooth_flax.py index 6606af4f17..cb751b09ac 100644 --- a/examples/dreambooth/train_dreambooth_flax.py +++ b/examples/dreambooth/train_dreambooth_flax.py @@ -89,8 +89,8 @@ def parse_args(): type=int, default=100, help=( - "Minimal class images for prior preservation loss. If not have enough images, additional images will be" - " sampled with class_prompt." + "Minimal class images for prior preservation loss. If there are not enough images already present in" + " class_data_dir, additional images will be sampled with class_prompt." ), ) parser.add_argument(