1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Fix textual inversion SDXL and add support for 2nd text encoder (#9010)

* Fix textual inversion SDXL and add support for 2nd text encoder

Signed-off-by: Daniel Socek <daniel.socek@intel.com>

* Fix style/quality of text inv for sdxl

Signed-off-by: Daniel Socek <daniel.socek@intel.com>

---------

Signed-off-by: Daniel Socek <daniel.socek@intel.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Daniel Socek
2024-08-09 10:53:06 -04:00
committed by GitHub
parent 65e30907b5
commit c1079f0887
2 changed files with 81 additions and 12 deletions

View File

@@ -23,4 +23,25 @@ accelerate launch textual_inversion_sdxl.py \
--output_dir="./textual_inversion_cat_sdxl"
```
For now, only training of the first text encoder is supported.
Training of both text encoders is supported.
### Inference Example
Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionXLPipeline`.
Make sure to include the `placeholder_token` in your prompt.
```python
from diffusers import StableDiffusionXLPipeline
import torch
model_id = "./textual_inversion_cat_sdxl"
pipe = StableDiffusionXLPipeline.from_pretrained(model_id,torch_dtype=torch.float16).to("cuda")
prompt = "A <cat-toy> backpack"
image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("cat-backpack.png")
image = pipe(prompt="", prompt_2=prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
image.save("cat-backpack-prompt_2.png")
```

View File

@@ -135,7 +135,7 @@ def log_validation(
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder_1),
text_encoder_2=text_encoder_2,
text_encoder_2=accelerator.unwrap_model(text_encoder_2),
tokenizer=tokenizer_1,
tokenizer_2=tokenizer_2,
unet=unet,
@@ -678,36 +678,54 @@ def main():
f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
" `placeholder_token` that is not already in the tokenizer."
)
num_added_tokens = tokenizer_2.add_tokens(placeholder_tokens)
if num_added_tokens != args.num_vectors:
raise ValueError(
f"The 2nd tokenizer already contains the token {args.placeholder_token}. Please pass a different"
" `placeholder_token` that is not already in the tokenizer."
)
# Convert the initializer_token, placeholder_token to ids
token_ids = tokenizer_1.encode(args.initializer_token, add_special_tokens=False)
token_ids_2 = tokenizer_2.encode(args.initializer_token, add_special_tokens=False)
# Check if initializer_token is a single token or a sequence of tokens
if len(token_ids) > 1:
if len(token_ids) > 1 or len(token_ids_2) > 1:
raise ValueError("The initializer token must be a single token.")
initializer_token_id = token_ids[0]
placeholder_token_ids = tokenizer_1.convert_tokens_to_ids(placeholder_tokens)
initializer_token_id_2 = token_ids_2[0]
placeholder_token_ids_2 = tokenizer_2.convert_tokens_to_ids(placeholder_tokens)
# Resize the token embeddings as we are adding new special tokens to the tokenizer
text_encoder_1.resize_token_embeddings(len(tokenizer_1))
text_encoder_2.resize_token_embeddings(len(tokenizer_2))
# Initialise the newly added placeholder token with the embeddings of the initializer token
token_embeds = text_encoder_1.get_input_embeddings().weight.data
token_embeds_2 = text_encoder_2.get_input_embeddings().weight.data
with torch.no_grad():
for token_id in placeholder_token_ids:
token_embeds[token_id] = token_embeds[initializer_token_id].clone()
for token_id in placeholder_token_ids_2:
token_embeds_2[token_id] = token_embeds_2[initializer_token_id_2].clone()
# Freeze vae and unet
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder_2.requires_grad_(False)
# Freeze all parameters except for the token embeddings in text encoder
text_encoder_1.text_model.encoder.requires_grad_(False)
text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
text_encoder_2.text_model.encoder.requires_grad_(False)
text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)
if args.gradient_checkpointing:
text_encoder_1.gradient_checkpointing_enable()
text_encoder_2.gradient_checkpointing_enable()
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
@@ -746,7 +764,11 @@ def main():
optimizer_class = torch.optim.AdamW
optimizer = optimizer_class(
text_encoder_1.get_input_embeddings().parameters(), # only optimize the embeddings
# only optimize the embeddings
[
text_encoder_1.text_model.embeddings.token_embedding.weight,
text_encoder_2.text_model.embeddings.token_embedding.weight,
],
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
@@ -786,9 +808,10 @@ def main():
)
text_encoder_1.train()
text_encoder_2.train()
# Prepare everything with our `accelerator`.
text_encoder_1, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder_1, optimizer, train_dataloader, lr_scheduler
text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler
)
# For mixed precision training we cast all non-trainable weigths (vae, non-lora text_encoder and non-lora unet) to half-precision
@@ -866,11 +889,13 @@ def main():
# keep original embeddings as reference
orig_embeds_params = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone()
orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone()
for epoch in range(first_epoch, args.num_train_epochs):
text_encoder_1.train()
text_encoder_2.train()
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(text_encoder_1):
with accelerator.accumulate([text_encoder_1, text_encoder_2]):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
latents = latents * vae.config.scaling_factor
@@ -892,9 +917,7 @@ def main():
.hidden_states[-2]
.to(dtype=weight_dtype)
)
encoder_output_2 = text_encoder_2(
batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True
)
encoder_output_2 = text_encoder_2(batch["input_ids_2"], output_hidden_states=True)
encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype)
original_size = [
(batch["original_size"][0][i].item(), batch["original_size"][1][i].item())
@@ -938,11 +961,16 @@ def main():
# Let's make sure we don't update any embedding weights besides the newly added token
index_no_updates = torch.ones((len(tokenizer_1),), dtype=torch.bool)
index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
index_no_updates_2 = torch.ones((len(tokenizer_2),), dtype=torch.bool)
index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False
with torch.no_grad():
accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
index_no_updates
] = orig_embeds_params[index_no_updates]
accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
index_no_updates_2
] = orig_embeds_params_2[index_no_updates_2]
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
@@ -960,6 +988,16 @@ def main():
save_path,
safe_serialization=True,
)
weight_name = f"learned_embeds_2-steps-{global_step}.safetensors"
save_path = os.path.join(args.output_dir, weight_name)
save_progress(
text_encoder_2,
placeholder_token_ids_2,
accelerator,
args,
save_path,
safe_serialization=True,
)
if accelerator.is_main_process:
if global_step % args.checkpointing_steps == 0:
@@ -1034,7 +1072,7 @@ def main():
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder_1),
text_encoder_2=text_encoder_2,
text_encoder_2=accelerator.unwrap_model(text_encoder_2),
vae=vae,
unet=unet,
tokenizer=tokenizer_1,
@@ -1052,6 +1090,16 @@ def main():
save_path,
safe_serialization=True,
)
weight_name = "learned_embeds_2.safetensors"
save_path = os.path.join(args.output_dir, weight_name)
save_progress(
text_encoder_2,
placeholder_token_ids_2,
accelerator,
args,
save_path,
safe_serialization=True,
)
if args.push_to_hub:
save_model_card(