1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Research Projects] ORPO diffusion for alignment (#7423)

* barebones orpo

* remove reference model.

* full implementation

* change default of beta_orpo

* add a training command.

* fix: dataloading issues.

* interpreting the formulation.

* revert styling

* add: wds full blown version

* fix: per_gpu_batch_siz

* start debuggin

* debugging

* remove print

* fix

* remove filter keys.

* turn on non-blocking calls.

* device_placement

* let's see.

* add bigger training run command

* reinitialize generator for fair repro

* add: detailed readme and requirements

---------

Co-authored-by: Sayak Paul <sayakpaul@Sayaks-MacBook-Pro-2.local>
This commit is contained in:
Sayak Paul
2024-03-25 08:37:41 +05:30
committed by GitHub
parent f7dfcfd971
commit e29f16cfaa
4 changed files with 2307 additions and 0 deletions

View File

@@ -0,0 +1,121 @@
This project is an attempt to check if it's possible to apply to [ORPO](https://arxiv.org/abs/2403.07691) on a text-conditioned diffusion model to align it on preference data WITHOUT a reference model. The implementation is based on https://github.com/huggingface/trl/pull/1435/.
> [!WARNING]
> We assume that MSE in the diffusion formulation approximates the log-probs as required by ORPO (hat-tip to [@kashif](https://github.com/kashif) for the idea). So, please consider this to be extremely experimental.
## Training
Here's training command you can use on a 40GB A100 to validate things on a [small preference
dataset](https://hf.co/datasets/kashif/pickascore):
```bash
accelerate launch train_diffusion_orpo_sdxl_lora.py \
--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
--output_dir="diffusion-sdxl-orpo" \
--mixed_precision="fp16" \
--dataset_name=kashif/pickascore \
--train_batch_size=8 \
--gradient_accumulation_steps=2 \
--gradient_checkpointing \
--use_8bit_adam \
--rank=8 \
--learning_rate=1e-5 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=2000 \
--checkpointing_steps=500 \
--run_validation --validation_steps=50 \
--seed="0" \
--report_to="wandb" \
--push_to_hub
```
We also provide a simple script to scale up the training on the [yuvalkirstain/pickapic_v2](https://huggingface.co/datasets/yuvalkirstain/pickapic_v2) dataset:
```bash
accelerate launch --multi_gpu train_diffusion_orpo_sdxl_lora_wds.py \
--pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
--pretrained_vae_model_name_or_path=madebyollin/sdxl-vae-fp16-fix \
--dataset_path="pipe:aws s3 cp s3://diffusion-preference-opt/{00000..00644}.tar -" \
--output_dir="diffusion-sdxl-orpo-wds" \
--mixed_precision="fp16" \
--gradient_accumulation_steps=1 \
--gradient_checkpointing \
--use_8bit_adam \
--rank=8 \
--dataloader_num_workers=8 \
--learning_rate=3e-5 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=50000 \
--checkpointing_steps=2000 \
--run_validation --validation_steps=500 \
--seed="0" \
--report_to="wandb" \
--push_to_hub
```
We tested the above on a node of 8 H100s but it should also work on A100s. It requires the `webdataset` library for faster dataloading. Note that we kept the dataset shards on an S3 bucket but it should be also possible to have them stored locally.
You can use the code below to convert the original dataset into `webdataset` shards:
```python
import os
import io
import ray
import webdataset as wds
from datasets import Dataset
from PIL import Image
ray.init(num_cpus=8)
def convert_to_image(im_bytes):
return Image.open(io.BytesIO(im_bytes)).convert("RGB")
def main():
dataset_path = "/pickapic_v2/data"
wds_shards_path = "/pickapic_v2_webdataset"
# get all .parquet files in the dataset path
dataset_files = [
os.path.join(dataset_path, f)
for f in os.listdir(dataset_path)
if f.endswith(".parquet")
]
@ray.remote
def create_shard(path):
# get basename of the file
basename = os.path.basename(path)
# get the shard number data-00123-of-01034.parquet -> 00123
shard_num = basename.split("-")[1]
dataset = Dataset.from_parquet(path)
# create a webdataset shard
shard = wds.TarWriter(os.path.join(wds_shards_path, f"{shard_num}.tar"))
for i, example in enumerate(dataset):
wds_example = {
"__key__": str(i),
"original_prompt.txt": example["caption"],
"jpg_0.jpg": convert_to_image(example["jpg_0"]),
"jpg_1.jpg": convert_to_image(example["jpg_1"]),
"label_0.txt": str(example["label_0"]),
"label_1.txt": str(example["label_1"])
}
shard.write(wds_example)
shard.close()
futures = [create_shard.remote(path) for path in dataset_files]
ray.get(futures)
if __name__ == "__main__":
main()
```
## Inference
Refer to [sayakpaul/diffusion-sdxl-orpo](https://huggingface.co/sayakpaul/diffusion-sdxl-orpo) for an experimental checkpoint.

View File

@@ -0,0 +1,7 @@
datasets
accelerate
transformers
torchvision
wandb
peft
webdataset

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff