mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[examples/advanced_diffusion_training] bug fixes and improvements for LoRA Dreambooth SDXL advanced training script (#5935)
* imports and readme bug fixes * bug fix - ensures text_encoder params are dtype==float32 (when using pivotal tuning) even if the rest of the model is loaded in fp16 * added pivotal tuning to readme * mapping token identifier to new inserted token in validation prompt (if used) * correct default value of --train_text_encoder_frac * change default value of --adam_weight_decay_text_encoder * validation prompt generations when using pivotal tuning bug fix * style fix * textual inversion embeddings name change * style fix * bug fix - stopping text encoder optimization halfway * readme - will include token abstraction and new inserted tokens when using pivotal tuning - added type to --num_new_tokens_per_abstraction * style fix --------- Co-authored-by: Linoy Tsaban <linoy@huggingface.co>
This commit is contained in:
@@ -54,7 +54,7 @@ from diffusers import (
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.loaders import LoraLoaderMixin
|
||||
from diffusers.models.lora import LoRALinearLayer, text_encoder_lora_state_dict
|
||||
from diffusers.models.lora import LoRALinearLayer
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.training_utils import compute_snr, unet_lora_state_dict
|
||||
from diffusers.utils import check_min_version, is_wandb_available
|
||||
@@ -67,11 +67,46 @@ check_min_version("0.24.0.dev0")
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: This function should be removed once training scripts are rewritten in PEFT
|
||||
def text_encoder_lora_state_dict(text_encoder):
|
||||
state_dict = {}
|
||||
|
||||
def text_encoder_attn_modules(text_encoder):
|
||||
from transformers import CLIPTextModel, CLIPTextModelWithProjection
|
||||
|
||||
attn_modules = []
|
||||
|
||||
if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
|
||||
for i, layer in enumerate(text_encoder.text_model.encoder.layers):
|
||||
name = f"text_model.encoder.layers.{i}.self_attn"
|
||||
mod = layer.self_attn
|
||||
attn_modules.append((name, mod))
|
||||
|
||||
return attn_modules
|
||||
|
||||
for name, module in text_encoder_attn_modules(text_encoder):
|
||||
for k, v in module.q_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.q_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.k_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.k_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.v_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.v_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
for k, v in module.out_proj.lora_linear_layer.state_dict().items():
|
||||
state_dict[f"{name}.out_proj.lora_linear_layer.{k}"] = v
|
||||
|
||||
return state_dict
|
||||
|
||||
|
||||
def save_model_card(
|
||||
repo_id: str,
|
||||
images=None,
|
||||
base_model=str,
|
||||
train_text_encoder=False,
|
||||
train_text_encoder_ti=False,
|
||||
token_abstraction_dict=None,
|
||||
instance_prompt=str,
|
||||
validation_prompt=str,
|
||||
repo_folder=None,
|
||||
@@ -83,10 +118,23 @@ def save_model_card(
|
||||
img_str += f"""
|
||||
- text: '{validation_prompt if validation_prompt else ' ' }'
|
||||
output:
|
||||
url: >-
|
||||
url:
|
||||
"image_{i}.png"
|
||||
"""
|
||||
|
||||
trigger_str = f"You should use {instance_prompt} to trigger the image generation."
|
||||
if train_text_encoder_ti:
|
||||
trigger_str = (
|
||||
"To trigger image generation of trained concept(or concepts) replace each concept identifier "
|
||||
"in you prompt with the new inserted tokens:\n"
|
||||
)
|
||||
if token_abstraction_dict:
|
||||
for key, value in token_abstraction_dict.items():
|
||||
tokens = "".join(value)
|
||||
trigger_str += f"""
|
||||
to trigger concept {key}-> use {tokens} in your prompt \n
|
||||
"""
|
||||
|
||||
yaml = f"""
|
||||
---
|
||||
tags:
|
||||
@@ -96,9 +144,7 @@ tags:
|
||||
- diffusers
|
||||
- lora
|
||||
- template:sd-lora
|
||||
widget:
|
||||
{img_str}
|
||||
---
|
||||
base_model: {base_model}
|
||||
instance_prompt: {instance_prompt}
|
||||
license: openrail++
|
||||
@@ -112,14 +158,19 @@ license: openrail++
|
||||
|
||||
## Model description
|
||||
|
||||
These are {repo_id} LoRA adaption weights for {base_model}.
|
||||
### These are {repo_id} LoRA adaption weights for {base_model}.
|
||||
|
||||
The weights were trained using [DreamBooth](https://dreambooth.github.io/).
|
||||
|
||||
LoRA for the text encoder was enabled: {train_text_encoder}.
|
||||
|
||||
Pivotal tuning was enabled: {train_text_encoder_ti}.
|
||||
|
||||
Special VAE used for training: {vae_path}.
|
||||
|
||||
## Trigger words
|
||||
|
||||
You should use {instance_prompt} to trigger the image generation.
|
||||
{trigger_str}
|
||||
|
||||
## Download model
|
||||
|
||||
@@ -244,6 +295,7 @@ def parse_args(input_args=None):
|
||||
|
||||
parser.add_argument(
|
||||
"--num_new_tokens_per_abstraction",
|
||||
type=int,
|
||||
default=2,
|
||||
help="number of new tokens inserted to the tokenizers per token_abstraction value when "
|
||||
"--train_text_encoder_ti = True. By default, each --token_abstraction (e.g. TOK) is mapped to 2 new "
|
||||
@@ -455,7 +507,7 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--train_text_encoder_frac",
|
||||
type=float,
|
||||
default=0.5,
|
||||
default=1.0,
|
||||
help=("The percentage of epochs to perform text encoder tuning"),
|
||||
)
|
||||
|
||||
@@ -488,7 +540,7 @@ def parse_args(input_args=None):
|
||||
parser.add_argument("--prodigy_decouple", type=bool, default=True, help="Use AdamW style decoupled weight decay")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-04, help="Weight decay to use for unet params")
|
||||
parser.add_argument(
|
||||
"--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
|
||||
"--adam_weight_decay_text_encoder", type=float, default=None, help="Weight decay to use for text_encoder"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
@@ -679,12 +731,19 @@ class TokenEmbeddingsHandler:
|
||||
def save_embeddings(self, file_path: str):
|
||||
assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
|
||||
tensors = {}
|
||||
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
|
||||
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
|
||||
self.tokenizers[0]
|
||||
), "Tokenizers should be the same."
|
||||
new_token_embeddings = text_encoder.text_model.embeddings.token_embedding.weight.data[self.train_ids]
|
||||
tensors[f"text_encoders_{idx}"] = new_token_embeddings
|
||||
|
||||
# New tokens for each text encoder are saved under "clip_l" (for text_encoder 0), "clip_g" (for
|
||||
# text_encoder 1) to keep compatible with the ecosystem.
|
||||
# Note: When loading with diffusers, any name can work - simply specify in inference
|
||||
tensors[idx_to_text_encoder_name[idx]] = new_token_embeddings
|
||||
# tensors[f"text_encoders_{idx}"] = new_token_embeddings
|
||||
|
||||
save_file(tensors, file_path)
|
||||
|
||||
@@ -696,19 +755,6 @@ class TokenEmbeddingsHandler:
|
||||
def device(self):
|
||||
return self.text_encoders[0].device
|
||||
|
||||
# def _load_embeddings(self, loaded_embeddings, tokenizer, text_encoder):
|
||||
# # Assuming new tokens are of the format <s_i>
|
||||
# self.inserting_toks = [f"<s{i}>" for i in range(loaded_embeddings.shape[0])]
|
||||
# special_tokens_dict = {"additional_special_tokens": self.inserting_toks}
|
||||
# tokenizer.add_special_tokens(special_tokens_dict)
|
||||
# text_encoder.resize_token_embeddings(len(tokenizer))
|
||||
#
|
||||
# self.train_ids = tokenizer.convert_tokens_to_ids(self.inserting_toks)
|
||||
# assert self.train_ids is not None, "New tokens could not be converted to IDs."
|
||||
# text_encoder.text_model.embeddings.token_embedding.weight.data[
|
||||
# self.train_ids
|
||||
# ] = loaded_embeddings.to(device=self.device).to(dtype=self.dtype)
|
||||
|
||||
@torch.no_grad()
|
||||
def retract_embeddings(self):
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
@@ -730,15 +776,6 @@ class TokenEmbeddingsHandler:
|
||||
new_embeddings = new_embeddings * (off_ratio**0.1)
|
||||
text_encoder.text_model.embeddings.token_embedding.weight.data[index_updates] = new_embeddings
|
||||
|
||||
# def load_embeddings(self, file_path: str):
|
||||
# with safe_open(file_path, framework="pt", device=self.device.type) as f:
|
||||
# for idx in range(len(self.text_encoders)):
|
||||
# text_encoder = self.text_encoders[idx]
|
||||
# tokenizer = self.tokenizers[idx]
|
||||
#
|
||||
# loaded_embeddings = f.get_tensor(f"text_encoders_{idx}")
|
||||
# self._load_embeddings(loaded_embeddings, tokenizer, text_encoder)
|
||||
|
||||
|
||||
class DreamBoothDataset(Dataset):
|
||||
"""
|
||||
@@ -1216,6 +1253,8 @@ def main(args):
|
||||
text_lora_parameters_one = []
|
||||
for name, param in text_encoder_one.named_parameters():
|
||||
if "token_embedding" in name:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
param = param.to(dtype=torch.float32)
|
||||
param.requires_grad = True
|
||||
text_lora_parameters_one.append(param)
|
||||
else:
|
||||
@@ -1223,6 +1262,8 @@ def main(args):
|
||||
text_lora_parameters_two = []
|
||||
for name, param in text_encoder_two.named_parameters():
|
||||
if "token_embedding" in name:
|
||||
# ensure that dtype is float32, even if rest of the model that isn't trained is loaded in fp16
|
||||
param = param.to(dtype=torch.float32)
|
||||
param.requires_grad = True
|
||||
text_lora_parameters_two.append(param)
|
||||
else:
|
||||
@@ -1309,12 +1350,16 @@ def main(args):
|
||||
# different learning rate for text encoder and unet
|
||||
text_lora_parameters_one_with_lr = {
|
||||
"params": text_lora_parameters_one,
|
||||
"weight_decay": args.adam_weight_decay_text_encoder,
|
||||
"weight_decay": args.adam_weight_decay_text_encoder
|
||||
if args.adam_weight_decay_text_encoder
|
||||
else args.adam_weight_decay,
|
||||
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
|
||||
}
|
||||
text_lora_parameters_two_with_lr = {
|
||||
"params": text_lora_parameters_two,
|
||||
"weight_decay": args.adam_weight_decay_text_encoder,
|
||||
"weight_decay": args.adam_weight_decay_text_encoder
|
||||
if args.adam_weight_decay_text_encoder
|
||||
else args.adam_weight_decay,
|
||||
"lr": args.text_encoder_lr if args.text_encoder_lr else args.learning_rate,
|
||||
}
|
||||
params_to_optimize = [
|
||||
@@ -1494,6 +1539,12 @@ def main(args):
|
||||
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
|
||||
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
|
||||
|
||||
if args.train_text_encoder_ti and args.validation_prompt:
|
||||
# replace instances of --token_abstraction in validation prompt with the new tokens: "<si><si+1>" etc.
|
||||
for token_abs, token_replacement in train_dataset.token_abstraction_dict.items():
|
||||
args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
|
||||
print("validation prompt:", args.validation_prompt)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
@@ -1593,27 +1644,10 @@ def main(args):
|
||||
if epoch == num_train_epochs_text_encoder:
|
||||
print("PIVOT HALFWAY", epoch)
|
||||
# stopping optimization of text_encoder params
|
||||
params_to_optimize = params_to_optimize[:1]
|
||||
# reinitializing the optimizer to optimize only on unet params
|
||||
if args.optimizer.lower() == "prodigy":
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
beta3=args.prodigy_beta3,
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
decouple=args.prodigy_decouple,
|
||||
use_bias_correction=args.prodigy_use_bias_correction,
|
||||
safeguard_warmup=args.prodigy_safeguard_warmup,
|
||||
)
|
||||
else: # AdamW or 8-bit-AdamW
|
||||
optimizer = optimizer_class(
|
||||
params_to_optimize,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
# re setting the optimizer to optimize only on unet params
|
||||
optimizer.param_groups[1]["lr"] = 0.0
|
||||
optimizer.param_groups[2]["lr"] = 0.0
|
||||
|
||||
else:
|
||||
# still optimizng the text encoder
|
||||
text_encoder_one.train()
|
||||
@@ -1628,7 +1662,7 @@ def main(args):
|
||||
with accelerator.accumulate(unet):
|
||||
pixel_values = batch["pixel_values"].to(dtype=vae.dtype)
|
||||
prompts = batch["prompts"]
|
||||
print(prompts)
|
||||
# print(prompts)
|
||||
# encode batch prompts when custom prompts are provided for each image -
|
||||
if train_dataset.custom_instance_prompts:
|
||||
if freeze_text_encoder:
|
||||
@@ -1801,7 +1835,7 @@ def main(args):
|
||||
f" {args.validation_prompt}."
|
||||
)
|
||||
# create pipeline
|
||||
if not args.train_text_encoder:
|
||||
if freeze_text_encoder:
|
||||
text_encoder_one = text_encoder_cls_one.from_pretrained(
|
||||
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
|
||||
)
|
||||
@@ -1948,6 +1982,8 @@ def main(args):
|
||||
images=images,
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
train_text_encoder=args.train_text_encoder,
|
||||
train_text_encoder_ti=args.train_text_encoder_ti,
|
||||
token_abstraction_dict=train_dataset.token_abstraction_dict,
|
||||
instance_prompt=args.instance_prompt,
|
||||
validation_prompt=args.validation_prompt,
|
||||
repo_folder=args.output_dir,
|
||||
|
||||
Reference in New Issue
Block a user