mirror of
https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
[research_projects] add flux training script with quantization (#9754)
* add flux training script with quantization * remove exclamation
This commit is contained in:
166
examples/research_projects/flux_lora_quantization/README.md
Normal file
166
examples/research_projects/flux_lora_quantization/README.md
Normal file
@@ -0,0 +1,166 @@
|
||||
## LoRA fine-tuning Flux.1 Dev with quantization
|
||||
|
||||
> [!NOTE]
|
||||
> This example is educational in nature and fixes some arguments to keep things simple. It should act as a reference to build things further.
|
||||
|
||||
This example shows how to fine-tune [Flux.1 Dev](https://huggingface.co/black-forest-labs/FLUX.1-dev) with LoRA and quantization. We show this by using the [`Norod78/Yarn-art-style`](https://huggingface.co/datasets/Norod78/Yarn-art-style) dataset. Steps below summarize the workflow:
|
||||
|
||||
* We precompute the text embeddings in `compute_embeddings.py` and serialize them into a parquet file.
|
||||
* `train_dreambooth_lora_flux_miniature.py` takes care of training:
|
||||
* Since we already precomputed the text embeddings, we don't load the text encoders.
|
||||
* We load the VAE and use it to precompute the image latents and we then delete it.
|
||||
* Load the Flux transformer, quantize it with the [NF4 datatype](https://arxiv.org/abs/2305.14314) through `bitsandbytes`, prepare it for 4bit training.
|
||||
* Add LoRA adapter layers to it and then ensure they are kept in FP32 precision.
|
||||
* Train!
|
||||
|
||||
To run training in a memory-optimized manner, we additionally use:
|
||||
|
||||
* 8Bit Adam
|
||||
* Gradient checkpointing
|
||||
|
||||
We have tested the scripts on a 24GB 4090. It works on a free-tier Colab Notebook, too, but it's extremely slow.
|
||||
|
||||
## Training
|
||||
|
||||
Ensure you have installed the required libraries:
|
||||
|
||||
```bash
|
||||
pip install -U transformers accelerate bitsandbytes peft datasets
|
||||
pip install git+https://github.com/huggingface/diffusers -U
|
||||
```
|
||||
|
||||
Now, compute the text embeddings:
|
||||
|
||||
```bash
|
||||
python compute_embeddings.py
|
||||
```
|
||||
|
||||
It should create a file named `embeddings.parquet`. We're then ready to launch training. First, authenticate so that you can access the Flux.1 Dev model:
|
||||
|
||||
```bash
|
||||
huggingface-cli
|
||||
```
|
||||
|
||||
Then launch:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file=accelerate.yaml \
|
||||
train_dreambooth_lora_flux_miniature.py \
|
||||
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
|
||||
--data_df_path="embeddings.parquet" \
|
||||
--output_dir="yarn_art_lora_flux_nf4" \
|
||||
--mixed_precision="fp16" \
|
||||
--use_8bit_adam \
|
||||
--weighting_scheme="none" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--repeats=1 \
|
||||
--learning_rate=1e-4 \
|
||||
--guidance_scale=1 \
|
||||
--report_to="wandb" \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--cache_latents \
|
||||
--rank=4 \
|
||||
--max_train_steps=700 \
|
||||
--seed="0"
|
||||
```
|
||||
|
||||
We can direcly pass a quantized checkpoint path, too:
|
||||
|
||||
```diff
|
||||
+ --quantized_model_path="hf-internal-testing/flux.1-dev-nf4-pkg"
|
||||
```
|
||||
|
||||
Depending on the machine, training time will vary but for our case, it was 1.5 hours. It maybe possible to speed this up by using `torch.bfloat16`.
|
||||
|
||||
We support training with the DeepSpeed Zero2 optimizer, too. To use it, first install DeepSpeed:
|
||||
|
||||
```bash
|
||||
pip install -Uq deepspeed
|
||||
```
|
||||
|
||||
And then launch:
|
||||
|
||||
```bash
|
||||
accelerate launch --config_file=ds2.yaml \
|
||||
train_dreambooth_lora_flux_miniature.py \
|
||||
--pretrained_model_name_or_path="black-forest-labs/FLUX.1-dev" \
|
||||
--data_df_path="embeddings.parquet" \
|
||||
--output_dir="yarn_art_lora_flux_nf4" \
|
||||
--mixed_precision="no" \
|
||||
--use_8bit_adam \
|
||||
--weighting_scheme="none" \
|
||||
--resolution=1024 \
|
||||
--train_batch_size=1 \
|
||||
--repeats=1 \
|
||||
--learning_rate=1e-4 \
|
||||
--guidance_scale=1 \
|
||||
--report_to="wandb" \
|
||||
--gradient_accumulation_steps=4 \
|
||||
--gradient_checkpointing \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--cache_latents \
|
||||
--rank=4 \
|
||||
--max_train_steps=700 \
|
||||
--seed="0"
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
When loading the LoRA params (that were obtained on a quantized base model) and merging them into the base model, it is recommended to first dequantize the base model, merge the LoRA params into it, and then quantize the model again. This is because merging into 4bit quantized models can lead to some rounding errors. Below, we provide an end-to-end example:
|
||||
|
||||
1. First, load the original model and merge the LoRA params into it:
|
||||
|
||||
```py
|
||||
from diffusers import FluxPipeline
|
||||
import torch
|
||||
|
||||
ckpt_id = "black-forest-labs/FLUX.1-dev"
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
ckpt_id, text_encoder=None, text_encoder_2=None, torch_dtype=torch.float16
|
||||
)
|
||||
pipeline.load_lora_weights("yarn_art_lora_flux_nf4", weight_name="pytorch_lora_weights.safetensors")
|
||||
pipeline.fuse_lora()
|
||||
pipeline.unload_lora_weights()
|
||||
|
||||
pipeline.transformer.save_pretrained("fused_transformer")
|
||||
```
|
||||
|
||||
2. Quantize the model and run inference
|
||||
|
||||
```py
|
||||
from diffusers import AutoPipelineForText2Image, FluxTransformer2DModel, BitsAndBytesConfig
|
||||
import torch
|
||||
|
||||
ckpt_id = "black-forest-labs/FLUX.1-dev"
|
||||
bnb_4bit_compute_dtype = torch.float16
|
||||
nf4_config = BitsAndBytesConfig(
|
||||
load_in_4bit=True,
|
||||
bnb_4bit_quant_type="nf4",
|
||||
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype,
|
||||
)
|
||||
transformer = FluxTransformer2DModel.from_pretrained(
|
||||
"fused_transformer",
|
||||
quantization_config=nf4_config,
|
||||
torch_dtype=bnb_4bit_compute_dtype,
|
||||
)
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
ckpt_id, transformer=transformer, torch_dtype=bnb_4bit_compute_dtype
|
||||
)
|
||||
pipeline.enable_model_cpu_offload()
|
||||
|
||||
image = pipeline(
|
||||
"a puppy in a pond, yarn art style", num_inference_steps=28, guidance_scale=3.5, height=768
|
||||
).images[0]
|
||||
image.save("yarn_merged.png")
|
||||
```
|
||||
|
||||
| Dequantize, merge, quantize | Merging directly into quantized model |
|
||||
|-------|-------|
|
||||
|  |  |
|
||||
|
||||
As we can notice the first column result follows the style more closely.
|
||||
@@ -0,0 +1,17 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
distributed_type: NO
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: true
|
||||
gpu_ids: all
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: bf16
|
||||
num_machines: 1
|
||||
num_processes: 1
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
@@ -0,0 +1,107 @@
|
||||
#!/usr/bin/env python
|
||||
# coding=utf-8
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub.utils import insecure_hashlib
|
||||
from tqdm.auto import tqdm
|
||||
from transformers import T5EncoderModel
|
||||
|
||||
from diffusers import FluxPipeline
|
||||
|
||||
|
||||
MAX_SEQ_LENGTH = 77
|
||||
OUTPUT_PATH = "embeddings.parquet"
|
||||
|
||||
|
||||
def generate_image_hash(image):
|
||||
return insecure_hashlib.sha256(image.tobytes()).hexdigest()
|
||||
|
||||
|
||||
def load_flux_dev_pipeline():
|
||||
id = "black-forest-labs/FLUX.1-dev"
|
||||
text_encoder = T5EncoderModel.from_pretrained(id, subfolder="text_encoder_2", load_in_8bit=True, device_map="auto")
|
||||
pipeline = FluxPipeline.from_pretrained(
|
||||
id, text_encoder_2=text_encoder, transformer=None, vae=None, device_map="balanced"
|
||||
)
|
||||
return pipeline
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_embeddings(pipeline, prompts, max_sequence_length):
|
||||
all_prompt_embeds = []
|
||||
all_pooled_prompt_embeds = []
|
||||
all_text_ids = []
|
||||
for prompt in tqdm(prompts, desc="Encoding prompts."):
|
||||
(
|
||||
prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
text_ids,
|
||||
) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, max_sequence_length=max_sequence_length)
|
||||
all_prompt_embeds.append(prompt_embeds)
|
||||
all_pooled_prompt_embeds.append(pooled_prompt_embeds)
|
||||
all_text_ids.append(text_ids)
|
||||
|
||||
max_memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
|
||||
print(f"Max memory allocated: {max_memory:.3f} GB")
|
||||
return all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids
|
||||
|
||||
|
||||
def run(args):
|
||||
dataset = load_dataset("Norod78/Yarn-art-style", split="train")
|
||||
image_prompts = {generate_image_hash(sample["image"]): sample["text"] for sample in dataset}
|
||||
all_prompts = list(image_prompts.values())
|
||||
print(f"{len(all_prompts)=}")
|
||||
|
||||
pipeline = load_flux_dev_pipeline()
|
||||
all_prompt_embeds, all_pooled_prompt_embeds, all_text_ids = compute_embeddings(
|
||||
pipeline, all_prompts, args.max_sequence_length
|
||||
)
|
||||
|
||||
data = []
|
||||
for i, (image_hash, _) in enumerate(image_prompts.items()):
|
||||
data.append((image_hash, all_prompt_embeds[i], all_pooled_prompt_embeds[i], all_text_ids[i]))
|
||||
print(f"{len(data)=}")
|
||||
|
||||
# Create a DataFrame
|
||||
embedding_cols = ["prompt_embeds", "pooled_prompt_embeds", "text_ids"]
|
||||
df = pd.DataFrame(data, columns=["image_hash"] + embedding_cols)
|
||||
print(f"{len(df)=}")
|
||||
|
||||
# Convert embedding lists to arrays (for proper storage in parquet)
|
||||
for col in embedding_cols:
|
||||
df[col] = df[col].apply(lambda x: x.cpu().numpy().flatten().tolist())
|
||||
|
||||
# Save the dataframe to a parquet file
|
||||
df.to_parquet(args.output_path)
|
||||
print(f"Data successfully serialized to {args.output_path}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
"--max_sequence_length",
|
||||
type=int,
|
||||
default=MAX_SEQ_LENGTH,
|
||||
help="Maximum sequence length to use for computing the embeddings. The more the higher computational costs.",
|
||||
)
|
||||
parser.add_argument("--output_path", type=str, default=OUTPUT_PATH, help="Path to serialize the parquet file.")
|
||||
args = parser.parse_args()
|
||||
|
||||
run(args)
|
||||
23
examples/research_projects/flux_lora_quantization/ds2.yaml
Normal file
23
examples/research_projects/flux_lora_quantization/ds2.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: false
|
||||
deepspeed_config:
|
||||
gradient_accumulation_steps: 1
|
||||
gradient_clipping: 1.0
|
||||
offload_optimizer_device: cpu
|
||||
offload_param_device: cpu
|
||||
zero3_init_flag: false
|
||||
zero_stage: 2
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
enable_cpu_affinity: false
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
mixed_precision: 'no'
|
||||
num_machines: 1
|
||||
num_processes: 1
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user