
[SDXL DreamBooth LoRA] add support for text encoder fine-tuning (#4097)

* Allow low precision sd xl

* finish

* finish

* feat: initial draft for supporting text encoder lora finetuning for SDXL DreamBooth

* fix: variable assignments.

* add: autocast block.

* add debugging

* vae dtype hell

* fix: vae dtype hell.

* fix: vae dtype hell 3.

* clean up

* lora text encoder loader.

* fix: unwrapping models.

* add: tests.

* docs.

* handle unexpected keys.

* fix vae dtype in the final inference.

* fix scope problem.

* fix: save_model_card args.

* initialize: prefix to None.

* fix: dtype issues.

* apply fixes.

* debugging.

* debugging

* debugging

* debugging

* debugging

* debugging

* add: fast tests.

* pre-tokenize.

* address: will's comments.

* fix: loader and tests.

* fix: dataloader.

* simplify dataloader.

* length.

* simplification.

* make style && make quality

* simplify state_dict munging

* fix: tests.

* fix: state_dict packing.

* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Author: Sayak Paul
Committed: 2023-07-25 05:35:48 +05:30 (via GitHub)
Parent: fed12376c5
Commit: 365e8461ac
6 changed files with 565 additions and 170 deletions


@@ -164,6 +164,17 @@ Here's a side-by-side comparison of the with and without Refiner pipeline output
|---|---|
| ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/sks_dog.png) | ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/refined_sks_dog.png) |
### Training with text encoder(s)

Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. To do so, just specify `--train_text_encoder` while launching training (an example command follows the list below). Please keep the following points in mind:

* SDXL has two text encoders. So, we fine-tune both using LoRA.
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
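For reference, a minimal launch command with text encoder training enabled might look like this (a sketch, not a recipe: the model ID, data paths, and hyperparameter values are placeholders to adapt to your setup):

```bash
accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --instance_data_dir="dog" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-4 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --train_text_encoder \
  --output_dir="lora-trained-xl"
```

Note that `--train_text_encoder` is incompatible with `--pre_compute_text_embeddings`: when the text encoders are being trained, their embeddings cannot be precomputed, so expect a higher memory footprint than a UNet-only run.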
### Specifying a better VAE

SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
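For example, the fp16-fix VAE linked above can be swapped in by adding a single flag to the launch command (again a sketch with placeholder values):

```bash
accelerate launch train_dreambooth_lora_sdxl.py \
  --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
  --pretrained_vae_model_name_or_path="madebyollin/sdxl-vae-fp16-fix" \
  --instance_data_dir="dog" \
  --instance_prompt="a photo of sks dog" \
  --output_dir="lora-trained-xl"
```

When this flag is set, the script keeps the VAE in the same low-precision dtype as the other frozen weights instead of upcasting it to float32 (see the `vae.to(...)` branch in the training-script diff below).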
## Notes

In our experiments, we found that SDXL yields very good initial results using the default settings of the script. We didn't run further hyperparameter-tuning experiments, but we do encourage the community to explore this avenue further and share their results with us 🤗
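Once training finishes, the saved `pytorch_lora_weights.bin` bundles the UNet LoRA layers and, when `--train_text_encoder` was used, the LoRA layers of both text encoders; the `StableDiffusionXLPipeline.load_lora_weights` override added in this PR restores all of them. A minimal inference sketch (model ID, weights directory, and prompt are placeholders):

```python
import torch

from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# Routes the "unet.", "text_encoder." and "text_encoder_2."-prefixed entries
# of the saved state dict into the corresponding models.
pipe.load_lora_weights("lora-trained-xl")

image = pipe("a photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("sks_dog.png")
```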


@@ -16,6 +16,7 @@
import argparse
import gc
import hashlib
import itertools
import logging
import math
import os
@@ -45,11 +46,11 @@ import diffusers
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
@@ -63,12 +64,7 @@ logger = get_logger(__name__)
def save_model_card(
repo_id: str,
images=None,
base_model=str,
train_text_encoder=False,
prompt=str,
repo_folder=None,
repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
):
img_str = ""
for i, image in enumerate(images):
@@ -96,6 +92,8 @@ These are LoRA adaption weights for {base_model}. The weights were trained on {p
{img_str}
LoRA for the text encoder was enabled: {train_text_encoder}.
Special VAE used for training: {vae_path}.
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
@@ -130,6 +128,12 @@ def parse_args(input_args=None):
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--pretrained_vae_model_name_or_path",
type=str,
default=None,
help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
)
parser.add_argument(
"--revision",
type=str,
@@ -420,38 +424,25 @@ def parse_args(input_args=None):
if args.class_prompt is not None:
warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
if args.train_text_encoder and args.pre_compute_text_embeddings:
raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")
return args
class DreamBoothDataset(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
It pre-processes the images and tokenizes the prompts.
It pre-processes the images.
"""
def __init__(
self,
instance_data_root,
instance_prompt,
class_data_root=None,
class_prompt=None,
class_num=None,
size=1024,
center_crop=False,
instance_prompt_hidden_states=None,
class_prompt_hidden_states=None,
instance_unet_added_conditions=None,
class_unet_added_conditions=None,
):
self.size = size
self.center_crop = center_crop
self.instance_prompt_hidden_states = instance_prompt_hidden_states
self.class_prompt_hidden_states = class_prompt_hidden_states
self.instance_unet_added_conditions = instance_unet_added_conditions
self.class_unet_added_conditions = class_unet_added_conditions
self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
@@ -459,7 +450,6 @@ class DreamBoothDataset(Dataset):
self.instance_images_path = list(Path(instance_data_root).iterdir())
self.num_instance_images = len(self.instance_images_path)
self.instance_prompt = instance_prompt
self._length = self.num_instance_images
if class_data_root is not None:
@@ -471,7 +461,6 @@ class DreamBoothDataset(Dataset):
else:
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images)
self.class_prompt = class_prompt
else:
self.class_data_root = None
@@ -496,9 +485,6 @@ class DreamBoothDataset(Dataset):
instance_image = instance_image.convert("RGB")
example["instance_images"] = self.image_transforms(instance_image)
example["instance_prompt_ids"] = self.instance_prompt_hidden_states
example["instance_added_cond_kwargs"] = self.instance_unet_added_conditions
if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
class_image = exif_transpose(class_image)
@@ -506,49 +492,22 @@ class DreamBoothDataset(Dataset):
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
example["class_prompt_ids"] = self.class_prompt_hidden_states
example["class_added_cond_kwargs"] = self.class_unet_added_conditions
return example
def collate_fn(examples, with_prior_preservation=False):
has_attention_mask = "instance_attention_mask" in examples[0]
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]
add_text_embeds = [example["instance_added_cond_kwargs"]["text_embeds"] for example in examples]
add_time_ids = [example["instance_added_cond_kwargs"]["time_ids"] for example in examples]
if has_attention_mask:
attention_mask = [example["instance_attention_mask"] for example in examples]
# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if with_prior_preservation:
input_ids += [example["class_prompt_ids"] for example in examples]
pixel_values += [example["class_images"] for example in examples]
add_text_embeds += [example["class_added_cond_kwargs"]["text_embeds"] for example in examples]
add_time_ids += [example["class_added_cond_kwargs"]["time_ids"] for example in examples]
if has_attention_mask:
attention_mask += [example["class_attention_mask"] for example in examples]
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = torch.cat(input_ids, dim=0)
add_text_embeds = torch.cat(add_text_embeds, dim=0)
add_time_ids = torch.cat(add_time_ids, dim=0)
batch = {
"input_ids": input_ids,
"pixel_values": pixel_values,
"unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids},
}
if has_attention_mask:
batch["attention_mask"] = attention_mask
batch = {"pixel_values": pixel_values}
return batch
@@ -569,27 +528,29 @@ class PromptDataset(Dataset):
return example
def tokenize_prompt(tokenizer, prompt):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
return text_input_ids
# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(text_encoders, tokenizers, prompt):
def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
prompt_embeds_list = []
for tokenizer, text_encoder in zip(tokenizers, text_encoders):
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {tokenizer.model_max_length} tokens: {removed_text}"
)
for i, text_encoder in enumerate(text_encoders):
if tokenizers is not None:
tokenizer = tokenizers[i]
text_input_ids = tokenize_prompt(tokenizer, prompt)
else:
assert text_input_ids_list is not None
text_input_ids = text_input_ids_list[i]
prompt_embeds = text_encoder(
text_input_ids.to(text_encoder.device),
@@ -641,9 +602,6 @@ def main(args):
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb
if args.train_text_encoder:
raise NotImplementedError("Text encoder training not yet supported.")
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -677,7 +635,7 @@ def main(args):
torch_dtype = torch.float16
elif args.prior_generation_precision == "bf16":
torch_dtype = torch.bfloat16
pipeline = DiffusionPipeline.from_pretrained(
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
safety_checker=None,
@@ -742,7 +700,14 @@ def main(args):
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
vae_path = (
args.pretrained_model_name_or_path
if args.pretrained_vae_model_name_or_path is None
else args.pretrained_vae_model_name_or_path
)
vae = AutoencoderKL.from_pretrained(
vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
@@ -764,7 +729,10 @@ def main(args):
# Move unet, vae and text_encoder to device and cast to weight_dtype
# The VAE is in float32 to avoid NaN losses.
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=torch.float32)
if args.pretrained_vae_model_name_or_path is None:
vae.to(accelerator.device, dtype=torch.float32)
else:
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder_one.to(accelerator.device, dtype=weight_dtype)
text_encoder_two.to(accelerator.device, dtype=weight_dtype)
@@ -804,42 +772,66 @@ def main(args):
unet_lora_parameters.extend(module.parameters())
unet.set_attn_processor(unet_lora_attn_procs)
# unet_lora_layers = AttnProcsLayers(unet.attn_processors)
# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks.
if args.train_text_encoder:
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(text_encoder_one, dtype=torch.float32)
text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(text_encoder_two, dtype=torch.float32)
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
# there are only two options here. Either are just the unet attn processor layers
# or there are the unet and text encoder atten layers
unet_lora_layers_to_save = None
text_encoder_one_lora_layers_to_save = None
text_encoder_two_lora_layers_to_save = None
for model in models:
unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
else:
raise ValueError(f"unexpected save model: {model.__class__}")
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
LoraLoaderMixin.save_lora_weights(
StableDiffusionXLPipeline.save_lora_weights(
output_dir,
unet_lora_layers=unet_lora_layers_to_save,
text_encoder_lora_layers=None,
text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
)
def load_model_hook(models, input_dir):
unet_ = None
text_encoder_ = None
text_encoder_one_ = None
text_encoder_two_ = None
while len(models) > 0:
model = models.pop()
if isinstance(model, type(accelerator.unwrap_model(unet))):
unet_ = model
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
text_encoder_one_ = model
elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
text_encoder_two_ = model
else:
raise ValueError(f"unexpected save model: {model.__class__}")
lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_one_
)
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_two_
)
accelerator.register_save_state_pre_hook(save_model_hook)
@@ -869,7 +861,11 @@ def main(args):
optimizer_class = torch.optim.AdamW
# Optimizer creation
params_to_optimize = unet_lora_parameters
params_to_optimize = (
itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
if args.train_text_encoder
else unet_lora_parameters
)
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
@@ -878,62 +874,81 @@ def main(args):
eps=args.adam_epsilon,
)
# We ALWAYS pre-compute the additional condition embeddings needed for SDXL
# UNet as the model is already big and it uses two text encoders.
# TODO: when we add support for text encoder training, will revisit.
tokenizers = [tokenizer_one, tokenizer_two]
text_encoders = [text_encoder_one, text_encoder_two]
# Computes additional embeddings/ids required by the SDXL UNet.
# regular text embeddings (when `train_text_encoder` is not True)
# pooled text embeddings
# time ids
# Here, we compute not just the text embeddings but also the additional embeddings
# needed for the SD XL UNet to operate.
def compute_embeddings(prompt, text_encoders, tokenizers):
def compute_time_ids():
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
original_size = (args.resolution, args.resolution)
target_size = (args.resolution, args.resolution)
crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
return add_time_ids
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
add_text_embeds = pooled_prompt_embeds
if not args.train_text_encoder:
tokenizers = [tokenizer_one, tokenizer_two]
text_encoders = [text_encoder_one, text_encoder_two]
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
add_time_ids = list(original_size + crops_coords_top_left + target_size)
add_time_ids = torch.tensor([add_time_ids])
def compute_text_embeddings(prompt, text_encoders, tokenizers):
with torch.no_grad():
prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
prompt_embeds = prompt_embeds.to(accelerator.device)
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
return prompt_embeds, pooled_prompt_embeds
prompt_embeds = prompt_embeds.to(accelerator.device)
add_text_embeds = add_text_embeds.to(accelerator.device)
add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
return prompt_embeds, unet_added_cond_kwargs
instance_prompt_hidden_states, instance_unet_added_conditions = compute_embeddings(
args.instance_prompt, text_encoders, tokenizers
)
class_prompt_hidden_states, class_unet_added_conditions = None, None
if args.with_prior_preservation:
class_prompt_hidden_states, class_unet_added_conditions = compute_embeddings(
args.class_prompt, text_encoders, tokenizers
# Handle instance prompt.
instance_time_ids = compute_time_ids()
if not args.train_text_encoder:
instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
args.instance_prompt, text_encoders, tokenizers
)
del tokenizers, text_encoders
# Handle class prompt for prior-preservation.
if args.with_prior_preservation:
class_time_ids = compute_time_ids()
if not args.train_text_encoder:
class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
args.class_prompt, text_encoders, tokenizers
)
gc.collect()
torch.cuda.empty_cache()
# Clear the memory here.
if not args.train_text_encoder:
del tokenizers, text_encoders
gc.collect()
torch.cuda.empty_cache()
# Pack the statically computed variables appropriately. This is so that we don't
# have to pass them to the dataloader.
add_time_ids = instance_time_ids
if args.with_prior_preservation:
add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
if not args.train_text_encoder:
prompt_embeds = instance_prompt_hidden_states
unet_add_text_embeds = instance_pooled_prompt_embeds
if args.with_prior_preservation:
prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
else:
tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
if args.with_prior_preservation:
class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_prompt=args.class_prompt,
class_num=args.num_class_images,
size=args.resolution,
center_crop=args.center_crop,
instance_prompt_hidden_states=instance_prompt_hidden_states,
class_prompt_hidden_states=class_prompt_hidden_states,
instance_unet_added_conditions=instance_unet_added_conditions,
class_unet_added_conditions=class_unet_added_conditions,
)
train_dataloader = torch.utils.data.DataLoader(
@@ -954,16 +969,21 @@ def main(args):
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
# Prepare everything with our `accelerator`.
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
if args.train_text_encoder:
unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
)
else:
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1022,6 +1042,9 @@ def main(args):
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
if args.train_text_encoder:
text_encoder_one.train()
text_encoder_two.train()
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -1030,12 +1053,16 @@ def main(args):
continue
with accelerator.accumulate(unet):
# pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
if args.pretrained_vae_model_name_or_path is None:
pixel_values = batch["pixel_values"]
else:
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
# Convert images to latent space
model_input = vae.encode(batch["pixel_values"]).latent_dist.sample()
model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
model_input = model_input.to(weight_dtype)
if args.pretrained_vae_model_name_or_path is None:
model_input = model_input.to(weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(model_input)
@@ -1051,9 +1078,30 @@ def main(args):
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
# Predict the noise residual
model_pred = unet(
noisy_model_input, timesteps, batch["input_ids"], added_cond_kwargs=batch["unet_added_conditions"]
).sample
if not args.train_text_encoder:
unet_added_conditions = {
"time_ids": add_time_ids.repeat(bsz, 1),
"text_embeds": unet_add_text_embeds.repeat(bsz, 1),
}
model_pred = unet(
noisy_model_input,
timesteps,
prompt_embeds.repeat(bsz, 1, 1),
added_cond_kwargs=unet_added_conditions,
).sample
else:
unet_added_conditions = {"time_ids": add_time_ids.repeat(bsz, 1)}
prompt_embeds, pooled_prompt_embeds = encode_prompt(
text_encoders=[text_encoder_one, text_encoder_two],
tokenizers=None,
prompt=None,
text_input_ids_list=[tokens_one, tokens_two],
)
unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(bsz, 1)})
prompt_embeds = prompt_embeds.repeat(bsz, 1, 1)
model_pred = unet(
noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
@@ -1081,7 +1129,11 @@ def main(args):
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = unet_lora_parameters
params_to_clip = (
itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
if args.train_text_encoder
else unet_lora_parameters
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
@@ -1132,8 +1184,22 @@ def main(args):
f" {args.validation_prompt}."
)
# create pipeline
pipeline = DiffusionPipeline.from_pretrained(
if not args.train_text_encoder:
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=accelerator.unwrap_model(text_encoder_one)
if args.train_text_encoder
else text_encoder_one,
text_encoder_2=accelerator.unwrap_model(text_encoder_two)
if args.train_text_encoder
else text_encoder_two,
unet=accelerator.unwrap_model(unet),
revision=args.revision,
torch_dtype=weight_dtype,
@@ -1161,9 +1227,11 @@ def main(args):
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
pipeline_args = {"prompt": args.validation_prompt}
images = [
pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)
]
with torch.cuda.amp.autocast():
images = [
pipeline(**pipeline_args, generator=generator).images[0]
for _ in range(args.num_validation_images)
]
for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
@@ -1189,16 +1257,32 @@ def main(args):
unet = unet.to(torch.float32)
unet_lora_layers = unet_attn_processors_state_dict(unet)
LoraLoaderMixin.save_lora_weights(
if args.train_text_encoder:
text_encoder_one = accelerator.unwrap_model(text_encoder_one)
text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
text_encoder_two = accelerator.unwrap_model(text_encoder_two)
text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
else:
text_encoder_lora_layers = None
text_encoder_2_lora_layers = None
StableDiffusionXLPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=None,
text_encoder_lora_layers=text_encoder_lora_layers,
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
)
# Final inference
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
vae = AutoencoderKL.from_pretrained(
vae_path,
subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
revision=args.revision,
torch_dtype=weight_dtype,
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
)
# We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
@@ -1250,6 +1334,7 @@ def main(args):
train_text_encoder=args.train_text_encoder,
prompt=args.instance_prompt,
repo_folder=args.output_dir,
vae_path=args.pretrained_vae_model_name_or_path,
)
upload_folder(
repo_id=repo_id,


@@ -385,6 +385,42 @@ class ExamplesTestsAccelerate(unittest.TestCase):
starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
self.assertTrue(starts_with_unet)
def test_dreambooth_lora_sdxl_with_text_encoder(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""
examples/dreambooth/train_dreambooth_lora_sdxl.py
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
--instance_data_dir docs/source/en/imgs
--instance_prompt photo
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 2
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--train_text_encoder
""".split()
run_command(self._launch_args + test_args)
# save_pretrained smoke test
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))
# make sure the state_dict has the correct naming in the parameters.
lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
is_lora = all("lora" in k for k in lora_state_dict.keys())
self.assertTrue(is_lora)
# when training the text encoder, all the parameters in the state dict should start
# with `"unet"`, `"text_encoder"`, or `"text_encoder_2"` in their names.
keys = lora_state_dict.keys()
starts_with_expected_prefix = all(
k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
)
self.assertTrue(starts_with_expected_prefix)
def test_custom_diffusion(self):
with tempfile.TemporaryDirectory() as tmpdir:
test_args = f"""


@@ -59,7 +59,7 @@ if is_safetensors_available():
import safetensors
if is_transformers_available():
from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer
from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer
if is_accelerate_available():
from accelerate import init_empty_weights
@@ -108,7 +108,7 @@ class PatchedLoraProjection(nn.Module):
def text_encoder_attn_modules(text_encoder):
attn_modules = []
if isinstance(text_encoder, CLIPTextModel):
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
name = f"text_model.encoder.layers.{i}.self_attn"
mod = layer.self_attn
@@ -1016,18 +1016,20 @@ class LoraLoaderMixin:
warnings.warn(warn_message)
@classmethod
def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0):
def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, prefix=None, lora_scale=1.0):
"""
This will load the LoRA layers specified in `state_dict` into `text_encoder`
Parameters:
state_dict (`dict`):
A standard state dict containing the lora layer parameters. The key shoult be prefixed with an
A standard state dict containing the lora layer parameters. The key should be prefixed with an
additional `text_encoder` to distinguish between unet lora layers.
network_alpha (`float`):
See `LoRALinearLayer` for more details.
text_encoder (`CLIPTextModel`):
The text encoder model to load the LoRA layers into.
prefix (`str`):
Expected prefix of the `text_encoder` in the `state_dict`.
lora_scale (`float`):
How much to scale the output of the lora linear layer before it is added with the output of the regular
lora layer.
@@ -1037,14 +1039,16 @@ class LoraLoaderMixin:
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
# their prefixes.
keys = list(state_dict.keys())
if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
prefix = cls.text_encoder_name if prefix is None else prefix
if any(cls.text_encoder_name in key for key in keys):
# Load the layers corresponding to text encoder and make necessary adjustments.
text_encoder_keys = [k for k in keys if k.startswith(cls.text_encoder_name)]
text_encoder_keys = [k for k in keys if k.startswith(prefix)]
text_encoder_lora_state_dict = {
k.replace(f"{cls.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
}
if len(text_encoder_lora_state_dict) > 0:
logger.info(f"Loading {cls.text_encoder_name}.")
logger.info(f"Loading {prefix}.")
if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
# Convert from the old naming convention to the new naming convention.
@@ -1184,23 +1188,10 @@ class LoraLoaderMixin:
replace `torch.save` with another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
"""
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if save_function is None:
if safe_serialization:
def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
else:
save_function = torch.save
os.makedirs(save_directory, exist_ok=True)
# Create a flat dictionary.
state_dict = {}
# Populate the dictionary.
if unet_lora_layers is not None:
weights = (
unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
@@ -1222,6 +1213,38 @@ class LoraLoaderMixin:
state_dict.update(text_encoder_lora_state_dict)
# Save the model
self.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
def write_lora_layers(
state_dict: Dict[str, torch.Tensor],
save_directory: str,
is_main_process: bool,
weight_name: str,
save_function: Callable,
safe_serialization: bool,
):
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if save_function is None:
if safe_serialization:
def save_function(weights, filename):
return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
else:
save_function = torch.save
os.makedirs(save_directory, exist_ok=True)
if weight_name is None:
if safe_serialization:
weight_name = LORA_WEIGHT_NAME_SAFE


@@ -13,6 +13,7 @@
# limitations under the License.
import inspect
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
@@ -841,3 +842,66 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
return (image,)
return StableDiffusionXLPipelineOutput(images=image)
# Override to properly handle the loading and unloading of the additional text encoder.
def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)
text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
if len(text_encoder_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_state_dict,
network_alpha=network_alpha,
text_encoder=self.text_encoder,
prefix="text_encoder",
lora_scale=self.lora_scale,
)
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
if len(text_encoder_2_state_dict) > 0:
self.load_lora_into_text_encoder(
text_encoder_2_state_dict,
network_alpha=network_alpha,
text_encoder=self.text_encoder_2,
prefix="text_encoder_2",
lora_scale=self.lora_scale,
)
@classmethod
def save_lora_weights(
self,
save_directory: Union[str, os.PathLike],
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
is_main_process: bool = True,
weight_name: str = None,
save_function: Callable = None,
safe_serialization: bool = False,
):
state_dict = {}
def pack_weights(layers, prefix):
layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
return layers_state_dict
state_dict.update(pack_weights(unet_lora_layers, "unet"))
if text_encoder_lora_layers and text_encoder_2_lora_layers:
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
self.write_lora_layers(
state_dict=state_dict,
save_directory=save_directory,
is_main_process=is_main_process,
weight_name=weight_name,
save_function=save_function,
safe_serialization=safe_serialization,
)
def _remove_text_encoder_monkey_patch(self):
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)


@@ -21,9 +21,16 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub.repocard import RepoCard
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers import (
AutoencoderKL,
DDIMScheduler,
EulerDiscreteScheduler,
StableDiffusionPipeline,
StableDiffusionXLPipeline,
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin, PatchedLoraProjection, text_encoder_attn_modules
from diffusers.models.attention_processor import (
Attention,
@@ -399,7 +406,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
)
self.assertIsInstance(module.processor, attn_proc_class)
def test_unload_lora(self):
def test_unload_lora_sd(self):
pipeline_components, lora_components = self.get_dummy_components()
_, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
sd_pipe = StableDiffusionPipeline(**pipeline_components)
@@ -503,6 +510,175 @@ class LoraLoaderMixinTests(unittest.TestCase):
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
class SDXLLoraLoaderMixinTests(unittest.TestCase):
def get_dummy_components(self):
torch.manual_seed(0)
unet = UNet2DConditionModel(
block_out_channels=(32, 64),
layers_per_block=2,
sample_size=32,
in_channels=4,
out_channels=4,
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
# SDXL-specific config below
attention_head_dim=(2, 4),
use_linear_projection=True,
addition_embed_type="text_time",
addition_time_embed_dim=8,
transformer_layers_per_block=(1, 2),
projection_class_embeddings_input_dim=80, # 6 * 8 + 32
cross_attention_dim=64,
)
scheduler = EulerDiscreteScheduler(
beta_start=0.00085,
beta_end=0.012,
steps_offset=1,
beta_schedule="scaled_linear",
timestep_spacing="leading",
)
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
sample_size=128,
)
torch.manual_seed(0)
text_encoder_config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
# SDXL-specific config below
hidden_act="gelu",
projection_dim=32,
)
text_encoder = CLIPTextModel(text_encoder_config)
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)
unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
text_encoder_one_lora_layers = create_text_encoder_lora_layers(text_encoder)
text_encoder_two_lora_layers = create_text_encoder_lora_layers(text_encoder_2)
pipeline_components = {
"unet": unet,
"scheduler": scheduler,
"vae": vae,
"text_encoder": text_encoder,
"text_encoder_2": text_encoder_2,
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
}
lora_components = {
"unet_lora_layers": unet_lora_layers,
"text_encoder_one_lora_layers": text_encoder_one_lora_layers,
"text_encoder_two_lora_layers": text_encoder_two_lora_layers,
"unet_lora_attn_procs": unet_lora_attn_procs,
}
return pipeline_components, lora_components
def get_dummy_inputs(self, with_generator=True):
batch_size = 1
sequence_length = 10
num_channels = 4
sizes = (32, 32)
generator = torch.manual_seed(0)
noise = floats_tensor((batch_size, num_channels) + sizes)
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
pipeline_inputs = {
"prompt": "A painting of a squirrel eating a burger",
"num_inference_steps": 2,
"guidance_scale": 6.0,
"output_type": "np",
}
if with_generator:
pipeline_inputs.update({"generator": generator})
return noise, input_ids, pipeline_inputs
def test_lora_save_load(self):
pipeline_components, lora_components = self.get_dummy_components()
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)
_, _, pipeline_inputs = self.get_dummy_inputs()
original_images = sd_pipe(**pipeline_inputs).images
orig_image_slice = original_images[0, -3:, -3:, -1]
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
sd_pipe.load_lora_weights(tmpdirname)
lora_images = sd_pipe(**pipeline_inputs).images
lora_image_slice = lora_images[0, -3:, -3:, -1]
# Outputs shouldn't match.
self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))
def test_unload_lora_sdxl(self):
pipeline_components, lora_components = self.get_dummy_components()
_, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
orig_image_slice = original_images[0, -3:, -3:, -1]
# Emulate training.
set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)
with tempfile.TemporaryDirectory() as tmpdirname:
StableDiffusionXLPipeline.save_lora_weights(
save_directory=tmpdirname,
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
sd_pipe.load_lora_weights(tmpdirname)
lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
lora_image_slice = lora_images[0, -3:, -3:, -1]
# Unload LoRA parameters.
sd_pipe.unload_lora_weights()
original_images_two = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
orig_image_slice_two = original_images_two[0, -3:, -3:, -1]
assert not np.allclose(
orig_image_slice, lora_image_slice
), "LoRA parameters should lead to a different image slice."
assert not np.allclose(
orig_image_slice_two, lora_image_slice
), "LoRA parameters should lead to a different image slice."
assert np.allclose(
orig_image_slice, orig_image_slice_two, atol=1e-3
), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters."
@slow
@require_torch_gpu
class LoraIntegrationTests(unittest.TestCase):