mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[dreambooth] allow fine-tuning text encoder (#883)
* allow fine-tuning text encoder * fix a few things * update readme
This commit is contained in:
@@ -160,6 +160,39 @@ accelerate launch train_dreambooth.py \
|
||||
--mixed_precision=fp16
|
||||
```
|
||||
|
||||
### Fine-tune text encoder with the UNet.
|
||||
|
||||
The script also allows to fine-tune the `text_encoder` along with the `unet`. It's been observed experimentally that fine-tuning `text_encoder` gives much better results especially on faces.
|
||||
Pass the `--train_text_encoder` argument to the script to enable training `text_encoder`.
|
||||
|
||||
___Note: Training text encoder requires more memory, with this option the training won't fit on 16GB GPU. It needs at least 24GB VRAM.___
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
accelerate launch train_dreambooth.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--train_text_encoder \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--class_data_dir=$CLASS_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--with_prior_preservation --prior_loss_weight=1.0 \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--class_prompt="a photo of dog" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--use_8bit_adam
|
||||
--gradient_checkpointing \
|
||||
--learning_rate=2e-6 \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt.
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import argparse
|
||||
import itertools
|
||||
import math
|
||||
import os
|
||||
from pathlib import Path
|
||||
@@ -100,6 +101,7 @@ def parse_args():
|
||||
parser.add_argument(
|
||||
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
|
||||
)
|
||||
parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
@@ -320,6 +322,15 @@ def main():
|
||||
logging_dir=logging_dir,
|
||||
)
|
||||
|
||||
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
|
||||
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
|
||||
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
|
||||
if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
|
||||
raise ValueError(
|
||||
"Gradient accumulation is not supported when training the text encoder in distributed training. "
|
||||
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
|
||||
)
|
||||
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
@@ -385,8 +396,14 @@ def main():
|
||||
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
|
||||
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
|
||||
|
||||
vae.requires_grad_(False)
|
||||
if not args.train_text_encoder:
|
||||
text_encoder.requires_grad_(False)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
if args.train_text_encoder:
|
||||
text_encoder.gradient_checkpointing_enable()
|
||||
|
||||
if args.scale_lr:
|
||||
args.learning_rate = (
|
||||
@@ -406,8 +423,11 @@ def main():
|
||||
else:
|
||||
optimizer_class = torch.optim.AdamW
|
||||
|
||||
params_to_optimize = (
|
||||
itertools.chain(unet.parameters(), text_encoder.parameters()) if args.train_text_encoder else unet.parameters()
|
||||
)
|
||||
optimizer = optimizer_class(
|
||||
unet.parameters(), # only optimize unet
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
@@ -467,9 +487,14 @@ def main():
|
||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
if args.train_text_encoder:
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, text_encoder, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
else:
|
||||
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
unet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if args.mixed_precision == "fp16":
|
||||
@@ -480,8 +505,9 @@ def main():
|
||||
# Move text_encode and vae to gpu.
|
||||
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
||||
# as these models are only used for inference, keeping weights in full precision is not required.
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
if not args.train_text_encoder:
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
@@ -516,9 +542,8 @@ def main():
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(unet):
|
||||
# Convert images to latent space
|
||||
with torch.no_grad():
|
||||
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(latents)
|
||||
@@ -532,8 +557,7 @@ def main():
|
||||
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
|
||||
# Get the text embedding for conditioning
|
||||
with torch.no_grad():
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
|
||||
|
||||
# Predict the noise residual
|
||||
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
|
||||
@@ -556,7 +580,12 @@ def main():
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
|
||||
params_to_clip = (
|
||||
itertools.chain(unet.parameters(), text_encoder.parameters())
|
||||
if args.train_text_encoder
|
||||
else unet.parameters()
|
||||
)
|
||||
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
@@ -578,7 +607,9 @@ def main():
|
||||
# Create the pipeline using using the trained modules and save it.
|
||||
if accelerator.is_main_process:
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path, unet=accelerator.unwrap_model(unet)
|
||||
args.pretrained_model_name_or_path,
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
)
|
||||
pipeline.save_pretrained(args.output_dir)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user