1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-29 07:22:12 +03:00

[docs] Move text-to-image LoRA training from blog to docs (#2527)

* include text2image lora training in docs

* 🖍 apply feedback

* 🖍 minor edits
This commit is contained in:
Steven Liu
2023-03-06 13:45:07 -08:00
committed by GitHub
parent 9136be14a7
commit 62bea2df36
2 changed files with 156 additions and 120 deletions

View File

@@ -86,7 +86,7 @@
- local: training/text2image
title: Text-to-image
- local: training/lora
title: LoRA Support in Diffusers
title: Low-Rank Adaptation of Large Language Models (LoRA)
title: Training
- sections:
- local: conceptual/philosophy

View File

@@ -10,54 +10,151 @@ an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express o
specific language governing permissions and limitations under the License.
-->
# LoRA Support in Diffusers
# Low-Rank Adaptation of Large Language Models (LoRA)
Diffusers supports LoRA for faster fine-tuning of Stable Diffusion, allowing greater memory efficiency and easier portability.
[[open-in-colab]]
Low-Rank Adaption of Large Language Models was first introduced by Microsoft in
[LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
<Tip warning={true}>
In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition weight matrices (called **update matrices**)
to existing weights and **only** training those newly added weights. This has a couple of advantages:
- Previous pretrained weights are kept frozen so that the model is not so prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA matrices are generally added to the attention layers of the original model and they control to which extent the model is adapted toward new training images via a `scale` parameter.
**__Note that the usage of LoRA is not just limited to attention layers. In the original LoRA work, the authors found out that just amending
the attention layers of a language model is sufficient to obtain good downstream performance with great efficiency. This is why, it's common
to just add the LoRA weights to the attention layers of a model.__**
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
<Tip>
LoRA allows us to achieve greater memory efficiency since the pretrained weights are kept frozen and only the LoRA weights are trained, thereby
allowing us to run fine-tuning on consumer GPUs like Tesla T4, RTX 3080 or even RTX 2080 Ti! One can get access to GPUs like T4 in the free
tiers of Kaggle Kernels and Google Colab Notebooks.
Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionalModel`].
</Tip>
## Getting started with LoRA for fine-tuning
[Low-Rank Adaptation of Large Language Models (LoRA)](https://arxiv.org/abs/2106.09685) is a training method that accelerates the training of large models while consuming less memory. It adds pairs of rank-decomposition weight matrices (called **update matrices**) to existing weights, and **only** trains those newly added weights. This has a couple of advantages:
Stable Diffusion can be fine-tuned in different ways:
- Previous pretrained weights are kept frozen so the model is not as prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA matrices are generally added to the attention layers of the original model. 🧨 Diffusers provides the [`~diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method to load the LoRA weights into a model's attention layers. You can control the extent to which the model is adapted toward new training images via a `scale` parameter.
- The greater memory-efficiency allows you to run fine-tuning on consumer GPUs like the Tesla T4, RTX 3080 or even the RTX 2080 Ti! GPUs like the T4 are free and readily accessible in Kaggle or Google Colab notebooks.
* [Textual inversion](https://huggingface.co/docs/diffusers/main/en/training/text_inversion)
* [DreamBooth](https://huggingface.co/docs/diffusers/main/en/training/dreambooth)
* [Text2Image fine-tuning](https://huggingface.co/docs/diffusers/main/en/training/text2image)
<Tip>
We provide two end-to-end examples that show how to run fine-tuning with LoRA:
💡 LoRA is not only limited to attention layers. The authors found that amending
the attention layers of a language model is sufficient to obtain good downstream performance with great efficiency. This is why it's common to just add the LoRA weights to the attention layers of a model. Check out the [Using LoRA for efficient Stable Diffusion fine-tuning](https://huggingface.co/blog/lora) blog for more information about how LoRA works!
* [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-with-low-rank-adaptation-of-large-language-models-lora)
* [Text2Image](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora)
</Tip>
If you want to perform DreamBooth training with LoRA, for instance, you would run:
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository. 🧨 Diffusers now supports finetuning with LoRA for [text-to-image generation](https://github.com/huggingface/diffusers/tree/main/examples/text_to_image#training-with-lora) and [DreamBooth](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#training-with-low-rank-adaptation-of-large-language-models-lora). This guide will show you how to do both.
If you'd like to store or share your model with the community, login to your Hugging Face account (create [one](hf.co/join) if you don't have one already):
```bash
huggingface-cli login
```
## Text-to-image
Finetuning a model like Stable Diffusion, which has billions of parameters, can be slow and difficult. With LoRA, it is much easier and faster to finetune a diffusion model. It can run on hardware with as little as 11GB of GPU RAM without resorting to tricks such as 8-bit optimizers.
### Training[[text-to-image-training]]
Let's finetune [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) on the [Pokémon BLIP captions](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions) dataset to generate your own Pokémon.
To start, make sure you have the `MODEL_NAME` and `DATASET_NAME` environment variables set. The `OUTPUT_DIR` and `HUB_MODEL_ID` variables are optional and specify where to save the model to on the Hub:
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export OUTPUT_DIR="/sddata/finetune/lora/pokemon"
export HUB_MODEL_ID="pokemon-lora"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
```
There are some flags to be aware of before you start training:
* `--push_to_hub` stores the trained LoRA embeddings on the Hub.
* `--report_to=wandb` reports and logs the training results to your Weights & Biases dashboard (as an example, take a look at this [report](https://wandb.ai/pcuenq/text2image-fine-tune/runs/b4k1w0tn?workspace=user-pcuenq)).
* `--learning_rate=1e-04`, you can afford to use a higher learning rate than you normally would with LoRA.
Now you're ready to launch the training (you can find the full training script [here](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py)):
```bash
accelerate launch --mixed_precision="fp16" train_text_to_image_lora.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_NAME \
--dataloader_num_workers=8 \
--resolution=512 --center_crop --random_flip \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--max_train_steps=15000 \
--learning_rate=1e-04 \
--max_grad_norm=1 \
--lr_scheduler="cosine" --lr_warmup_steps=0 \
--output_dir=${OUTPUT_DIR} \
--push_to_hub \
--hub_model_id=${HUB_MODEL_ID} \
--report_to=wandb \
--checkpointing_steps=500 \
--validation_prompt="A pokemon with blue eyes." \
--seed=1337
```
### Inference[[text-to-image-inference]]
Now you can use the model for inference by loading the base model in the [`StableDiffusionPipeline`] and then the [`DPMSolverMultistepScheduler`]:
```py
>>> import torch
>>> from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
>>> model_base = "runwayml/stable-diffusion-v1-5"
>>> pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float16)
>>> pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
```
Load the LoRA weights from your finetuned model *on top of the base model weights*, and then move the pipeline to a GPU for faster inference. When you merge the LoRA weights with the frozen pretrained model weights, you can optionally adjust how much of the weights to merge with the `scale` parameter:
<Tip>
💡 A `scale` value of `0` is the same as not using your LoRA weights and you're only using the base model weights, and a `scale` value of `1` means you're only using the fully finetuned LoRA weights. Values between `0` and `1` interpolates between the two weights.
</Tip>
```py
>>> pipe.unet.load_attn_procs(model_path)
>>> pipe.to("cuda")
# use half the weights from the LoRA finetuned model and half the weights from the base model
>>> image = pipe(
... "A pokemon with blue eyes.", num_inference_steps=25, guidance_scale=7.5, cross_attention_kwargs={"scale": 0.5}
... ).images[0]
# use the weights from the fully finetuned LoRA model
>>> image = pipe("A pokemon with blue eyes.", num_inference_steps=25, guidance_scale=7.5).images[0]
>>> image.save("blue_pokemon.png")
```
## DreamBooth
[DreamBooth](https://arxiv.org/abs/2208.12242) is a finetuning technique for personalizing a text-to-image model like Stable Diffusion to generate photorealistic images of a subject in different contexts, given a few images of the subject. However, DreamBooth is very sensitive to hyperparameters and it is easy to overfit. Some important hyperparameters to consider include those that affect the training time (learning rate, number of training steps), and inference time (number of steps, scheduler type).
<Tip>
💡 Take a look at the [Training Stable Diffusion with DreamBooth using 🧨 Diffusers](https://huggingface.co/blog/dreambooth) blog for an in-depth analysis of DreamBooth experiments and recommended settings.
</Tip>
### Training[[dreambooth-training]]
Let's finetune [`stable-diffusion-v1-5`](https://huggingface.co/runwayml/stable-diffusion-v1-5) with DreamBooth and LoRA with some 🐶 [dog images](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ). Download and save these images to a directory.
To start, make sure you have the `MODEL_NAME` and `INSTANCE_DIR` (path to directory containing images) environment variables set. The `OUTPUT_DIR` variables is optional and specifies where to save the model to on the Hub:
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"
```
There are some flags to be aware of before you start training:
* `--push_to_hub` stores the trained LoRA embeddings on the Hub.
* `--report_to=wandb` reports and logs the training results to your Weights & Biases dashboard (as an example, take a look at this [report](https://wandb.ai/pcuenq/text2image-fine-tune/runs/b4k1w0tn?workspace=user-pcuenq)).
* `--learning_rate=1e-04`, you can afford to use a higher learning rate than you normally would with LoRA.
Now you're ready to launch the training (you can find the full training script [here](https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora.py)):
```bash
accelerate launch train_dreambooth_lora.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
@@ -78,101 +175,40 @@ accelerate launch train_dreambooth_lora.py \
--push_to_hub
```
A similar process can be followed to fully fine-tune Stable Diffusion on a custom dataset using the
`examples/text_to_image/train_text_to_image_lora.py` script.
### Inference[[dreambooth-inference]]
Refer to the respective examples linked above to learn more.
Now you can use the model for inference by loading the base model in the [`StableDiffusionPipeline`]:
```py
>>> import torch
>>> from diffusers import StableDiffusionPipeline
>>> model_base = "runwayml/stable-diffusion-v1-5"
>>> pipe = StableDiffusionPipeline.from_pretrained(model_base, torch_dtype=torch.float16)
```
Load the LoRA weights from your finetuned DreamBooth model *on top of the base model weights*, and then move the pipeline to a GPU for faster inference. When you merge the LoRA weights with the frozen pretrained model weights, you can optionally adjust how much of the weights to merge with the `scale` parameter:
<Tip>
When using LoRA we can use a much higher learning rate (typically 1e-4 as opposed to ~1e-6) compared to non-LoRA Dreambooth fine-tuning.
💡 A `scale` value of `0` is the same as not using your LoRA weights and you're only using the base model weights, and a `scale` value of `1` means you're only using the fully finetuned LoRA weights. Values between `0` and `1` interpolates between the two weights.
</Tip>
But there is no free lunch. For the given dataset and expected generation quality, you'd still need to experiment with
different hyperparameters. Here are some important ones:
* Training time
* Learning rate
* Number of training steps
* Inference time
* Number of steps
* Scheduler type
Additionally, you can follow [this blog](https://huggingface.co/blog/dreambooth) that documents some of our experimental
findings for performing DreamBooth training of Stable Diffusion.
When fine-tuning, the LoRA update matrices are only added to the attention layers. To enable this, we added new weight
loading functionalities. Their details are available [here](https://huggingface.co/docs/diffusers/main/en/api/loaders).
## Inference
Assuming you used the `examples/text_to_image/train_text_to_image_lora.py` to fine-tune Stable Diffusion on the [Pokemon
dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions), you can perform inference like so:
```py
from diffusers import StableDiffusionPipeline
import torch
model_path = "sayakpaul/sd-model-finetuned-lora-t4"
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
pipe.unet.load_attn_procs(model_path)
pipe.to("cuda")
prompt = "A pokemon with blue eyes."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("pokemon.png")
```
Here are some example images you can expect:
<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/pokemon-collage.png"/>
[`sayakpaul/sd-model-finetuned-lora-t4`](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4) contains [LoRA fine-tuned update matrices](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4/blob/main/pytorch_lora_weights.bin)
which is only 3 MBs in size. During inference, the pre-trained Stable Diffusion checkpoints are loaded alongside these update
matrices and then they are combined to run inference.
You can use the [`huggingface_hub`](https://github.com/huggingface/huggingface_hub) library to retrieve the base model
from [`sayakpaul/sd-model-finetuned-lora-t4`](https://huggingface.co/sayakpaul/sd-model-finetuned-lora-t4) like so:
```py
from huggingface_hub.repocard import RepoCard
>>> pipe.unet.load_attn_procs(model_path)
>>> pipe.to("cuda")
# use half the weights from the LoRA finetuned model and half the weights from the base model
card = RepoCard.load("sayakpaul/sd-model-finetuned-lora-t4")
base_model = card.data.to_dict()["base_model"]
# 'CompVis/stable-diffusion-v1-4'
```
>>> image = pipe(
... "A picture of a sks dog in a bucket.",
... num_inference_steps=25,
... guidance_scale=7.5,
... cross_attention_kwargs={"scale": 0.5},
... ).images[0]
# use the weights from the fully finetuned LoRA model
And then you can use `pipe = StableDiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16)`.
This is especially useful when you don't want to hardcode the base model identifier during initializing the `StableDiffusionPipeline`.
Inference for DreamBooth training remains the same. Check
[this section](https://github.com/huggingface/diffusers/tree/main/examples/dreambooth#inference-1) for more details.
### Merging LoRA with original model
When performing inference, you can merge the trained LoRA weights with the frozen pre-trained model weights, to interpolate between the original model's inference result (as if no fine-tuning had occurred) and the fully fine-tuned version.
You can adjust the merging ratio with a parameter called α (alpha) in the paper, or `scale` in our implementation. You can tweak it with the following code, that passes `scale` as `cross_attention_kwargs` in the pipeline call:
```py
from diffusers import StableDiffusionPipeline
import torch
model_path = "sayakpaul/sd-model-finetuned-lora-t4"
pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", torch_dtype=torch.float16)
pipe.unet.load_attn_procs(model_path)
pipe.to("cuda")
prompt = "A pokemon with blue eyes."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5, cross_attention_kwargs={"scale": 0.5}).images[0]
image.save("pokemon.png")
```
A value of `0` is the same as _not_ using the LoRA weights, whereas `1` means only the LoRA fine-tuned weights will be used. Values between 0 and 1 will interpolate between the two versions.
## Known limitations
* Currently, we only support LoRA for the attention layers of [`UNet2DConditionModel`](https://huggingface.co/docs/diffusers/main/en/api/models#diffusers.UNet2DConditionModel).
>>> image = pipe("A picture of a sks dog in a bucket.", num_inference_steps=25, guidance_scale=7.5).images[0]
>>> image.save("bucket-dog.png")
```