From c1079f0887bec9a7d58f49001960febbbbd3f92b Mon Sep 17 00:00:00 2001
From: Daniel Socek
Date: Fri, 9 Aug 2024 10:53:06 -0400
Subject: [PATCH] Fix textual inversion SDXL and add support for 2nd text encoder (#9010)

* Fix textual inversion SDXL and add support for 2nd text encoder

Signed-off-by: Daniel Socek

* Fix style/quality of text inv for sdxl

Signed-off-by: Daniel Socek

---------

Signed-off-by: Daniel Socek
Co-authored-by: Sayak Paul
---
 examples/textual_inversion/README_sdxl.md | 23 +++++-
 .../textual_inversion_sdxl.py             | 70 ++++++++++++++++---
 2 files changed, 81 insertions(+), 12 deletions(-)

diff --git a/examples/textual_inversion/README_sdxl.md b/examples/textual_inversion/README_sdxl.md
index 560fd841e0..c971ee2a0f 100644
--- a/examples/textual_inversion/README_sdxl.md
+++ b/examples/textual_inversion/README_sdxl.md
@@ -23,4 +23,25 @@ accelerate launch textual_inversion_sdxl.py \
   --output_dir="./textual_inversion_cat_sdxl"
 ```

-For now, only training of the first text encoder is supported.
\ No newline at end of file
+Training of both text encoders is supported.
+
+### Inference Example
+
+Once you have trained a model using the above command, inference can be done simply using the `StableDiffusionXLPipeline`.
+Make sure to include the `placeholder_token` in your prompt.
+
+```python
+from diffusers import StableDiffusionXLPipeline
+import torch
+
+model_id = "./textual_inversion_cat_sdxl"
+pipe = StableDiffusionXLPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
+
+prompt = "A <cat-toy> backpack"
+
+image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save("cat-backpack.png")
+
+image = pipe(prompt="", prompt_2=prompt, num_inference_steps=50, guidance_scale=7.5).images[0]
+image.save("cat-backpack-prompt_2.png")
+```
diff --git a/examples/textual_inversion/textual_inversion_sdxl.py b/examples/textual_inversion/textual_inversion_sdxl.py
index 5bd9165fc5..f70a99ecf4 100644
--- a/examples/textual_inversion/textual_inversion_sdxl.py
+++ b/examples/textual_inversion/textual_inversion_sdxl.py
@@ -135,7 +135,7 @@ def log_validation(
     pipeline = DiffusionPipeline.from_pretrained(
         args.pretrained_model_name_or_path,
         text_encoder=accelerator.unwrap_model(text_encoder_1),
-        text_encoder_2=text_encoder_2,
+        text_encoder_2=accelerator.unwrap_model(text_encoder_2),
         tokenizer=tokenizer_1,
         tokenizer_2=tokenizer_2,
         unet=unet,
@@ -678,36 +678,54 @@ def main():
             f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
             " `placeholder_token` that is not already in the tokenizer."
         )
+    num_added_tokens = tokenizer_2.add_tokens(placeholder_tokens)
+    if num_added_tokens != args.num_vectors:
+        raise ValueError(
+            f"The 2nd tokenizer already contains the token {args.placeholder_token}. Please pass a different"
+            " `placeholder_token` that is not already in the tokenizer."
+        )

     # Convert the initializer_token, placeholder_token to ids
     token_ids = tokenizer_1.encode(args.initializer_token, add_special_tokens=False)
+    token_ids_2 = tokenizer_2.encode(args.initializer_token, add_special_tokens=False)
+
     # Check if initializer_token is a single token or a sequence of tokens
-    if len(token_ids) > 1:
+    if len(token_ids) > 1 or len(token_ids_2) > 1:
         raise ValueError("The initializer token must be a single token.")

     initializer_token_id = token_ids[0]
     placeholder_token_ids = tokenizer_1.convert_tokens_to_ids(placeholder_tokens)
+    initializer_token_id_2 = token_ids_2[0]
+    placeholder_token_ids_2 = tokenizer_2.convert_tokens_to_ids(placeholder_tokens)

     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder_1.resize_token_embeddings(len(tokenizer_1))
+    text_encoder_2.resize_token_embeddings(len(tokenizer_2))

     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder_1.get_input_embeddings().weight.data
+    token_embeds_2 = text_encoder_2.get_input_embeddings().weight.data
     with torch.no_grad():
         for token_id in placeholder_token_ids:
             token_embeds[token_id] = token_embeds[initializer_token_id].clone()
+        for token_id in placeholder_token_ids_2:
+            token_embeds_2[token_id] = token_embeds_2[initializer_token_id_2].clone()

     # Freeze vae and unet
     vae.requires_grad_(False)
     unet.requires_grad_(False)
-    text_encoder_2.requires_grad_(False)
+
     # Freeze all parameters except for the token embeddings in text encoder
     text_encoder_1.text_model.encoder.requires_grad_(False)
     text_encoder_1.text_model.final_layer_norm.requires_grad_(False)
     text_encoder_1.text_model.embeddings.position_embedding.requires_grad_(False)
+    text_encoder_2.text_model.encoder.requires_grad_(False)
+    text_encoder_2.text_model.final_layer_norm.requires_grad_(False)
+    text_encoder_2.text_model.embeddings.position_embedding.requires_grad_(False)

     if args.gradient_checkpointing:
         text_encoder_1.gradient_checkpointing_enable()
+        text_encoder_2.gradient_checkpointing_enable()

     if args.enable_xformers_memory_efficient_attention:
         if is_xformers_available():
@@ -746,7 +764,11 @@ def main():
         optimizer_class = torch.optim.AdamW

     optimizer = optimizer_class(
-        text_encoder_1.get_input_embeddings().parameters(),  # only optimize the embeddings
+        # only optimize the embeddings
+        [
+            text_encoder_1.text_model.embeddings.token_embedding.weight,
+            text_encoder_2.text_model.embeddings.token_embedding.weight,
+        ],
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
@@ -786,9 +808,10 @@ def main():
     )

     text_encoder_1.train()
+    text_encoder_2.train()
     # Prepare everything with our `accelerator`.
-    text_encoder_1, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder_1, optimizer, train_dataloader, lr_scheduler
+    text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder_1, text_encoder_2, optimizer, train_dataloader, lr_scheduler
     )

     # For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
@@ -866,11 +889,13 @@ def main():

     # keep original embeddings as reference
     orig_embeds_params = accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight.data.clone()
+    orig_embeds_params_2 = accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight.data.clone()

     for epoch in range(first_epoch, args.num_train_epochs):
         text_encoder_1.train()
+        text_encoder_2.train()
         for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(text_encoder_1):
+            with accelerator.accumulate([text_encoder_1, text_encoder_2]):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample().detach()
                 latents = latents * vae.config.scaling_factor
@@ -892,9 +917,7 @@ def main():
                     .hidden_states[-2]
                     .to(dtype=weight_dtype)
                 )
-                encoder_output_2 = text_encoder_2(
-                    batch["input_ids_2"].reshape(batch["input_ids_1"].shape[0], -1), output_hidden_states=True
-                )
+                encoder_output_2 = text_encoder_2(batch["input_ids_2"], output_hidden_states=True)
                 encoder_hidden_states_2 = encoder_output_2.hidden_states[-2].to(dtype=weight_dtype)
                 original_size = [
                     (batch["original_size"][0][i].item(), batch["original_size"][1][i].item())
@@ -938,11 +961,16 @@ def main():
                 # Let's make sure we don't update any embedding weights besides the newly added token
                 index_no_updates = torch.ones((len(tokenizer_1),), dtype=torch.bool)
                 index_no_updates[min(placeholder_token_ids) : max(placeholder_token_ids) + 1] = False
+                index_no_updates_2 = torch.ones((len(tokenizer_2),), dtype=torch.bool)
+                index_no_updates_2[min(placeholder_token_ids_2) : max(placeholder_token_ids_2) + 1] = False

                 with torch.no_grad():
                     accelerator.unwrap_model(text_encoder_1).get_input_embeddings().weight[
                         index_no_updates
                     ] = orig_embeds_params[index_no_updates]
+                    accelerator.unwrap_model(text_encoder_2).get_input_embeddings().weight[
+                        index_no_updates_2
+                    ] = orig_embeds_params_2[index_no_updates_2]

             # Checks if the accelerator has performed an optimization step behind the scenes
             if accelerator.sync_gradients:
@@ -960,6 +988,16 @@ def main():
                         save_path,
                         safe_serialization=True,
                     )
+                    weight_name = f"learned_embeds_2-steps-{global_step}.safetensors"
+                    save_path = os.path.join(args.output_dir, weight_name)
+                    save_progress(
+                        text_encoder_2,
+                        placeholder_token_ids_2,
+                        accelerator,
+                        args,
+                        save_path,
+                        safe_serialization=True,
+                    )

                 if accelerator.is_main_process:
                     if global_step % args.checkpointing_steps == 0:
@@ -1034,7 +1072,7 @@ def main():
             pipeline = DiffusionPipeline.from_pretrained(
                 args.pretrained_model_name_or_path,
                 text_encoder=accelerator.unwrap_model(text_encoder_1),
-                text_encoder_2=text_encoder_2,
+                text_encoder_2=accelerator.unwrap_model(text_encoder_2),
                 vae=vae,
                 unet=unet,
                 tokenizer=tokenizer_1,
@@ -1052,6 +1090,16 @@ def main():
             save_path,
             safe_serialization=True,
         )
+        weight_name = "learned_embeds_2.safetensors"
+        save_path = os.path.join(args.output_dir, weight_name)
+        save_progress(
+            text_encoder_2,
+            placeholder_token_ids_2,
+            accelerator,
+            args,
+            save_path,
+            safe_serialization=True,
+        )

         if args.push_to_hub:
             save_model_card(