1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Refactor instructpix2pix lora to support peft (#10205)

* make base code changes referred from train_instructpix2pix script in examples

* change code to use PEFT as discussed in issue 10062

* update README training command

* update README training command

* refactor variable name and freezing unet

* Update examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* update README installation instructions.

* cleanup code using make style and quality

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Rahul Raman
2025-01-06 22:30:45 -08:00
committed by GitHub
parent b94cfd7937
commit f1e0c7ce4a
2 changed files with 262 additions and 124 deletions

View File

@@ -2,6 +2,34 @@
This extended LoRA training script was authored by [Aiden-Frost](https://github.com/Aiden-Frost).
This is an experimental LoRA extension of [this example](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py). This script provides further support add LoRA layers for unet model.
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```
Then cd in the example folder and run
```bash
pip install -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Note also that we use PEFT library as backend for LoRA training, make sure to have `peft>=0.6.0` installed in your environment.
## Training script example
```bash
@@ -9,7 +37,7 @@ export MODEL_ID="timbrooks/instruct-pix2pix"
export DATASET_ID="instruction-tuning-sd/cartoonization"
export OUTPUT_DIR="instructPix2Pix-cartoonization"
accelerate launch finetune_instruct_pix2pix.py \
accelerate launch train_instruct_pix2pix_lora.py \
--pretrained_model_name_or_path=$MODEL_ID \
--dataset_name=$DATASET_ID \
--enable_xformers_memory_efficient_attention \
@@ -24,7 +52,10 @@ accelerate launch finetune_instruct_pix2pix.py \
--rank=4 \
--output_dir=$OUTPUT_DIR \
--report_to=wandb \
--push_to_hub
--push_to_hub \
--original_image_column="original_image" \
--edited_image_column="cartoonized_image" \
--edit_prompt_column="edit_prompt"
```
## Inference

View File

@@ -14,7 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Script to fine-tune Stable Diffusion for InstructPix2Pix."""
"""
Script to fine-tune Stable Diffusion for LORA InstructPix2Pix.
Base code referred from: https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/train_instruct_pix2pix.py
"""
import argparse
import logging
@@ -30,6 +33,7 @@ import numpy as np
import PIL
import requests
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
@@ -39,21 +43,28 @@ from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from peft import LoraConfig
from peft.utils import get_peft_model_state_dict
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionInstructPix2PixPipeline, UNet2DConditionModel
from diffusers.models.lora import LoRALinearLayer
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.training_utils import EMAModel, cast_training_params
from diffusers.utils import check_min_version, convert_state_dict_to_diffusers, deprecate, is_wandb_available
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
if is_wandb_available():
import wandb
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.26.0.dev0")
check_min_version("0.32.0.dev0")
logger = get_logger(__name__, log_level="INFO")
@@ -63,6 +74,92 @@ DATASET_NAME_MAPPING = {
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
def save_model_card(
repo_id: str,
images: list = None,
base_model: str = None,
dataset_name: str = None,
repo_folder: str = None,
):
img_str = ""
if images is not None:
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
img_str += f"![img_{i}](./image_{i}.png)\n"
model_description = f"""
# LoRA text2image fine-tuning - {repo_id}
These are LoRA adaption weights for {base_model}. The weights were fine-tuned on the {dataset_name} dataset. You can find some example images in the following. \n
{img_str}
"""
model_card = load_or_create_model_card(
repo_id_or_path=repo_id,
from_training=True,
license="creativeml-openrail-m",
base_model=base_model,
model_description=model_description,
inference=True,
)
tags = [
"stable-diffusion",
"stable-diffusion-diffusers",
"text-to-image",
"instruct-pix2pix",
"diffusers",
"diffusers-training",
"lora",
]
model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(repo_folder, "README.md"))
def log_validation(
pipeline,
args,
accelerator,
generator,
):
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
original_image = download_image(args.val_image_url)
edited_images = []
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)
with autocast_ctx:
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
)
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt)
tracker.log({"validation": wandb_table})
return edited_images
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script for InstructPix2Pix.")
parser.add_argument(
@@ -417,11 +514,6 @@ def main():
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -467,49 +559,58 @@ def main():
args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
)
# InstructPix2Pix uses an additional image for conditioning. To accommodate that,
# it uses 8 channels (instead of 4) in the first (conv) layer of the UNet. This UNet is
# then fine-tuned on the custom InstructPix2Pix dataset. This modified UNet is initialized
# from the pre-trained checkpoints. For the extra channels added to the first layer, they are
# initialized to zero.
logger.info("Initializing the InstructPix2Pix UNet from the pretrained UNet.")
in_channels = 8
out_channels = unet.conv_in.out_channels
unet.register_to_config(in_channels=in_channels)
with torch.no_grad():
new_conv_in = nn.Conv2d(
in_channels, out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding
)
new_conv_in.weight.zero_()
new_conv_in.weight[:, :in_channels, :, :].copy_(unet.conv_in.weight)
unet.conv_in = new_conv_in
# Freeze vae, text_encoder and unet
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
# referred to https://github.com/huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image_lora.py
unet_lora_parameters = []
for attn_processor_name, attn_processor in unet.attn_processors.items():
# Parse the attention module.
attn_module = unet
for n in attn_processor_name.split(".")[:-1]:
attn_module = getattr(attn_module, n)
# For mixed precision training we cast all non-trainable weights (vae, non-lora text_encoder and non-lora unet) to half-precision
# as these weights are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Set the `lora_layer` attribute of the attention-related matrices.
attn_module.to_q.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=args.rank
)
)
attn_module.to_k.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=args.rank
)
)
# Freeze the unet parameters before adding adapters
unet.requires_grad_(False)
attn_module.to_v.set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=args.rank
)
)
attn_module.to_out[0].set_lora_layer(
LoRALinearLayer(
in_features=attn_module.to_out[0].in_features,
out_features=attn_module.to_out[0].out_features,
rank=args.rank,
)
)
unet_lora_config = LoraConfig(
r=args.rank,
lora_alpha=args.rank,
init_lora_weights="gaussian",
target_modules=["to_k", "to_q", "to_v", "to_out.0"],
)
# Accumulate the LoRA params to optimize.
unet_lora_parameters.extend(attn_module.to_q.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_k.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_v.lora_layer.parameters())
unet_lora_parameters.extend(attn_module.to_out[0].lora_layer.parameters())
# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
# Add adapter and make sure the trainable params are in float32.
unet.add_adapter(unet_lora_config)
if args.mixed_precision == "fp16":
# only upcast trainable parameters (LoRA) into fp32
cast_training_params(unet, dtype=torch.float32)
# Create EMA for the unet.
if args.use_ema:
@@ -528,6 +629,13 @@ def main():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
trainable_params = filter(lambda p: p.requires_grad, unet.parameters())
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -540,7 +648,8 @@ def main():
model.save_pretrained(os.path.join(output_dir, "unet"))
# make sure to pop weight so that corresponding model is not saved again
weights.pop()
if weights:
weights.pop()
def load_model_hook(models, input_dir):
if args.use_ema:
@@ -589,9 +698,9 @@ def main():
else:
optimizer_cls = torch.optim.AdamW
# train on only unet_lora_parameters
# train on only lora_layers
optimizer = optimizer_cls(
unet_lora_parameters,
trainable_params,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
@@ -730,22 +839,27 @@ def main():
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
# Check the PR https://github.com/huggingface/diffusers/pull/8312 for detailed explanation.
num_warmup_steps_for_scheduler = args.lr_warmup_steps * accelerator.num_processes
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
len_train_dataloader_after_sharding = math.ceil(len(train_dataloader) / accelerator.num_processes)
num_update_steps_per_epoch = math.ceil(len_train_dataloader_after_sharding / args.gradient_accumulation_steps)
num_training_steps_for_scheduler = (
args.num_train_epochs * num_update_steps_per_epoch * accelerator.num_processes
)
else:
num_training_steps_for_scheduler = args.max_train_steps * accelerator.num_processes
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_warmup_steps=num_warmup_steps_for_scheduler,
num_training_steps=num_training_steps_for_scheduler,
)
# Prepare everything with our `accelerator`.
unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, unet_lora_parameters, optimizer, train_dataloader, lr_scheduler
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
if args.use_ema:
@@ -765,8 +879,14 @@ def main():
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
if num_training_steps_for_scheduler != args.max_train_steps * accelerator.num_processes:
logger.warning(
f"The length of the 'train_dataloader' after 'accelerator.prepare' ({len(train_dataloader)}) does not match "
f"the expected length ({len_train_dataloader_after_sharding}) when the learning rate scheduler was created. "
f"This inconsistency may result in the learning rate scheduler not functioning properly."
)
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
@@ -885,7 +1005,7 @@ def main():
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
# Predict the noise residual and compute loss
model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states).sample
model_pred = unet(concatenated_noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
@@ -895,7 +1015,7 @@ def main():
# Backpropagate
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(unet_lora_parameters, args.max_grad_norm)
accelerator.clip_grad_norm_(trainable_params, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
@@ -903,7 +1023,7 @@ def main():
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
if args.use_ema:
ema_unet.step(unet_lora_parameters)
ema_unet.step(trainable_params)
progress_bar.update(1)
global_step += 1
accelerator.log({"train_loss": train_loss}, step=global_step)
@@ -933,6 +1053,16 @@ def main():
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
unwrapped_unet = unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(
get_peft_model_state_dict(unwrapped_unet)
)
StableDiffusionInstructPix2PixPipeline.save_lora_weights(
save_directory=save_path,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
)
logger.info(f"Saved state to {save_path}")
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -959,45 +1089,22 @@ def main():
# The models need unwrapping because for compatibility in distributed training mode.
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
vae=accelerator.unwrap_model(vae),
unet=unwrap_model(unet),
text_encoder=unwrap_model(text_encoder),
vae=unwrap_model(vae),
revision=args.revision,
variant=args.variant,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
original_image = download_image(args.val_image_url)
edited_images = []
if torch.backends.mps.is_available():
autocast_ctx = nullcontext()
else:
autocast_ctx = torch.autocast(accelerator.device.type)
log_validation(
pipeline,
args,
accelerator,
generator,
)
with autocast_ctx:
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
)
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
)
tracker.log({"validation": wandb_table})
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())
@@ -1008,22 +1115,47 @@ def main():
# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = accelerator.unwrap_model(unet)
if args.use_ema:
ema_unet.copy_to(unet.parameters())
# store only LORA layers
unet = unet.to(torch.float32)
unwrapped_unet = unwrap_model(unet)
unet_lora_state_dict = convert_state_dict_to_diffusers(get_peft_model_state_dict(unwrapped_unet))
StableDiffusionInstructPix2PixPipeline.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_state_dict,
safe_serialization=True,
)
pipeline = StableDiffusionInstructPix2PixPipeline.from_pretrained(
args.pretrained_model_name_or_path,
text_encoder=accelerator.unwrap_model(text_encoder),
vae=accelerator.unwrap_model(vae),
unet=unet,
text_encoder=unwrap_model(text_encoder),
vae=unwrap_model(vae),
unet=unwrap_model(unet),
revision=args.revision,
variant=args.variant,
)
# store only LORA layers
unet.save_attn_procs(args.output_dir)
pipeline.load_lora_weights(args.output_dir)
images = None
if (args.val_image_url is not None) and (args.validation_prompt is not None):
images = log_validation(
pipeline,
args,
accelerator,
generator,
)
if args.push_to_hub:
save_model_card(
repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
dataset_name=args.dataset_name,
repo_folder=args.output_dir,
)
upload_folder(
repo_id=repo_id,
folder_path=args.output_dir,
@@ -1031,31 +1163,6 @@ def main():
ignore_patterns=["step_*", "epoch_*"],
)
if args.validation_prompt is not None:
edited_images = []
pipeline = pipeline.to(accelerator.device)
with torch.autocast(str(accelerator.device).replace(":0", "")):
for _ in range(args.num_validation_images):
edited_images.append(
pipeline(
args.validation_prompt,
image=original_image,
num_inference_steps=20,
image_guidance_scale=1.5,
guidance_scale=7,
generator=generator,
).images[0]
)
for tracker in accelerator.trackers:
if tracker.name == "wandb":
wandb_table = wandb.Table(columns=WANDB_TABLE_COL_NAMES)
for edited_image in edited_images:
wandb_table.add_data(
wandb.Image(original_image), wandb.Image(edited_image), args.validation_prompt
)
tracker.log({"test": wandb_table})
accelerator.end_training()