Mirror of https://github.com/huggingface/diffusers.git (synced 2026-01-27 17:22:53 +03:00)
Add features to the Dreambooth LoRA SDXL training script (#5508)
* Additions:
  - support for different lr for text encoder
  - support for Prodigy optimizer
  - support for min snr gamma
  - support for custom captions and dataset loading from the hub
* adjusted --caption_column behaviour (to -not- use the second column of the dataset by default if --caption_column is not provided)
* fixed --output_dir / --model_dir_name confusion
* added --repeats, --adam_weight_decay_text_encoder + some fixes
* Update examples/dreambooth/train_dreambooth_lora_sdxl.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Update examples/dreambooth/train_dreambooth_lora_sdxl.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Update examples/dreambooth/train_dreambooth_lora_sdxl.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* - import compute_snr from diffusers/training_utils.py
  - cluster adamw together
  - when using 'prodigy', if --train_text_encoder == True and --text_encoder_lr != --learning_rate, change the lr of the text encoders' optimization params to be --learning_rate (otherwise errors)
* shape fixes when custom captions are used
* formatting and a little cleanup
* code styling
* --repeats default value fixed, changed to 1
* bug fix: removed redundant lines of embedding concatenation when using prior_preservation (that duplicated class_prompt embeddings)
* changed dataset loading logic according to the following use cases (to avoid an unnecessary dependency on datasets):
  1. user provides --dataset_name
  2. user provides a local dir --instance_data_dir that contains a metadata .jsonl file
  3. user provides a local dir --instance_data_dir that contains only images
  In cases [1, 2] we import datasets and use the load_dataset method; in case [3] we process the data the same way as in the original script.
* styling fix
* arg name fix
* adjusted the --repeats logic
* - removed redundant arg and 'if' when loading a local folder with prompts
  - updated readme template
  - some default val fixes
  - custom caption tests
* image path fix for readme
* code style
* bug fix
* --caption_column arg
* readme fix

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Linoy Tsaban <linoy@huggingface.co>
@@ -52,7 +52,7 @@ from diffusers import (
 from diffusers.loaders import LoraLoaderMixin
 from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import unet_lora_state_dict
+from diffusers.training_utils import compute_snr, unet_lora_state_dict
 from diffusers.utils import check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
@@ -64,36 +64,65 @@ logger = get_logger(__name__)
 def save_model_card(
-    repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
+    repo_id: str,
+    images=None,
+    base_model=str,
+    train_text_encoder=False,
+    instance_prompt=str,
+    validation_prompt=str,
+    repo_folder=None,
+    vae_path=None,
 ):
-    img_str = ""
+    img_str = "widget:\n" if images else ""
     for i, image in enumerate(images):
         image.save(os.path.join(repo_folder, f"image_{i}.png"))
-        img_str += f"![img_{i}](./image_{i}.png)\n"
+        img_str += f"""
+        - text: '{validation_prompt if validation_prompt else ' ' }'
+          output:
+            url: >-
+                "image_{i}.png"
+        """

     yaml = f"""
 ---
-license: openrail++
-base_model: {base_model}
-instance_prompt: {prompt}
 tags:
 - stable-diffusion-xl
 - stable-diffusion-xl-diffusers
 - text-to-image
 - diffusers
 - lora
-inference: true
+- template:sd-lora
+widget:
+{img_str}
 ---
+base_model: {base_model}
+instance_prompt: {instance_prompt}
+license: openrail++
+---
 """

     model_card = f"""
-# LoRA DreamBooth - {repo_id}
+# SDXL LoRA DreamBooth - {repo_id}

-These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
-{img_str}
+<Gallery />
+
+## Model description
+
+These are {repo_id} LoRA adaption weights for {base_model}.
+The weights were trained using [DreamBooth](https://dreambooth.github.io/).
 LoRA for the text encoder was enabled: {train_text_encoder}.

 Special VAE used for training: {vae_path}.
+
+## Trigger words
+
+You should use {instance_prompt} to trigger the image generation.
+
+## Download model
+
+Weights for this model are available in Safetensors format.
+
+[Download]({repo_id}/tree/main) them in the Files & versions tab.
+
 """
     with open(os.path.join(repo_folder, "README.md"), "w") as f:
         f.write(yaml + model_card)
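To make the new card template concrete, here is a small sketch of the `widget:` entry the loop above emits for a single validation image. The values are made up for illustration; in the script they come from `--validation_prompt` and the saved validation images.

```python
validation_prompt = "photo of a TOK dog in a bucket"  # hypothetical --validation_prompt value

img_str = "widget:\n"
for i in range(1):  # pretend one validation image was generated
    img_str += f"""
        - text: '{validation_prompt if validation_prompt else ' ' }'
          output:
            url: >-
                "image_{i}.png"
        """

print(img_str)
# widget:
#
#         - text: 'photo of a TOK dog in a bucket'
#           output:
#             url: >-
#                 "image_0.png"
```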
@@ -141,13 +170,53 @@ def parse_args(input_args=None):
         required=False,
         help="Revision of pretrained model identifier from huggingface.co/models.",
     )
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default=None,
+        help=(
+            "The name of the Dataset (from the HuggingFace hub) containing the training data of instance images (could be your own, possibly private,"
+            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
+            " or to a folder containing files that 🤗 Datasets can understand."
+        ),
+    )
+    parser.add_argument(
+        "--dataset_config_name",
+        type=str,
+        default=None,
+        help="The config of the Dataset, leave as None if there's only one config.",
+    )
     parser.add_argument(
         "--instance_data_dir",
         type=str,
         default=None,
-        required=True,
-        help="A folder containing the training data of instance images.",
+        help=("A folder containing the training data. "),
     )
+
+    parser.add_argument(
+        "--cache_dir",
+        type=str,
+        default=None,
+        help="The directory where the downloaded models and datasets will be stored.",
+    )
+
+    parser.add_argument(
+        "--image_column",
+        type=str,
+        default="image",
+        help="The column of the dataset containing the target image. By "
+        "default, the standard Image Dataset maps out 'file_name' "
+        "to 'image'.",
+    )
+    parser.add_argument(
+        "--caption_column",
+        type=str,
+        default=None,
+        help="The column of the dataset containing the instance prompt for each image",
+    )
+
+    parser.add_argument("--repeats", type=int, default=1, help="How many times to repeat the training data.")
+
     parser.add_argument(
         "--class_data_dir",
         type=str,
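The commit message's three data-source cases map onto these flags: a hub dataset id (or any local folder 🤗 Datasets can load, e.g. images plus a `metadata.jsonl`) goes through `--dataset_name`, while a plain folder of images goes through `--instance_data_dir` and never imports `datasets`. A hedged sketch of the first two cases from the Datasets side; the dataset id is the dummy one used in the tests further below, and the local path is a placeholder:

```python
from datasets import load_dataset  # pip install datasets

# Case 1: hub dataset with an image column and a caption column.
ds = load_dataset("hf-internal-testing/dummy_image_text_data", split="train")
print(ds.column_names)  # e.g. ['image', 'text'] -> pass --caption_column text

# Case 2: a local folder that datasets understands, e.g. images plus a
# metadata.jsonl whose lines look like:
#   {"file_name": "0001.png", "text": "photo of a TOK dog"}
ds = load_dataset("imagefolder", data_dir="./dog_photos", split="train")
```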
@@ -160,7 +229,7 @@ def parse_args(input_args=None):
         type=str,
         default=None,
         required=True,
-        help="The prompt with identifier specifying the instance",
+        help="The prompt with identifier specifying the instance, e.g. 'photo of a TOK dog', 'in the style of TOK'",
     )
     parser.add_argument(
         "--class_prompt",
@@ -299,9 +368,16 @@ def parse_args(input_args=None):
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=5e-4,
+        default=1e-4,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
+
+    parser.add_argument(
+        "--text_encoder_lr",
+        type=float,
+        default=5e-6,
+        help="Text encoder learning rate to use.",
+    )
     parser.add_argument(
         "--scale_lr",
         action="store_true",
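`--text_encoder_lr` only takes effect because the script later hands the optimizer separate parameter groups. A minimal sketch of that mechanism with toy parameters (the real groups are the UNet and text encoder LoRA parameters):

```python
import torch

# Toy stand-ins for the UNet and text encoder LoRA parameters.
unet_params = [torch.nn.Parameter(torch.randn(4, 4))]
text_encoder_params = [torch.nn.Parameter(torch.randn(4, 4))]

# Per-group keys such as "lr" and "weight_decay" override the constructor defaults.
optimizer = torch.optim.AdamW(
    [
        {"params": unet_params, "lr": 1e-4},
        {"params": text_encoder_params, "lr": 5e-6, "weight_decay": 1e-3},
    ],
    betas=(0.9, 0.999),
    weight_decay=1e-4,  # applies to groups that do not override it
    eps=1e-8,
)

for group in optimizer.param_groups:
    print(group["lr"], group["weight_decay"])  # 0.0001 0.0001, then 5e-06 0.001
```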
@@ -317,6 +393,14 @@ def parse_args(input_args=None):
             ' "constant", "constant_with_warmup"]'
         ),
     )
+
+    parser.add_argument(
+        "--snr_gamma",
+        type=float,
+        default=None,
+        help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
+        "More details here: https://arxiv.org/abs/2303.09556.",
+    )
     parser.add_argument(
         "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
     )
@@ -335,13 +419,59 @@ def parse_args(input_args=None):
             "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
         ),
     )

     parser.add_argument(
-        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
+        "--optimizer",
+        type=str,
+        default="AdamW",
+        help=('The optimizer type to use. Choose between ["AdamW", "prodigy"]'),
     )
+
+    parser.add_argument(
+        "--use_8bit_adam",
+        action="store_true",
+        help="Whether or not to use 8-bit Adam from bitsandbytes. Ignored if optimizer is not set to AdamW",
+    )
+
+    parser.add_argument(
+        "--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam and Prodigy optimizers."
+    )
+    parser.add_argument(
+        "--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam and Prodigy optimizers."
+    )
+    parser.add_argument(
+        "--prodigy_beta3",
+        type=float,
+        default=None,
+        help="Coefficients for computing the Prodigy stepsize using running averages. If set to None, "
+        "uses the value of square root of beta2. Ignored if optimizer is adamW",
+    )
+    parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
+    parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
+    parser.add_argument(
+        "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
+    )
+
+    parser.add_argument(
+        "--adam_epsilon",
+        type=float,
+        default=1e-08,
+        help="Epsilon value for the Adam optimizer and Prodigy optimizers.",
+    )
+
+    parser.add_argument(
+        "--prodigy_use_bias_correction",
+        type=bool,
+        default=True,
+        help="Turn on Adam's bias correction. True by default. Ignored if optimizer is adamW",
+    )
+    parser.add_argument(
+        "--prodigy_safeguard_warmup",
+        type=bool,
+        default=True,
+        help="Remove lr from the denominator of D estimate to avoid issues during warm-up stage. True by default. "
+        "Ignored if optimizer is adamW",
+    )
-    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
-    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
-    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
-    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
     parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
     parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
     parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
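One caveat about the flags above (an editorial observation, not part of the patch): argparse's `type=bool` does not parse truthy strings the way one might expect, since `bool(s)` is `True` for every non-empty string. Flags like `--prodigy_decouple` therefore cannot be turned off by passing `False` on the command line:

```python
import argparse

parser = argparse.ArgumentParser()
# Declared the same way as --prodigy_decouple above.
parser.add_argument("--prodigy_decouple", type=bool, default=True)

# bool("False") is True: any non-empty string parses as True.
print(parser.parse_args(["--prodigy_decouple", "False"]).prodigy_decouple)  # True
# Only the empty string yields False.
print(parser.parse_args(["--prodigy_decouple", ""]).prodigy_decouple)  # False
```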
@@ -414,6 +544,12 @@ def parse_args(input_args=None):
     else:
         args = parser.parse_args()

+    if args.dataset_name is None and args.instance_data_dir is None:
+        raise ValueError("Specify either `--dataset_name` or `--instance_data_dir`")
+
+    if args.dataset_name is not None and args.instance_data_dir is not None:
+        raise ValueError("Specify only one of `--dataset_name` or `--instance_data_dir`")
+
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
@@ -442,20 +578,84 @@ class DreamBoothDataset(Dataset):
     def __init__(
         self,
         instance_data_root,
         instance_prompt,
         class_prompt,
         class_data_root=None,
         class_num=None,
         size=1024,
+        repeats=1,
         center_crop=False,
     ):
         self.size = size
         self.center_crop = center_crop

-        self.instance_data_root = Path(instance_data_root)
-        if not self.instance_data_root.exists():
-            raise ValueError("Instance images root doesn't exist.")
         self.instance_prompt = instance_prompt
+        self.custom_instance_prompts = None
         self.class_prompt = class_prompt

-        self.instance_images_path = list(Path(instance_data_root).iterdir())
-        self.num_instance_images = len(self.instance_images_path)
+        # if --dataset_name is provided or a metadata jsonl file is provided in the local --instance_data directory,
+        # we load the training data using load_dataset
+        if args.dataset_name is not None:
+            try:
+                from datasets import load_dataset
+            except ImportError:
+                raise ImportError(
+                    "You are trying to load your data using the datasets library. If you wish to train using custom "
+                    "captions please install the datasets library: `pip install datasets`. If you wish to load a "
+                    "local folder containing images only, specify --instance_data_dir instead."
+                )
+            # Downloading and loading a dataset from the hub.
+            # See more about loading custom images at
+            # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script
+            dataset = load_dataset(
+                args.dataset_name,
+                args.dataset_config_name,
+                cache_dir=args.cache_dir,
+            )
+            # Preprocessing the datasets.
+            column_names = dataset["train"].column_names
+
+            # Get the column names for input/target.
+            if args.image_column is None:
+                image_column = column_names[0]
+                logger.info(f"image column defaulting to {image_column}")
+            else:
+                image_column = args.image_column
+                if image_column not in column_names:
+                    raise ValueError(
+                        f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+                    )
+            instance_images = dataset["train"][image_column]
+
+            if args.caption_column is None:
+                logger.info(
+                    "No caption column provided, defaulting to instance_prompt for all images. If your dataset "
+                    "contains captions/prompts for the images, make sure to specify the "
+                    "column as --caption_column"
+                )
+                self.custom_instance_prompts = None
+            else:
+                if args.caption_column not in column_names:
+                    raise ValueError(
+                        f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}"
+                    )
+                custom_instance_prompts = dataset["train"][args.caption_column]
+                # create final list of captions according to --repeats
+                self.custom_instance_prompts = []
+                for caption in custom_instance_prompts:
+                    self.custom_instance_prompts.extend(itertools.repeat(caption, repeats))
+        else:
+            self.instance_data_root = Path(instance_data_root)
+            if not self.instance_data_root.exists():
+                raise ValueError("Instance images root doesn't exist.")
+
+            instance_images = [Image.open(path) for path in list(Path(instance_data_root).iterdir())]
+            self.custom_instance_prompts = None
+
+        self.instance_images = []
+        for img in instance_images:
+            self.instance_images.extend(itertools.repeat(img, repeats))
+        self.num_instance_images = len(self.instance_images)
+        self._length = self.num_instance_images

         if class_data_root is not None:
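The `--repeats` handling above boils down to duplicating every image, and every caption when captions exist, with `itertools.repeat`, so the two lists stay index-aligned. A standalone sketch of the expansion (filenames and captions are stand-ins for the loaded data):

```python
import itertools

images = ["img_0.png", "img_1.png"]           # stand-ins for loaded PIL images
captions = ["a TOK dog", "TOK dog swimming"]  # stand-ins for --caption_column values
repeats = 3

expanded_images, expanded_captions = [], []
for img in images:
    expanded_images.extend(itertools.repeat(img, repeats))
for cap in captions:
    expanded_captions.extend(itertools.repeat(cap, repeats))

# Index i of one list still corresponds to index i of the other.
print(len(expanded_images), len(expanded_captions))    # 6 6
print(expanded_images[4], "->", expanded_captions[4])  # img_1.png -> TOK dog swimming
```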
@@ -484,13 +684,23 @@ class DreamBoothDataset(Dataset):

     def __getitem__(self, index):
         example = {}
-        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
+        instance_image = self.instance_images[index % self.num_instance_images]
         instance_image = exif_transpose(instance_image)

         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
         example["instance_images"] = self.image_transforms(instance_image)

+        if self.custom_instance_prompts:
+            caption = self.custom_instance_prompts[index % self.num_instance_images]
+            if caption:
+                example["instance_prompt"] = caption
+            else:
+                example["instance_prompt"] = self.instance_prompt
+
+        else:  # custom prompts were not provided (or their length did not match the image dataset)
+            example["instance_prompt"] = self.instance_prompt
+
         if self.class_data_root:
             class_image = Image.open(self.class_images_path[index % self.num_class_images])
             class_image = exif_transpose(class_image)
@@ -498,22 +708,25 @@ class DreamBoothDataset(Dataset):
         if not class_image.mode == "RGB":
             class_image = class_image.convert("RGB")
         example["class_images"] = self.image_transforms(class_image)
+        example["class_prompt"] = self.class_prompt

         return example


 def collate_fn(examples, with_prior_preservation=False):
     pixel_values = [example["instance_images"] for example in examples]
+    prompts = [example["instance_prompt"] for example in examples]

     # Concat class and instance examples for prior preservation.
     # We do this to avoid doing two forward passes.
     if with_prior_preservation:
         pixel_values += [example["class_images"] for example in examples]
+        prompts += [example["class_prompt"] for example in examples]

     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

-    batch = {"pixel_values": pixel_values}
+    batch = {"pixel_values": pixel_values, "prompts": prompts}
     return batch

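With prior preservation the collated batch doubles: the instance half comes first, the class half second, and the new `prompts` key keeps the matching order. A standalone sketch using the same `collate_fn` logic with dummy tensors (shapes are illustrative):

```python
import torch

def collate_fn(examples, with_prior_preservation=False):
    pixel_values = [example["instance_images"] for example in examples]
    prompts = [example["instance_prompt"] for example in examples]
    if with_prior_preservation:
        pixel_values += [example["class_images"] for example in examples]
        prompts += [example["class_prompt"] for example in examples]
    pixel_values = torch.stack(pixel_values).to(memory_format=torch.contiguous_format).float()
    return {"pixel_values": pixel_values, "prompts": prompts}

examples = [
    {
        "instance_images": torch.randn(3, 64, 64),
        "instance_prompt": "photo of a TOK dog",
        "class_images": torch.randn(3, 64, 64),
        "class_prompt": "photo of a dog",
    }
    for _ in range(2)
]

batch = collate_fn(examples, with_prior_preservation=True)
print(batch["pixel_values"].shape)  # torch.Size([4, 3, 64, 64]) -- instance half + class half
print(batch["prompts"])
# ['photo of a TOK dog', 'photo of a TOK dog', 'photo of a dog', 'photo of a dog']
```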
@@ -732,7 +945,8 @@ def main(args):
             xformers_version = version.parse(xformers.__version__)
             if xformers_version == version.parse("0.0.16"):
                 logger.warn(
-                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
+                    "xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, "
+                    "please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
                 )
             unet.enable_xformers_memory_efficient_attention()
         else:
@@ -866,35 +1080,119 @@ def main(args):
             args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
         )

-    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
-    if args.use_8bit_adam:
-        try:
-            import bitsandbytes as bnb
-        except ImportError:
-            raise ImportError(
-                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
-            )
-
-        optimizer_class = bnb.optim.AdamW8bit
+    # Optimization parameters
+    unet_lora_parameters_with_lr = {"params": unet_lora_parameters, "lr": args.learning_rate}
+    if args.train_text_encoder:
+        # different learning rate for text encoder and unet
+        text_lora_parameters_one_with_lr = {
+            "params": text_lora_parameters_one,
+            "weight_decay": args.adam_weight_decay_text_encoder,
+            "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+        }
+        text_lora_parameters_two_with_lr = {
+            "params": text_lora_parameters_two,
+            "weight_decay": args.adam_weight_decay_text_encoder,
+            "lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
+        }
+        params_to_optimize = [
+            unet_lora_parameters_with_lr,
+            text_lora_parameters_one_with_lr,
+            text_lora_parameters_two_with_lr,
+        ]
     else:
-        optimizer_class = torch.optim.AdamW
+        params_to_optimize = [unet_lora_parameters_with_lr]

     # Optimizer creation
-    params_to_optimize = (
-        itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
-        if args.train_text_encoder
-        else unet_lora_parameters
-    )
+    if not (args.optimizer.lower() == "prodigy" or args.optimizer.lower() == "adamw"):
+        logger.warn(
+            f"Unsupported choice of optimizer: {args.optimizer}. Supported optimizers include [adamW, prodigy]. "
+            "Defaulting to adamW"
+        )
+        args.optimizer = "adamw"
+
+    if args.use_8bit_adam and not args.optimizer.lower() == "adamw":
+        logger.warn(
+            f"use_8bit_adam is ignored when optimizer is not set to 'AdamW'. Optimizer was "
+            f"set to {args.optimizer.lower()}"
+        )
+
+    if args.optimizer.lower() == "adamw":
+        if args.use_8bit_adam:
+            try:
+                import bitsandbytes as bnb
+            except ImportError:
+                raise ImportError(
+                    "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+                )
+
+            optimizer_class = bnb.optim.AdamW8bit
+        else:
+            optimizer_class = torch.optim.AdamW
+
+        optimizer = optimizer_class(
+            params_to_optimize,
+            betas=(args.adam_beta1, args.adam_beta2),
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+        )
+
+    if args.optimizer.lower() == "prodigy":
+        try:
+            import prodigyopt
+        except ImportError:
+            raise ImportError("To use Prodigy, please install the prodigyopt library: `pip install prodigyopt`")
+
+        optimizer_class = prodigyopt.Prodigy
+
+        if args.learning_rate <= 0.1:
+            logger.warn(
+                "Learning rate is too low. When using prodigy, it's generally better to set learning rate around 1.0"
+            )
+        if args.train_text_encoder and args.text_encoder_lr:
+            logger.warn(
+                f"Learning rates were provided both for the unet and the text encoder, e.g. text_encoder_lr:"
+                f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
+                f"When using prodigy only learning_rate is used as the initial learning rate."
+            )
+            # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
+            # --learning_rate
+            params_to_optimize[1]["lr"] = args.learning_rate
+            params_to_optimize[2]["lr"] = args.learning_rate
+
+        optimizer = optimizer_class(
+            params_to_optimize,
+            lr=args.learning_rate,
+            betas=(args.adam_beta1, args.adam_beta2),
+            beta3=args.prodigy_beta3,
+            weight_decay=args.adam_weight_decay,
+            eps=args.adam_epsilon,
+            decouple=args.prodigy_decouple,
+            use_bias_correction=args.prodigy_use_bias_correction,
+            safeguard_warmup=args.prodigy_safeguard_warmup,
+        )

+    # Dataset and DataLoaders creation:
+    train_dataset = DreamBoothDataset(
+        instance_data_root=args.instance_data_dir,
+        instance_prompt=args.instance_prompt,
+        class_prompt=args.class_prompt,
+        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
+        class_num=args.num_class_images,
+        size=args.resolution,
+        repeats=args.repeats,
+        center_crop=args.center_crop,
+    )
-    optimizer = optimizer_class(
-        params_to_optimize,
-        lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-    )

     train_dataloader = torch.utils.data.DataLoader(
         train_dataset,
         batch_size=args.train_batch_size,
         shuffle=True,
         collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
         num_workers=args.dataloader_num_workers,
     )

     # Computes additional embeddings/ids required by the SDXL UNet.
-    # regular text emebddings (when `train_text_encoder` is not True)
+    # regular text embeddings (when `train_text_encoder` is not True)
     # pooled text embeddings
     # time ids
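A note on the Prodigy branch above: Prodigy estimates the step size itself, so `lr` acts as a multiplier on that estimate rather than a conventional learning rate. That is why the patch warns when the value sits well below 1.0 and resets the text encoder groups to `--learning_rate`. A minimal standalone sketch of the same construction, with toy parameters in place of the script's LoRA parameters:

```python
import torch
import prodigyopt  # pip install prodigyopt

# Toy stand-ins for unet_lora_parameters / text_lora_parameters_one / ..._two.
unet_params = [torch.nn.Parameter(torch.randn(4, 4))]
te1_params = [torch.nn.Parameter(torch.randn(4, 4))]
te2_params = [torch.nn.Parameter(torch.randn(4, 4))]

# With Prodigy every group gets the same initial lr (around 1.0); the
# per-group "weight_decay" overrides still apply.
params_to_optimize = [
    {"params": unet_params, "lr": 1.0},
    {"params": te1_params, "lr": 1.0, "weight_decay": 1e-3},
    {"params": te2_params, "lr": 1.0, "weight_decay": 1e-3},
]

optimizer = prodigyopt.Prodigy(
    params_to_optimize,
    lr=1.0,
    betas=(0.9, 0.999),
    beta3=None,  # defaults to sqrt(beta2), mirroring --prodigy_beta3=None
    weight_decay=1e-4,
    eps=1e-8,
    decouple=True,
    use_bias_correction=True,
    safeguard_warmup=True,
)
```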
@@ -921,7 +1219,11 @@ def main(args):

     # Handle instance prompt.
     instance_time_ids = compute_time_ids()
-    if not args.train_text_encoder:
+
+    # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
+    # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
+    # the redundant encoding.
+    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
         instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
             args.instance_prompt, text_encoders, tokenizers
         )
@@ -934,49 +1236,36 @@ def main(args):
             args.class_prompt, text_encoders, tokenizers
         )

-    # Clear the memory here.
-    if not args.train_text_encoder:
+    # Clear the memory here
+    if not args.train_text_encoder and not train_dataset.custom_instance_prompts:
         del tokenizers, text_encoders
         gc.collect()
         torch.cuda.empty_cache()

-    # Pack the statically computed variables appropriately. This is so that we don't
+    # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
+    # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
     add_time_ids = instance_time_ids
     if args.with_prior_preservation:
         add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)

-    if not args.train_text_encoder:
-        prompt_embeds = instance_prompt_hidden_states
-        unet_add_text_embeds = instance_pooled_prompt_embeds
-        if args.with_prior_preservation:
-            prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
-            unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
-    else:
-        tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
-        tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
-        if args.with_prior_preservation:
-            class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
-            class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
-            tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
-            tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
-
-    # Dataset and DataLoaders creation:
-    train_dataset = DreamBoothDataset(
-        instance_data_root=args.instance_data_dir,
-        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
-        class_num=args.num_class_images,
-        size=args.resolution,
-        center_crop=args.center_crop,
-    )
-
-    train_dataloader = torch.utils.data.DataLoader(
-        train_dataset,
-        batch_size=args.train_batch_size,
-        shuffle=True,
-        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
-        num_workers=args.dataloader_num_workers,
-    )
+    if not train_dataset.custom_instance_prompts:
+        if not args.train_text_encoder:
+            prompt_embeds = instance_prompt_hidden_states
+            unet_add_text_embeds = instance_pooled_prompt_embeds
+            if args.with_prior_preservation:
+                prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
+                unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
+        # if we're optimizing the text encoder (both if the instance prompt is used for all images or custom
+        # prompts are used) we need to tokenize and encode the batch prompts on all training steps
+        else:
+            tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
+            tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
+            if args.with_prior_preservation:
+                class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
+                class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
+                tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
+                tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)

     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
@@ -1079,6 +1368,17 @@ def main(args):
         for step, batch in enumerate(train_dataloader):
             with accelerator.accumulate(unet):
                 pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
+                prompts = batch["prompts"]
+
+                # encode batch prompts when custom prompts are provided for each image -
+                if train_dataset.custom_instance_prompts:
+                    if not args.train_text_encoder:
+                        prompt_embeds, unet_add_text_embeds = compute_text_embeddings(
+                            prompts, text_encoders, tokenizers
+                        )
+                    else:
+                        tokens_one = tokenize_prompt(tokenizer_one, prompts)
+                        tokens_two = tokenize_prompt(tokenizer_two, prompts)

                 # Convert images to latent space
                 model_input = vae.encode(pixel_values).latent_dist.sample()
@@ -1099,16 +1399,21 @@ def main(args):
                 # (this is the forward diffusion process)
                 noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

-                # Calculate the elements to repeat depending on the use of prior-preservation.
-                elems_to_repeat = bsz // 2 if args.with_prior_preservation else bsz
+                # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
+                if not train_dataset.custom_instance_prompts:
+                    elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
+                    elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
+                else:
+                    elems_to_repeat_text_embeds = 1
+                    elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz

                 # Predict the noise residual
                 if not args.train_text_encoder:
                     unet_added_conditions = {
-                        "time_ids": add_time_ids.repeat(elems_to_repeat, 1),
-                        "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat, 1),
+                        "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
+                        "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
                     }
-                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
+                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
                     model_pred = unet(
                         noisy_model_input,
                         timesteps,
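To make the repeat arithmetic concrete: with `train_batch_size=2` and prior preservation, the collated batch holds `bsz = 4` latents (2 instance + 2 class). The statically packed `add_time_ids` already contains one instance row and one class row, so repeating it `bsz // 2 = 2` times yields one row per latent; with custom captions the per-step embeddings already have batch dimension `bsz`, so they repeat once. A toy check of the shapes (the trailing dimensions are illustrative):

```python
import torch

bsz = 4  # 2 instance + 2 class latents after collation
with_prior_preservation = True

# Statically packed time ids: one instance row + one class row.
add_time_ids = torch.randn(2, 6)
elems_to_repeat_time_ids = bsz // 2 if with_prior_preservation else bsz
print(add_time_ids.repeat(elems_to_repeat_time_ids, 1).shape)  # torch.Size([4, 6])

# Case 1: shared instance prompt, embeddings packed once (instance + class rows).
prompt_embeds = torch.randn(2, 77, 2048)
print(prompt_embeds.repeat(bsz // 2, 1, 1).shape)  # torch.Size([4, 77, 2048])

# Case 2: custom captions, embeddings computed per step with batch dim already bsz.
prompt_embeds = torch.randn(bsz, 77, 2048)
print(prompt_embeds.repeat(1, 1, 1).shape)  # torch.Size([4, 77, 2048])
```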
@@ -1116,15 +1421,17 @@ def main(args):
                         added_cond_kwargs=unet_added_conditions,
                     ).sample
                 else:
-                    unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat, 1)}
+                    unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
                     prompt_embeds, pooled_prompt_embeds = encode_prompt(
                         text_encoders=[text_encoder_one, text_encoder_two],
                         tokenizers=None,
                         prompt=None,
                         text_input_ids_list=[tokens_one, tokens_two],
                     )
-                    unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat, 1)})
-                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat, 1, 1)
+                    unet_added_conditions.update(
+                        {"text_embeds": pooled_prompt_embeds.repeat(elems_to_repeat_text_embeds, 1)}
+                    )
+                    prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
                     model_pred = unet(
                         noisy_model_input, timesteps, prompt_embeds_input, added_cond_kwargs=unet_added_conditions
                     ).sample
@@ -1142,16 +1449,34 @@ def main(args):
                     model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                     target, target_prior = torch.chunk(target, 2, dim=0)

-                    # Compute instance loss
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
                     # Compute prior loss
                     prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

+                if args.snr_gamma is None:
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+                else:
+                    # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
+                    # Since we predict the noise instead of x_0, the original formulation is slightly changed.
+                    # This is discussed in Section 4.2 of the same paper.
+                    snr = compute_snr(noise_scheduler, timesteps)
+                    base_weight = (
+                        torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+                    )
+
+                    if noise_scheduler.config.prediction_type == "v_prediction":
+                        # Velocity objective needs to be floored to an SNR weight of one.
+                        mse_loss_weights = base_weight + 1
+                    else:
+                        # Epsilon and sample both use the same loss weights.
+                        mse_loss_weights = base_weight
+
+                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+                    loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
+                    loss = loss.mean()
+
                 if args.with_prior_preservation:
                     # Add the prior loss to the instance loss.
                     loss = loss + args.prior_loss_weight * prior_loss
-                else:
-                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
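The min-SNR-gamma weighting above clips each timestep's weight to min(SNR, gamma)/SNR, so high-SNR (low-noise) timesteps stop dominating the loss. A self-contained sketch of the weight computation for a linear-beta DDPM schedule; the in-repo `compute_snr` helper derives SNR from the scheduler's `alphas_cumprod` in essentially this way:

```python
import torch

# Toy linear-beta DDPM schedule.
betas = torch.linspace(1e-4, 0.02, 1000)
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)

def snr_for(timesteps: torch.Tensor) -> torch.Tensor:
    # SNR(t) = alpha_bar_t / (1 - alpha_bar_t): signal power over noise power.
    a = alphas_cumprod[timesteps]
    return a / (1.0 - a)

timesteps = torch.tensor([10, 500, 990])
snr = snr_for(timesteps)
snr_gamma = 5.0

# min(SNR, gamma) / SNR, as in Section 3.4 of https://arxiv.org/abs/2303.09556
mse_loss_weights = torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
print(snr)               # large for early timesteps, tiny for late ones
print(mse_loss_weights)  # ~gamma/snr (<1) where snr > gamma, exactly 1.0 elsewhere
```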
@@ -1353,7 +1678,8 @@ def main(args):
             images=images,
             base_model=args.pretrained_model_name_or_path,
             train_text_encoder=args.train_text_encoder,
-            prompt=args.instance_prompt,
+            instance_prompt=args.instance_prompt,
+            validation_prompt=args.validation_prompt,
             repo_folder=args.output_dir,
             vae_path=args.pretrained_vae_model_name_or_path,
         )
@@ -421,6 +421,49 @@ class ExamplesTestsAccelerate(unittest.TestCase):
             )
         self.assertTrue(starts_with_unet)

+    def test_dreambooth_lora_sdxl_custom_captions(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora_sdxl.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+                --dataset_name hf-internal-testing/dummy_image_text_data
+                --caption_column text
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
+    def test_dreambooth_lora_sdxl_text_encoder_custom_captions(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora_sdxl.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
+                --dataset_name hf-internal-testing/dummy_image_text_data
+                --caption_column text
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                --train_text_encoder
+                """.split()
+
+            run_command(self._launch_args + test_args)
+
     def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
         pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"