mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Merge branch 'main' into cogvideox-lora-and-training
This commit is contained in:
152
examples/controlnet/README_sd3.md
Normal file
152
examples/controlnet/README_sd3.md
Normal file
@@ -0,0 +1,152 @@
|
||||
# ControlNet training example for Stable Diffusion 3 (SD3)
|
||||
|
||||
The `train_controlnet_sd3.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion 3](https://arxiv.org/abs/2403.03206).
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/controlnet` folder and run
|
||||
```bash
|
||||
pip install -r requirements_sd3.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell (e.g., a notebook)
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
|
||||
## Circle filling dataset
|
||||
|
||||
The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
|
||||
Please download the dataset and unzip it in the directory `fill50k` in the `examples/controlnet` folder.
|
||||
|
||||
## Training
|
||||
|
||||
First download the SD3 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium). We will use it as a base model for the ControlNet training.
|
||||
> [!NOTE]
|
||||
> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
|
||||
|
||||
|
||||
Our training examples use two test conditioning images. They can be downloaded by running
|
||||
|
||||
```sh
|
||||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
|
||||
|
||||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
|
||||
```
|
||||
|
||||
Then run the following commands to train a ControlNet model.
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="stabilityai/stable-diffusion-3-medium-diffusers"
|
||||
export OUTPUT_DIR="sd3-controlnet-out"
|
||||
|
||||
accelerate launch train_controlnet_sd3.py \
|
||||
--pretrained_model_name_or_path=$MODEL_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--train_data_dir="fill50k" \
|
||||
--resolution=1024 \
|
||||
--learning_rate=1e-5 \
|
||||
--max_train_steps=15000 \
|
||||
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
|
||||
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
|
||||
--validation_steps=100 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using flags `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
Our experiments were conducted on a single 40GB A100 GPU.
|
||||
|
||||
### Inference
|
||||
|
||||
Once training is done, we can perform inference like so:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusion3ControlNetPipeline, SD3ControlNetModel
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
base_model_path = "stabilityai/stable-diffusion-3-medium-diffusers"
|
||||
controlnet_path = "sd3-controlnet-out/checkpoint-6500/controlnet"
|
||||
|
||||
controlnet = SD3ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
|
||||
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
|
||||
base_model_path, controlnet=controlnet
|
||||
)
|
||||
pipe.to("cuda", torch.float16)
|
||||
|
||||
|
||||
control_image = load_image("./conditioning_image_1.png").resize((1024, 1024))
|
||||
prompt = "pale golden rod circle with old lace background"
|
||||
|
||||
# generate image
|
||||
generator = torch.manual_seed(0)
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=20, generator=generator, control_image=control_image
|
||||
).images[0]
|
||||
image.save("./output.png")
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
### GPU usage
|
||||
|
||||
SD3 is a large model and requires a lot of GPU memory.
|
||||
We recommend using one GPU with at least 80GB of memory.
|
||||
Make sure to use the right GPU when configuring the [accelerator](https://huggingface.co/docs/transformers/en/accelerate).
|
||||
|
||||
|
||||
## Example results
|
||||
|
||||
#### After 500 steps with batch size 8
|
||||
|
||||
| | |
|
||||
|-------------------|:-------------------------:|
|
||||
|| pale golden rod circle with old lace background |
|
||||
 |  |
|
||||
|
||||
|
||||
#### After 6500 steps with batch size 8:
|
||||
|
||||
| | |
|
||||
|-------------------|:-------------------------:|
|
||||
|| pale golden rod circle with old lace background |
|
||||
 |  |
|
||||
|
||||
8
examples/controlnet/requirements_sd3.txt
Normal file
8
examples/controlnet/requirements_sd3.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
accelerate>=0.16.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
datasets
|
||||
wandb
|
||||
@@ -115,3 +115,24 @@ class ControlNetSDXL(ExamplesTestsAccelerate):
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
|
||||
|
||||
|
||||
class ControlNetSD3(ExamplesTestsAccelerate):
|
||||
def test_controlnet_sd3(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/controlnet/train_controlnet_sd3.py
|
||||
--pretrained_model_name_or_path=DavyMorgan/tiny-sd3-pipe
|
||||
--dataset_name=hf-internal-testing/fill10
|
||||
--output_dir={tmpdir}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--controlnet_model_name_or_path=DavyMorgan/tiny-controlnet-sd3
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
|
||||
|
||||
1415
examples/controlnet/train_controlnet_sd3.py
Normal file
1415
examples/controlnet/train_controlnet_sd3.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1597,6 +1597,7 @@ def main(args):
|
||||
tokenizers=[None, None],
|
||||
text_input_ids_list=[tokens_one, tokens_two],
|
||||
max_sequence_length=args.max_sequence_length,
|
||||
device=accelerator.device,
|
||||
prompt=prompts,
|
||||
)
|
||||
else:
|
||||
@@ -1606,6 +1607,7 @@ def main(args):
|
||||
tokenizers=[None, None],
|
||||
text_input_ids_list=[tokens_one, tokens_two],
|
||||
max_sequence_length=args.max_sequence_length,
|
||||
device=accelerator.device,
|
||||
prompt=args.instance_prompt,
|
||||
)
|
||||
|
||||
|
||||
167
examples/research_projects/pytorch_xla/README.md
Normal file
167
examples/research_projects/pytorch_xla/README.md
Normal file
@@ -0,0 +1,167 @@
|
||||
# Stable Diffusion text-to-image fine-tuning using PyTorch/XLA
|
||||
|
||||
The `train_text_to_image_xla.py` script shows how to fine-tune stable diffusion model on TPU devices using PyTorch/XLA.
|
||||
|
||||
It has been tested on v4 and v5p TPU versions. Training code has been tested on multi-host.
|
||||
|
||||
This script implements Distributed Data Parallel using GSPMD feature in XLA compiler
|
||||
where we shard the input batches over the TPU devices.
|
||||
|
||||
As of 9-11-2024, these are some expected step times.
|
||||
|
||||
| accelerator | global batch size | step time (seconds) |
|
||||
| ----------- | ----------------- | --------- |
|
||||
| v5p-128 | 1024 | 0.245 |
|
||||
| v5p-256 | 2048 | 0.234 |
|
||||
| v5p-512 | 4096 | 0.2498 |
|
||||
|
||||
## Create TPU
|
||||
|
||||
To create a TPU on Google Cloud first set these environment variables:
|
||||
|
||||
```bash
|
||||
export TPU_NAME=<tpu-name>
|
||||
export PROJECT_ID=<project-id>
|
||||
export ZONE=<google-cloud-zone>
|
||||
export ACCELERATOR_TYPE=<accelerator type like v5p-8>
|
||||
export RUNTIME_VERSION=<runtime version like v2-alpha-tpuv5 for v5p>
|
||||
```
|
||||
|
||||
Then run the create TPU command:
|
||||
```bash
|
||||
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --project ${PROJECT_ID}
|
||||
--zone ${ZONE} --accelerator-type ${ACCELERATOR_TYPE} --version ${RUNTIME_VERSION}
|
||||
--reserved
|
||||
```
|
||||
|
||||
You can also use other ways to reserve TPUs like GKE or queued resources.
|
||||
|
||||
## Setup TPU environment
|
||||
|
||||
Install PyTorch and PyTorch/XLA nightly versions:
|
||||
```bash
|
||||
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
|
||||
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
|
||||
--command='
|
||||
pip3 install --pre torch==2.5.0.dev20240905+cpu torchvision==0.20.0.dev20240905+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240905-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
|
||||
'
|
||||
```
|
||||
|
||||
Verify that PyTorch and PyTorch/XLA were installed correctly:
|
||||
|
||||
```bash
|
||||
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
|
||||
--project ${PROJECT_ID} --zone ${ZONE} --worker=all \
|
||||
--command='python3 -c "import torch; import torch_xla;"'
|
||||
```
|
||||
|
||||
Install dependencies:
|
||||
```bash
|
||||
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
|
||||
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
|
||||
--command='
|
||||
git clone https://github.com/huggingface/diffusers.git
|
||||
cd diffusers
|
||||
git checkout main
|
||||
cd examples/research_projects/pytorch_xla
|
||||
pip3 install -r requirements.txt
|
||||
pip3 install pillow --upgrade
|
||||
cd ../../..
|
||||
pip3 install .'
|
||||
```
|
||||
|
||||
## Run the training job
|
||||
|
||||
### Authenticate
|
||||
|
||||
Run the following command to authenticate your token.
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
This script only trains the unet part of the network. The VAE and text encoder
|
||||
are fixed.
|
||||
|
||||
```bash
|
||||
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
|
||||
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
|
||||
--command='
|
||||
export XLA_DISABLE_FUNCTIONALIZATION=1
|
||||
export PROFILE_DIR=/tmp/
|
||||
export CACHE_DIR=/tmp/
|
||||
export DATASET_NAME=lambdalabs/naruto-blip-captions
|
||||
export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
|
||||
export TRAIN_STEPS=50
|
||||
export OUTPUT_DIR=/tmp/trained-model/
|
||||
python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=4 --loader_prefetch_size=4 --device_prefetch_size=4'
|
||||
|
||||
```
|
||||
|
||||
### Environment Envs Explained
|
||||
|
||||
* `XLA_DISABLE_FUNCTIONALIZATION`: To optimize the performance for AdamW optimizer.
|
||||
* `PROFILE_DIR`: Specify where to put the profiling results.
|
||||
* `CACHE_DIR`: Directory to store XLA compiled graphs for persistent caching.
|
||||
* `DATASET_NAME`: Dataset to train the model.
|
||||
* `PER_HOST_BATCH_SIZE`: Size of the batch to load per CPU host. For e.g. for a v5p-16 with 2 CPU hosts, the global batch size will be 2xPER_HOST_BATCH_SIZE. The input batch is sharded along the batch axis.
|
||||
* `TRAIN_STEPS`: Total number of training steps to run the training for.
|
||||
* `OUTPUT_DIR`: Directory to store the fine-tuned model.
|
||||
|
||||
## Run inference using the output model
|
||||
|
||||
To run inference using the output, you can simply load the model and pass it
|
||||
input prompts. The first pass will compile the graph and takes longer with the following passes running much faster.
|
||||
|
||||
```bash
|
||||
export CACHE_DIR=/tmp/
|
||||
```
|
||||
|
||||
```python
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
import torch_xla.core.xla_model as xm
|
||||
from time import time
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
CACHE_DIR = os.environ.get("CACHE_DIR", None)
|
||||
if CACHE_DIR:
|
||||
xr.initialize_cache(CACHE_DIR, readonly=False)
|
||||
|
||||
def main():
|
||||
device = xm.xla_device()
|
||||
model_path = "jffacevedo/pxla_trained_model"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.to(device)
|
||||
prompt = ["A naruto with green eyes and red legs."]
|
||||
start = time()
|
||||
print("compiling...")
|
||||
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
|
||||
print(f"compile time: {time() - start}")
|
||||
print("generate...")
|
||||
start = time()
|
||||
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
|
||||
print(f"generation time (after compile) : {time() - start}")
|
||||
image.save("naruto.png")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
```
|
||||
|
||||
Expected Results:
|
||||
|
||||
```bash
|
||||
compiling...
|
||||
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [10:03<00:00, 20.10s/it]
|
||||
compile time: 720.656970500946
|
||||
generate...
|
||||
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 17.65it/s]
|
||||
generation time (after compile) : 1.8461642265319824
|
||||
8
examples/research_projects/pytorch_xla/requirements.txt
Normal file
8
examples/research_projects/pytorch_xla/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
accelerate>=0.16.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
datasets>=2.19.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft==0.7.0
|
||||
@@ -0,0 +1,669 @@
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.debug.profiler as xp
|
||||
import torch_xla.distributed.parallel_loader as pl
|
||||
import torch_xla.distributed.spmd as xs
|
||||
import torch_xla.runtime as xr
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import is_wandb_available
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
pass
|
||||
|
||||
PROFILE_DIR = os.environ.get("PROFILE_DIR", None)
|
||||
CACHE_DIR = os.environ.get("CACHE_DIR", None)
|
||||
if CACHE_DIR:
|
||||
xr.initialize_cache(CACHE_DIR, readonly=False)
|
||||
xr.use_spmd()
|
||||
DATASET_NAME_MAPPING = {
|
||||
"lambdalabs/naruto-blip-captions": ("image", "text"),
|
||||
}
|
||||
PORT = 9012
|
||||
|
||||
|
||||
def save_model_card(
|
||||
args,
|
||||
repo_id: str,
|
||||
repo_folder: str = None,
|
||||
):
|
||||
model_description = f"""
|
||||
# Text-to-image finetuning - {repo_id}
|
||||
|
||||
This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. \n
|
||||
|
||||
## Pipeline usage
|
||||
|
||||
You can use the pipeline like so:
|
||||
|
||||
```python
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
import torch_xla.core.xla_model as xm
|
||||
from time import time
|
||||
from typing import Tuple
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
def main(args):
|
||||
device = xm.xla_device()
|
||||
model_path = <output_dir>
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.to(device)
|
||||
prompt = ["A naruto with green eyes and red legs."]
|
||||
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
|
||||
image.save("naruto.png")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
```
|
||||
|
||||
## Training info
|
||||
|
||||
These are the key hyperparameters used during training:
|
||||
|
||||
* Steps: {args.max_train_steps}
|
||||
* Learning rate: {args.learning_rate}
|
||||
* Batch size: {args.train_batch_size}
|
||||
* Image resolution: {args.resolution}
|
||||
* Mixed-precision: {args.mixed_precision}
|
||||
|
||||
"""
|
||||
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="creativeml-openrail-m",
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
model_description=model_description,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "diffusers-training"]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
|
||||
|
||||
class TrainSD:
|
||||
def __init__(
|
||||
self,
|
||||
vae,
|
||||
weight_dtype,
|
||||
device,
|
||||
noise_scheduler,
|
||||
unet,
|
||||
optimizer,
|
||||
text_encoder,
|
||||
dataloader,
|
||||
args,
|
||||
):
|
||||
self.vae = vae
|
||||
self.weight_dtype = weight_dtype
|
||||
self.device = device
|
||||
self.noise_scheduler = noise_scheduler
|
||||
self.unet = unet
|
||||
self.optimizer = optimizer
|
||||
self.text_encoder = text_encoder
|
||||
self.args = args
|
||||
self.mesh = xs.get_global_mesh()
|
||||
self.dataloader = iter(dataloader)
|
||||
self.global_step = 0
|
||||
|
||||
def run_optimizer(self):
|
||||
self.optimizer.step()
|
||||
|
||||
def start_training(self):
|
||||
times = []
|
||||
last_time = time.time()
|
||||
step = 0
|
||||
while True:
|
||||
if self.global_step >= self.args.max_train_steps:
|
||||
xm.mark_step()
|
||||
break
|
||||
if step == 4 and PROFILE_DIR is not None:
|
||||
xm.wait_device_ops()
|
||||
xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
|
||||
try:
|
||||
batch = next(self.dataloader)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
break
|
||||
loss = self.step_fn(batch["pixel_values"], batch["input_ids"])
|
||||
step_time = time.time() - last_time
|
||||
if step >= 10:
|
||||
times.append(step_time)
|
||||
print(f"step: {step}, step_time: {step_time}")
|
||||
if step % 5 == 0:
|
||||
print(f"step: {step}, loss: {loss}")
|
||||
last_time = time.time()
|
||||
self.global_step += 1
|
||||
step += 1
|
||||
# print(f"Average step time: {sum(times)/len(times)}")
|
||||
xm.wait_device_ops()
|
||||
|
||||
def step_fn(
|
||||
self,
|
||||
pixel_values,
|
||||
input_ids,
|
||||
):
|
||||
with xp.Trace("model.forward"):
|
||||
self.optimizer.zero_grad()
|
||||
latents = self.vae.encode(pixel_values).latent_dist.sample()
|
||||
latents = latents * self.vae.config.scaling_factor
|
||||
noise = torch.randn_like(latents).to(self.device, dtype=self.weight_dtype)
|
||||
bsz = latents.shape[0]
|
||||
timesteps = torch.randint(
|
||||
0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
encoder_hidden_states = self.text_encoder(input_ids, return_dict=False)[0]
|
||||
if self.args.prediction_type is not None:
|
||||
# set prediction_type of scheduler if defined
|
||||
self.noise_scheduler.register_to_config(prediction_type=self.args.prediction_type)
|
||||
|
||||
if self.noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif self.noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
|
||||
model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
|
||||
with xp.Trace("model.backward"):
|
||||
if self.args.snr_gamma is None:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
else:
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(self.noise_scheduler, timesteps)
|
||||
mse_loss_weights = torch.stack([snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if self.noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif self.noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
loss.backward()
|
||||
with xp.Trace("optimizer_step"):
|
||||
self.run_optimizer()
|
||||
return loss
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
|
||||
)
|
||||
parser.add_argument("--profile_duration", type=int, default=10000, help="Profile duration in ms")
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
||||
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
||||
" or to a folder containing files that 🤗 Datasets can understand."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The config of the Dataset, leave as None if there's only one config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"A folder containing the training data. Folder contents must follow the structure described in"
|
||||
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
|
||||
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_column",
|
||||
type=str,
|
||||
default="text",
|
||||
help="The column of the dataset containing a caption or a list of captions.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_train_samples",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="sd-model-finetuned",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The directory where the downloaded models and datasets will be stored.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
||||
" cropped. The images will be resized to the resolution first before cropping."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random_flip",
|
||||
action="store_true",
|
||||
help="whether to randomly flip images horizontally",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--snr_gamma",
|
||||
type=float,
|
||||
default=None,
|
||||
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
|
||||
"More details here: https://arxiv.org/abs/2303.09556.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--non_ema_revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=(
|
||||
"Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
|
||||
" remote repository specified with --pretrained_model_name_or_path."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--loader_prefetch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help=("Number of subprocesses to use for data loading to cpu."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device_prefetch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help=("Number of subprocesses to use for data loading to tpu from cpu. "),
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument(
|
||||
"--prediction_type",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# default to using the same revision for the non-ema model if not specified
|
||||
if args.non_ema_revision is None:
|
||||
args.non_ema_revision = args.revision
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def setup_optimizer(unet, args):
|
||||
optimizer_cls = torch.optim.AdamW
|
||||
return optimizer_cls(
|
||||
unet.parameters(),
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
foreach=True,
|
||||
)
|
||||
|
||||
|
||||
def load_dataset(args):
|
||||
if args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
dataset = datasets.load_dataset(
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
cache_dir=args.cache_dir,
|
||||
data_dir=args.train_data_dir,
|
||||
)
|
||||
else:
|
||||
data_files = {}
|
||||
if args.train_data_dir is not None:
|
||||
data_files["train"] = os.path.join(args.train_data_dir, "**")
|
||||
dataset = datasets.load_dataset(
|
||||
"imagefolder",
|
||||
data_files=data_files,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
||||
def get_column_names(dataset, args):
|
||||
column_names = dataset["train"].column_names
|
||||
|
||||
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
|
||||
if args.image_column is None:
|
||||
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
|
||||
else:
|
||||
image_column = args.image_column
|
||||
if image_column not in column_names:
|
||||
raise ValueError(
|
||||
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
|
||||
)
|
||||
if args.caption_column is None:
|
||||
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
|
||||
else:
|
||||
caption_column = args.caption_column
|
||||
if caption_column not in column_names:
|
||||
raise ValueError(
|
||||
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
|
||||
)
|
||||
return image_column, caption_column
|
||||
|
||||
|
||||
def main(args):
|
||||
args = parse_args()
|
||||
|
||||
_ = xp.start_server(PORT)
|
||||
|
||||
num_devices = xr.global_runtime_device_count()
|
||||
device_ids = np.arange(num_devices)
|
||||
mesh_shape = (num_devices, 1)
|
||||
mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
|
||||
xs.set_global_mesh(mesh)
|
||||
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="vae",
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="unet",
|
||||
revision=args.non_ema_revision,
|
||||
)
|
||||
|
||||
if xm.is_master_ordinal() and args.push_to_hub:
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
||||
).repo_id
|
||||
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="tokenizer",
|
||||
revision=args.revision,
|
||||
)
|
||||
|
||||
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
|
||||
|
||||
unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
|
||||
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
unet.train()
|
||||
|
||||
# For mixed precision training we cast all non-trainable weights (vae,
|
||||
# non-lora text_encoder and non-lora unet) to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full
|
||||
# precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
if args.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif args.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
device = xm.xla_device()
|
||||
print("device: ", device)
|
||||
print("weight_dtype: ", weight_dtype)
|
||||
|
||||
text_encoder = text_encoder.to(device, dtype=weight_dtype)
|
||||
vae = vae.to(device, dtype=weight_dtype)
|
||||
unet = unet.to(device, dtype=weight_dtype)
|
||||
optimizer = setup_optimizer(unet, args)
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
unet.train()
|
||||
|
||||
dataset = load_dataset(args)
|
||||
image_column, caption_column = get_column_names(dataset, args)
|
||||
|
||||
def tokenize_captions(examples, is_train=True):
|
||||
captions = []
|
||||
for caption in examples[caption_column]:
|
||||
if isinstance(caption, str):
|
||||
captions.append(caption)
|
||||
elif isinstance(caption, (list, np.ndarray)):
|
||||
# take a random caption if there are multiple
|
||||
captions.append(random.choice(caption) if is_train else caption[0])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Caption column `{caption_column}` should contain either strings or lists of strings."
|
||||
)
|
||||
inputs = tokenizer(
|
||||
captions,
|
||||
max_length=tokenizer.model_max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
return inputs.input_ids
|
||||
|
||||
train_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
(transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)),
|
||||
(transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
def preprocess_train(examples):
|
||||
images = [image.convert("RGB") for image in examples[image_column]]
|
||||
examples["pixel_values"] = [train_transforms(image) for image in images]
|
||||
examples["input_ids"] = tokenize_captions(examples)
|
||||
return examples
|
||||
|
||||
train_dataset = dataset["train"]
|
||||
train_dataset.set_format("torch")
|
||||
train_dataset.set_transform(preprocess_train)
|
||||
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).to(weight_dtype)
|
||||
input_ids = torch.stack([example["input_ids"] for example in examples])
|
||||
return {"pixel_values": pixel_values, "input_ids": input_ids}
|
||||
|
||||
g = torch.Generator()
|
||||
g.manual_seed(xr.host_index())
|
||||
sampler = torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=int(1e10), generator=g)
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
batch_size=args.train_batch_size,
|
||||
)
|
||||
|
||||
train_dataloader = pl.MpDeviceLoader(
|
||||
train_dataloader,
|
||||
device,
|
||||
input_sharding={
|
||||
"pixel_values": xs.ShardingSpec(mesh, ("x", None, None, None), minibatch=True),
|
||||
"input_ids": xs.ShardingSpec(mesh, ("x", None), minibatch=True),
|
||||
},
|
||||
loader_prefetch_size=args.loader_prefetch_size,
|
||||
device_prefetch_size=args.device_prefetch_size,
|
||||
)
|
||||
|
||||
if xm.is_master_ordinal():
|
||||
print("***** Running training *****")
|
||||
print(f"Instantaneous batch size per device = {args.train_batch_size}")
|
||||
print(
|
||||
f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_devices}"
|
||||
)
|
||||
print(f" Total optimization steps = {args.max_train_steps}")
|
||||
|
||||
trainer = TrainSD(
|
||||
vae=vae,
|
||||
weight_dtype=weight_dtype,
|
||||
device=device,
|
||||
noise_scheduler=noise_scheduler,
|
||||
unet=unet,
|
||||
optimizer=optimizer,
|
||||
text_encoder=text_encoder,
|
||||
dataloader=train_dataloader,
|
||||
args=args,
|
||||
)
|
||||
|
||||
trainer.start_training()
|
||||
unet = trainer.unet.to("cpu")
|
||||
vae = trainer.vae.to("cpu")
|
||||
text_encoder = trainer.text_encoder.to("cpu")
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
pipeline.save_pretrained(args.output_dir)
|
||||
|
||||
if xm.is_master_ordinal() and args.push_to_hub:
|
||||
save_model_card(args, repo_id, repo_folder=args.output_dir)
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
commit_message="End of training",
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
Reference in New Issue
Block a user