mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[SDXL DreamBooth LoRA] add support for text encoder fine-tuning (#4097)
* Allow low precision sd xl
* finish
* finish
* feat: initial draft for supporting text encoder lora finetuning for SDXL DreamBooth
* fix: variable assignments.
* add: autocast block.
* add debugging
* vae dtype hell
* fix: vae dtype hell.
* fix: vae dtype hell 3.
* clean up
* lora text encoder loader.
* fix: unwrapping models.
* add: tests.
* docs.
* handle unexpected keys.
* fix vae dtype in the final inference.
* fix scope problem.
* fix: save_model_card args.
* initialize: prefix to None.
* fix: dtype issues.
* apply fixes.
* debugging.
* debugging
* debugging
* debugging
* debugging
* debugging
* add: fast tests.
* pre-tokenize.
* address: will's comments.
* fix: loader and tests.
* fix: dataloader.
* simplify dataloader.
* length.
* simplification.
* make style && make quality
* simplify state_dict munging
* fix: tests.
* fix: state_dict packing.
* Apply suggestions from code review

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@@ -164,6 +164,17 @@ Here's a side-by-side comparison of the with and without Refiner pipeline output
### Training with text encoder(s)
Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind (a loading sketch follows the list):

* SDXL has two text encoders. So, we fine-tune both using LoRA.
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memory.
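Once training finishes, the script writes a single `pytorch_lora_weights.bin` that holds the UNet LoRA layers and, when `--train_text_encoder` was passed, the LoRA layers of both text encoders. A minimal loading sketch, assuming the SDXL base checkpoint; the output path is a placeholder:

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# load_lora_weights picks up the "unet", "text_encoder" and "text_encoder_2"
# prefixed entries from the saved state dict
pipe.load_lora_weights("path/to/output_dir")
image = pipe("a photo of sks dog").images[0]
```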
### Specifying a better VAE

SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
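At inference time the same VAE can simply be passed to the pipeline. A minimal sketch (model IDs shown for illustration):

```python
import torch
from diffusers import AutoencoderKL, StableDiffusionXLPipeline

# the fp16-friendly VAE mentioned above
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", vae=vae, torch_dtype=torch.float16
).to("cuda")
```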
## Notes

In our experiments we found that SDXL yields very good initial results using the default settings of the script. We didn't explore hyper-parameter tuning any further, but we do encourage the community to explore this avenue and share their results with us 🤗
@@ -16,6 +16,7 @@
import argparse
import gc
import hashlib
import itertools
import logging
import math
import os
@@ -45,11 +46,11 @@ import diffusers
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from diffusers.loaders import LoraLoaderMixin
from diffusers.loaders import LoraLoaderMixin, text_encoder_lora_state_dict
from diffusers.models.attention_processor import LoRAAttnProcessor, LoRAAttnProcessor2_0
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
@@ -63,12 +64,7 @@ logger = get_logger(__name__)
def save_model_card(
    repo_id: str,
    images=None,
    base_model=str,
    train_text_encoder=False,
    prompt=str,
    repo_folder=None,
    repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None, vae_path=None
):
    img_str = ""
    for i, image in enumerate(images):
@@ -96,6 +92,8 @@ These are LoRA adaption weights for {base_model}. The weights were trained on {p
{img_str}

LoRA for the text encoder was enabled: {train_text_encoder}.

Special VAE used for training: {vae_path}.
"""
    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml + model_card)
@@ -130,6 +128,12 @@ def parse_args(input_args=None):
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_vae_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained VAE model with better numerical stability. More details: https://github.com/huggingface/diffusers/pull/4038.",
    )
    parser.add_argument(
        "--revision",
        type=str,
@@ -420,38 +424,25 @@ def parse_args(input_args=None):
    if args.class_prompt is not None:
        warnings.warn("You need not use --class_prompt without --with_prior_preservation.")

    if args.train_text_encoder and args.pre_compute_text_embeddings:
        raise ValueError("`--train_text_encoder` cannot be used with `--pre_compute_text_embeddings`")

    return args


class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and the tokenizes prompts.
    It pre-processes the images.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        class_data_root=None,
        class_prompt=None,
        class_num=None,
        size=1024,
        center_crop=False,
        instance_prompt_hidden_states=None,
        class_prompt_hidden_states=None,
        instance_unet_added_conditions=None,
        class_unet_added_conditions=None,
    ):
        self.size = size
        self.center_crop = center_crop
        self.instance_prompt_hidden_states = instance_prompt_hidden_states
        self.class_prompt_hidden_states = class_prompt_hidden_states
        self.instance_unet_added_conditions = instance_unet_added_conditions
        self.class_unet_added_conditions = class_unet_added_conditions

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
@@ -459,7 +450,6 @@ class DreamBoothDataset(Dataset):
        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
@@ -471,7 +461,6 @@ class DreamBoothDataset(Dataset):
            else:
                self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None
@@ -496,9 +485,6 @@ class DreamBoothDataset(Dataset):
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)

        example["instance_prompt_ids"] = self.instance_prompt_hidden_states
        example["instance_added_cond_kwargs"] = self.instance_unet_added_conditions

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            class_image = exif_transpose(class_image)
@@ -506,49 +492,22 @@
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.class_prompt_hidden_states
            example["class_added_cond_kwargs"] = self.class_unet_added_conditions

        return example


def collate_fn(examples, with_prior_preservation=False):
    has_attention_mask = "instance_attention_mask" in examples[0]

    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]
    add_text_embeds = [example["instance_added_cond_kwargs"]["text_embeds"] for example in examples]
    add_time_ids = [example["instance_added_cond_kwargs"]["time_ids"] for example in examples]
    if has_attention_mask:
        attention_mask = [example["instance_attention_mask"] for example in examples]

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]
        add_text_embeds += [example["class_added_cond_kwargs"]["text_embeds"] for example in examples]
        add_time_ids += [example["class_added_cond_kwargs"]["time_ids"] for example in examples]

        if has_attention_mask:
            attention_mask += [example["class_attention_mask"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.cat(input_ids, dim=0)
    add_text_embeds = torch.cat(add_text_embeds, dim=0)
    add_time_ids = torch.cat(add_time_ids, dim=0)

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
        "unet_added_conditions": {"text_embeds": add_text_embeds, "time_ids": add_time_ids},
    }

    if has_attention_mask:
        batch["attention_mask"] = attention_mask

    batch = {"pixel_values": pixel_values}
    return batch
@@ -569,27 +528,29 @@ class PromptDataset(Dataset):
        return example


def tokenize_prompt(tokenizer, prompt):
    text_inputs = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    return text_input_ids


# Adapted from pipelines.StableDiffusionXLPipeline.encode_prompt
def encode_prompt(text_encoders, tokenizers, prompt):
def encode_prompt(text_encoders, tokenizers, prompt, text_input_ids_list=None):
    prompt_embeds_list = []

    for tokenizer, text_encoder in zip(tokenizers, text_encoders):
        text_inputs = tokenizer(
            prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = tokenizer.batch_decode(untruncated_ids[:, tokenizer.model_max_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because CLIP can only handle sequences up to"
                f" {tokenizer.model_max_length} tokens: {removed_text}"
            )
    for i, text_encoder in enumerate(text_encoders):
        if tokenizers is not None:
            tokenizer = tokenizers[i]
            text_input_ids = tokenize_prompt(tokenizer, prompt)
        else:
            assert text_input_ids_list is not None
            text_input_ids = text_input_ids_list[i]

        prompt_embeds = text_encoder(
            text_input_ids.to(text_encoder.device),
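The reworked `encode_prompt` supports two call modes, mirroring how the training loop uses it: tokenize on the fly when the prompt is fixed, or accept pre-tokenized ids when the text encoders are being trained. A hypothetical sketch (the prompt text is illustrative):

```python
# tokenizers provided: encode_prompt tokenizes the prompt itself
prompt_embeds, pooled_prompt_embeds = encode_prompt(
    text_encoders=[text_encoder_one, text_encoder_two],
    tokenizers=[tokenizer_one, tokenizer_two],
    prompt="a photo of sks dog",
)

# tokenizers=None: pass pre-tokenized ids instead (the --train_text_encoder path)
tokens_one = tokenize_prompt(tokenizer_one, "a photo of sks dog")
tokens_two = tokenize_prompt(tokenizer_two, "a photo of sks dog")
prompt_embeds, pooled_prompt_embeds = encode_prompt(
    text_encoders=[text_encoder_one, text_encoder_two],
    tokenizers=None,
    prompt=None,
    text_input_ids_list=[tokens_one, tokens_two],
)
```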
@@ -641,9 +602,6 @@ def main(args):
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    if args.train_text_encoder:
        raise NotImplementedError("Text encoder training not yet supported.")

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -677,7 +635,7 @@
                torch_dtype = torch.float16
            elif args.prior_generation_precision == "bf16":
                torch_dtype = torch.bfloat16
            pipeline = DiffusionPipeline.from_pretrained(
            pipeline = StableDiffusionXLPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                torch_dtype=torch_dtype,
                safety_checker=None,
@@ -742,7 +700,14 @@
    text_encoder_two = text_encoder_cls_two.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    vae_path = (
        args.pretrained_model_name_or_path
        if args.pretrained_vae_model_name_or_path is None
        else args.pretrained_vae_model_name_or_path
    )
    vae = AutoencoderKL.from_pretrained(
        vae_path, subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None, revision=args.revision
    )
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )
@@ -764,7 +729,10 @@
    # Move unet, vae and text_encoder to device and cast to weight_dtype
    # The VAE is in float32 to avoid NaN losses.
    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=torch.float32)
    if args.pretrained_vae_model_name_or_path is None:
        vae.to(accelerator.device, dtype=torch.float32)
    else:
        vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder_one.to(accelerator.device, dtype=weight_dtype)
    text_encoder_two.to(accelerator.device, dtype=weight_dtype)
@@ -804,42 +772,66 @@
            unet_lora_parameters.extend(module.parameters())

    unet.set_attn_processor(unet_lora_attn_procs)
    # unet_lora_layers = AttnProcsLayers(unet.attn_processors)

    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
    # So, instead, we monkey-patch the forward calls of its attention-blocks.
    if args.train_text_encoder:
        # ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
        text_lora_parameters_one = LoraLoaderMixin._modify_text_encoder(text_encoder_one, dtype=torch.float32)
        text_lora_parameters_two = LoraLoaderMixin._modify_text_encoder(text_encoder_two, dtype=torch.float32)

    # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
    def save_model_hook(models, weights, output_dir):
        # there are only two options here. Either are just the unet attn processor layers
        # or there are the unet and text encoder atten layers
        unet_lora_layers_to_save = None
        text_encoder_one_lora_layers_to_save = None
        text_encoder_two_lora_layers_to_save = None

        for model in models:
            unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
            if isinstance(model, type(accelerator.unwrap_model(unet))):
                unet_lora_layers_to_save = unet_attn_processors_state_dict(model)
            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
                text_encoder_one_lora_layers_to_save = text_encoder_lora_state_dict(model)
            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
                text_encoder_two_lora_layers_to_save = text_encoder_lora_state_dict(model)
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

            # make sure to pop weight so that corresponding model is not saved again
            weights.pop()

        LoraLoaderMixin.save_lora_weights(
        StableDiffusionXLPipeline.save_lora_weights(
            output_dir,
            unet_lora_layers=unet_lora_layers_to_save,
            text_encoder_lora_layers=None,
            text_encoder_lora_layers=text_encoder_one_lora_layers_to_save,
            text_encoder_2_lora_layers=text_encoder_two_lora_layers_to_save,
        )

    def load_model_hook(models, input_dir):
        unet_ = None
        text_encoder_ = None
        text_encoder_one_ = None
        text_encoder_two_ = None

        while len(models) > 0:
            model = models.pop()

            if isinstance(model, type(accelerator.unwrap_model(unet))):
                unet_ = model
            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_one))):
                text_encoder_one_ = model
            elif isinstance(model, type(accelerator.unwrap_model(text_encoder_two))):
                text_encoder_two_ = model
            else:
                raise ValueError(f"unexpected save model: {model.__class__}")

        lora_state_dict, network_alpha = LoraLoaderMixin.lora_state_dict(input_dir)
        LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alpha=network_alpha, unet=unet_)
        LoraLoaderMixin.load_lora_into_text_encoder(
            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_
            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_one_
        )
        LoraLoaderMixin.load_lora_into_text_encoder(
            lora_state_dict, network_alpha=network_alpha, text_encoder=text_encoder_two_
        )

    accelerator.register_save_state_pre_hook(save_model_hook)
@@ -869,7 +861,11 @@
    optimizer_class = torch.optim.AdamW

    # Optimizer creation
    params_to_optimize = unet_lora_parameters
    params_to_optimize = (
        itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
        if args.train_text_encoder
        else unet_lora_parameters
    )
    optimizer = optimizer_class(
        params_to_optimize,
        lr=args.learning_rate,
@@ -878,62 +874,81 @@
        eps=args.adam_epsilon,
    )

    # We ALWAYS pre-compute the additional condition embeddings needed for SDXL
    # UNet as the model is already big and it uses two text encoders.
    # TODO: when we add support for text encoder training, will revisit.
    tokenizers = [tokenizer_one, tokenizer_two]
    text_encoders = [text_encoder_one, text_encoder_two]
    # Computes additional embeddings/ids required by the SDXL UNet.
    # regular text embeddings (when `train_text_encoder` is not True)
    # pooled text embeddings
    # time ids

    # Here, we compute not just the text embeddings but also the additional embeddings
    # needed for the SD XL UNet to operate.
    def compute_embeddings(prompt, text_encoders, tokenizers):
    def compute_time_ids():
        # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
        original_size = (args.resolution, args.resolution)
        target_size = (args.resolution, args.resolution)
        crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids])
        add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
        return add_time_ids
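    # Illustration (an assumption for clarity, not part of the diff): with
    # --resolution 1024 and crop coordinates (0, 0), compute_time_ids() returns
    # tensor([[1024., 1024., 0., 0., 1024., 1024.]]), i.e. original_size,
    # crops_coords_top_left and target_size concatenated into one row.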
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
        add_text_embeds = pooled_prompt_embeds
    if not args.train_text_encoder:
        tokenizers = [tokenizer_one, tokenizer_two]
        text_encoders = [text_encoder_one, text_encoder_two]

        # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids])
    def compute_text_embeddings(prompt, text_encoders, tokenizers):
        with torch.no_grad():
            prompt_embeds, pooled_prompt_embeds = encode_prompt(text_encoders, tokenizers, prompt)
            prompt_embeds = prompt_embeds.to(accelerator.device)
            pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
        return prompt_embeds, pooled_prompt_embeds

        prompt_embeds = prompt_embeds.to(accelerator.device)
        add_text_embeds = add_text_embeds.to(accelerator.device)
        add_time_ids = add_time_ids.to(accelerator.device, dtype=prompt_embeds.dtype)
        unet_added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

        return prompt_embeds, unet_added_cond_kwargs

    instance_prompt_hidden_states, instance_unet_added_conditions = compute_embeddings(
        args.instance_prompt, text_encoders, tokenizers
    )

    class_prompt_hidden_states, class_unet_added_conditions = None, None
    if args.with_prior_preservation:
        class_prompt_hidden_states, class_unet_added_conditions = compute_embeddings(
            args.class_prompt, text_encoders, tokenizers
    # Handle instance prompt.
    instance_time_ids = compute_time_ids()
    if not args.train_text_encoder:
        instance_prompt_hidden_states, instance_pooled_prompt_embeds = compute_text_embeddings(
            args.instance_prompt, text_encoders, tokenizers
        )

    del tokenizers, text_encoders
    # Handle class prompt for prior-preservation.
    if args.with_prior_preservation:
        class_time_ids = compute_time_ids()
        if not args.train_text_encoder:
            class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
                args.class_prompt, text_encoders, tokenizers
            )

    gc.collect()
    torch.cuda.empty_cache()
    # Clear the memory here.
    if not args.train_text_encoder:
        del tokenizers, text_encoders
        gc.collect()
        torch.cuda.empty_cache()

    # Pack the statically computed variables appropriately. This is so that we don't
    # have to pass them to the dataloader.
    add_time_ids = instance_time_ids
    if args.with_prior_preservation:
        add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)

    if not args.train_text_encoder:
        prompt_embeds = instance_prompt_hidden_states
        unet_add_text_embeds = instance_pooled_prompt_embeds
        if args.with_prior_preservation:
            prompt_embeds = torch.cat([prompt_embeds, class_prompt_hidden_states], dim=0)
            unet_add_text_embeds = torch.cat([unet_add_text_embeds, class_pooled_prompt_embeds], dim=0)
    else:
        tokens_one = tokenize_prompt(tokenizer_one, args.instance_prompt)
        tokens_two = tokenize_prompt(tokenizer_two, args.instance_prompt)
        if args.with_prior_preservation:
            class_tokens_one = tokenize_prompt(tokenizer_one, args.class_prompt)
            class_tokens_two = tokenize_prompt(tokenizer_two, args.class_prompt)
            tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
            tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)

    # Dataset and DataLoaders creation:
    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        class_num=args.num_class_images,
        size=args.resolution,
        center_crop=args.center_crop,
        instance_prompt_hidden_states=instance_prompt_hidden_states,
        class_prompt_hidden_states=class_prompt_hidden_states,
        instance_unet_added_conditions=instance_unet_added_conditions,
        class_unet_added_conditions=class_unet_added_conditions,
    )

    train_dataloader = torch.utils.data.DataLoader(
@@ -954,16 +969,21 @@ def main(args):
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
        num_training_steps=args.max_train_steps * accelerator.num_processes,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        unet, optimizer, train_dataloader, lr_scheduler
    )
    if args.train_text_encoder:
        unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, text_encoder_one, text_encoder_two, optimizer, train_dataloader, lr_scheduler
        )
    else:
        unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
            unet, optimizer, train_dataloader, lr_scheduler
        )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -1022,6 +1042,9 @@
    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        if args.train_text_encoder:
            text_encoder_one.train()
            text_encoder_two.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -1030,12 +1053,16 @@
                continue

            with accelerator.accumulate(unet):
                # pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
                if args.pretrained_vae_model_name_or_path is None:
                    pixel_values = batch["pixel_values"]
                else:
                    pixel_values = batch["pixel_values"].to(dtype=weight_dtype)

                # Convert images to latent space
                model_input = vae.encode(batch["pixel_values"]).latent_dist.sample()
                model_input = vae.encode(pixel_values).latent_dist.sample()
                model_input = model_input * vae.config.scaling_factor
                model_input = model_input.to(weight_dtype)
                if args.pretrained_vae_model_name_or_path is None:
                    model_input = model_input.to(weight_dtype)

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(model_input)
@@ -1051,9 +1078,30 @@
                noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)

                # Predict the noise residual
                model_pred = unet(
                    noisy_model_input, timesteps, batch["input_ids"], added_cond_kwargs=batch["unet_added_conditions"]
                ).sample
                if not args.train_text_encoder:
                    unet_added_conditions = {
                        "time_ids": add_time_ids.repeat(bsz, 1),
                        "text_embeds": unet_add_text_embeds.repeat(bsz, 1),
                    }
                    model_pred = unet(
                        noisy_model_input,
                        timesteps,
                        prompt_embeds.repeat(bsz, 1, 1),
                        added_cond_kwargs=unet_added_conditions,
                    ).sample
                else:
                    unet_added_conditions = {"time_ids": add_time_ids.repeat(bsz, 1)}
                    prompt_embeds, pooled_prompt_embeds = encode_prompt(
                        text_encoders=[text_encoder_one, text_encoder_two],
                        tokenizers=None,
                        prompt=None,
                        text_input_ids_list=[tokens_one, tokens_two],
                    )
                    unet_added_conditions.update({"text_embeds": pooled_prompt_embeds.repeat(bsz, 1)})
                    prompt_embeds = prompt_embeds.repeat(bsz, 1, 1)
                    model_pred = unet(
                        noisy_model_input, timesteps, prompt_embeds, added_cond_kwargs=unet_added_conditions
                    ).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
@@ -1081,7 +1129,11 @@
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = unet_lora_parameters
                    params_to_clip = (
                        itertools.chain(unet_lora_parameters, text_lora_parameters_one, text_lora_parameters_two)
                        if args.train_text_encoder
                        else unet_lora_parameters
                    )
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
@@ -1132,8 +1184,22 @@
                        f" {args.validation_prompt}."
                    )
                    # create pipeline
                    pipeline = DiffusionPipeline.from_pretrained(
                    if not args.train_text_encoder:
                        text_encoder_one = text_encoder_cls_one.from_pretrained(
                            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
                        )
                        text_encoder_two = text_encoder_cls_two.from_pretrained(
                            args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
                        )
                    pipeline = StableDiffusionXLPipeline.from_pretrained(
                        args.pretrained_model_name_or_path,
                        vae=vae,
                        text_encoder=accelerator.unwrap_model(text_encoder_one)
                        if args.train_text_encoder
                        else text_encoder_one,
                        text_encoder_2=accelerator.unwrap_model(text_encoder_two)
                        if args.train_text_encoder
                        else text_encoder_two,
                        unet=accelerator.unwrap_model(unet),
                        revision=args.revision,
                        torch_dtype=weight_dtype,
@@ -1161,9 +1227,11 @@
                    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
                    pipeline_args = {"prompt": args.validation_prompt}

                    images = [
                        pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)
                    ]
                    with torch.cuda.amp.autocast():
                        images = [
                            pipeline(**pipeline_args, generator=generator).images[0]
                            for _ in range(args.num_validation_images)
                        ]

                    for tracker in accelerator.trackers:
                        if tracker.name == "tensorboard":
@@ -1189,16 +1257,32 @@
        unet = unet.to(torch.float32)
        unet_lora_layers = unet_attn_processors_state_dict(unet)

        LoraLoaderMixin.save_lora_weights(
        if args.train_text_encoder:
            text_encoder_one = accelerator.unwrap_model(text_encoder_one)
            text_encoder_lora_layers = text_encoder_lora_state_dict(text_encoder_one.to(torch.float32))
            text_encoder_two = accelerator.unwrap_model(text_encoder_two)
            text_encoder_2_lora_layers = text_encoder_lora_state_dict(text_encoder_two.to(torch.float32))
        else:
            text_encoder_lora_layers = None
            text_encoder_2_lora_layers = None

        StableDiffusionXLPipeline.save_lora_weights(
            save_directory=args.output_dir,
            unet_lora_layers=unet_lora_layers,
            text_encoder_lora_layers=None,
            text_encoder_lora_layers=text_encoder_lora_layers,
            text_encoder_2_lora_layers=text_encoder_2_lora_layers,
        )

        # Final inference
        # Load previous pipeline
        pipeline = DiffusionPipeline.from_pretrained(
            args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
        vae = AutoencoderKL.from_pretrained(
            vae_path,
            subfolder="vae" if args.pretrained_vae_model_name_or_path is None else None,
            revision=args.revision,
            torch_dtype=weight_dtype,
        )
        pipeline = StableDiffusionXLPipeline.from_pretrained(
            args.pretrained_model_name_or_path, vae=vae, revision=args.revision, torch_dtype=weight_dtype
        )

        # We train on the simplified learning objective. If we were previously predicting a variance, we need the scheduler to ignore it
@@ -1250,6 +1334,7 @@
            train_text_encoder=args.train_text_encoder,
            prompt=args.instance_prompt,
            repo_folder=args.output_dir,
            vae_path=args.pretrained_vae_model_name_or_path,
        )
        upload_folder(
            repo_id=repo_id,
@@ -385,6 +385,42 @@ class ExamplesTestsAccelerate(unittest.TestCase):
            starts_with_unet = all(key.startswith("unet") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_unet)

    def test_dreambooth_lora_sdxl_with_text_encoder(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/dreambooth/train_dreambooth_lora_sdxl.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --instance_data_dir docs/source/en/imgs
                --instance_prompt photo
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --train_text_encoder
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when training the text encoder, all the parameters in the state dict should start
            # with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names.
            keys = lora_state_dict.keys()
            starts_with_unet = all(
                k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
            )
            self.assertTrue(starts_with_unet)

    def test_custom_diffusion(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
@@ -59,7 +59,7 @@ if is_safetensors_available():
    import safetensors

if is_transformers_available():
    from transformers import CLIPTextModel, PreTrainedModel, PreTrainedTokenizer
    from transformers import CLIPTextModel, CLIPTextModelWithProjection, PreTrainedModel, PreTrainedTokenizer

if is_accelerate_available():
    from accelerate import init_empty_weights
@@ -108,7 +108,7 @@ class PatchedLoraProjection(nn.Module):
def text_encoder_attn_modules(text_encoder):
    attn_modules = []

    if isinstance(text_encoder, CLIPTextModel):
    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
        for i, layer in enumerate(text_encoder.text_model.encoder.layers):
            name = f"text_model.encoder.layers.{i}.self_attn"
            mod = layer.self_attn
@@ -1016,18 +1016,20 @@ class LoraLoaderMixin:
        warnings.warn(warn_message)

    @classmethod
    def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, lora_scale=1.0):
    def load_lora_into_text_encoder(cls, state_dict, network_alpha, text_encoder, prefix=None, lora_scale=1.0):
        """
        This will load the LoRA layers specified in `state_dict` into `text_encoder`

        Parameters:
            state_dict (`dict`):
                A standard state dict containing the lora layer parameters. The key shoult be prefixed with an
                A standard state dict containing the lora layer parameters. The key should be prefixed with an
                additional `text_encoder` to distinguish between unet lora layers.
            network_alpha (`float`):
                See `LoRALinearLayer` for more details.
            text_encoder (`CLIPTextModel`):
                The text encoder model to load the LoRA layers into.
            prefix (`str`):
                Expected prefix of the `text_encoder` in the `state_dict`.
            lora_scale (`float`):
                How much to scale the output of the lora linear layer before it is added with the output of the regular
                lora layer.
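For reference, a hypothetical call using the new `prefix` argument, mirroring how the SDXL pipeline override further below splits its state dict (the checkpoint path and `pipe` are illustrative):

```python
state_dict, network_alpha = LoraLoaderMixin.lora_state_dict("path/to/lora")
text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
    text_encoder_2_state_dict,
    network_alpha=network_alpha,
    text_encoder=pipe.text_encoder_2,
    prefix="text_encoder_2",
)
```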
@@ -1037,14 +1039,16 @@
        # then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
        # their prefixes.
        keys = list(state_dict.keys())
        if all(key.startswith(cls.unet_name) or key.startswith(cls.text_encoder_name) for key in keys):
        prefix = cls.text_encoder_name if prefix is None else prefix

        if any(cls.text_encoder_name in key for key in keys):
            # Load the layers corresponding to text encoder and make necessary adjustments.
            text_encoder_keys = [k for k in keys if k.startswith(cls.text_encoder_name)]
            text_encoder_keys = [k for k in keys if k.startswith(prefix)]
            text_encoder_lora_state_dict = {
                k.replace(f"{cls.text_encoder_name}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
                k.replace(f"{prefix}.", ""): v for k, v in state_dict.items() if k in text_encoder_keys
            }
            if len(text_encoder_lora_state_dict) > 0:
                logger.info(f"Loading {cls.text_encoder_name}.")
                logger.info(f"Loading {prefix}.")

                if any("to_out_lora" in k for k in text_encoder_lora_state_dict.keys()):
                    # Convert from the old naming convention to the new naming convention.
@@ -1184,23 +1188,10 @@
        replace `torch.save` with another method. Can be configured with the environment variable
        `DIFFUSERS_SAVE_MODE`.
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        if save_function is None:
            if safe_serialization:

                def save_function(weights, filename):
                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})

            else:
                save_function = torch.save

        os.makedirs(save_directory, exist_ok=True)

        # Create a flat dictionary.
        state_dict = {}

        # Populate the dictionary.
        if unet_lora_layers is not None:
            weights = (
                unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
@@ -1222,6 +1213,38 @@ class LoraLoaderMixin:
            state_dict.update(text_encoder_lora_state_dict)

        # Save the model
        self.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    def write_lora_layers(
        state_dict: Dict[str, torch.Tensor],
        save_directory: str,
        is_main_process: bool,
        weight_name: str,
        save_function: Callable,
        safe_serialization: bool,
    ):
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        if save_function is None:
            if safe_serialization:

                def save_function(weights, filename):
                    return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})

            else:
                save_function = torch.save

        os.makedirs(save_directory, exist_ok=True)

        if weight_name is None:
            if safe_serialization:
                weight_name = LORA_WEIGHT_NAME_SAFE
@@ -13,6 +13,7 @@
# limitations under the License.

import inspect
import os
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
@@ -841,3 +842,66 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
            return (image,)

        return StableDiffusionXLPipelineOutput(images=image)

    # Override to properly handle the loading and unloading of the additional text encoder.
    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
        state_dict, network_alpha = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
        self.load_lora_into_unet(state_dict, network_alpha=network_alpha, unet=self.unet)

        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
        if len(text_encoder_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_state_dict,
                network_alpha=network_alpha,
                text_encoder=self.text_encoder,
                prefix="text_encoder",
                lora_scale=self.lora_scale,
            )

        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
        if len(text_encoder_2_state_dict) > 0:
            self.load_lora_into_text_encoder(
                text_encoder_2_state_dict,
                network_alpha=network_alpha,
                text_encoder=self.text_encoder_2,
                prefix="text_encoder_2",
                lora_scale=self.lora_scale,
            )

    @classmethod
    def save_lora_weights(
        self,
        save_directory: Union[str, os.PathLike],
        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
        is_main_process: bool = True,
        weight_name: str = None,
        save_function: Callable = None,
        safe_serialization: bool = False,
    ):
        state_dict = {}

        def pack_weights(layers, prefix):
            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
            return layers_state_dict

        state_dict.update(pack_weights(unet_lora_layers, "unet"))

        if text_encoder_lora_layers and text_encoder_2_lora_layers:
            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))

        self.write_lora_layers(
            state_dict=state_dict,
            save_directory=save_directory,
            is_main_process=is_main_process,
            weight_name=weight_name,
            save_function=save_function,
            safe_serialization=safe_serialization,
        )

    def _remove_text_encoder_monkey_patch(self):
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
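The `pack_weights` helper only namespaces keys so that the UNet and both text encoder LoRA layers can share one flat state dict. A small self-contained sketch of the effect (the layer key is illustrative, not taken from a real checkpoint):

```python
import torch

def pack_weights(layers, prefix):
    layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
    return {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}

layers = {"to_q_lora.up.weight": torch.zeros(4, 4)}
assert list(pack_weights(layers, "unet")) == ["unet.to_q_lora.up.weight"]
assert list(pack_weights(layers, "text_encoder_2")) == ["text_encoder_2.to_q_lora.up.weight"]
```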
@@ -21,9 +21,16 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub.repocard import RepoCard
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer

from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers import (
    AutoencoderKL,
    DDIMScheduler,
    EulerDiscreteScheduler,
    StableDiffusionPipeline,
    StableDiffusionXLPipeline,
    UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin, PatchedLoraProjection, text_encoder_attn_modules
from diffusers.models.attention_processor import (
    Attention,
@@ -399,7 +406,7 @@ class LoraLoaderMixinTests(unittest.TestCase):
            )
            self.assertIsInstance(module.processor, attn_proc_class)

    def test_unload_lora(self):
    def test_unload_lora_sd(self):
        pipeline_components, lora_components = self.get_dummy_components()
        _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
        sd_pipe = StableDiffusionPipeline(**pipeline_components)
@@ -503,6 +510,175 @@
        self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))


class SDXLLoraLoaderMixinTests(unittest.TestCase):
    def get_dummy_components(self):
        torch.manual_seed(0)
        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
            in_channels=4,
            out_channels=4,
            down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            # SD2-specific config below
            attention_head_dim=(2, 4),
            use_linear_projection=True,
            addition_embed_type="text_time",
            addition_time_embed_dim=8,
            transformer_layers_per_block=(1, 2),
            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
            cross_attention_dim=64,
        )
        scheduler = EulerDiscreteScheduler(
            beta_start=0.00085,
            beta_end=0.012,
            steps_offset=1,
            beta_schedule="scaled_linear",
            timestep_spacing="leading",
        )
        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            sample_size=128,
        )
        torch.manual_seed(0)
        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
            # SD2-specific config below
            hidden_act="gelu",
            projection_dim=32,
        )
        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)

        text_encoder_2 = CLIPTextModelWithProjection(text_encoder_config)
        tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip", local_files_only=True)

        unet_lora_attn_procs, unet_lora_layers = create_unet_lora_layers(unet)
        text_encoder_one_lora_layers = create_text_encoder_lora_layers(text_encoder)
        text_encoder_two_lora_layers = create_text_encoder_lora_layers(text_encoder_2)

        pipeline_components = {
            "unet": unet,
            "scheduler": scheduler,
            "vae": vae,
            "text_encoder": text_encoder,
            "text_encoder_2": text_encoder_2,
            "tokenizer": tokenizer,
            "tokenizer_2": tokenizer_2,
        }
        lora_components = {
            "unet_lora_layers": unet_lora_layers,
            "text_encoder_one_lora_layers": text_encoder_one_lora_layers,
            "text_encoder_two_lora_layers": text_encoder_two_lora_layers,
            "unet_lora_attn_procs": unet_lora_attn_procs,
        }
        return pipeline_components, lora_components

    def get_dummy_inputs(self, with_generator=True):
        batch_size = 1
        sequence_length = 10
        num_channels = 4
        sizes = (32, 32)

        generator = torch.manual_seed(0)
        noise = floats_tensor((batch_size, num_channels) + sizes)
        input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)

        pipeline_inputs = {
            "prompt": "A painting of a squirrel eating a burger",
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "output_type": "np",
        }
        if with_generator:
            pipeline_inputs.update({"generator": generator})

        return noise, input_ids, pipeline_inputs

    def test_lora_save_load(self):
        pipeline_components, lora_components = self.get_dummy_components()
        sd_pipe = StableDiffusionXLPipeline(**pipeline_components)
        sd_pipe = sd_pipe.to(torch_device)
        sd_pipe.set_progress_bar_config(disable=None)

        _, _, pipeline_inputs = self.get_dummy_inputs()

        original_images = sd_pipe(**pipeline_inputs).images
        orig_image_slice = original_images[0, -3:, -3:, -1]

        with tempfile.TemporaryDirectory() as tmpdirname:
            StableDiffusionXLPipeline.save_lora_weights(
                save_directory=tmpdirname,
                unet_lora_layers=lora_components["unet_lora_layers"],
                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
            )
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            sd_pipe.load_lora_weights(tmpdirname)

        lora_images = sd_pipe(**pipeline_inputs).images
        lora_image_slice = lora_images[0, -3:, -3:, -1]

        # Outputs shouldn't match.
        self.assertFalse(torch.allclose(torch.from_numpy(orig_image_slice), torch.from_numpy(lora_image_slice)))

    def test_unload_lora_sdxl(self):
        pipeline_components, lora_components = self.get_dummy_components()
        _, _, pipeline_inputs = self.get_dummy_inputs(with_generator=False)
        sd_pipe = StableDiffusionXLPipeline(**pipeline_components)

        original_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
        orig_image_slice = original_images[0, -3:, -3:, -1]

        # Emulate training.
        set_lora_weights(lora_components["unet_lora_layers"].parameters(), randn_weight=True)
        set_lora_weights(lora_components["text_encoder_one_lora_layers"].parameters(), randn_weight=True)
        set_lora_weights(lora_components["text_encoder_two_lora_layers"].parameters(), randn_weight=True)

        with tempfile.TemporaryDirectory() as tmpdirname:
            StableDiffusionXLPipeline.save_lora_weights(
                save_directory=tmpdirname,
                unet_lora_layers=lora_components["unet_lora_layers"],
                text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
                text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
            )
            self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
            sd_pipe.load_lora_weights(tmpdirname)

        lora_images = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
        lora_image_slice = lora_images[0, -3:, -3:, -1]

        # Unload LoRA parameters.
        sd_pipe.unload_lora_weights()
        original_images_two = sd_pipe(**pipeline_inputs, generator=torch.manual_seed(0)).images
        orig_image_slice_two = original_images_two[0, -3:, -3:, -1]

        assert not np.allclose(
            orig_image_slice, lora_image_slice
        ), "LoRA parameters should lead to a different image slice."
        assert not np.allclose(
            orig_image_slice_two, lora_image_slice
        ), "LoRA parameters should lead to a different image slice."
        assert np.allclose(
            orig_image_slice, orig_image_slice_two, atol=1e-3
        ), "Unloading LoRA parameters should lead to results similar to what was obtained with the pipeline without any LoRA parameters."


@slow
@require_torch_gpu
class LoraIntegrationTests(unittest.TestCase):