Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[Examples] Support train_text_to_image_lora_sdxl.py (#4365)
* add train_text_to_image_lora_sdxl.py
* add train_text_to_image_lora_sdxl.py
* add test and minor fix
* Update examples/text_to_image/README_sdxl.md
  Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
* fix unwrap_model rule
* add invisible-watermark in requirements
* del invisible-watermark
* Update examples/text_to_image/README_sdxl.md
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Update examples/text_to_image/README_sdxl.md
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Update examples/text_to_image/train_text_to_image_lora_sdxl.py
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* del comment & update readme

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@@ -401,4 +401,8 @@ Thanks to [@isidentical](https://github.com/isidentical) for helping us on integ
### Known limitations specific to the Kohya-styled LoRAs
* When images don't look similar to those from other UIs, such as ComfyUI, it can be because of multiple reasons, as explained [here](https://github.com/huggingface/diffusers/pull/4287/#issuecomment-1655110736).
## Stable Diffusion XL
We support fine-tuning of the UNet shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with LoRA via the `train_text_to_image_lora_sdxl.py` script. Please refer to the docs [here](https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/README_sdxl.md).
@@ -1420,3 +1420,64 @@ class ExamplesTestsAccelerate(unittest.TestCase):
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-6", "checkpoint-8", "checkpoint-10"},
            )

    def test_text_to_image_lora_sdxl(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/text_to_image/train_text_to_image_lora_sdxl.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

    def test_text_to_image_lora_sdxl_with_text_encoder(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                examples/text_to_image/train_text_to_image_lora_sdxl.py
                --pretrained_model_name_or_path hf-internal-testing/tiny-stable-diffusion-xl-pipe
                --dataset_name hf-internal-testing/dummy_image_text_data
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                --train_text_encoder
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.bin")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = torch.load(os.path.join(tmpdir, "pytorch_lora_weights.bin"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when training the text encoder, all the parameters in the state dict should start
            # with `"unet"` or `"text_encoder"` or `"text_encoder_2"` in their names.
            keys = lora_state_dict.keys()
            starts_with_unet = all(
                k.startswith("unet") or k.startswith("text_encoder") or k.startswith("text_encoder_2") for k in keys
            )
            self.assertTrue(starts_with_unet)
@@ -316,3 +316,7 @@ xFormers training is not available for Flax/JAX.
**Note**:
According to [this issue](https://github.com/huggingface/diffusers/issues/2234#issuecomment-1416931212), xFormers `v0.0.16` cannot be used for training on some GPUs. If you observe this problem, please install a development version as indicated in that comment.
## Stable Diffusion XL
We support fine-tuning of the UNet and Text Encoder shipped in [Stable Diffusion XL](https://huggingface.co/papers/2307.01952) with LoRA via the `train_text_to_image_lora_sdxl.py` script. Please refer to the docs [here](./README_sdxl.md).
134 examples/text_to_image/README_sdxl.md Normal file
@@ -0,0 +1,134 @@
# LoRA training example for Stable Diffusion XL (SDXL)
Low-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights (see the minimal sketch after the list below). This has a couple of advantages:
- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers allow you to control to what extent the model is adapted toward new training images via a `scale` parameter.
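
For intuition, a LoRA-style linear layer can be sketched as follows. This is a minimal, hypothetical illustration (the class name, default rank, and wiring are made up for exposition; the actual script injects such adapters into the UNet's attention layers via diffusers' LoRA utilities):

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Toy sketch of a LoRA adapter around a frozen linear layer (illustrative only)."""

    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)  # pretrained weights stay frozen
        self.scale = scale  # controls how strongly the adapter steers the model
        # The only trainable parameters: a pair of rank-decomposition matrices.
        self.lora_down = nn.Linear(base.in_features, rank, bias=False)
        self.lora_up = nn.Linear(rank, base.out_features, bias=False)
        nn.init.zeros_(self.lora_up.weight)  # the adapter starts as a no-op

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Frozen base output plus the scaled low-rank update.
        return self.base(x) + self.scale * self.lora_up(self.lora_down(x))
```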
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
With LoRA, it's possible to fine-tune Stable Diffusion on a custom image-caption dataset on consumer GPUs like the Tesla T4 or Tesla V100.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd into the `examples/text_to_image` folder and run
```bash
pip install -r requirements_sdxl.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell (e.g., a notebook)
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
When running `accelerate config`, setting torch compile mode to True can bring dramatic speedups.
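
For reference, that compile mode amounts to wrapping the model with `torch.compile` (PyTorch 2.0+). A hedged sketch, compiling just the UNet by hand rather than through `accelerate`:

```python
import torch
from diffusers import UNet2DConditionModel

# Sketch only: accelerate applies an equivalent wrapping when torch compile mode
# is enabled in its config; here we compile the SDXL UNet directly.
unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)
unet = torch.compile(unet)  # subsequent forward passes are JIT-compiled and faster
```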
### Training
First, you need to set up your development environment as explained in the [installation section](#installing-the-dependencies). Make sure to set the `MODEL_NAME` and `DATASET_NAME` environment variables. Here, we will use [Stable Diffusion XL 1.0-base](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0) and the [Pokémon dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).
**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [Weights and Biases](https://docs.wandb.ai/quickstart) is a nice solution to easily see the generated images during training. All you need to do is run `pip install wandb` before training to automatically log images.___**
```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export DATASET_NAME="lambdalabs/pokemon-blip-captions"
```
For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the `--push_to_hub` flag.
```bash
huggingface-cli login
```
Now we can start training!
```bash
accelerate launch train_text_to_image_lora_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$DATASET_NAME --caption_column="text" \
  --resolution=1024 --random_flip \
  --train_batch_size=1 \
  --num_train_epochs=2 --checkpointing_steps=500 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --output_dir="sd-pokemon-model-lora-sdxl" \
  --validation_prompt="cute dragon creature" --report_to="wandb" \
  --push_to_hub
```
The above command will also run inference as fine-tuning progresses and log the results to Weights and Biases.
### Finetuning the text encoder and UNet
The script also allows you to finetune the `text_encoder` along with the `unet`.
🚨 Training the text encoder requires additional memory.
Pass the `--train_text_encoder` argument to the training script to enable finetuning the `text_encoder` and `unet`:
```bash
accelerate launch train_text_to_image_lora_sdxl.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$DATASET_NAME --caption_column="text" \
  --resolution=1024 --random_flip \
  --train_batch_size=1 \
  --num_train_epochs=2 --checkpointing_steps=500 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --output_dir="sd-pokemon-model-lora-sdxl-txt" \
  --train_text_encoder \
  --validation_prompt="cute dragon creature" --report_to="wandb" \
  --push_to_hub
```
### Inference
Once you have trained a model using the above command, inference can be done simply with the `DiffusionPipeline` after loading the trained LoRA weights. You need to pass the `output_dir` for loading the LoRA weights, which, in this case, is `sd-pokemon-model-lora-sdxl`.
```python
from diffusers import DiffusionPipeline
import torch

model_path = "takuoko/sd-pokemon-model-lora-sdxl"
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
pipe.to("cuda")
pipe.load_lora_weights(model_path)

prompt = "A pokemon with green eyes and red legs."
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
image.save("pokemon.png")
```
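
Because the LoRA layers expose a `scale` parameter (see the advantages listed above), the adapter strength can also be adjusted at inference time. With weights loaded via `load_lora_weights`, this is typically passed through `cross_attention_kwargs`; the `0.5` below is an arbitrary example value:

```python
# Reusing `pipe` and `prompt` from the example above: scale=1.0 applies the full
# adapter, while scale=0.0 falls back to the original base model.
image = pipe(
    prompt,
    num_inference_steps=30,
    guidance_scale=7.5,
    cross_attention_kwargs={"scale": 0.5},
).images[0]
```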
6 examples/text_to_image/requirements_sdxl.txt Normal file
@@ -0,0 +1,6 @@
accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
1282 examples/text_to_image/train_text_to_image_lora_sdxl.py Normal file
File diff suppressed because it is too large