
[DreamBooth] add text encoder LoRA support in the DreamBooth training script (#3130)

* add: LoRA text encoder support for DreamBooth example.

* fix initialization.

* fix: modification call.

* add: entry in the readme.

* use dog dataset from hub.

* fix: params to clip.

* add entry to the LoRA doc.

* add: tests for lora.

* remove unnecessary list comprehension.
Author: Sayak Paul
Date: 2023-04-20 17:25:17 +05:30
Committed by: Daniel Gu
parent f99a9ffab9
commit 206b9b69d1

5 changed files with 196 additions and 31 deletions


@@ -60,7 +60,18 @@ DreamBooth finetuning is very sensitive to hyperparameters and easy to overfit.
<frameworkcontent>
<pt>
-Let's try DreamBooth with a [few images of a dog](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ); download and save them to a directory and then set the `INSTANCE_DIR` environment variable to that path:
+Let's try DreamBooth with a
+[few images of a dog](https://huggingface.co/datasets/diffusers/dog-example);
+download and save them to a directory and then set the `INSTANCE_DIR` environment variable to that path:
+
+```python
+from huggingface_hub import snapshot_download
+
+local_dir = "./path_to_training_images"
+snapshot_download(
+    "diffusers/dog-example",
+    local_dir=local_dir, repo_type="dataset",
+    ignore_patterns=".gitattributes",
+)
+```
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"


@@ -16,7 +16,9 @@ specific language governing permissions and limitations under the License.
<Tip warning={true}>
-Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionModel`].
+Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionModel`]. We also
+support LoRA fine-tuning of the text encoder for DreamBooth in a limited capacity. For more details on how we support
+LoRA fine-tuning of the text encoder, refer to the discussion on [this PR](https://github.com/huggingface/diffusers/pull/2918).
</Tip>
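To make that scope concrete, here is a minimal sketch of where LoRA attaches inside the UNet; the model ID is only an example:

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="unet"
)

# LoRA weights attach to these attention processors only; the rest of the
# UNet (ResNet blocks, up/downsamplers) stays untouched.
print(len(unet.attn_processors))         # 32 processors for Stable Diffusion v1
print(next(iter(unet.attn_processors)))  # e.g. a "down_blocks...processor" key
```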
@@ -175,6 +177,11 @@ accelerate launch train_dreambooth_lora.py \
--push_to_hub
```
+It's also possible to additionally fine-tune the text encoder with LoRA. In most cases this leads
+to better results with a slight increase in compute. To allow fine-tuning the text encoder with LoRA,
+specify the `--train_text_encoder` argument while launching the `train_dreambooth_lora.py` script.
### Inference[[dreambooth-inference]]
Now you can use the model for inference by loading the base model in the [`StableDiffusionPipeline`]:
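A minimal sketch of that inference step (the LoRA checkpoint directory is a placeholder, and `load_lora_weights` assumes the `LoraLoaderMixin` support referenced above):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
).to("cuda")

# Loads the UNet (and, if trained, text encoder) LoRA layers on top of the base weights.
pipe.load_lora_weights("path-to-save-model")

image = pipe("A photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("dog-bucket.png")
```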


@@ -45,15 +45,28 @@ write_basic_config()
### Dog toy example
-Now let's get our dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data.
+Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
-And launch the training using
+Let's first download it locally:
+```python
+from huggingface_hub import snapshot_download
+
+local_dir = "./dog"
+snapshot_download(
+    "diffusers/dog-example",
+    local_dir=local_dir, repo_type="dataset",
+    ignore_patterns=".gitattributes",
+)
+```
+
+And launch the training using:
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="path-to-save-model"
accelerate launch train_dreambooth.py \
@@ -77,7 +90,7 @@ According to the paper, it's recommended to generate `num_epochs * num_samples`
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
@@ -108,7 +121,7 @@ To install `bitsandbytes` please refer to this [readme](https://github.com/TimDet
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
@@ -141,7 +154,7 @@ It is possible to run dreambooth on a 12GB GPU by using the following optimizati
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
@@ -185,7 +198,7 @@ does not seem to be compatible with DeepSpeed at the moment.
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
@@ -217,7 +230,7 @@ ___Note: Training text encoder requires more memory, with this option the traini
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
@@ -300,7 +313,7 @@ Now, you can launch the training. Here we will use [Stable Diffusion 1-5](https:
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="path-to-save-model"
```
@@ -342,6 +355,12 @@ The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dr
The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
You can use the `Step` slider to see how the model learned the features of our subject while the model trained.
+Optionally, we can also train additional LoRA layers for the text encoder. Specify the `--train_text_encoder` argument above for that. If you're interested in learning more about how we
+enable this support, check out this [PR](https://github.com/huggingface/diffusers/pull/2918).
+
+With the default hyperparameters from the above, the training seems to go in a positive direction. Check out [this panel](https://wandb.ai/sayakpaul/dreambooth-lora/reports/test-23-04-17-17-00-13---Vmlldzo0MDkwNjMy). The trained LoRA layers are available [here](https://huggingface.co/sayakpaul/dreambooth).
### Inference
After training, LoRA weights can be loaded very easily into the original pipeline. First, you need to
@@ -386,7 +405,7 @@ pip install -U -r requirements_flax.txt
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="path-to-save-model"
python train_dreambooth_flax.py \
@@ -405,7 +424,7 @@ python train_dreambooth_flax.py \
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
@@ -429,7 +448,7 @@ python train_dreambooth_flax.py \
```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"


@@ -15,6 +15,7 @@
import argparse
import hashlib
import itertools
+import logging
import math
import os
@@ -43,12 +44,13 @@ from diffusers import (
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
+    StableDiffusionPipeline,
    UNet2DConditionModel,
)
-from diffusers.loaders import AttnProcsLayers
+from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
@@ -58,7 +60,7 @@ check_min_version("0.16.0.dev0")
logger = get_logger(__name__)
-def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_folder=None):
+def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None):
    img_str = ""
    for i, image in enumerate(images):
        image.save(os.path.join(repo_folder, f"image_{i}.png"))
@@ -83,6 +85,8 @@ inference: true
These are LoRA adaptation weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
{img_str}
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
"""
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml + model_card)
@@ -219,6 +223,11 @@ def parse_args(input_args=None):
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--train_text_encoder",
action="store_true",
help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
@@ -547,7 +556,13 @@ def main(args):
    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
-    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+    # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate.
+    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+        raise ValueError(
+            "Gradient accumulation is not supported when training the text encoder in distributed training. "
+            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+        )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -691,7 +706,7 @@ def main(args):
    # => 32 layers

    # Set correct lora layers
-    lora_attn_procs = {}
+    unet_lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
@@ -703,12 +718,33 @@ def main(args):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

-        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
+        unet_lora_attn_procs[name] = LoRAAttnProcessor(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        )

-    unet.set_attn_processor(lora_attn_procs)
-    lora_layers = AttnProcsLayers(unet.attn_processors)
-    accelerator.register_for_checkpointing(lora_layers)
+    unet.set_attn_processor(unet_lora_attn_procs)
+    unet_lora_layers = AttnProcsLayers(unet.attn_processors)
+    accelerator.register_for_checkpointing(unet_lora_layers)

+    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
+    # So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
+    # we first load a dummy pipeline with the text encoder and then do the monkey-patching.
+    text_encoder_lora_layers = None
+    if args.train_text_encoder:
+        text_lora_attn_procs = {}
+        for name, module in text_encoder.named_modules():
+            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
+                text_lora_attn_procs[name] = LoRAAttnProcessor(
+                    hidden_size=module.out_features, cross_attention_dim=None
+                )
+        text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
+        temp_pipeline = StableDiffusionPipeline.from_pretrained(
+            args.pretrained_model_name_or_path, text_encoder=text_encoder
+        )
+        temp_pipeline._modify_text_encoder(text_lora_attn_procs)
+        text_encoder = temp_pipeline.text_encoder
+        accelerator.register_for_checkpointing(unet_lora_layers)
+        del temp_pipeline
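    # To make the monkey-patching idea above concrete: conceptually, each targeted
    # projection in the text encoder is replaced by a wrapper that adds a trainable
    # low-rank update on top of the frozen layer. A minimal sketch of that idea,
    # with a hypothetical `LoRALinear` (not the API used in this script):
    #
    #     class LoRALinear(torch.nn.Module):
    #         def __init__(self, base: torch.nn.Linear, rank: int = 4):
    #             super().__init__()
    #             self.base = base  # frozen pretrained projection
    #             self.down = torch.nn.Linear(base.in_features, rank, bias=False)
    #             self.up = torch.nn.Linear(rank, base.out_features, bias=False)
    #             torch.nn.init.zeros_(self.up.weight)  # the update starts as a no-op
    #
    #         def forward(self, x):
    #             return self.base(x) + self.up(self.down(x))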
    if args.scale_lr:
        args.learning_rate = (
@@ -739,8 +775,13 @@ def main(args):
        optimizer_class = torch.optim.AdamW

    # Optimizer creation
+    params_to_optimize = (
+        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+        if args.train_text_encoder
+        else unet_lora_layers.parameters()
+    )
    optimizer = optimizer_class(
-        lora_layers.parameters(),
+        params_to_optimize,
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
@@ -784,9 +825,14 @@ def main(args):
    )

    # Prepare everything with our `accelerator`.
-    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        lora_layers, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.train_text_encoder:
+        unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler
+        )
+    else:
+        unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet_lora_layers, optimizer, train_dataloader, lr_scheduler
+        )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -845,6 +891,8 @@ def main(args):
    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
+        if args.train_text_encoder:
+            text_encoder.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -900,7 +948,11 @@ def main(args):
                accelerator.backward(loss)
                if accelerator.sync_gradients:
-                    params_to_clip = lora_layers.parameters()
+                    params_to_clip = (
+                        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+                        if args.train_text_encoder
+                        else unet_lora_layers.parameters()
+                    )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
@@ -914,7 +966,14 @@ def main(args):
                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                        accelerator.save_state(save_path)
+                        # We combine the text encoder and UNet LoRA parameters with a simple
+                        # custom logic. `accelerator.save_state()` won't know that. So,
+                        # use `LoraLoaderMixin.save_lora_weights()`.
+                        LoraLoaderMixin.save_lora_weights(
+                            save_directory=save_path,
+                            unet_lora_layers=unet_lora_layers,
+                            text_encoder_lora_layers=text_encoder_lora_layers,
+                        )
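                        # (A checkpoint saved this way can later be reloaded with
                        # `pipe.load_lora_weights(save_path)` — a sketch, assuming the
                        # `LoraLoaderMixin` loading counterpart is available.)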
logger.info(f"Saved state to {save_path}")
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -970,7 +1029,12 @@ def main(args):
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = unet.to(torch.float32)
-        unet.save_attn_procs(args.output_dir)
+        text_encoder = text_encoder.to(torch.float32)
+        LoraLoaderMixin.save_lora_weights(
+            save_directory=args.output_dir,
+            unet_lora_layers=unet_lora_layers,
+            text_encoder_lora_layers=text_encoder_lora_layers,
+        )

    # Final inference
    # Load previous pipeline
@@ -981,7 +1045,7 @@ def main(args):
    pipeline = pipeline.to(accelerator.device)

    # load attention processors
-    pipeline.unet.load_attn_procs(args.output_dir)
+    pipeline.load_attn_procs(args.output_dir)

    # run inference
    if args.validation_prompt and args.num_validation_images > 0:
@@ -1010,6 +1074,7 @@ def main(args):
                repo_id,
                images=images,
                base_model=args.pretrained_model_name_or_path,
+                train_text_encoder=args.train_text_encoder,
                prompt=args.instance_prompt,
                repo_folder=args.output_dir,
            )


@@ -23,6 +23,7 @@ import tempfile
import unittest
from typing import List
+import torch
from accelerate.utils import write_basic_config
from diffusers import DiffusionPipeline, UNet2DConditionModel
@@ -221,6 +222,68 @@ class ExamplesTestsAccelerate(unittest.TestCase):
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))

+    def test_dreambooth_lora(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"unet"` in their names.
+            starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
+            self.assertTrue(starts_with_unet)
+
+    def test_dreambooth_lora_with_text_encoder(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                examples/dreambooth/train_dreambooth_lora.py
+                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
+                --instance_data_dir docs/source/en/imgs
+                --instance_prompt photo
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --train_text_encoder
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
+
+            # the names of the keys of the state dict should either start with `unet`
+            # or `text_encoder`.
+            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
+            keys = lora_state_dict.keys()
+            is_correct_naming = all(k.startswith("unet") or k.startswith("text_encoder") for k in keys)
+            self.assertTrue(is_correct_naming)
+
    def test_custom_diffusion(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""