mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-29 07:22:12 +03:00
Merge branch 'main' into cogvideox-lora-and-training
This commit is contained in:
@@ -914,6 +914,89 @@ export_to_gif(frames, "animatelcm-motion-lora.gif")
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
## Using FreeNoise
|
||||
|
||||
[FreeNoise: Tuning-Free Longer Video Diffusion via Noise Rescheduling](https://arxiv.org/abs/2310.15169) by Haonan Qiu, Menghan Xia, Yong Zhang, Yingqing He, Xintao Wang, Ying Shan, Ziwei Liu.
|
||||
|
||||
FreeNoise is a sampling mechanism that can generate longer videos with short-video generation models by employing noise-rescheduling, temporal attention over sliding windows, and weighted averaging of latent frames. It also can be used with multiple prompts to allow for interpolated video generations. More details are available in the paper.
|
||||
|
||||
The currently supported AnimateDiff pipelines that can be used with FreeNoise are:
|
||||
- [`AnimateDiffPipeline`]
|
||||
- [`AnimateDiffControlNetPipeline`]
|
||||
- [`AnimateDiffVideoToVideoPipeline`]
|
||||
- [`AnimateDiffVideoToVideoControlNetPipeline`]
|
||||
|
||||
In order to use FreeNoise, a single line needs to be added to the inference code after loading your pipelines.
|
||||
|
||||
```diff
|
||||
+ pipe.enable_free_noise()
|
||||
```
|
||||
|
||||
After this, either a single prompt could be used, or multiple prompts can be passed as a dictionary of integer-string pairs. The integer keys of the dictionary correspond to the frame index at which the influence of that prompt would be maximum. Each frame index should map to a single string prompt. The prompts for intermediate frame indices, that are not passed in the dictionary, are created by interpolating between the frame prompts that are passed. By default, simple linear interpolation is used. However, you can customize this behaviour with a callback to the `prompt_interpolation_callback` parameter when enabling FreeNoise.
|
||||
|
||||
Full example:
|
||||
|
||||
```python
|
||||
import torch
|
||||
from diffusers import AutoencoderKL, AnimateDiffPipeline, LCMScheduler, MotionAdapter
|
||||
from diffusers.utils import export_to_video, load_image
|
||||
|
||||
# Load pipeline
|
||||
dtype = torch.float16
|
||||
motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM", torch_dtype=dtype)
|
||||
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype)
|
||||
|
||||
pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=motion_adapter, vae=vae, torch_dtype=dtype)
|
||||
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")
|
||||
|
||||
pipe.load_lora_weights(
|
||||
"wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm_lora"
|
||||
)
|
||||
pipe.set_adapters(["lcm_lora"], [0.8])
|
||||
|
||||
# Enable FreeNoise for long prompt generation
|
||||
pipe.enable_free_noise(context_length=16, context_stride=4)
|
||||
pipe.to("cuda")
|
||||
|
||||
# Can be a single prompt, or a dictionary with frame timesteps
|
||||
prompt = {
|
||||
0: "A caterpillar on a leaf, high quality, photorealistic",
|
||||
40: "A caterpillar transforming into a cocoon, on a leaf, near flowers, photorealistic",
|
||||
80: "A cocoon on a leaf, flowers in the backgrond, photorealistic",
|
||||
120: "A cocoon maturing and a butterfly being born, flowers and leaves visible in the background, photorealistic",
|
||||
160: "A beautiful butterfly, vibrant colors, sitting on a leaf, flowers in the background, photorealistic",
|
||||
200: "A beautiful butterfly, flying away in a forest, photorealistic",
|
||||
240: "A cyberpunk butterfly, neon lights, glowing",
|
||||
}
|
||||
negative_prompt = "bad quality, worst quality, jpeg artifacts"
|
||||
|
||||
# Run inference
|
||||
output = pipe(
|
||||
prompt=prompt,
|
||||
negative_prompt=negative_prompt,
|
||||
num_frames=256,
|
||||
guidance_scale=2.5,
|
||||
num_inference_steps=10,
|
||||
generator=torch.Generator("cpu").manual_seed(0),
|
||||
)
|
||||
|
||||
# Save video
|
||||
frames = output.frames[0]
|
||||
export_to_video(frames, "output.mp4", fps=16)
|
||||
```
|
||||
|
||||
### FreeNoise memory savings
|
||||
|
||||
Since FreeNoise processes multiple frames together, there are parts in the modeling where the memory required exceeds that available on normal consumer GPUs. The main memory bottlenecks that we identified are spatial and temporal attention blocks, upsampling and downsampling blocks, resnet blocks and feed-forward layers. Since most of these blocks operate effectively only on the channel/embedding dimension, one can perform chunked inference across the batch dimensions. The batch dimension in AnimateDiff are either spatial (`[B x F, H x W, C]`) or temporal (`B x H x W, F, C`) in nature (note that it may seem counter-intuitive, but the batch dimension here are correct, because spatial blocks process across the `B x F` dimension while the temporal blocks process across the `B x H x W` dimension). We introduce a `SplitInferenceModule` that makes it easier to chunk across any dimension and perform inference. This saves a lot of memory but comes at the cost of requiring more time for inference.
|
||||
|
||||
```diff
|
||||
# Load pipeline and adapters
|
||||
# ...
|
||||
+ pipe.enable_free_noise_split_inference()
|
||||
+ pipe.unet.enable_forward_chunking(16)
|
||||
```
|
||||
|
||||
The call to `pipe.enable_free_noise_split_inference` method accepts two parameters: `spatial_split_size` (defaults to `256`) and `temporal_split_size` (defaults to `16`). These can be configured based on how much VRAM you have available. A lower split size results in lower memory usage but slower inference, whereas a larger split size results in faster inference at the cost of more memory.
|
||||
|
||||
## Using `from_single_file` with the MotionAdapter
|
||||
|
||||
|
||||
152
examples/controlnet/README_sd3.md
Normal file
152
examples/controlnet/README_sd3.md
Normal file
@@ -0,0 +1,152 @@
|
||||
# ControlNet training example for Stable Diffusion 3 (SD3)
|
||||
|
||||
The `train_controlnet_sd3.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion 3](https://arxiv.org/abs/2403.03206).
|
||||
|
||||
## Running locally with PyTorch
|
||||
|
||||
### Installing the dependencies
|
||||
|
||||
Before running the scripts, make sure to install the library's training dependencies:
|
||||
|
||||
**Important**
|
||||
|
||||
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/huggingface/diffusers
|
||||
cd diffusers
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
Then cd in the `examples/controlnet` folder and run
|
||||
```bash
|
||||
pip install -r requirements_sd3.txt
|
||||
```
|
||||
|
||||
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
|
||||
|
||||
```bash
|
||||
accelerate config
|
||||
```
|
||||
|
||||
Or for a default accelerate configuration without answering questions about your environment
|
||||
|
||||
```bash
|
||||
accelerate config default
|
||||
```
|
||||
|
||||
Or if your environment doesn't support an interactive shell (e.g., a notebook)
|
||||
|
||||
```python
|
||||
from accelerate.utils import write_basic_config
|
||||
write_basic_config()
|
||||
```
|
||||
|
||||
When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups.
|
||||
|
||||
## Circle filling dataset
|
||||
|
||||
The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script.
|
||||
Please download the dataset and unzip it in the directory `fill50k` in the `examples/controlnet` folder.
|
||||
|
||||
## Training
|
||||
|
||||
First download the SD3 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium). We will use it as a base model for the ControlNet training.
|
||||
> [!NOTE]
|
||||
> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in:
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
This will also allow us to push the trained model parameters to the Hugging Face Hub platform.
|
||||
|
||||
|
||||
Our training examples use two test conditioning images. They can be downloaded by running
|
||||
|
||||
```sh
|
||||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png
|
||||
|
||||
wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png
|
||||
```
|
||||
|
||||
Then run the following commands to train a ControlNet model.
|
||||
|
||||
```bash
|
||||
export MODEL_DIR="stabilityai/stable-diffusion-3-medium-diffusers"
|
||||
export OUTPUT_DIR="sd3-controlnet-out"
|
||||
|
||||
accelerate launch train_controlnet_sd3.py \
|
||||
--pretrained_model_name_or_path=$MODEL_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--train_data_dir="fill50k" \
|
||||
--resolution=1024 \
|
||||
--learning_rate=1e-5 \
|
||||
--max_train_steps=15000 \
|
||||
--validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \
|
||||
--validation_prompt "red circle with blue background" "cyan circle with brown floral background" \
|
||||
--validation_steps=100 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=4
|
||||
```
|
||||
|
||||
To better track our training experiments, we're using flags `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected.
|
||||
|
||||
Our experiments were conducted on a single 40GB A100 GPU.
|
||||
|
||||
### Inference
|
||||
|
||||
Once training is done, we can perform inference like so:
|
||||
|
||||
```python
|
||||
from diffusers import StableDiffusion3ControlNetPipeline, SD3ControlNetModel
|
||||
from diffusers.utils import load_image
|
||||
import torch
|
||||
|
||||
base_model_path = "stabilityai/stable-diffusion-3-medium-diffusers"
|
||||
controlnet_path = "sd3-controlnet-out/checkpoint-6500/controlnet"
|
||||
|
||||
controlnet = SD3ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)
|
||||
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
|
||||
base_model_path, controlnet=controlnet
|
||||
)
|
||||
pipe.to("cuda", torch.float16)
|
||||
|
||||
|
||||
control_image = load_image("./conditioning_image_1.png").resize((1024, 1024))
|
||||
prompt = "pale golden rod circle with old lace background"
|
||||
|
||||
# generate image
|
||||
generator = torch.manual_seed(0)
|
||||
image = pipe(
|
||||
prompt, num_inference_steps=20, generator=generator, control_image=control_image
|
||||
).images[0]
|
||||
image.save("./output.png")
|
||||
```
|
||||
|
||||
## Notes
|
||||
|
||||
### GPU usage
|
||||
|
||||
SD3 is a large model and requires a lot of GPU memory.
|
||||
We recommend using one GPU with at least 80GB of memory.
|
||||
Make sure to use the right GPU when configuring the [accelerator](https://huggingface.co/docs/transformers/en/accelerate).
|
||||
|
||||
|
||||
## Example results
|
||||
|
||||
#### After 500 steps with batch size 8
|
||||
|
||||
| | |
|
||||
|-------------------|:-------------------------:|
|
||||
|| pale golden rod circle with old lace background |
|
||||
 |  |
|
||||
|
||||
|
||||
#### After 6500 steps with batch size 8:
|
||||
|
||||
| | |
|
||||
|-------------------|:-------------------------:|
|
||||
|| pale golden rod circle with old lace background |
|
||||
 |  |
|
||||
|
||||
8
examples/controlnet/requirements_sd3.txt
Normal file
8
examples/controlnet/requirements_sd3.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
accelerate>=0.16.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
datasets
|
||||
wandb
|
||||
@@ -115,3 +115,24 @@ class ControlNetSDXL(ExamplesTestsAccelerate):
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
|
||||
|
||||
|
||||
class ControlNetSD3(ExamplesTestsAccelerate):
|
||||
def test_controlnet_sd3(self):
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
test_args = f"""
|
||||
examples/controlnet/train_controlnet_sd3.py
|
||||
--pretrained_model_name_or_path=DavyMorgan/tiny-sd3-pipe
|
||||
--dataset_name=hf-internal-testing/fill10
|
||||
--output_dir={tmpdir}
|
||||
--resolution=64
|
||||
--train_batch_size=1
|
||||
--gradient_accumulation_steps=1
|
||||
--controlnet_model_name_or_path=DavyMorgan/tiny-controlnet-sd3
|
||||
--max_train_steps=4
|
||||
--checkpointing_steps=2
|
||||
""".split()
|
||||
|
||||
run_command(self._launch_args + test_args)
|
||||
|
||||
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors")))
|
||||
|
||||
1415
examples/controlnet/train_controlnet_sd3.py
Normal file
1415
examples/controlnet/train_controlnet_sd3.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1597,6 +1597,7 @@ def main(args):
|
||||
tokenizers=[None, None],
|
||||
text_input_ids_list=[tokens_one, tokens_two],
|
||||
max_sequence_length=args.max_sequence_length,
|
||||
device=accelerator.device,
|
||||
prompt=prompts,
|
||||
)
|
||||
else:
|
||||
@@ -1606,6 +1607,7 @@ def main(args):
|
||||
tokenizers=[None, None],
|
||||
text_input_ids_list=[tokens_one, tokens_two],
|
||||
max_sequence_length=args.max_sequence_length,
|
||||
device=accelerator.device,
|
||||
prompt=args.instance_prompt,
|
||||
)
|
||||
|
||||
|
||||
167
examples/research_projects/pytorch_xla/README.md
Normal file
167
examples/research_projects/pytorch_xla/README.md
Normal file
@@ -0,0 +1,167 @@
|
||||
# Stable Diffusion text-to-image fine-tuning using PyTorch/XLA
|
||||
|
||||
The `train_text_to_image_xla.py` script shows how to fine-tune stable diffusion model on TPU devices using PyTorch/XLA.
|
||||
|
||||
It has been tested on v4 and v5p TPU versions. Training code has been tested on multi-host.
|
||||
|
||||
This script implements Distributed Data Parallel using GSPMD feature in XLA compiler
|
||||
where we shard the input batches over the TPU devices.
|
||||
|
||||
As of 9-11-2024, these are some expected step times.
|
||||
|
||||
| accelerator | global batch size | step time (seconds) |
|
||||
| ----------- | ----------------- | --------- |
|
||||
| v5p-128 | 1024 | 0.245 |
|
||||
| v5p-256 | 2048 | 0.234 |
|
||||
| v5p-512 | 4096 | 0.2498 |
|
||||
|
||||
## Create TPU
|
||||
|
||||
To create a TPU on Google Cloud first set these environment variables:
|
||||
|
||||
```bash
|
||||
export TPU_NAME=<tpu-name>
|
||||
export PROJECT_ID=<project-id>
|
||||
export ZONE=<google-cloud-zone>
|
||||
export ACCELERATOR_TYPE=<accelerator type like v5p-8>
|
||||
export RUNTIME_VERSION=<runtime version like v2-alpha-tpuv5 for v5p>
|
||||
```
|
||||
|
||||
Then run the create TPU command:
|
||||
```bash
|
||||
gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --project ${PROJECT_ID}
|
||||
--zone ${ZONE} --accelerator-type ${ACCELERATOR_TYPE} --version ${RUNTIME_VERSION}
|
||||
--reserved
|
||||
```
|
||||
|
||||
You can also use other ways to reserve TPUs like GKE or queued resources.
|
||||
|
||||
## Setup TPU environment
|
||||
|
||||
Install PyTorch and PyTorch/XLA nightly versions:
|
||||
```bash
|
||||
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
|
||||
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
|
||||
--command='
|
||||
pip3 install --pre torch==2.5.0.dev20240905+cpu torchvision==0.20.0.dev20240905+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
|
||||
pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240905-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html
|
||||
'
|
||||
```
|
||||
|
||||
Verify that PyTorch and PyTorch/XLA were installed correctly:
|
||||
|
||||
```bash
|
||||
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
|
||||
--project ${PROJECT_ID} --zone ${ZONE} --worker=all \
|
||||
--command='python3 -c "import torch; import torch_xla;"'
|
||||
```
|
||||
|
||||
Install dependencies:
|
||||
```bash
|
||||
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
|
||||
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
|
||||
--command='
|
||||
git clone https://github.com/huggingface/diffusers.git
|
||||
cd diffusers
|
||||
git checkout main
|
||||
cd examples/research_projects/pytorch_xla
|
||||
pip3 install -r requirements.txt
|
||||
pip3 install pillow --upgrade
|
||||
cd ../../..
|
||||
pip3 install .'
|
||||
```
|
||||
|
||||
## Run the training job
|
||||
|
||||
### Authenticate
|
||||
|
||||
Run the following command to authenticate your token.
|
||||
|
||||
```bash
|
||||
huggingface-cli login
|
||||
```
|
||||
|
||||
This script only trains the unet part of the network. The VAE and text encoder
|
||||
are fixed.
|
||||
|
||||
```bash
|
||||
gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
|
||||
--project=${PROJECT_ID} --zone=${ZONE} --worker=all \
|
||||
--command='
|
||||
export XLA_DISABLE_FUNCTIONALIZATION=1
|
||||
export PROFILE_DIR=/tmp/
|
||||
export CACHE_DIR=/tmp/
|
||||
export DATASET_NAME=lambdalabs/naruto-blip-captions
|
||||
export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p
|
||||
export TRAIN_STEPS=50
|
||||
export OUTPUT_DIR=/tmp/trained-model/
|
||||
python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=4 --loader_prefetch_size=4 --device_prefetch_size=4'
|
||||
|
||||
```
|
||||
|
||||
### Environment Envs Explained
|
||||
|
||||
* `XLA_DISABLE_FUNCTIONALIZATION`: To optimize the performance for AdamW optimizer.
|
||||
* `PROFILE_DIR`: Specify where to put the profiling results.
|
||||
* `CACHE_DIR`: Directory to store XLA compiled graphs for persistent caching.
|
||||
* `DATASET_NAME`: Dataset to train the model.
|
||||
* `PER_HOST_BATCH_SIZE`: Size of the batch to load per CPU host. For e.g. for a v5p-16 with 2 CPU hosts, the global batch size will be 2xPER_HOST_BATCH_SIZE. The input batch is sharded along the batch axis.
|
||||
* `TRAIN_STEPS`: Total number of training steps to run the training for.
|
||||
* `OUTPUT_DIR`: Directory to store the fine-tuned model.
|
||||
|
||||
## Run inference using the output model
|
||||
|
||||
To run inference using the output, you can simply load the model and pass it
|
||||
input prompts. The first pass will compile the graph and takes longer with the following passes running much faster.
|
||||
|
||||
```bash
|
||||
export CACHE_DIR=/tmp/
|
||||
```
|
||||
|
||||
```python
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
import torch_xla.core.xla_model as xm
|
||||
from time import time
|
||||
from diffusers import StableDiffusionPipeline
|
||||
import torch_xla.runtime as xr
|
||||
|
||||
CACHE_DIR = os.environ.get("CACHE_DIR", None)
|
||||
if CACHE_DIR:
|
||||
xr.initialize_cache(CACHE_DIR, readonly=False)
|
||||
|
||||
def main():
|
||||
device = xm.xla_device()
|
||||
model_path = "jffacevedo/pxla_trained_model"
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.to(device)
|
||||
prompt = ["A naruto with green eyes and red legs."]
|
||||
start = time()
|
||||
print("compiling...")
|
||||
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
|
||||
print(f"compile time: {time() - start}")
|
||||
print("generate...")
|
||||
start = time()
|
||||
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
|
||||
print(f"generation time (after compile) : {time() - start}")
|
||||
image.save("naruto.png")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
```
|
||||
|
||||
Expected Results:
|
||||
|
||||
```bash
|
||||
compiling...
|
||||
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [10:03<00:00, 20.10s/it]
|
||||
compile time: 720.656970500946
|
||||
generate...
|
||||
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 30/30 [00:01<00:00, 17.65it/s]
|
||||
generation time (after compile) : 1.8461642265319824
|
||||
8
examples/research_projects/pytorch_xla/requirements.txt
Normal file
8
examples/research_projects/pytorch_xla/requirements.txt
Normal file
@@ -0,0 +1,8 @@
|
||||
accelerate>=0.16.0
|
||||
torchvision
|
||||
transformers>=4.25.1
|
||||
datasets>=2.19.1
|
||||
ftfy
|
||||
tensorboard
|
||||
Jinja2
|
||||
peft==0.7.0
|
||||
@@ -0,0 +1,669 @@
|
||||
import argparse
|
||||
import os
|
||||
import random
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.checkpoint
|
||||
import torch_xla.core.xla_model as xm
|
||||
import torch_xla.debug.profiler as xp
|
||||
import torch_xla.distributed.parallel_loader as pl
|
||||
import torch_xla.distributed.spmd as xs
|
||||
import torch_xla.runtime as xr
|
||||
from huggingface_hub import create_repo, upload_folder
|
||||
from torchvision import transforms
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from diffusers import (
|
||||
AutoencoderKL,
|
||||
DDPMScheduler,
|
||||
StableDiffusionPipeline,
|
||||
UNet2DConditionModel,
|
||||
)
|
||||
from diffusers.training_utils import compute_snr
|
||||
from diffusers.utils import is_wandb_available
|
||||
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
pass
|
||||
|
||||
PROFILE_DIR = os.environ.get("PROFILE_DIR", None)
|
||||
CACHE_DIR = os.environ.get("CACHE_DIR", None)
|
||||
if CACHE_DIR:
|
||||
xr.initialize_cache(CACHE_DIR, readonly=False)
|
||||
xr.use_spmd()
|
||||
DATASET_NAME_MAPPING = {
|
||||
"lambdalabs/naruto-blip-captions": ("image", "text"),
|
||||
}
|
||||
PORT = 9012
|
||||
|
||||
|
||||
def save_model_card(
|
||||
args,
|
||||
repo_id: str,
|
||||
repo_folder: str = None,
|
||||
):
|
||||
model_description = f"""
|
||||
# Text-to-image finetuning - {repo_id}
|
||||
|
||||
This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. \n
|
||||
|
||||
## Pipeline usage
|
||||
|
||||
You can use the pipeline like so:
|
||||
|
||||
```python
|
||||
import torch
|
||||
import os
|
||||
import sys
|
||||
import numpy as np
|
||||
|
||||
import torch_xla.core.xla_model as xm
|
||||
from time import time
|
||||
from typing import Tuple
|
||||
from diffusers import StableDiffusionPipeline
|
||||
|
||||
def main(args):
|
||||
device = xm.xla_device()
|
||||
model_path = <output_dir>
|
||||
pipe = StableDiffusionPipeline.from_pretrained(
|
||||
model_path,
|
||||
torch_dtype=torch.bfloat16
|
||||
)
|
||||
pipe.to(device)
|
||||
prompt = ["A naruto with green eyes and red legs."]
|
||||
image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0]
|
||||
image.save("naruto.png")
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
```
|
||||
|
||||
## Training info
|
||||
|
||||
These are the key hyperparameters used during training:
|
||||
|
||||
* Steps: {args.max_train_steps}
|
||||
* Learning rate: {args.learning_rate}
|
||||
* Batch size: {args.train_batch_size}
|
||||
* Image resolution: {args.resolution}
|
||||
* Mixed-precision: {args.mixed_precision}
|
||||
|
||||
"""
|
||||
|
||||
model_card = load_or_create_model_card(
|
||||
repo_id_or_path=repo_id,
|
||||
from_training=True,
|
||||
license="creativeml-openrail-m",
|
||||
base_model=args.pretrained_model_name_or_path,
|
||||
model_description=model_description,
|
||||
inference=True,
|
||||
)
|
||||
|
||||
tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "diffusers-training"]
|
||||
model_card = populate_model_card(model_card, tags=tags)
|
||||
|
||||
model_card.save(os.path.join(repo_folder, "README.md"))
|
||||
|
||||
|
||||
class TrainSD:
|
||||
def __init__(
|
||||
self,
|
||||
vae,
|
||||
weight_dtype,
|
||||
device,
|
||||
noise_scheduler,
|
||||
unet,
|
||||
optimizer,
|
||||
text_encoder,
|
||||
dataloader,
|
||||
args,
|
||||
):
|
||||
self.vae = vae
|
||||
self.weight_dtype = weight_dtype
|
||||
self.device = device
|
||||
self.noise_scheduler = noise_scheduler
|
||||
self.unet = unet
|
||||
self.optimizer = optimizer
|
||||
self.text_encoder = text_encoder
|
||||
self.args = args
|
||||
self.mesh = xs.get_global_mesh()
|
||||
self.dataloader = iter(dataloader)
|
||||
self.global_step = 0
|
||||
|
||||
def run_optimizer(self):
|
||||
self.optimizer.step()
|
||||
|
||||
def start_training(self):
|
||||
times = []
|
||||
last_time = time.time()
|
||||
step = 0
|
||||
while True:
|
||||
if self.global_step >= self.args.max_train_steps:
|
||||
xm.mark_step()
|
||||
break
|
||||
if step == 4 and PROFILE_DIR is not None:
|
||||
xm.wait_device_ops()
|
||||
xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration)
|
||||
try:
|
||||
batch = next(self.dataloader)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
break
|
||||
loss = self.step_fn(batch["pixel_values"], batch["input_ids"])
|
||||
step_time = time.time() - last_time
|
||||
if step >= 10:
|
||||
times.append(step_time)
|
||||
print(f"step: {step}, step_time: {step_time}")
|
||||
if step % 5 == 0:
|
||||
print(f"step: {step}, loss: {loss}")
|
||||
last_time = time.time()
|
||||
self.global_step += 1
|
||||
step += 1
|
||||
# print(f"Average step time: {sum(times)/len(times)}")
|
||||
xm.wait_device_ops()
|
||||
|
||||
def step_fn(
|
||||
self,
|
||||
pixel_values,
|
||||
input_ids,
|
||||
):
|
||||
with xp.Trace("model.forward"):
|
||||
self.optimizer.zero_grad()
|
||||
latents = self.vae.encode(pixel_values).latent_dist.sample()
|
||||
latents = latents * self.vae.config.scaling_factor
|
||||
noise = torch.randn_like(latents).to(self.device, dtype=self.weight_dtype)
|
||||
bsz = latents.shape[0]
|
||||
timesteps = torch.randint(
|
||||
0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device
|
||||
)
|
||||
timesteps = timesteps.long()
|
||||
|
||||
noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
|
||||
encoder_hidden_states = self.text_encoder(input_ids, return_dict=False)[0]
|
||||
if self.args.prediction_type is not None:
|
||||
# set prediction_type of scheduler if defined
|
||||
self.noise_scheduler.register_to_config(prediction_type=self.args.prediction_type)
|
||||
|
||||
if self.noise_scheduler.config.prediction_type == "epsilon":
|
||||
target = noise
|
||||
elif self.noise_scheduler.config.prediction_type == "v_prediction":
|
||||
target = self.noise_scheduler.get_velocity(latents, noise, timesteps)
|
||||
else:
|
||||
raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
|
||||
model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
|
||||
with xp.Trace("model.backward"):
|
||||
if self.args.snr_gamma is None:
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
|
||||
else:
|
||||
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
|
||||
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
|
||||
# This is discussed in Section 4.2 of the same paper.
|
||||
snr = compute_snr(self.noise_scheduler, timesteps)
|
||||
mse_loss_weights = torch.stack([snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(
|
||||
dim=1
|
||||
)[0]
|
||||
if self.noise_scheduler.config.prediction_type == "epsilon":
|
||||
mse_loss_weights = mse_loss_weights / snr
|
||||
elif self.noise_scheduler.config.prediction_type == "v_prediction":
|
||||
mse_loss_weights = mse_loss_weights / (snr + 1)
|
||||
|
||||
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
|
||||
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
|
||||
loss = loss.mean()
|
||||
loss.backward()
|
||||
with xp.Trace("optimizer_step"):
|
||||
self.run_optimizer()
|
||||
return loss
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of a training script.")
|
||||
parser.add_argument(
|
||||
"--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1."
|
||||
)
|
||||
parser.add_argument("--profile_duration", type=int, default=10000, help="Profile duration in ms")
|
||||
parser.add_argument(
|
||||
"--pretrained_model_name_or_path",
|
||||
type=str,
|
||||
default=None,
|
||||
required=True,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help="Revision of pretrained model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--variant",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
||||
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
||||
" or to a folder containing files that 🤗 Datasets can understand."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The config of the Dataset, leave as None if there's only one config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"A folder containing the training data. Folder contents must follow the structure described in"
|
||||
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
|
||||
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_column",
|
||||
type=str,
|
||||
default="text",
|
||||
help="The column of the dataset containing a caption or a list of captions.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_train_samples",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="sd-model-finetuned",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The directory where the downloaded models and datasets will be stored.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--center_crop",
|
||||
default=False,
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether to center crop the input images to the resolution. If not set, the images will be randomly"
|
||||
" cropped. The images will be resized to the resolution first before cropping."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--random_flip",
|
||||
action="store_true",
|
||||
help="whether to randomly flip images horizontally",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="Initial learning rate (after the potential warmup period) to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--snr_gamma",
|
||||
type=float,
|
||||
default=None,
|
||||
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
|
||||
"More details here: https://arxiv.org/abs/2303.09556.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--non_ema_revision",
|
||||
type=str,
|
||||
default=None,
|
||||
required=False,
|
||||
help=(
|
||||
"Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or"
|
||||
" remote repository specified with --pretrained_model_name_or_path."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--loader_prefetch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help=("Number of subprocesses to use for data loading to cpu."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--device_prefetch_size",
|
||||
type=int,
|
||||
default=1,
|
||||
help=("Number of subprocesses to use for data loading to tpu from cpu. "),
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument(
|
||||
"--prediction_type",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
|
||||
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
|
||||
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# default to using the same revision for the non-ema model if not specified
|
||||
if args.non_ema_revision is None:
|
||||
args.non_ema_revision = args.revision
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def setup_optimizer(unet, args):
|
||||
optimizer_cls = torch.optim.AdamW
|
||||
return optimizer_cls(
|
||||
unet.parameters(),
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
foreach=True,
|
||||
)
|
||||
|
||||
|
||||
def load_dataset(args):
|
||||
if args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
dataset = datasets.load_dataset(
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
cache_dir=args.cache_dir,
|
||||
data_dir=args.train_data_dir,
|
||||
)
|
||||
else:
|
||||
data_files = {}
|
||||
if args.train_data_dir is not None:
|
||||
data_files["train"] = os.path.join(args.train_data_dir, "**")
|
||||
dataset = datasets.load_dataset(
|
||||
"imagefolder",
|
||||
data_files=data_files,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
return dataset
|
||||
|
||||
|
||||
def get_column_names(dataset, args):
|
||||
column_names = dataset["train"].column_names
|
||||
|
||||
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
|
||||
if args.image_column is None:
|
||||
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
|
||||
else:
|
||||
image_column = args.image_column
|
||||
if image_column not in column_names:
|
||||
raise ValueError(
|
||||
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
|
||||
)
|
||||
if args.caption_column is None:
|
||||
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
|
||||
else:
|
||||
caption_column = args.caption_column
|
||||
if caption_column not in column_names:
|
||||
raise ValueError(
|
||||
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
|
||||
)
|
||||
return image_column, caption_column
|
||||
|
||||
|
||||
def main(args):
|
||||
args = parse_args()
|
||||
|
||||
_ = xp.start_server(PORT)
|
||||
|
||||
num_devices = xr.global_runtime_device_count()
|
||||
device_ids = np.arange(num_devices)
|
||||
mesh_shape = (num_devices, 1)
|
||||
mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y"))
|
||||
xs.set_global_mesh(mesh)
|
||||
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="text_encoder",
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
vae = AutoencoderKL.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="vae",
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
|
||||
unet = UNet2DConditionModel.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="unet",
|
||||
revision=args.non_ema_revision,
|
||||
)
|
||||
|
||||
if xm.is_master_ordinal() and args.push_to_hub:
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
||||
).repo_id
|
||||
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
|
||||
tokenizer = CLIPTokenizer.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
subfolder="tokenizer",
|
||||
revision=args.revision,
|
||||
)
|
||||
|
||||
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
|
||||
|
||||
unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
|
||||
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
unet.train()
|
||||
|
||||
# For mixed precision training we cast all non-trainable weights (vae,
|
||||
# non-lora text_encoder and non-lora unet) to half-precision
|
||||
# as these weights are only used for inference, keeping weights in full
|
||||
# precision is not required.
|
||||
weight_dtype = torch.float32
|
||||
if args.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif args.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
device = xm.xla_device()
|
||||
print("device: ", device)
|
||||
print("weight_dtype: ", weight_dtype)
|
||||
|
||||
text_encoder = text_encoder.to(device, dtype=weight_dtype)
|
||||
vae = vae.to(device, dtype=weight_dtype)
|
||||
unet = unet.to(device, dtype=weight_dtype)
|
||||
optimizer = setup_optimizer(unet, args)
|
||||
vae.requires_grad_(False)
|
||||
text_encoder.requires_grad_(False)
|
||||
unet.train()
|
||||
|
||||
dataset = load_dataset(args)
|
||||
image_column, caption_column = get_column_names(dataset, args)
|
||||
|
||||
def tokenize_captions(examples, is_train=True):
|
||||
captions = []
|
||||
for caption in examples[caption_column]:
|
||||
if isinstance(caption, str):
|
||||
captions.append(caption)
|
||||
elif isinstance(caption, (list, np.ndarray)):
|
||||
# take a random caption if there are multiple
|
||||
captions.append(random.choice(caption) if is_train else caption[0])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Caption column `{caption_column}` should contain either strings or lists of strings."
|
||||
)
|
||||
inputs = tokenizer(
|
||||
captions,
|
||||
max_length=tokenizer.model_max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
return inputs.input_ids
|
||||
|
||||
train_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
|
||||
(transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)),
|
||||
(transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x)),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize([0.5], [0.5]),
|
||||
]
|
||||
)
|
||||
|
||||
def preprocess_train(examples):
|
||||
images = [image.convert("RGB") for image in examples[image_column]]
|
||||
examples["pixel_values"] = [train_transforms(image) for image in images]
|
||||
examples["input_ids"] = tokenize_captions(examples)
|
||||
return examples
|
||||
|
||||
train_dataset = dataset["train"]
|
||||
train_dataset.set_format("torch")
|
||||
train_dataset.set_transform(preprocess_train)
|
||||
|
||||
def collate_fn(examples):
|
||||
pixel_values = torch.stack([example["pixel_values"] for example in examples])
|
||||
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).to(weight_dtype)
|
||||
input_ids = torch.stack([example["input_ids"] for example in examples])
|
||||
return {"pixel_values": pixel_values, "input_ids": input_ids}
|
||||
|
||||
g = torch.Generator()
|
||||
g.manual_seed(xr.host_index())
|
||||
sampler = torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=int(1e10), generator=g)
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
sampler=sampler,
|
||||
collate_fn=collate_fn,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
batch_size=args.train_batch_size,
|
||||
)
|
||||
|
||||
train_dataloader = pl.MpDeviceLoader(
|
||||
train_dataloader,
|
||||
device,
|
||||
input_sharding={
|
||||
"pixel_values": xs.ShardingSpec(mesh, ("x", None, None, None), minibatch=True),
|
||||
"input_ids": xs.ShardingSpec(mesh, ("x", None), minibatch=True),
|
||||
},
|
||||
loader_prefetch_size=args.loader_prefetch_size,
|
||||
device_prefetch_size=args.device_prefetch_size,
|
||||
)
|
||||
|
||||
if xm.is_master_ordinal():
|
||||
print("***** Running training *****")
|
||||
print(f"Instantaneous batch size per device = {args.train_batch_size}")
|
||||
print(
|
||||
f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_devices}"
|
||||
)
|
||||
print(f" Total optimization steps = {args.max_train_steps}")
|
||||
|
||||
trainer = TrainSD(
|
||||
vae=vae,
|
||||
weight_dtype=weight_dtype,
|
||||
device=device,
|
||||
noise_scheduler=noise_scheduler,
|
||||
unet=unet,
|
||||
optimizer=optimizer,
|
||||
text_encoder=text_encoder,
|
||||
dataloader=train_dataloader,
|
||||
args=args,
|
||||
)
|
||||
|
||||
trainer.start_training()
|
||||
unet = trainer.unet.to("cpu")
|
||||
vae = trainer.vae.to("cpu")
|
||||
text_encoder = trainer.text_encoder.to("cpu")
|
||||
|
||||
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||
args.pretrained_model_name_or_path,
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
revision=args.revision,
|
||||
variant=args.variant,
|
||||
)
|
||||
pipeline.save_pretrained(args.output_dir)
|
||||
|
||||
if xm.is_master_ordinal() and args.push_to_hub:
|
||||
save_model_card(args, repo_id, repo_folder=args.output_dir)
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
commit_message="End of training",
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
main(args)
|
||||
@@ -690,7 +690,7 @@ class FluxPosEmbed(nn.Module):
|
||||
n_axes = ids.shape[-1]
|
||||
cos_out = []
|
||||
sin_out = []
|
||||
pos = ids.squeeze().float()
|
||||
pos = ids.float()
|
||||
is_mps = ids.device.type == "mps"
|
||||
freqs_dtype = torch.float32 if is_mps else torch.float64
|
||||
for i in range(n_axes):
|
||||
|
||||
@@ -465,6 +465,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
|
||||
"Please remove the batch dimension and pass it as a 2d torch Tensor"
|
||||
)
|
||||
img_ids = img_ids[0]
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=0)
|
||||
image_rotary_emb = self.pos_embed(ids)
|
||||
|
||||
|
||||
@@ -28,6 +28,7 @@ from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import (
|
||||
USE_PEFT_BACKEND,
|
||||
deprecate,
|
||||
is_torch_xla_available,
|
||||
logging,
|
||||
replace_example_docstring,
|
||||
scale_lora_layers,
|
||||
@@ -39,6 +40,13 @@ from .pipeline_output import StableDiffusionPipelineOutput
|
||||
from .safety_checker import StableDiffusionSafetyChecker
|
||||
|
||||
|
||||
if is_torch_xla_available():
|
||||
import torch_xla.core.xla_model as xm
|
||||
|
||||
XLA_AVAILABLE = True
|
||||
else:
|
||||
XLA_AVAILABLE = False
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
EXAMPLE_DOC_STRING = """
|
||||
@@ -1036,6 +1044,9 @@ class StableDiffusionPipeline(
|
||||
step_idx = i // getattr(self.scheduler, "order", 1)
|
||||
callback(step_idx, t, latents)
|
||||
|
||||
if XLA_AVAILABLE:
|
||||
xm.mark_step()
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[
|
||||
0
|
||||
|
||||
@@ -38,7 +38,20 @@ class BatchedBrownianTree:
|
||||
except TypeError:
|
||||
seed = [seed]
|
||||
self.batched = False
|
||||
self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
|
||||
self.trees = [
|
||||
torchsde.BrownianInterval(
|
||||
t0=t0,
|
||||
t1=t1,
|
||||
size=w0.shape,
|
||||
dtype=w0.dtype,
|
||||
device=w0.device,
|
||||
entropy=s,
|
||||
tol=1e-6,
|
||||
pool_size=24,
|
||||
halfway_tree=True,
|
||||
)
|
||||
for s in seed
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def sort(a, b):
|
||||
|
||||
@@ -18,11 +18,11 @@ from ..test_pipelines_common import PipelineTesterMixin
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
|
||||
class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
pipeline_class = FluxImg2ImgPipeline
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
test_xformers_attention = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
@@ -18,11 +18,11 @@ from ..test_pipelines_common import PipelineTesterMixin
|
||||
enable_full_determinism()
|
||||
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.")
|
||||
class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin):
|
||||
pipeline_class = FluxInpaintPipeline
|
||||
params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"])
|
||||
batch_params = frozenset(["prompt"])
|
||||
test_xformers_attention = False
|
||||
|
||||
def get_dummy_components(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
Reference in New Issue
Block a user