diff --git a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py index 04f5296666..0eeeb821d6 100644 --- a/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py +++ b/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py @@ -20,6 +20,7 @@ import itertools import logging import math import os +import re import shutil import warnings from pathlib import Path @@ -41,7 +42,7 @@ from peft import LoraConfig from peft.utils import get_peft_model_state_dict from PIL import Image from PIL.ImageOps import exif_transpose -from safetensors.torch import save_file +from safetensors.torch import load_file, save_file from torch.utils.data import Dataset from torchvision import transforms from tqdm.auto import tqdm @@ -58,7 +59,13 @@ from diffusers import ( from diffusers.loaders import LoraLoaderMixin from diffusers.optimization import get_scheduler from diffusers.training_utils import compute_snr -from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, is_wandb_available +from diffusers.utils import ( + check_min_version, + convert_all_state_dict_to_peft, + convert_state_dict_to_diffusers, + convert_state_dict_to_kohya, + is_wandb_available, +) from diffusers.utils.import_utils import is_xformers_available @@ -93,10 +100,17 @@ def save_model_card( img_str += f""" - text: '{instance_prompt}' """ - + embeddings_filename = f"{repo_folder}_emb" + instance_prompt_webui = re.sub(r"", "", re.sub(r"", embeddings_filename, instance_prompt, count=1)) + ti_keys = ", ".join(f'"{match}"' for match in re.findall(r"", instance_prompt)) + if instance_prompt_webui != embeddings_filename: + instance_prompt_sentence = f"For example, `{instance_prompt_webui}`" + else: + instance_prompt_sentence = "" trigger_str = f"You should use {instance_prompt} to trigger the image generation." diffusers_imports_pivotal = "" diffusers_example_pivotal = "" + webui_example_pivotal = "" if train_text_encoder_ti: trigger_str = ( "To trigger image generation of trained concept(or concepts) replace each concept identifier " @@ -105,11 +119,16 @@ def save_model_card( diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download from safetensors.torch import load_file """ - diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename="embeddings.safetensors", repo_type="model") + diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model") state_dict = load_file(embedding_path) -pipeline.load_textual_inversion(state_dict["clip_l"], token=["", ""], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) -pipeline.load_textual_inversion(state_dict["clip_g"], token=["", ""], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2) +pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer) +pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2) """ + webui_example_pivotal = f"""- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**. + - Place it on it on your `embeddings` folder + - Use it by adding `{embeddings_filename}` to your prompt. {instance_prompt_sentence} + (you need both the LoRA and the embeddings as they were trained together for this LoRA) + """ if token_abstraction_dict: for key, value in token_abstraction_dict.items(): tokens = "".join(value) @@ -141,9 +160,14 @@ license: openrail++ ### These are {repo_id} LoRA adaption weights for {base_model}. -## Trigger words +## Download model -{trigger_str} +### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke + +- **LoRA**: download **[`{repo_folder}.safetensors` here 💾](/{repo_id}/blob/main/{repo_folder}.safetensors)**. + - Place it on your `models/Lora` folder. + - On AUTOMATIC1111, load the LoRA by adding `` to your prompt. On ComfyUI just [load it as a regular LoRA](https://comfyanonymous.github.io/ComfyUI_examples/lora/). +{webui_example_pivotal} ## Use it with the [🧨 diffusers library](https://github.com/huggingface/diffusers) @@ -159,16 +183,12 @@ image = pipeline('{validation_prompt if validation_prompt else instance_prompt}' For more details, including weighting, merging and fusing LoRAs, check the [documentation on loading LoRAs in diffusers](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) -## Download model +## Trigger words -### Use it with UIs such as AUTOMATIC1111, Comfy UI, SD.Next, Invoke - -- Download the LoRA *.safetensors [here](/{repo_id}/blob/main/pytorch_lora_weights.safetensors). Rename it and place it on your Lora folder. -- Download the text embeddings *.safetensors [here](/{repo_id}/blob/main/embeddings.safetensors). Rename it and place it on it on your embeddings folder. - -All [Files & versions](/{repo_id}/tree/main). +{trigger_str} ## Details +All [Files & versions](/{repo_id}/tree/main). The weights were trained using [🧨 diffusers Advanced Dreambooth Training Script](https://github.com/huggingface/diffusers/blob/main/examples/advanced_diffusion_training/train_dreambooth_lora_sdxl_advanced.py). @@ -2035,8 +2055,15 @@ def main(args): if args.train_text_encoder_ti: embedding_handler.save_embeddings( - f"{args.output_dir}/embeddings.safetensors", + f"{args.output_dir}/{args.output_dir}_emb.safetensors", ) + + # Conver to WebUI format + lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors") + peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict) + kohya_state_dict = convert_state_dict_to_kohya(peft_state_dict) + save_file(kohya_state_dict, f"{args.output_dir}/{args.output_dir}.safetensors") + save_model_card( model_id if not args.push_to_hub else repo_id, images=images,