mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[advanced dreambooth lora sdxl script] new features + bug fixes (#6691)
* add noise_offset param * micro conditioning - wip * image processing adjusted and moved to support micro conditioning * change time ids to be computed inside train loop * change time ids to be computed inside train loop * change time ids to be computed inside train loop * time ids shape fix * move token replacement of validation prompt to the same section of instance prompt and class prompt * add offset noise to sd15 advanced script * fix token loading during validation * fix token loading during validation in sdxl script * a little clean * style * a little clean * style * sdxl script - a little clean + minor path fix sd 1.5 script - change default resolution value * ad 1.5 script - minor path fix * fix missing comma in code example in model card * clean up commented lines * style * remove time ids computed outside training loop - no longer used now that we utilize micro-conditioning, as all time ids are now computed inside the training loop * style * [WIP] - added draft readme, building off of examples/dreambooth/README.md * readme * readme * readme * readme * readme * readme * readme * readme * removed --crops_coords_top_left from CLI args * style * fix missing shape bug due to missing RGB if statement * add blog mention at the start of the reamde as well * Update examples/advanced_diffusion_training/README.md Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> * change note to render nicely as well --------- Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
244
examples/advanced_diffusion_training/README.md
Normal file
244
examples/advanced_diffusion_training/README.md
Normal file
@@ -0,0 +1,244 @@
|
||||
# Advanced diffusion training examples
|
||||
|
||||
## Train Dreambooth LoRA with Stable Diffusion XL
|
||||
> [!TIP]
|
||||
> 💡 This example follows the techniques and recommended practices covered in the blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script). Make sure to check it out before starting 🤗
|
||||
|
||||
[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject.
|
||||
|
||||
LoRA - Low-Rank Adaption of Large Language Models, was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*
|
||||
In a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
|
||||
- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114)
|
||||
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
|
||||
- LoRA attention layers allow to control to which extent the model is adapted towards new training images via a `scale` parameter.
|
||||
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
|
||||
the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
|
||||
|
||||
The `train_dreambooth_lora_sdxl_advanced.py` script shows how to implement dreambooth-LoRA, combining the training process shown in `train_dreambooth_lora_sdxl.py`, with
|
||||
advanced features and techniques, inspired and built upon contributions by [Nataniel Ruiz](https://twitter.com/natanielruizg): [Dreambooth](https://dreambooth.github.io), [Rinon Gal](https://twitter.com/RinonGal): [Textual Inversion](https://textual-inversion.github.io), [Ron Mokady](https://twitter.com/MokadyRon): [Pivotal Tuning](https://arxiv.org/abs/2106.05744), [Simo Ryu](https://twitter.com/cloneofsimo): [cog-sdxl](https://github.com/replicate/cog-sdxl),
|
||||
[Kohya](https://twitter.com/kohya_tech/): [sd-scripts](https://github.com/kohya-ss/sd-scripts), [The Last Ben](https://twitter.com/__TheBen): [fast-stable-diffusion](https://github.com/TheLastBen/fast-stable-diffusion) ❤️
|
||||
|
||||
> [!NOTE]
|
||||
> 💡If this is your first time training a Dreambooth LoRA, congrats!🥳
|
||||
> You might want to familiarize yourself more with the techniques: [Dreambooth blog](https://huggingface.co/blog/dreambooth), [Using LoRA for Efficient Stable Diffusion Fine-Tuning blog](https://huggingface.co/blog/lora)
|
||||
|
||||
📚 Read more about the advanced features and best practices in this community derived blog post: [LoRA training scripts of the world, unite!](https://huggingface.co/blog/sdxl_lora_advanced_script)
|
||||
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/advanced_diffusion_training` folder and run
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell e.g. a notebook
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
|
||||
|
||||
### Pivotal Tuning
|
||||
**Training with text encoder(s)**
|
||||
|
||||
Alongside the UNet, LoRA fine-tuning of the text encoders is also supported. In addition to the text encoder optimization
|
||||
available with `train_dreambooth_lora_sdxl_advanced.py`, in the advanced script **pivotal tuning** is also supported.
|
||||
[pivotal tuning](https://huggingface.co/blog/sdxl_lora_advanced_script#pivotal-tuning) combines Textual Inversion with regular diffusion fine-tuning -
|
||||
we insert new tokens into the text encoders of the model, instead of reusing existing ones.
|
||||
We then optimize the newly-inserted token embeddings to represent the new concept.
|
||||
|
||||
To do so, just specify `--train_text_encoder_ti` while launching training (for regular text encoder optimizations, use `--train_text_encoder`).
|
||||
Please keep the following points in mind:
|
||||
|
||||
* SDXL has two text encoders. So, we fine-tune both using LoRA.
|
||||
* When not fine-tuning the text encoders, we ALWAYS precompute the text embeddings to save memoםהקרry.
|
||||
|
||||
|
||||
### 3D icon example
|
||||
|
||||
Now let's get our dataset. For this example we will use some cool images of 3d rendered icons: https://huggingface.co/datasets/linoyts/3d_icon.
|
||||
|
||||
Let's first download it locally:
|
||||
|
||||
```python
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
local_dir = "./3d_icon"
|
||||
snapshot_download(
|
||||
"LinoyTsaban/3d_icon",
|
||||
local_dir=local_dir, repo_type="dataset",
|
||||
ignore_patterns=".gitattributes",
|
||||
)
|
||||
```
|
||||
|
||||
Let's review some of the advanced features we're going to be using for this example:
|
||||
- **custom captions**:
|
||||
To use custom captioning, first ensure that you have the datasets library installed, otherwise you can install it by
|
||||
```bash
|
||||
pip install datasets
|
||||
```
|
||||
|
||||
Now we'll simply specify the name of the dataset and caption column (in this case it's "prompt")
|
||||
|
||||
```
|
||||
--dataset_name=./3d_icon
|
||||
--caption_column=prompt
|
||||
```
|
||||
|
||||
You can also load a dataset straight from by specifying it's name in `dataset_name`.
|
||||
Look [here](https://huggingface.co/blog/sdxl_lora_advanced_script#custom-captioning) for more info on creating/loadin your own caption dataset.
|
||||
|
||||
- **optimizer**: for this example, we'll use [prodigy](https://huggingface.co/blog/sdxl_lora_advanced_script#adaptive-optimizers) - an adaptive optimizer
|
||||
- **pivotal tuning**
|
||||
- **min SNR gamma**
|
||||
|
||||
**Now, we can launch training:**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
|
||||
export DATASET_NAME="./3d_icon"
|
||||
export OUTPUT_DIR="3d-icon-SDXL-LoRA"
|
||||
export VAE_PATH="madebyollin/sdxl-vae-fp16-fix"
|
||||
|
||||
accelerate launch train_dreambooth_lora_sdxl_advanced.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--pretrained_vae_model_name_or_path=$VAE_PATH \
|
||||
--dataset_name=$DATASET_NAME \
|
||||
--instance_prompt="3d icon in the style of TOK" \
|
||||
--validation_prompt="a TOK icon of an astronaut riding a horse, in the style of TOK" \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--caption_column="prompt" \
|
||||
--mixed_precision="bf16" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=3 \
|
||||
--repeats=1 \
|
||||
--report_to="wandb"\
|
||||
--gradient_accumulation_steps=1 \
|
||||
--gradient_checkpointing \
|
||||
--learning_rate=1.0 \
|
||||
--text_encoder_lr=1.0 \
|
||||
--optimizer="prodigy"\
|
||||
--train_text_encoder_ti\
|
||||
--train_text_encoder_ti_frac=0.5\
|
||||
--snr_gamma=5.0 \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--rank=8 \
|
||||
--max_train_steps=1000 \
|
||||
--checkpointing_steps=2000 \
|
||||
--seed="0" \
|
||||
--push_to_hub
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using the following flags in the command above:
|
||||
|
||||
* `report_to="wandb` will ensure the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
|
||||
* `validation_prompt` and `validation_epochs` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
Our experiments were conducted on a single 40GB A100 GPU.
|
||||
|
||||
|
||||
### Inference
|
||||
|
||||
Once training is done, we can perform inference like so:
|
||||
1. starting with loading the unet lora weights
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download, upload_file
|
||||
from diffusers import DiffusionPipeline
|
||||
from diffusers.models import AutoencoderKL
|
||||
from safetensors.torch import load_file
|
||||
|
||||
username = "linoyts"
|
||||
repo_id = f"{username}/3d-icon-SDXL-LoRA"
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"stabilityai/stable-diffusion-xl-base-1.0",
|
||||
torch_dtype=torch.float16,
|
||||
variant="fp16",
|
||||
).to("cuda")
|
||||
|
||||
|
||||
pipe.load_lora_weights(repo_id, weight_name="pytorch_lora_weights.safetensors")
|
||||
```
|
||||
2. now we load the pivotal tuning embeddings
|
||||
|
||||
```python
|
||||
text_encoders = [pipe.text_encoder, pipe.text_encoder_2]
|
||||
tokenizers = [pipe.tokenizer, pipe.tokenizer_2]
|
||||
|
||||
embedding_path = hf_hub_download(repo_id=repo_id, filename="3d-icon-SDXL-LoRA_emb.safetensors", repo_type="model")
|
||||
|
||||
state_dict = load_file(embedding_path)
|
||||
# load embeddings of text_encoder 1 (CLIP ViT-L/14)
|
||||
pipe.load_textual_inversion(state_dict["clip_l"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder, tokenizer=pipe.tokenizer)
|
||||
# load embeddings of text_encoder 2 (CLIP ViT-G/14)
|
||||
pipe.load_textual_inversion(state_dict["clip_g"], token=["<s0>", "<s1>"], text_encoder=pipe.text_encoder_2, tokenizer=pipe.tokenizer_2)
|
||||
```
|
||||
|
||||
3. let's generate images
|
||||
|
||||
```python
|
||||
instance_token = "<s0><s1>"
|
||||
prompt = f"a {instance_token} icon of an orange llama eating ramen, in the style of {instance_token}"
|
||||
|
||||
image = pipe(prompt=prompt, num_inference_steps=25, cross_attention_kwargs={"scale": 1.0}).images[0]
|
||||
image.save("llama.png")
|
||||
```
|
||||
|
||||
### Comfy UI / AUTOMATIC1111 Inference
|
||||
The new script fully supports textual inversion loading with Comfy UI and AUTOMATIC1111 formats!
|
||||
|
||||
**AUTOMATIC1111 / SD.Next** \
|
||||
In AUTOMATIC1111/SD.Next we will load a LoRA and a textual embedding at the same time.
|
||||
- *LoRA*: Besides the diffusers format, the script will also train a WebUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory.
|
||||
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `embeddings` directory.
|
||||
|
||||
You can then run inference by prompting `a y2k_emb webpage about the movie Mean Girls <lora:y2k:0.9>`. You can use the `y2k_emb` token normally, including increasing its weight by doing `(y2k_emb:1.2)`.
|
||||
|
||||
**ComfyUI** \
|
||||
In ComfyUI we will load a LoRA and a textual embedding at the same time.
|
||||
- *LoRA*: Besides the diffusers format, the script will also train a ComfyUI compatible LoRA. It is generated as `{your_lora_name}.safetensors`. You can then include it in your `models/Lora` directory. Then you will load the LoRALoader node and hook that up with your model and CLIP. [Official guide for loading LoRAs](https://comfyanonymous.github.io/ComfyUI_examples/lora/)
|
||||
- *Embedding*: the embedding is the same for diffusers and WebUI. You can download your `{lora_name}_emb.safetensors` file from a trained model, and include it in your `models/embeddings` directory and use it in your prompts like `embedding:y2k_emb`. [Official guide for loading embeddings](https://comfyanonymous.github.io/ComfyUI_examples/textual_inversion_embeddings/).
|
||||
-
|
||||
### Specifying a better VAE
|
||||
|
||||
SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
|
||||
|
||||
|
||||
### Tips and Tricks
|
||||
Check out [these recommended practices](https://huggingface.co/blog/sdxl_lora_advanced_script#additional-good-practices)
|
||||
|
||||
## Running on Colab Notebook
|
||||
Check out [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_advanced_example.ipynb).
|
||||
to train using the advanced features (including pivotal tuning), and [this notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/SDXL_DreamBooth_LoRA_.ipynb) to train on a free colab, using some of the advanced features (excluding pivotal tuning)
|
||||
|
||||
7
examples/advanced_diffusion_training/requirements.txt
Normal file
7
examples/advanced_diffusion_training/requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
||||
accelerate>=0.16.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft==0.7.0
|
||||
@@ -119,10 +119,9 @@ def save_model_card(
|
||||
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file
|
||||
"""
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model")
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model")
|
||||
state_dict = load_file(embedding_path)
|
||||
pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
|
||||
pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
|
||||
"""
|
||||
webui_example_pivotal = f"""- *Embeddings*: download **[`{embeddings_filename}.safetensors` here 💾](/{repo_id}/blob/main/{embeddings_filename}.safetensors)**.
|
||||
- Place it on it on your `embeddings` folder
|
||||
@@ -389,7 +388,7 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=1024,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
@@ -645,6 +644,7 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
|
||||
parser.add_argument(
|
||||
"--rank",
|
||||
type=int,
|
||||
@@ -745,10 +745,11 @@ class TokenEmbeddingsHandler:
|
||||
|
||||
idx += 1
|
||||
|
||||
# copied from train_dreambooth_lora_sdxl_advanced.py
|
||||
def save_embeddings(self, file_path: str):
|
||||
assert self.train_ids is not None, "Initialize new tokens before saving embeddings."
|
||||
tensors = {}
|
||||
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14
|
||||
# text_encoder_0 - CLIP ViT-L/14, text_encoder_1 - CLIP ViT-G/14 - TODO - change for sd
|
||||
idx_to_text_encoder_name = {0: "clip_l", 1: "clip_g"}
|
||||
for idx, text_encoder in enumerate(self.text_encoders):
|
||||
assert text_encoder.text_model.embeddings.token_embedding.weight.data.shape[0] == len(
|
||||
@@ -1634,6 +1635,11 @@ def main(args):
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(model_input)
|
||||
if args.noise_offset:
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
noise += args.noise_offset * torch.randn(
|
||||
(model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
|
||||
)
|
||||
bsz = model_input.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
@@ -1788,6 +1794,7 @@ def main(args):
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
tokenizer=tokenizer_one,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder_one),
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
revision=args.revision,
|
||||
@@ -1860,6 +1867,11 @@ def main(args):
|
||||
unet_lora_layers=unet_lora_layers,
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
)
|
||||
|
||||
if args.train_text_encoder_ti:
|
||||
embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors"
|
||||
embedding_handler.save_embeddings(embeddings_path)
|
||||
|
||||
images = []
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
# Final inference
|
||||
@@ -1895,6 +1907,18 @@ def main(args):
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
# load new tokens
|
||||
if args.train_text_encoder_ti:
|
||||
state_dict = load_file(embeddings_path)
|
||||
all_new_tokens = []
|
||||
for key, value in token_abstraction_dict.items():
|
||||
all_new_tokens.extend(value)
|
||||
pipeline.load_textual_inversion(
|
||||
state_dict["clip_l"],
|
||||
token=all_new_tokens,
|
||||
text_encoder=pipeline.text_encoder,
|
||||
tokenizer=pipeline.tokenizer,
|
||||
)
|
||||
# run inference
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
@@ -1917,11 +1941,6 @@ def main(args):
|
||||
}
|
||||
)
|
||||
|
||||
if args.train_text_encoder_ti:
|
||||
embedding_handler.save_embeddings(
|
||||
f"{args.output_dir}/{args.output_dir}_emb.safetensors",
|
||||
)
|
||||
|
||||
# Conver to WebUI format
|
||||
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
|
||||
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
|
||||
|
||||
@@ -20,6 +20,7 @@ import itertools
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import shutil
|
||||
import warnings
|
||||
@@ -45,6 +46,7 @@ from PIL.ImageOps import exif_transpose
|
||||
from safetensors.torch import load_file, save_file
|
||||
from torch.utils.data import Dataset
|
||||
from torchvision import transforms
|
||||
from torchvision.transforms.functional import crop
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import AutoTokenizer, PretrainedConfig
|
||||
|
||||
@@ -121,7 +123,7 @@ def save_model_card(
|
||||
diffusers_imports_pivotal = """from huggingface_hub import hf_hub_download
|
||||
from safetensors.torch import load_file
|
||||
"""
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors' repo_type="model")
|
||||
diffusers_example_pivotal = f"""embedding_path = hf_hub_download(repo_id='{repo_id}', filename='{embeddings_filename}.safetensors', repo_type="model")
|
||||
state_dict = load_file(embedding_path)
|
||||
pipeline.load_textual_inversion(state_dict["clip_l"], token=[{ti_keys}], text_encoder=pipeline.text_encoder, tokenizer=pipeline.tokenizer)
|
||||
pipeline.load_textual_inversion(state_dict["clip_g"], token=[{ti_keys}], text_encoder=pipeline.text_encoder_2, tokenizer=pipeline.tokenizer_2)
|
||||
@@ -397,18 +399,6 @@ def parse_args(input_args=None):
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crops_coords_top_left_h",
|
||||
type=int,
|
||||
default=0,
|
||||
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crops_coords_top_left_w",
|
||||
type=int,
|
||||
default=0,
|
||||
help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop",
|
||||
default=False,
|
||||
@@ -418,6 +408,11 @@ def parse_args(input_args=None):
|
||||
" cropped. The images will be resized to the resolution first before cropping."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random_flip",
|
||||
action="store_true",
|
||||
help="whether to randomly flip images horizontally",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_text_encoder",
|
||||
action="store_true",
|
||||
@@ -659,6 +654,7 @@ def parse_args(input_args=None):
|
||||
parser.add_argument(
|
||||
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
|
||||
)
|
||||
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
|
||||
parser.add_argument(
|
||||
"--rank",
|
||||
type=int,
|
||||
@@ -901,6 +897,41 @@ class DreamBoothDataset(Dataset):
|
||||
self.instance_images = []
|
||||
for img in instance_images:
|
||||
self.instance_images.extend(itertools.repeat(img, repeats))
|
||||
|
||||
# image processing to prepare for using SD-XL micro-conditioning
|
||||
self.original_sizes = []
|
||||
self.crop_top_lefts = []
|
||||
self.pixel_values = []
|
||||
train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
|
||||
train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
|
||||
train_flip = transforms.RandomHorizontalFlip(p=1.0)
|
||||
train_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
for image in self.instance_images:
|
||||
image = exif_transpose(image)
|
||||
if not image.mode == "RGB":
|
||||
image = image.convert("RGB")
|
||||
self.original_sizes.append((image.height, image.width))
|
||||
image = train_resize(image)
|
||||
if args.random_flip and random.random() < 0.5:
|
||||
# flip
|
||||
image = train_flip(image)
|
||||
if args.center_crop:
|
||||
y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
|
||||
x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
|
||||
image = train_crop(image)
|
||||
else:
|
||||
y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
|
||||
image = crop(image, y1, x1, h, w)
|
||||
crop_top_left = (y1, x1)
|
||||
self.crop_top_lefts.append(crop_top_left)
|
||||
image = train_transforms(image)
|
||||
self.pixel_values.append(image)
|
||||
|
||||
self.num_instance_images = len(self.instance_images)
|
||||
self._length = self.num_instance_images
|
||||
|
||||
@@ -930,12 +961,12 @@ class DreamBoothDataset(Dataset):
|
||||
|
||||
def __getitem__(self, index):
|
||||
example = {}
|
||||
instance_image = self.instance_images[index % self.num_instance_images]
|
||||
instance_image = exif_transpose(instance_image)
|
||||
|
||||
if not instance_image.mode == "RGB":
|
||||
instance_image = instance_image.convert("RGB")
|
||||
example["instance_images"] = self.image_transforms(instance_image)
|
||||
instance_image = self.pixel_values[index % self.num_instance_images]
|
||||
original_size = self.original_sizes[index % self.num_instance_images]
|
||||
crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
|
||||
example["instance_images"] = instance_image
|
||||
example["original_size"] = original_size
|
||||
example["crop_top_left"] = crop_top_left
|
||||
|
||||
if self.custom_instance_prompts:
|
||||
caption = self.custom_instance_prompts[index % self.num_instance_images]
|
||||
@@ -966,6 +997,8 @@ class DreamBoothDataset(Dataset):
|
||||
def collate_fn(examples, with_prior_preservation=False):
|
||||
pixel_values = [example["instance_images"] for example in examples]
|
||||
prompts = [example["instance_prompt"] for example in examples]
|
||||
original_sizes = [example["original_size"] for example in examples]
|
||||
crop_top_lefts = [example["crop_top_left"] for example in examples]
|
||||
|
||||
# Concat class and instance examples for prior preservation.
|
||||
# We do this to avoid doing two forward passes.
|
||||
@@ -976,7 +1009,12 @@ def collate_fn(examples, with_prior_preservation=False):
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
|
||||
batch = {"pixel_values": pixel_values, "prompts": prompts}
|
||||
batch = {
|
||||
"pixel_values": pixel_values,
|
||||
"prompts": prompts,
|
||||
"original_sizes": original_sizes,
|
||||
"crop_top_lefts": crop_top_lefts,
|
||||
}
|
||||
return batch
|
||||
|
||||
|
||||
@@ -1198,7 +1236,9 @@ def main(args):
|
||||
args.instance_prompt = args.instance_prompt.replace(token_abs, "".join(token_replacement))
|
||||
if args.with_prior_preservation:
|
||||
args.class_prompt = args.class_prompt.replace(token_abs, "".join(token_replacement))
|
||||
|
||||
if args.validation_prompt:
|
||||
args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
|
||||
print("validation prompt:", args.validation_prompt)
|
||||
# initialize the new tokens for textual inversion
|
||||
embedding_handler = TokenEmbeddingsHandler(
|
||||
[text_encoder_one, text_encoder_two], [tokenizer_one, tokenizer_two]
|
||||
@@ -1539,11 +1579,11 @@ def main(args):
|
||||
# pooled text embeddings
|
||||
# time ids
|
||||
|
||||
def compute_time_ids():
|
||||
def compute_time_ids(crops_coords_top_left, original_size=None):
|
||||
# Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
|
||||
original_size = (args.resolution, args.resolution)
|
||||
if original_size is None:
|
||||
original_size = (args.resolution, args.resolution)
|
||||
target_size = (args.resolution, args.resolution)
|
||||
crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
|
||||
add_time_ids = list(original_size + crops_coords_top_left + target_size)
|
||||
add_time_ids = torch.tensor([add_time_ids])
|
||||
add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
|
||||
@@ -1560,9 +1600,6 @@ def main(args):
|
||||
pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
|
||||
return prompt_embeds, pooled_prompt_embeds
|
||||
|
||||
# Handle instance prompt.
|
||||
instance_time_ids = compute_time_ids()
|
||||
|
||||
# If no type of tuning is done on the text_encoder and custom instance prompts are NOT
|
||||
# provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
|
||||
# the redundant encoding.
|
||||
@@ -1573,7 +1610,6 @@ def main(args):
|
||||
|
||||
# Handle class prompt for prior-preservation.
|
||||
if args.with_prior_preservation:
|
||||
class_time_ids = compute_time_ids()
|
||||
if freeze_text_encoder:
|
||||
class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
|
||||
args.class_prompt, text_encoders, tokenizers
|
||||
@@ -1588,9 +1624,6 @@ def main(args):
|
||||
# If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
|
||||
# pack the statically computed variables appropriately here. This is so that we don't
|
||||
# have to pass them to the dataloader.
|
||||
add_time_ids = instance_time_ids
|
||||
if args.with_prior_preservation:
|
||||
add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
|
||||
|
||||
# if --train_text_encoder_ti we need add_special_tokens to be True fo textual inversion
|
||||
add_special_tokens = True if args.train_text_encoder_ti else False
|
||||
@@ -1613,12 +1646,6 @@ def main(args):
|
||||
tokens_one = torch.cat([tokens_one, class_tokens_one], dim=0)
|
||||
tokens_two = torch.cat([tokens_two, class_tokens_two], dim=0)
|
||||
|
||||
if args.train_text_encoder_ti and args.validation_prompt:
|
||||
# replace instances of --token_abstraction in validation prompt with the new tokens: "<si><si+1>" etc.
|
||||
for token_abs, token_replacement in train_dataset.token_abstraction_dict.items():
|
||||
args.validation_prompt = args.validation_prompt.replace(token_abs, "".join(token_replacement))
|
||||
print("validation prompt:", args.validation_prompt)
|
||||
|
||||
if args.cache_latents:
|
||||
latents_cache = []
|
||||
for batch in tqdm(train_dataloader, desc="Caching latents"):
|
||||
@@ -1778,6 +1805,12 @@ def main(args):
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn_like(model_input)
|
||||
if args.noise_offset:
|
||||
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
|
||||
noise += args.noise_offset * torch.randn(
|
||||
(model_input.shape[0], model_input.shape[1], 1, 1), device=model_input.device
|
||||
)
|
||||
|
||||
bsz = model_input.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(
|
||||
@@ -1789,19 +1822,26 @@ def main(args):
|
||||
# (this is the forward diffusion process)
|
||||
noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
|
||||
|
||||
# time ids
|
||||
add_time_ids = torch.cat(
|
||||
[
|
||||
compute_time_ids(original_size=s, crops_coords_top_left=c)
|
||||
for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
|
||||
]
|
||||
)
|
||||
|
||||
# Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
|
||||
if not train_dataset.custom_instance_prompts:
|
||||
elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
|
||||
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
|
||||
|
||||
else:
|
||||
elems_to_repeat_text_embeds = 1
|
||||
elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
|
||||
|
||||
# Predict the noise residual
|
||||
if freeze_text_encoder:
|
||||
unet_added_conditions = {
|
||||
"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
|
||||
"time_ids": add_time_ids,
|
||||
# "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
|
||||
"text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
|
||||
}
|
||||
prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
|
||||
@@ -1812,7 +1852,7 @@ def main(args):
|
||||
added_cond_kwargs=unet_added_conditions,
|
||||
).sample
|
||||
else:
|
||||
unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
|
||||
unet_added_conditions = {"time_ids": add_time_ids}
|
||||
prompt_embeds, pooled_prompt_embeds = encode_prompt(
|
||||
text_encoders=[text_encoder_one, text_encoder_two],
|
||||
tokenizers=None,
|
||||
@@ -1954,6 +1994,8 @@ def main(args):
|
||||
pipeline = StableDiffusionXLPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
vae=vae,
|
||||
tokenizer=tokenizer_one,
|
||||
tokenizer_2=tokenizer_two,
|
||||
text_encoder=accelerator.unwrap_model(text_encoder_one),
|
||||
text_encoder_2=accelerator.unwrap_model(text_encoder_two),
|
||||
unet=accelerator.unwrap_model(unet),
|
||||
@@ -2033,6 +2075,11 @@ def main(args):
|
||||
text_encoder_lora_layers=text_encoder_lora_layers,
|
||||
text_encoder_2_lora_layers=text_encoder_2_lora_layers,
|
||||
)
|
||||
|
||||
if args.train_text_encoder_ti:
|
||||
embeddings_path = f"{args.output_dir}/{args.output_dir}_emb.safetensors"
|
||||
embedding_handler.save_embeddings(embeddings_path)
|
||||
|
||||
images = []
|
||||
if args.validation_prompt and args.num_validation_images > 0:
|
||||
# Final inference
|
||||
@@ -2068,6 +2115,25 @@ def main(args):
|
||||
# load attention processors
|
||||
pipeline.load_lora_weights(args.output_dir)
|
||||
|
||||
# load new tokens
|
||||
if args.train_text_encoder_ti:
|
||||
state_dict = load_file(embeddings_path)
|
||||
all_new_tokens = []
|
||||
for key, value in token_abstraction_dict.items():
|
||||
all_new_tokens.extend(value)
|
||||
pipeline.load_textual_inversion(
|
||||
state_dict["clip_l"],
|
||||
token=all_new_tokens,
|
||||
text_encoder=pipeline.text_encoder,
|
||||
tokenizer=pipeline.tokenizer,
|
||||
)
|
||||
pipeline.load_textual_inversion(
|
||||
state_dict["clip_g"],
|
||||
token=all_new_tokens,
|
||||
text_encoder=pipeline.text_encoder_2,
|
||||
tokenizer=pipeline.tokenizer_2,
|
||||
)
|
||||
|
||||
# run inference
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) if args.seed else None
|
||||
@@ -2090,11 +2156,6 @@ def main(args):
|
||||
}
|
||||
)
|
||||
|
||||
if args.train_text_encoder_ti:
|
||||
embedding_handler.save_embeddings(
|
||||
f"{args.output_dir}/{args.output_dir}_emb.safetensors",
|
||||
)
|
||||
|
||||
# Conver to WebUI format
|
||||
lora_state_dict = load_file(f"{args.output_dir}/pytorch_lora_weights.safetensors")
|
||||
peft_state_dict = convert_all_state_dict_to_peft(lora_state_dict)
|
||||
|
||||
Reference in New Issue
Block a user