# ControlNet training example for Stable Diffusion XL (SDXL)

The `train_controlnet_sdxl.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion XL](https://huggingface.co/papers/2307.01952).

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```

Then `cd` into the `examples/controlnet` folder and run:

```bash
pip install -r requirements_sdxl.txt
```

Then initialize an [🤗 Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

Or, for a default Accelerate configuration that skips the questions about your environment:

```bash
accelerate config default
```

Or, if your environment doesn't support an interactive shell (e.g., a notebook):

```python
from accelerate.utils import write_basic_config
write_basic_config()
```

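To confirm that one of the commands above actually produced a configuration, a minimal sketch like the following checks for the file. The location is an assumption based on Accelerate's default cache layout (`HF_HOME` if set, otherwise `XDG_CACHE_HOME`, otherwise `~/.cache`), and `default_accelerate_config_path` is a hypothetical helper name, not part of Accelerate.

```python
import os
from pathlib import Path


def default_accelerate_config_path() -> Path:
    """Return the path where Accelerate is assumed to store its default config."""
    # Assumption: the cache lives under HF_HOME when set, otherwise under
    # $XDG_CACHE_HOME/huggingface (which defaults to ~/.cache/huggingface).
    xdg_cache = os.environ.get("XDG_CACHE_HOME", "~/.cache")
    hf_home = os.environ.get("HF_HOME", os.path.join(xdg_cache, "huggingface"))
    return Path(hf_home).expanduser() / "accelerate" / "default_config.yaml"


config_path = default_accelerate_config_path()
if config_path.is_file():
    print(f"Found Accelerate config at {config_path}")
else:
    print(f"No config at {config_path}; run `accelerate config` first")
```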
When running `accelerate config`, setting torch compile mode to True can yield dramatic speedups.

## Circle filling dataset

The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles data loading within the training script.

## Training

Our training examples use two test conditioning images. They can be downloaded by running:

```sh
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png

wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
```

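Where `wget` is not available (for example on Windows or inside a notebook), the same two files could be fetched with Python's standard library. This is a sketch rather than part of the training example: `conditioning_image_url` and `download_conditioning_images` are hypothetical helper names, and the URLs are the ones from the command above.

```python
from pathlib import Path
from urllib.request import urlretrieve

BASE_URL = (
    "https://huggingface.co/datasets/huggingface/documentation-images/"
    "resolve/main/diffusers/controlnet_training"
)


def conditioning_image_url(index: int) -> str:
    """Build the URL of the index-th test conditioning image (1 or 2)."""
    return f"{BASE_URL}/conditioning_image_{index}.png"


def download_conditioning_images(target_dir: str = ".") -> list[Path]:
    """Download both conditioning images into target_dir and return their paths."""
    target = Path(target_dir)
    target.mkdir(parents=True, exist_ok=True)
    paths = []
    for index in (1, 2):
        destination = target / f"conditioning_image_{index}.png"
        if not destination.exists():  # skip files that are already present
            urlretrieve(conditioning_image_url(index), destination)
        paths.append(destination)
    return paths
```

Calling `download_conditioning_images()` from the `examples/controlnet` folder places the files where the training command below expects them.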
Then run `hf auth login` to log in to your Hugging Face account. This is needed to push the trained ControlNet parameters to the Hugging Face Hub.

```bash
export MODEL_DIR="stabilityai/stable-diffusion-xl-base-1.0"
export OUTPUT_DIR="path to save model"

accelerate launch train_controlnet_sdxl.py \
  --pretrained_model_name_or_path="$MODEL_DIR" \
  --output_dir="$OUTPUT_DIR" \
  --dataset_name=fusing/fill50k \
  --mixed_precision="fp16" \
  --resolution=1024 \
  --learning_rate=1e-5 \
  --max_train_steps=15000 \
  --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
  --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
  --validation_steps=100 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --report_to="wandb" \
  --seed=42 \
  --push_to_hub
```

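As a rough sanity check on the flags above, the arithmetic below works out how much data the run consumes. It assumes that `--max_train_steps` counts optimizer updates (so each step sees `train_batch_size * gradient_accumulation_steps` samples per process), that a single process is used, and that `fusing/fill50k` holds 50,000 examples as its name suggests. Treat the numbers as estimates.

```python
# Values copied from the launch command above.
train_batch_size = 1
gradient_accumulation_steps = 4
max_train_steps = 15_000
validation_steps = 100

# Assumptions (not stated in the command): one process, 50,000 training examples.
num_processes = 1
dataset_size = 50_000

effective_batch_size = train_batch_size * gradient_accumulation_steps * num_processes
samples_seen = effective_batch_size * max_train_steps
approx_epochs = samples_seen / dataset_size
validation_rounds = max_train_steps // validation_steps

print(f"effective batch size: {effective_batch_size}")  # 4
print(f"samples seen:         {samples_seen}")          # 60000
print(f"approx. epochs:       {approx_epochs:.1f}")     # 1.2
print(f"validation rounds:    {validation_rounds}")     # 150
```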
To better track our training experiments, we're using the following flags in the command above:

* `report_to="wandb"` ensures the training runs are tracked on Weights and Biases. To use it, be sure to install `wandb` with `pip install wandb`.
* `validation_image`, `validation_prompt`, and `validation_steps` allow the script to do a few validation inference runs. This lets us qualitatively check whether training is progressing as expected.

Our experiments were conducted on a single 40GB A100 GPU.

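The validation flags described above take parallel lists of images and prompts. The sketch below shows one plausible way such lists can be paired: one-to-one when the lengths match, or by repeating a single image or prompt across the other list. It is an illustration under that assumption, not the script's actual code, and `pair_validation_inputs` is a hypothetical name.

```python
def pair_validation_inputs(images: list[str], prompts: list[str]) -> list[tuple[str, str]]:
    """Pair conditioning images with prompts for validation runs."""
    if len(images) == len(prompts):
        # Same length: image i goes with prompt i.
        return list(zip(images, prompts))
    if len(images) == 1:
        # One image: reuse it for every prompt.
        return [(images[0], prompt) for prompt in prompts]
    if len(prompts) == 1:
        # One prompt: reuse it for every image.
        return [(image, prompts[0]) for image in images]
    raise ValueError("Provide equally many images and prompts, or exactly one of either.")


pairs = pair_validation_inputs(
    ["./conditioning_image_1.png", "./conditioning_image_2.png"],
    ["red circle with blue background", "cyan circle with brown floral background"],
)
for image, prompt in pairs:
    print(f"{image} -> {prompt}")
```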
### Inference

Once training is done, we can perform inference like so:

```python
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from diffusers.utils import load_image
import torch

base_model_path = "stabilityai/stable-diffusion-xl-base-1.0"
controlnet_path = "path to controlnet"

controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    base_model_path, controlnet=controlnet, torch_dtype=torch.float16
)

# speed up diffusion process with faster scheduler and memory optimization
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
# remove following line if xformers is not installed or when using Torch 2.0.
pipe.enable_xformers_memory_efficient_attention()
# memory optimization.
pipe.enable_model_cpu_offload()

control_image = load_image("./conditioning_image_1.png").resize((1024, 1024))
prompt = "pale golden rod circle with old lace background"

# generate image
generator = torch.manual_seed(0)
image = pipe(
    prompt, num_inference_steps=20, generator=generator, image=control_image
).images[0]
image.save("./output.png")
```

## Notes

### Specifying a better VAE

SDXL's VAE is known to suffer from numerical instability issues. For this reason, the script also exposes a CLI argument, `--pretrained_vae_model_name_or_path`, that lets you specify the location of an alternative VAE (such as [`madebyollin/sdxl-vae-fp16-fix`](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).

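One way to see why half precision is fragile: fp16 cannot represent finite magnitudes above 65504, so any intermediate value beyond that overflows. The sketch below illustrates that limit with Python's standard `struct` module. It demonstrates the number format only and is not taken from the SDXL VAE code; `fits_in_fp16` is a hypothetical helper name.

```python
import struct


def fits_in_fp16(value: float) -> bool:
    """Return True if value can be stored as a finite IEEE 754 half-precision float."""
    try:
        struct.pack("<e", value)  # "e" is the half-precision format code
    except OverflowError:
        return False
    return True


print(fits_in_fp16(65504.0))  # True: the largest finite fp16 value
print(fits_in_fp16(70000.0))  # False: too large for fp16
```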
If you use this VAE during training, make sure you also use it during inference. You can do so as follows:

```diff
+ from diffusers import AutoencoderKL
+ vae = AutoencoderKL.from_pretrained(vae_path_or_repo_id, torch_dtype=torch.float16)
  controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
  pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
      base_model_path, controlnet=controlnet, torch_dtype=torch.float16,
+     vae=vae,
  )