mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
adding custom diffusion training to diffusers examples (#3031)
* diffusers==0.14.0 update * custom diffusion update * custom diffusion update * custom diffusion update * custom diffusion update * custom diffusion update * custom diffusion update * custom diffusion * custom diffusion * custom diffusion * custom diffusion * custom diffusion * apply formatting and get rid of bare except. * refactor readme and other minor changes. * misc refactor. * fix: repo_id issue and loaders logging bug. * fix: save_model_card. * fix: save_model_card. * fix: save_model_card. * add: doc entry. * refactor doc,. * custom diffusion * custom diffusion * custom diffusion * apply style. * remove tralining whitespace. * fix: toctree entry. * remove unnecessary print. * custom diffusion * custom diffusion * custom diffusion test * custom diffusion xformer update * custom diffusion xformer update * custom diffusion xformer update --------- Co-authored-by: Nupur Kumari <nupurkumari@Nupurs-MacBook-Pro.local> Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Nupur Kumari <nupurkumari@nupurs-mbp.wifi.local.cmu.edu>
This commit is contained in:
@@ -74,6 +74,8 @@
|
||||
title: ControlNet
|
||||
- local: training/instructpix2pix
|
||||
title: InstructPix2Pix Training
|
||||
- local: training/custom_diffusion
|
||||
title: Custom Diffusion
|
||||
title: Training
|
||||
- sections:
|
||||
- local: using-diffusers/rl
|
||||
|
||||
287
docs/source/en/training/custom_diffusion.mdx
Normal file
287
docs/source/en/training/custom_diffusion.mdx
Normal file
@@ -0,0 +1,287 @@
|
||||
<!--Copyright 2023 Custom Diffusion authors The HuggingFace Team. All rights reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
|
||||
the License. You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
|
||||
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
|
||||
specific language governing permissions and limitations under the License.
|
||||
-->
|
||||
|
||||
# Custom Diffusion training example
|
||||
|
||||
[Custom Diffusion](https://arxiv.org/abs/2212.04488) is a method to customize text-to-image models like Stable Diffusion given just a few (4~5) images of a subject.
|
||||
The `train_custom_diffusion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the example folder and run
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
pip install clip-retrieval
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell e.g. a notebook
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
|
||||
write_basic_config()
|
||||
```
|
||||
### Cat example 😺
|
||||
|
||||
Now let's get our dataset. Download dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it.
|
||||
|
||||
We also collect 200 real images using `clip-retrieval` which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the the given target image. The following flags enable the regularization `with_prior_preservation`, `real_prior` with `prior_loss_weight=1.`.
|
||||
The `class_prompt` should be the category name same as target image. The collected real images are with text captions similar to the `class_prompt`. The retrieved image are saved in `class_data_dir`. You can disable `real_prior` to use generated images as regularization. To collect the real images use this command first before training.
|
||||
|
||||
```bash
|
||||
pip install clip-retrieval
|
||||
python retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --num_class_images 200
|
||||
```
|
||||
|
||||
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
export INSTANCE_DIR="./data/cat"
|
||||
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--class_data_dir=./real_reg/samples_cat/ \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--class_prompt="cat" --num_class_images=200 \
|
||||
--instance_prompt="photo of a <new1> cat" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=1e-5 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=250 \
|
||||
--scale_lr --hflip \
|
||||
--modifier_token "<new1>"
|
||||
```
|
||||
|
||||
**Use `--enable_xformers_memory_efficient_attention` for faster training with lower VRAM requirement (16GB per GPU). Follow [this guide](https://github.com/facebookresearch/xformers) for installation instructions.**
|
||||
|
||||
To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (whcih we HIGHLY recommend), follow these steps:
|
||||
|
||||
* Install `wandb`: `pip install wandb`.
|
||||
* Authorize: `wandb login`.
|
||||
* Then specify a `validation_prompt` and set `report_to` to `wandb` while launching training. You can also configure the following related arguments:
|
||||
* `num_validation_images`
|
||||
* `validation_steps`
|
||||
|
||||
Here is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--class_data_dir=./real_reg/samples_cat/ \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--class_prompt="cat" --num_class_images=200 \
|
||||
--instance_prompt="photo of a <new1> cat" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=1e-5 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=250 \
|
||||
--scale_lr --hflip \
|
||||
--modifier_token "<new1>" \
|
||||
--validation_prompt="<new1> cat sitting in a bucket" \
|
||||
--report_to="wandb"
|
||||
```
|
||||
|
||||
Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/26ghrcau) where you can check out the intermediate results along with other training details.
|
||||
|
||||
If you specify `--push_to_hub`, the learned parameters will be pushed to a repository on the Hugging Face Hub. Here is an [example repository](https://huggingface.co/sayakpaul/custom-diffusion-cat).
|
||||
|
||||
### Training on multiple concepts 🐱🪵
|
||||
|
||||
Provide a [json](https://github.com/adobe-research/custom-diffusion/blob/main/assets/concept_list.json) file with the info about each concept, similar to [this](https://github.com/ShivamShrirao/diffusers/blob/main/examples/dreambooth/train_dreambooth.py).
|
||||
|
||||
To collect the real images run this command for each concept in the json file.
|
||||
|
||||
```bash
|
||||
pip install clip-retrieval
|
||||
python retrieve.py --class_prompt {} --class_data_dir {} --num_class_images 200
|
||||
```
|
||||
|
||||
And then we're ready to start training!
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--concepts_list=./concept_list.json \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=1e-5 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=500 \
|
||||
--num_class_images=200 \
|
||||
--scale_lr --hflip \
|
||||
--modifier_token "<new1>+<new2>"
|
||||
```
|
||||
|
||||
Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/3990tzkg) where you can check out the intermediate results along with other training details.
|
||||
|
||||
### Training on human faces
|
||||
|
||||
For fine-tuning on human faces we found the following configuration to work better: `learning_rate=5e-6`, `max_train_steps=1000 to 2000`, and `freeze_model=crossattn` with at least 15-20 images.
|
||||
|
||||
To collect the real images use this command first before training.
|
||||
|
||||
```bash
|
||||
pip install clip-retrieval
|
||||
python retrieve.py --class_prompt person --class_data_dir real_reg/samples_person --num_class_images 200
|
||||
```
|
||||
|
||||
Then start training!
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
export INSTANCE_DIR="path-to-images"
|
||||
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--class_data_dir=./real_reg/samples_person/ \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--class_prompt="person" --num_class_images=200 \
|
||||
--instance_prompt="photo of a <new1> person" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=5e-6 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=1000 \
|
||||
--scale_lr --hflip --noaug \
|
||||
--freeze_model crossattn \
|
||||
--modifier_token "<new1>" \
|
||||
--enable_xformers_memory_efficient_attention
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
Once you have trained a model using the above command, you can run inference using the below command. Make sure to include the `modifier token` (e.g. \<new1\> in above example) in your prompt.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16).to("cuda")
|
||||
pipe.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
pipe.load_textual_inversion("path-to-save-model", weight_name="<new1>.bin")
|
||||
|
||||
image = pipe(
|
||||
"<new1> cat sitting in a bucket",
|
||||
num_inference_steps=100,
|
||||
guidance_scale=6.0,
|
||||
eta=1.0,
|
||||
).images[0]
|
||||
image.save("cat.png")
|
||||
```
|
||||
|
||||
It's possible to directly load these parameters from a Hub repository:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
model_id = "sayakpaul/custom-diffusion-cat"
|
||||
card = RepoCard.load(model_id)
|
||||
base_model_id = card.data.to_dict()["base_model"]
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
|
||||
pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
|
||||
|
||||
image = pipe(
|
||||
"<new1> cat sitting in a bucket",
|
||||
num_inference_steps=100,
|
||||
guidance_scale=6.0,
|
||||
eta=1.0,
|
||||
).images[0]
|
||||
image.save("cat.png")
|
||||
```
|
||||
|
||||
Here is an example of performing inference with multiple concepts:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
model_id = "sayakpaul/custom-diffusion-cat-wooden-pot"
|
||||
card = RepoCard.load(model_id)
|
||||
base_model_id = card.data.to_dict()["base_model"]
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to("cuda")
|
||||
pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
|
||||
pipe.load_textual_inversion(model_id, weight_name="<new2>.bin")
|
||||
|
||||
image = pipe(
|
||||
"the <new1> cat sculpture in the style of a <new2> wooden pot",
|
||||
num_inference_steps=100,
|
||||
guidance_scale=6.0,
|
||||
eta=1.0,
|
||||
).images[0]
|
||||
image.save("multi-subject.png")
|
||||
```
|
||||
|
||||
Here, `cat` and `wooden pot` refer to the multiple concepts.
|
||||
|
||||
### Inference from a training checkpoint
|
||||
|
||||
You can also perform inference from one of the complete checkpoint saved during the training process, if you used the `--checkpointing_steps` argument.
|
||||
|
||||
TODO.
|
||||
|
||||
## Set grads to none
|
||||
To save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument.
|
||||
|
||||
More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
|
||||
|
||||
## Experimental results
|
||||
You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail.
|
||||
@@ -39,6 +39,8 @@ Training examples show how to pretrain or fine-tune diffusion models for a varie
|
||||
- [Dreambooth](./dreambooth)
|
||||
- [LoRA Support](./lora)
|
||||
- [ControlNet](./controlnet)
|
||||
- [InstructPix2Pix](./instructpix2pix)
|
||||
- [Custom Diffusion](./custom_diffusion)
|
||||
|
||||
If possible, please [install xFormers](../optimization/xformers) for memory efficient attention. This could help make your training faster and less memory intensive.
|
||||
|
||||
@@ -50,6 +52,8 @@ If possible, please [install xFormers](../optimization/xformers) for memory effi
|
||||
| [**Dreambooth**](./dreambooth) | ✅ | - | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/sd_dreambooth_training.ipynb)
|
||||
| [**Training with LoRA**](./lora) | ✅ | - | - |
|
||||
| [**ControlNet**](./controlnet) | ✅ | ✅ | - |
|
||||
| [**InstructPix2Pix**](./instructpix2pix) | ✅ | ✅ | - |
|
||||
| [**Custom Diffusion**](./custom_diffusion) | ✅ | ✅ | - |
|
||||
|
||||
## Community
|
||||
|
||||
|
||||
280
examples/custom_diffusion/README.md
Normal file
280
examples/custom_diffusion/README.md
Normal file
@@ -0,0 +1,280 @@
|
||||
# Custom Diffusion training example
|
||||
|
||||
[Custom Diffusion](https://arxiv.org/abs/2212.04488) is a method to customize text-to-image models like Stable Diffusion given just a few (4~5) images of a subject.
|
||||
The `train_custom_diffusion.py` script shows how to implement the training procedure and adapt it for stable diffusion.
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the example folder and run
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
pip install clip-retrieval
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell e.g. a notebook
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
### Cat example 😺
|
||||
|
||||
Now let's get our dataset. Download dataset from [here](https://www.cs.cmu.edu/~custom-diffusion/assets/data.zip) and unzip it.
|
||||
|
||||
We also collect 200 real images using `clip-retrieval` which are combined with the target images in the training dataset as a regularization. This prevents overfitting to the the given target image. The following flags enable the regularization `with_prior_preservation`, `real_prior` with `prior_loss_weight=1.`.
|
||||
The `class_prompt` should be the category name same as target image. The collected real images are with text captions similar to the `class_prompt`. The retrieved image are saved in `class_data_dir`. You can disable `real_prior` to use generated images as regularization. To collect the real images use this command first before training.
|
||||
|
||||
```bash
|
||||
pip install clip-retrieval
|
||||
python retrieve.py --class_prompt cat --class_data_dir real_reg/samples_cat --num_class_images 200
|
||||
```
|
||||
|
||||
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
export INSTANCE_DIR="./data/cat"
|
||||
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--class_data_dir=./real_reg/samples_cat/ \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--class_prompt="cat" --num_class_images=200 \
|
||||
--instance_prompt="photo of a <new1> cat" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=1e-5 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=250 \
|
||||
--scale_lr --hflip \
|
||||
--modifier_token "<new1>"
|
||||
```
|
||||
|
||||
**Use `--enable_xformers_memory_efficient_attention` for faster training with lower VRAM requirement (16GB per GPU). Follow [this guide](https://github.com/facebookresearch/xformers) for installation instructions.**
|
||||
|
||||
To track your experiments using Weights and Biases (`wandb`) and to save intermediate results (whcih we HIGHLY recommend), follow these steps:
|
||||
|
||||
* Install `wandb`: `pip install wandb`.
|
||||
* Authorize: `wandb login`.
|
||||
* Then specify a `validation_prompt` and set `report_to` to `wandb` while launching training. You can also configure the following related arguments:
|
||||
* `num_validation_images`
|
||||
* `validation_steps`
|
||||
|
||||
Here is an example command:
|
||||
|
||||
```bash
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--class_data_dir=./real_reg/samples_cat/ \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--class_prompt="cat" --num_class_images=200 \
|
||||
--instance_prompt="photo of a <new1> cat" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=1e-5 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=250 \
|
||||
--scale_lr --hflip \
|
||||
--modifier_token "<new1>" \
|
||||
--validation_prompt="<new1> cat sitting in a bucket" \
|
||||
--report_to="wandb"
|
||||
```
|
||||
|
||||
Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/26ghrcau) where you can check out the intermediate results along with other training details.
|
||||
|
||||
If you specify `--push_to_hub`, the learned parameters will be pushed to a repository on the Hugging Face Hub. Here is an [example repository](https://huggingface.co/sayakpaul/custom-diffusion-cat).
|
||||
|
||||
### Training on multiple concepts 🐱🪵
|
||||
|
||||
Provide a [json](https://github.com/adobe-research/custom-diffusion/blob/main/assets/concept_list.json) file with the info about each concept, similar to [this](https://github.com/ShivamShrirao/diffusers/blob/main/examples/dreambooth/train_dreambooth.py).
|
||||
|
||||
To collect the real images run this command for each concept in the json file.
|
||||
|
||||
```bash
|
||||
pip install clip-retrieval
|
||||
python retrieve.py --class_prompt {} --class_data_dir {} --num_class_images 200
|
||||
```
|
||||
|
||||
And then we're ready to start training!
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--concepts_list=./concept_list.json \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=1e-5 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=500 \
|
||||
--num_class_images=200 \
|
||||
--scale_lr --hflip \
|
||||
--modifier_token "<new1>+<new2>"
|
||||
```
|
||||
|
||||
Here is an example [Weights and Biases page](https://wandb.ai/sayakpaul/custom-diffusion/runs/3990tzkg) where you can check out the intermediate results along with other training details.
|
||||
|
||||
### Training on human faces
|
||||
|
||||
For fine-tuning on human faces we found the following configuration to work better: `learning_rate=5e-6`, `max_train_steps=1000 to 2000`, and `freeze_model=crossattn` with at least 15-20 images.
|
||||
|
||||
To collect the real images use this command first before training.
|
||||
|
||||
```bash
|
||||
pip install clip-retrieval
|
||||
python retrieve.py --class_prompt person --class_data_dir real_reg/samples_person --num_class_images 200
|
||||
```
|
||||
|
||||
Then start training!
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
export INSTANCE_DIR="path-to-images"
|
||||
|
||||
accelerate launch train_custom_diffusion.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--class_data_dir=./real_reg/samples_person/ \
|
||||
--with_prior_preservation --real_prior --prior_loss_weight=1.0 \
|
||||
--class_prompt="person" --num_class_images=200 \
|
||||
--instance_prompt="photo of a <new1> person" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=2 \
|
||||
--learning_rate=5e-6 \
|
||||
--lr_warmup_steps=0 \
|
||||
--max_train_steps=1000 \
|
||||
--scale_lr --hflip --noaug \
|
||||
--freeze_model crossattn \
|
||||
--modifier_token "<new1>" \
|
||||
--enable_xformers_memory_efficient_attention
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
Once you have trained a model using the above command, you can run inference using the below command. Make sure to include the `modifier token` (e.g. \<new1\> in above example) in your prompt.
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(
|
||||
"CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16
|
||||
).to("cuda")
|
||||
pipe.unet.load_attn_procs(
|
||||
"path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin"
|
||||
)
|
||||
pipe.load_textual_inversion("path-to-save-model", weight_name="<new1>.bin")
|
||||
|
||||
image = pipe(
|
||||
"<new1> cat sitting in a bucket",
|
||||
num_inference_steps=100,
|
||||
guidance_scale=6.0,
|
||||
eta=1.0,
|
||||
).images[0]
|
||||
image.save("cat.png")
|
||||
```
|
||||
|
||||
It's possible to directly load these parameters from a Hub repository:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
model_id = "sayakpaul/custom-diffusion-cat"
|
||||
card = RepoCard.load(model_id)
|
||||
base_model_id = card.data.to_dict()["base_model"]
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to(
|
||||
"cuda")
|
||||
pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
|
||||
|
||||
image = pipe(
|
||||
"<new1> cat sitting in a bucket",
|
||||
num_inference_steps=100,
|
||||
guidance_scale=6.0,
|
||||
eta=1.0,
|
||||
).images[0]
|
||||
image.save("cat.png")
|
||||
```
|
||||
|
||||
Here is an example of performing inference with multiple concepts:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from huggingface_hub.repocard import RepoCard
|
||||
from diffusers import DiffusionPipeline
|
||||
|
||||
model_id = "sayakpaul/custom-diffusion-cat-wooden-pot"
|
||||
card = RepoCard.load(model_id)
|
||||
base_model_id = card.data.to_dict()["base_model"]
|
||||
|
||||
pipe = DiffusionPipeline.from_pretrained(base_model_id, torch_dtype=torch.float16).to(
|
||||
"cuda")
|
||||
pipe.unet.load_attn_procs(model_id, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
pipe.load_textual_inversion(model_id, weight_name="<new1>.bin")
|
||||
pipe.load_textual_inversion(model_id, weight_name="<new2>.bin")
|
||||
|
||||
image = pipe(
|
||||
"the <new1> cat sculpture in the style of a <new2> wooden pot",
|
||||
num_inference_steps=100,
|
||||
guidance_scale=6.0,
|
||||
eta=1.0,
|
||||
).images[0]
|
||||
image.save("multi-subject.png")
|
||||
```
|
||||
|
||||
Here, `cat` and `wooden pot` refer to the multiple concepts.
|
||||
|
||||
### Inference from a training checkpoint
|
||||
|
||||
You can also perform inference from one of the complete checkpoint saved during the training process, if you used the `--checkpointing_steps` argument.
|
||||
|
||||
TODO.
|
||||
|
||||
## Set grads to none
|
||||
To save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument.
|
||||
|
||||
More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
|
||||
|
||||
## Experimental results
|
||||
You can refer to [our webpage](https://www.cs.cmu.edu/~custom-diffusion/) that discusses our experiments in detail.
|
||||
6
examples/custom_diffusion/requirements.txt
Normal file
6
examples/custom_diffusion/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
accelerate
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
87
examples/custom_diffusion/retrieve.py
Normal file
87
examples/custom_diffusion/retrieve.py
Normal file
@@ -0,0 +1,87 @@
|
||||
# Copyright 2023 Custom Diffusion authors. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import os
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
import requests
|
||||
from clip_retrieval.clip_client import ClipClient
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
def retrieve(class_prompt, class_data_dir, num_class_images):
|
||||
factor = 1.5
|
||||
num_images = int(factor * num_class_images)
|
||||
client = ClipClient(
|
||||
url="https://knn.laion.ai/knn-service", indice_name="laion_400m", num_images=num_images, aesthetic_weight=0.1
|
||||
)
|
||||
|
||||
os.makedirs(f"{class_data_dir}/images", exist_ok=True)
|
||||
if len(list(Path(f"{class_data_dir}/images").iterdir())) >= num_class_images:
|
||||
return
|
||||
|
||||
while True:
|
||||
class_images = client.query(text=class_prompt)
|
||||
if len(class_images) >= factor * num_class_images or num_images > 1e4:
|
||||
break
|
||||
else:
|
||||
num_images = int(factor * num_images)
|
||||
client = ClipClient(
|
||||
url="https://knn.laion.ai/knn-service",
|
||||
indice_name="laion_400m",
|
||||
num_images=num_images,
|
||||
aesthetic_weight=0.1,
|
||||
)
|
||||
|
||||
count = 0
|
||||
total = 0
|
||||
pbar = tqdm(desc="downloading real regularization images", total=num_class_images)
|
||||
|
||||
with open(f"{class_data_dir}/caption.txt", "w") as f1, open(f"{class_data_dir}/urls.txt", "w") as f2, open(
|
||||
f"{class_data_dir}/images.txt", "w"
|
||||
) as f3:
|
||||
while total < num_class_images:
|
||||
images = class_images[count]
|
||||
count += 1
|
||||
try:
|
||||
img = requests.get(images["url"])
|
||||
if img.status_code == 200:
|
||||
_ = Image.open(BytesIO(img.content))
|
||||
with open(f"{class_data_dir}/images/{total}.jpg", "wb") as f:
|
||||
f.write(img.content)
|
||||
f1.write(images["caption"] + "\n")
|
||||
f2.write(images["url"] + "\n")
|
||||
f3.write(f"{class_data_dir}/images/{total}.jpg" + "\n")
|
||||
total += 1
|
||||
pbar.update(1)
|
||||
else:
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
return
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser("", add_help=False)
|
||||
parser.add_argument("--class_prompt", help="text prompt to retrieve images", required=True, type=str)
|
||||
parser.add_argument("--class_data_dir", help="path to save images", required=True, type=str)
|
||||
parser.add_argument("--num_class_images", help="number of images to download", default=200, type=int)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
retrieve(args.class_prompt, args.class_data_dir, args.num_class_images)
|
||||
1289
examples/custom_diffusion/train_custom_diffusion.py
Normal file
1289
examples/custom_diffusion/train_custom_diffusion.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -221,6 +221,30 @@ class ExamplesTestsAccelerate(unittest.TestCase):
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
|
||||
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))
|
||||
|
||||
def test_custom_diffusion(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/custom_diffusion/train_custom_diffusion.py
|
||||
--pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-pipe
|
||||
--instance_data_dir docs/source/en/imgs
|
||||
--instance_prompt <new1>
|
||||
--resolution 64
|
||||
--train_batch_size 1
|
||||
--gradient_accumulation_steps 1
|
||||
--max_train_steps 2
|
||||
--learning_rate 1.0e-05
|
||||
--scale_lr
|
||||
--lr_scheduler constant
|
||||
--lr_warmup_steps 0
|
||||
--modifier_token <new1>
|
||||
--output_dir {tmpdir}
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
# save_pretrained smoke test
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_custom_diffusion_weights.bin")))
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "<new1>.bin")))
|
||||
|
||||
def test_text_to_image(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
|
||||
@@ -19,7 +19,11 @@ from typing import Callable, Dict, List, Optional, Union
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from .models.attention_processor import LoRAAttnProcessor
|
||||
from .models.attention_processor import (
|
||||
CustomDiffusionAttnProcessor,
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
LoRAAttnProcessor,
|
||||
)
|
||||
from .utils import (
|
||||
DIFFUSERS_CACHE,
|
||||
HF_HUB_OFFLINE,
|
||||
@@ -48,6 +52,9 @@ LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
|
||||
TEXT_INVERSION_NAME = "learned_embeds.bin"
|
||||
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"
|
||||
|
||||
CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
|
||||
CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
|
||||
|
||||
|
||||
class AttnProcsLayers(torch.nn.Module):
|
||||
def __init__(self, state_dict: Dict[str, torch.Tensor]):
|
||||
@@ -215,6 +222,7 @@ class UNet2DConditionLoadersMixin:
|
||||
attn_processors = {}
|
||||
|
||||
is_lora = all("lora" in k for k in state_dict.keys())
|
||||
is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
|
||||
|
||||
if is_lora:
|
||||
lora_grouped_dict = defaultdict(dict)
|
||||
@@ -231,9 +239,38 @@ class UNet2DConditionLoadersMixin:
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
|
||||
)
|
||||
attn_processors[key].load_state_dict(value_dict)
|
||||
elif is_custom_diffusion:
|
||||
custom_diffusion_grouped_dict = defaultdict(dict)
|
||||
for key, value in state_dict.items():
|
||||
if len(value) == 0:
|
||||
custom_diffusion_grouped_dict[key] = {}
|
||||
else:
|
||||
if "to_out" in key:
|
||||
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
|
||||
else:
|
||||
attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
|
||||
custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
|
||||
|
||||
for key, value_dict in custom_diffusion_grouped_dict.items():
|
||||
if len(value_dict) == 0:
|
||||
attn_processors[key] = CustomDiffusionAttnProcessor(
|
||||
train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
|
||||
)
|
||||
else:
|
||||
cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
|
||||
hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
|
||||
train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
|
||||
attn_processors[key] = CustomDiffusionAttnProcessor(
|
||||
train_kv=True,
|
||||
train_q_out=train_q_out,
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
attn_processors[key].load_state_dict(value_dict)
|
||||
else:
|
||||
raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
|
||||
raise ValueError(
|
||||
f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
|
||||
)
|
||||
|
||||
# set correct dtype & device
|
||||
attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
|
||||
@@ -287,16 +324,31 @@ class UNet2DConditionLoadersMixin:
|
||||
|
||||
os.makedirs(save_directory, exist_ok=True)
|
||||
|
||||
model_to_save = AttnProcsLayers(self.attn_processors)
|
||||
|
||||
# Save the model
|
||||
state_dict = model_to_save.state_dict()
|
||||
is_custom_diffusion = any(
|
||||
isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
|
||||
for (_, x) in self.attn_processors.items()
|
||||
)
|
||||
if is_custom_diffusion:
|
||||
model_to_save = AttnProcsLayers(
|
||||
{
|
||||
y: x
|
||||
for (y, x) in self.attn_processors.items()
|
||||
if isinstance(x, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor))
|
||||
}
|
||||
)
|
||||
state_dict = model_to_save.state_dict()
|
||||
for name, attn in self.attn_processors.items():
|
||||
if len(attn.state_dict()) == 0:
|
||||
state_dict[name] = {}
|
||||
else:
|
||||
model_to_save = AttnProcsLayers(self.attn_processors)
|
||||
state_dict = model_to_save.state_dict()
|
||||
|
||||
if weight_name is None:
|
||||
if safe_serialization:
|
||||
weight_name = LORA_WEIGHT_NAME_SAFE
|
||||
weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
|
||||
else:
|
||||
weight_name = LORA_WEIGHT_NAME
|
||||
weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
|
||||
|
||||
# Save the model
|
||||
save_function(state_dict, os.path.join(save_directory, weight_name))
|
||||
|
||||
@@ -149,6 +149,9 @@ class Attention(nn.Module):
|
||||
is_lora = hasattr(self, "processor") and isinstance(
|
||||
self.processor, (LoRAAttnProcessor, LoRAXFormersAttnProcessor)
|
||||
)
|
||||
is_custom_diffusion = hasattr(self, "processor") and isinstance(
|
||||
self.processor, (CustomDiffusionAttnProcessor, CustomDiffusionXFormersAttnProcessor)
|
||||
)
|
||||
|
||||
if use_memory_efficient_attention_xformers:
|
||||
if self.added_kv_proj_dim is not None:
|
||||
@@ -192,6 +195,17 @@ class Attention(nn.Module):
|
||||
)
|
||||
processor.load_state_dict(self.processor.state_dict())
|
||||
processor.to(self.processor.to_q_lora.up.weight.device)
|
||||
elif is_custom_diffusion:
|
||||
processor = CustomDiffusionXFormersAttnProcessor(
|
||||
train_kv=self.processor.train_kv,
|
||||
train_q_out=self.processor.train_q_out,
|
||||
hidden_size=self.processor.hidden_size,
|
||||
cross_attention_dim=self.processor.cross_attention_dim,
|
||||
attention_op=attention_op,
|
||||
)
|
||||
processor.load_state_dict(self.processor.state_dict())
|
||||
if hasattr(self.processor, "to_k_custom_diffusion"):
|
||||
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
||||
else:
|
||||
processor = XFormersAttnProcessor(attention_op=attention_op)
|
||||
else:
|
||||
@@ -203,6 +217,16 @@ class Attention(nn.Module):
|
||||
)
|
||||
processor.load_state_dict(self.processor.state_dict())
|
||||
processor.to(self.processor.to_q_lora.up.weight.device)
|
||||
elif is_custom_diffusion:
|
||||
processor = CustomDiffusionAttnProcessor(
|
||||
train_kv=self.processor.train_kv,
|
||||
train_q_out=self.processor.train_q_out,
|
||||
hidden_size=self.processor.hidden_size,
|
||||
cross_attention_dim=self.processor.cross_attention_dim,
|
||||
)
|
||||
processor.load_state_dict(self.processor.state_dict())
|
||||
if hasattr(self.processor, "to_k_custom_diffusion"):
|
||||
processor.to(self.processor.to_k_custom_diffusion.weight.device)
|
||||
else:
|
||||
processor = AttnProcessor()
|
||||
|
||||
@@ -459,6 +483,84 @@ class LoRAAttnProcessor(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomDiffusionAttnProcessor(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
train_kv=True,
|
||||
train_q_out=True,
|
||||
hidden_size=None,
|
||||
cross_attention_dim=None,
|
||||
out_bias=True,
|
||||
dropout=0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.train_kv = train_kv
|
||||
self.train_q_out = train_q_out
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
|
||||
# `_custom_diffusion` id for easy serialization and loading.
|
||||
if self.train_kv:
|
||||
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
if self.train_q_out:
|
||||
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
||||
self.to_out_custom_diffusion = nn.ModuleList([])
|
||||
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
||||
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
if self.train_q_out:
|
||||
query = self.to_q_custom_diffusion(hidden_states)
|
||||
else:
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
crossattn = False
|
||||
encoder_hidden_states = hidden_states
|
||||
else:
|
||||
crossattn = True
|
||||
if attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
if self.train_kv:
|
||||
key = self.to_k_custom_diffusion(encoder_hidden_states)
|
||||
value = self.to_v_custom_diffusion(encoder_hidden_states)
|
||||
else:
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
if crossattn:
|
||||
detach = torch.ones_like(key)
|
||||
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
||||
key = detach * key + (1 - detach) * key.detach()
|
||||
value = detach * value + (1 - detach) * value.detach()
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
if self.train_q_out:
|
||||
# linear proj
|
||||
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
||||
else:
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttnAddedKVProcessor:
|
||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
residual = hidden_states
|
||||
@@ -699,6 +801,91 @@ class LoRAXFormersAttnProcessor(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomDiffusionXFormersAttnProcessor(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
train_kv=True,
|
||||
train_q_out=False,
|
||||
hidden_size=None,
|
||||
cross_attention_dim=None,
|
||||
out_bias=True,
|
||||
dropout=0.0,
|
||||
attention_op: Optional[Callable] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.train_kv = train_kv
|
||||
self.train_q_out = train_q_out
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.attention_op = attention_op
|
||||
|
||||
# `_custom_diffusion` id for easy serialization and loading.
|
||||
if self.train_kv:
|
||||
self.to_k_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
self.to_v_custom_diffusion = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
if self.train_q_out:
|
||||
self.to_q_custom_diffusion = nn.Linear(hidden_size, hidden_size, bias=False)
|
||||
self.to_out_custom_diffusion = nn.ModuleList([])
|
||||
self.to_out_custom_diffusion.append(nn.Linear(hidden_size, hidden_size, bias=out_bias))
|
||||
self.to_out_custom_diffusion.append(nn.Dropout(dropout))
|
||||
|
||||
def __call__(self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if self.train_q_out:
|
||||
query = self.to_q_custom_diffusion(hidden_states)
|
||||
else:
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
crossattn = False
|
||||
encoder_hidden_states = hidden_states
|
||||
else:
|
||||
crossattn = True
|
||||
if attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
if self.train_kv:
|
||||
key = self.to_k_custom_diffusion(encoder_hidden_states)
|
||||
value = self.to_v_custom_diffusion(encoder_hidden_states)
|
||||
else:
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
if crossattn:
|
||||
detach = torch.ones_like(key)
|
||||
detach[:, :1, :] = detach[:, :1, :] * 0.0
|
||||
key = detach * key + (1 - detach) * key.detach()
|
||||
value = detach * value + (1 - detach) * value.detach()
|
||||
|
||||
query = attn.head_to_batch_dim(query).contiguous()
|
||||
key = attn.head_to_batch_dim(key).contiguous()
|
||||
value = attn.head_to_batch_dim(value).contiguous()
|
||||
|
||||
hidden_states = xformers.ops.memory_efficient_attention(
|
||||
query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
|
||||
)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
if self.train_q_out:
|
||||
# linear proj
|
||||
hidden_states = self.to_out_custom_diffusion[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = self.to_out_custom_diffusion[1](hidden_states)
|
||||
else:
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SlicedAttnProcessor:
|
||||
def __init__(self, slice_size):
|
||||
self.slice_size = slice_size
|
||||
@@ -834,4 +1021,6 @@ AttentionProcessor = Union[
|
||||
AttnAddedKVProcessor2_0,
|
||||
LoRAAttnProcessor,
|
||||
LoRAXFormersAttnProcessor,
|
||||
CustomDiffusionAttnProcessor,
|
||||
CustomDiffusionXFormersAttnProcessor,
|
||||
]
|
||||
|
||||
@@ -22,7 +22,7 @@ import torch
|
||||
from parameterized import parameterized
|
||||
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import LoRAAttnProcessor
|
||||
from diffusers.models.attention_processor import CustomDiffusionAttnProcessor, LoRAAttnProcessor
|
||||
from diffusers.utils import (
|
||||
floats_tensor,
|
||||
load_hf_numpy,
|
||||
@@ -68,6 +68,55 @@ def create_lora_layers(model, mock_weights: bool = True):
|
||||
return lora_attn_procs
|
||||
|
||||
|
||||
def create_custom_diffusion_layers(model, mock_weights: bool = True):
|
||||
train_kv = True
|
||||
train_q_out = True
|
||||
custom_diffusion_attn_procs = {}
|
||||
|
||||
st = model.state_dict()
|
||||
for name, _ in model.attn_processors.items():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = model.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = model.config.block_out_channels[block_id]
|
||||
layer_name = name.split(".processor")[0]
|
||||
weights = {
|
||||
"to_k_custom_diffusion.weight": st[layer_name + ".to_k.weight"],
|
||||
"to_v_custom_diffusion.weight": st[layer_name + ".to_v.weight"],
|
||||
}
|
||||
if train_q_out:
|
||||
weights["to_q_custom_diffusion.weight"] = st[layer_name + ".to_q.weight"]
|
||||
weights["to_out_custom_diffusion.0.weight"] = st[layer_name + ".to_out.0.weight"]
|
||||
weights["to_out_custom_diffusion.0.bias"] = st[layer_name + ".to_out.0.bias"]
|
||||
if cross_attention_dim is not None:
|
||||
custom_diffusion_attn_procs[name] = CustomDiffusionAttnProcessor(
|
||||
train_kv=train_kv,
|
||||
train_q_out=train_q_out,
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
).to(model.device)
|
||||
custom_diffusion_attn_procs[name].load_state_dict(weights)
|
||||
if mock_weights:
|
||||
# add 1 to weights to mock trained weights
|
||||
with torch.no_grad():
|
||||
custom_diffusion_attn_procs[name].to_k_custom_diffusion.weight += 1
|
||||
custom_diffusion_attn_procs[name].to_v_custom_diffusion.weight += 1
|
||||
else:
|
||||
custom_diffusion_attn_procs[name] = CustomDiffusionAttnProcessor(
|
||||
train_kv=False,
|
||||
train_q_out=False,
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
)
|
||||
del st
|
||||
return custom_diffusion_attn_procs
|
||||
|
||||
|
||||
class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DConditionModel
|
||||
|
||||
@@ -569,6 +618,96 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
assert (sample - on_sample).abs().max() < 1e-4
|
||||
assert (sample - off_sample).abs().max() < 1e-4
|
||||
|
||||
def test_custom_diffusion_processors(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
sample1 = model(**inputs_dict).sample
|
||||
|
||||
custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False)
|
||||
|
||||
# make sure we can set a list of attention processors
|
||||
model.set_attn_processor(custom_diffusion_attn_procs)
|
||||
model.to(torch_device)
|
||||
|
||||
# test that attn processors can be set to itself
|
||||
model.set_attn_processor(model.attn_processors)
|
||||
|
||||
with torch.no_grad():
|
||||
sample2 = model(**inputs_dict).sample
|
||||
|
||||
assert (sample1 - sample2).abs().max() < 1e-4
|
||||
|
||||
def test_custom_diffusion_save_load(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
old_sample = model(**inputs_dict).sample
|
||||
|
||||
custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False)
|
||||
model.set_attn_processor(custom_diffusion_attn_procs)
|
||||
|
||||
with torch.no_grad():
|
||||
sample = model(**inputs_dict).sample
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
model.save_attn_procs(tmpdirname)
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_custom_diffusion_weights.bin")))
|
||||
torch.manual_seed(0)
|
||||
new_model = self.model_class(**init_dict)
|
||||
new_model.to(torch_device)
|
||||
new_model.load_attn_procs(tmpdirname, weight_name="pytorch_custom_diffusion_weights.bin")
|
||||
|
||||
with torch.no_grad():
|
||||
new_sample = new_model(**inputs_dict).sample
|
||||
|
||||
assert (sample - new_sample).abs().max() < 1e-4
|
||||
|
||||
# custom diffusion and no custom diffusion should be the same
|
||||
assert (sample - old_sample).abs().max() < 1e-4
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_custom_diffusion_xformers_on_off(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
torch.manual_seed(0)
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
custom_diffusion_attn_procs = create_custom_diffusion_layers(model, mock_weights=False)
|
||||
model.set_attn_processor(custom_diffusion_attn_procs)
|
||||
|
||||
# default
|
||||
with torch.no_grad():
|
||||
sample = model(**inputs_dict).sample
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
on_sample = model(**inputs_dict).sample
|
||||
|
||||
model.disable_xformers_memory_efficient_attention()
|
||||
off_sample = model(**inputs_dict).sample
|
||||
|
||||
assert (sample - on_sample).abs().max() < 1e-4
|
||||
assert (sample - off_sample).abs().max() < 1e-4
|
||||
|
||||
|
||||
@slow
|
||||
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
|
||||
|
||||
Reference in New Issue
Block a user