
Move IP Adapter Scripts to research project (#9960)

* Move files to research-projects.

* docs: add IP Adapter training instructions

* Delete venv

* Update examples/ip_adapter/tutorial_train_sdxl.py

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>

* Cherry-picked commits and re-moved files
to research_projects.

* make style.

* Update toctree and delete ip_adapter.

* Nit Fix

* Fix nit.

* Fix nit.

* Create training script for single GPU and set
model format to .safetensors

* Add sample inference script and restore _toctree

* Restore toctree.yaml

* fix spacing.

* Update toctree.yaml

---------

Co-authored-by: AMohamedAakhil <a.aakhilmohamed@gmail.com>
Co-authored-by: BootesVoid <78485654+AMohamedAakhil@users.noreply.github.com>
Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
Parag Ekbote
2024-11-20 00:07:22 +05:30
committed by GitHub
parent ea40933f36
commit cc7d88f247
6 changed files with 2032 additions and 0 deletions


@@ -0,0 +1,226 @@
# IP Adapter Training Example
[IP Adapter](https://arxiv.org/abs/2308.06721) is a novel approach designed to enhance text-to-image models such as Stable Diffusion by enabling them to generate images based on image prompts rather than text prompts alone. Unlike traditional methods that rely solely on complex text prompts, IP Adapter introduces the concept of using image prompts, leveraging the idea that "an image is worth a thousand words." By decoupling cross-attention layers for text and image features, IP Adapter effectively integrates image prompts into the generation process without the need for extensive fine-tuning or large computing resources.
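For orientation, this is roughly what image prompting looks like at inference time once an IP Adapter has been trained. The sketch below relies on diffusers' built-in `load_ip_adapter` support together with the publicly released `h94/IP-Adapter` weights; the checkpoint names, scale value, and image path are illustrative assumptions rather than outputs of this training example.
```python
import torch
from diffusers import StableDiffusionPipeline
from diffusers.utils import load_image

# Load a Stable Diffusion 1.5 pipeline and attach publicly released IP-Adapter weights.
pipeline = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
pipeline.set_ip_adapter_scale(0.6)  # how strongly the image prompt steers generation

# The reference image acts as the prompt; the text prompt can stay generic.
ip_image = load_image("path/to/reference_image.png")  # placeholder path
image = pipeline(
    prompt="best quality, high quality",
    ip_adapter_image=ip_image,
    num_inference_steps=50,
).images[0]
image.save("ip_adapter_sample.png")
```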
## Training locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
**Important**
To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:
```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
Then cd into the example folder and run
```bash
pip install -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or, for a default accelerate configuration without answering questions about your environment:
```bash
accelerate config default
```
Or, if your environment doesn't support an interactive shell (e.g., a notebook):
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
### Accelerate Launch Command Documentation
#### Description:
The `accelerate launch` command is used to train the model, optionally across multiple GPUs and with mixed-precision training. It launches the training script `tutorial_train_ip-adapter.py` with the specified parameters and configuration.
#### Usage Example:
```bash
accelerate launch --mixed_precision "fp16" \
tutorial_train_ip-adapter.py \
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
--image_encoder_path="{image_encoder_path}" \
--data_json_file="{data.json}" \
--data_root_path="{image_path}" \
--mixed_precision="fp16" \
--resolution=512 \
--train_batch_size=8 \
--dataloader_num_workers=4 \
--learning_rate=1e-04 \
--weight_decay=0.01 \
--output_dir="{output_dir}" \
--save_steps=10000
```
#### Multi-GPU Script:
```bash
accelerate launch --num_processes 8 --multi_gpu --mixed_precision "fp16" \
tutorial_train_ip-adapter.py \
--pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5/" \
--image_encoder_path="{image_encoder_path}" \
--data_json_file="{data.json}" \
--data_root_path="{image_path}" \
--mixed_precision="fp16" \
--resolution=512 \
--train_batch_size=8 \
--dataloader_num_workers=4 \
--learning_rate=1e-04 \
--weight_decay=0.01 \
--output_dir="{output_dir}" \
--save_steps=10000
```
#### Parameters:
- `--num_processes`: Number of processes to launch for distributed training (in this example, 8 processes).
- `--multi_gpu`: Flag indicating the usage of multiple GPUs for training.
- `--mixed_precision "fp16"`: Enables mixed precision training with 16-bit floating-point precision.
- `tutorial_train_ip-adapter.py`: Name of the training script to be executed.
- `--pretrained_model_name_or_path`: Path or identifier for a pretrained model.
- `--image_encoder_path`: Path to the CLIP image encoder.
- `--data_json_file`: Path to the training data in JSON format (see the format example after this list).
- `--data_root_path`: Root path where training images are located.
- `--resolution`: Resolution of input images (512x512 in this example).
- `--train_batch_size`: Batch size for training data (8 in this example).
- `--dataloader_num_workers`: Number of subprocesses for data loading (4 in this example).
- `--learning_rate`: Learning rate for training (1e-04 in this example).
- `--weight_decay`: Weight decay for regularization (0.01 in this example).
- `--output_dir`: Directory to save model checkpoints and predictions.
- `--save_steps`: Frequency of saving checkpoints during training (10000 in this example).
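The training data file passed via `--data_json_file` is expected to be a JSON list of records, each pairing an image file (resolved relative to `--data_root_path`) with its caption, as noted in the training script comments, e.g. `[{"image_file": "1.png", "text": "A dog"}]` (the FaceID variant additionally expects an `"id_embed_file"` entry). A minimal sketch for producing such a file, with made-up file names and captions:
```python
import json

# Hypothetical records: images live under --data_root_path, captions are free-form text.
records = [
    {"image_file": "1.png", "text": "A dog"},
    {"image_file": "2.png", "text": "A cat sitting on a sofa"},
]

with open("data.json", "w") as f:
    json.dump(records, f, indent=2)
```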
### Inference
#### Description:
The provided inference code loads a trained model checkpoint and extracts the components related to image projection and the IP (image prompt) adapter. These components are then saved into separate `.safetensors` files for later use at inference time.
#### Usage Example:
```python
from safetensors.torch import load_file, save_file
# Load the trained model checkpoint in safetensors format
ckpt = "checkpoint-50000/pytorch_model.safetensors"
sd = load_file(ckpt) # Using safetensors load function
# Extract image projection and IP adapter components
image_proj_sd = {}
ip_sd = {}
for k in sd:
if k.startswith("unet"):
pass # Skip unet-related keys
elif k.startswith("image_proj_model"):
image_proj_sd[k.replace("image_proj_model.", "")] = sd[k]
elif k.startswith("adapter_modules"):
ip_sd[k.replace("adapter_modules.", "")] = sd[k]
# Save the components into separate safetensors files
save_file(image_proj_sd, "image_proj.safetensors")
save_file(ip_sd, "ip_adapter.safetensors")
```
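If you want to try the extracted weights with diffusers' `load_ip_adapter`, note that the loader reads a single state dict whose keys carry `image_proj.` and `ip_adapter.` prefixes. The re-bundling sketch below assumes that layout and that the trained module names match what the loader expects; treat it as a starting point rather than a guaranteed conversion.
```python
from safetensors.torch import load_file, save_file

# Re-combine the two extracted files into one state dict with prefixed keys
# (assumed layout for diffusers' IP-Adapter loader).
image_proj_sd = load_file("image_proj.safetensors")
ip_sd = load_file("ip_adapter.safetensors")

combined = {f"image_proj.{k}": v for k, v in image_proj_sd.items()}
combined.update({f"ip_adapter.{k}": v for k, v in ip_sd.items()})
save_file(combined, "ip_adapter_combined.safetensors")

# The combined file could then be loaded with, for example:
# pipeline.load_ip_adapter("path/to/folder", subfolder="", weight_name="ip_adapter_combined.safetensors")
```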
### Sample Inference Script using the CLIP Model
```python
import torch
from safetensors.torch import load_file
from transformers import CLIPProcessor, CLIPModel  # Using the Hugging Face CLIP model
from PIL import Image  # Used below to load the reference image from disk
# Load model components from safetensors
image_proj_ckpt = "image_proj.safetensors"
ip_adapter_ckpt = "ip_adapter.safetensors"
# Load the saved weights
image_proj_sd = load_file(image_proj_ckpt)
ip_adapter_sd = load_file(ip_adapter_ckpt)
# Define the model Parameters
class ImageProjectionModel(torch.nn.Module):
    def __init__(self, input_dim=512, output_dim=512):  # openai/clip-vit-base-patch32 image features are 512-dimensional
super().__init__()
self.model = torch.nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.model(x)
class IPAdapterModel(torch.nn.Module):
def __init__(self, input_dim=512, output_dim=10): # Example for 10 classes
super().__init__()
self.model = torch.nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.model(x)
# Initialize models
image_proj_model = ImageProjectionModel()
ip_adapter_model = IPAdapterModel()
# Load weights into models
image_proj_model.load_state_dict(image_proj_sd)
ip_adapter_model.load_state_dict(ip_adapter_sd)
# Set models to evaluation mode
image_proj_model.eval()
ip_adapter_model.eval()
# Inference pipeline
def inference(image_tensor):
"""
Run inference using the loaded models.
Args:
image_tensor: Preprocessed image tensor from CLIPProcessor
Returns:
Final inference results
"""
with torch.no_grad():
# Step 1: Project the image features
image_proj = image_proj_model(image_tensor)
# Step 2: Pass the projected features through the IP Adapter
result = ip_adapter_model(image_proj)
return result
# Using CLIP for image preprocessing
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
# Image file path
image_path = "path/to/image.jpg"
# Load the image with PIL and preprocess it with the CLIP processor
image = Image.open(image_path).convert("RGB")
inputs = processor(images=image, return_tensors="pt")
image_features = clip_model.get_image_features(inputs["pixel_values"])
# Normalize the image features as per CLIP's recommendations
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
# Run inference
output = inference(image_features)
print("Inference output:", output)
```
#### Parameters:
- `ckpt`: Path to the trained model checkpoint file.
- `load_file(ckpt)`: Loads the safetensors checkpoint; tensors are placed on the CPU by default.
- `image_proj_sd`: Dictionary to store the components related to image projection.
- `ip_sd`: Dictionary to store the components related to the IP adapter.
- `"unet"`, `"image_proj_model"`, `"adapter_modules"`: Prefixes indicating components of the model.


@@ -0,0 +1,4 @@
accelerate
torchvision
transformers>=4.25.1
ip_adapter


@@ -0,0 +1,415 @@
import argparse
import itertools
import json
import os
import random
import time
from pathlib import Path
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from ip_adapter.attention_processor_faceid import LoRAAttnProcessor, LoRAIPAttnProcessor
from ip_adapter.ip_adapter_faceid import MLPProjModel
from PIL import Image
from torchvision import transforms
from transformers import CLIPTextModel, CLIPTokenizer
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
# Dataset
class MyDataset(torch.utils.data.Dataset):
def __init__(
self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""
):
super().__init__()
self.tokenizer = tokenizer
self.size = size
self.i_drop_rate = i_drop_rate
self.t_drop_rate = t_drop_rate
self.ti_drop_rate = ti_drop_rate
self.image_root_path = image_root_path
self.data = json.load(
open(json_file)
) # list of dict: [{"image_file": "1.png", "id_embed_file": "faceid.bin"}]
self.transform = transforms.Compose(
[
transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(self.size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __getitem__(self, idx):
item = self.data[idx]
text = item["text"]
image_file = item["image_file"]
# read image
raw_image = Image.open(os.path.join(self.image_root_path, image_file))
image = self.transform(raw_image.convert("RGB"))
face_id_embed = torch.load(item["id_embed_file"], map_location="cpu")
face_id_embed = torch.from_numpy(face_id_embed)
# drop
drop_image_embed = 0
rand_num = random.random()
if rand_num < self.i_drop_rate:
drop_image_embed = 1
elif rand_num < (self.i_drop_rate + self.t_drop_rate):
text = ""
elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
text = ""
drop_image_embed = 1
if drop_image_embed:
face_id_embed = torch.zeros_like(face_id_embed)
# get text and tokenize
text_input_ids = self.tokenizer(
text,
max_length=self.tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt",
).input_ids
return {
"image": image,
"text_input_ids": text_input_ids,
"face_id_embed": face_id_embed,
"drop_image_embed": drop_image_embed,
}
def __len__(self):
return len(self.data)
def collate_fn(data):
images = torch.stack([example["image"] for example in data])
text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
face_id_embed = torch.stack([example["face_id_embed"] for example in data])
drop_image_embeds = [example["drop_image_embed"] for example in data]
return {
"images": images,
"text_input_ids": text_input_ids,
"face_id_embed": face_id_embed,
"drop_image_embeds": drop_image_embeds,
}
class IPAdapter(torch.nn.Module):
"""IP-Adapter"""
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
super().__init__()
self.unet = unet
self.image_proj_model = image_proj_model
self.adapter_modules = adapter_modules
if ckpt_path is not None:
self.load_from_checkpoint(ckpt_path)
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
ip_tokens = self.image_proj_model(image_embeds)
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
# Predict the noise residual
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
return noise_pred
def load_from_checkpoint(self, ckpt_path: str):
# Calculate original checksums
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
state_dict = torch.load(ckpt_path, map_location="cpu")
# Load state dict for image_proj_model and adapter_modules
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
# Calculate new checksums
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
# Verify if the weights have changed
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--pretrained_ip_adapter_path",
type=str,
default=None,
help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
)
parser.add_argument(
"--data_json_file",
type=str,
default=None,
required=True,
help="Training data",
)
parser.add_argument(
"--data_root_path",
type=str,
default="",
required=True,
help="Training data root path",
)
parser.add_argument(
"--image_encoder_path",
type=str,
default=None,
required=True,
help="Path to CLIP image encoder",
)
parser.add_argument(
"--output_dir",
type=str,
default="sd-ip_adapter",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--resolution",
type=int,
default=512,
help=("The resolution for input images"),
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Learning rate to use.",
)
parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument(
"--save_steps",
type=int,
default=2000,
help=("Save a checkpoint of the training state every X updates"),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
return args
def main():
args = parse_args()
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load scheduler, tokenizer and models.
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
# image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
# image_encoder.requires_grad_(False)
# ip-adapter
image_proj_model = MLPProjModel(
cross_attention_dim=unet.config.cross_attention_dim,
id_embeddings_dim=512,
num_tokens=4,
)
# init adapter modules
lora_rank = 128
attn_procs = {}
unet_sd = unet.state_dict()
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
)
else:
layer_name = name.split(".processor")[0]
weights = {
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
}
attn_procs[name] = LoRAIPAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=lora_rank
)
attn_procs[name].load_state_dict(weights, strict=False)
unet.set_attn_processor(attn_procs)
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
# image_encoder.to(accelerator.device, dtype=weight_dtype)
# optimizer
params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
# dataloader
train_dataset = MyDataset(
args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
# Prepare everything with our `accelerator`.
ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
global_step = 0
for epoch in range(0, args.num_train_epochs):
begin = time.perf_counter()
for step, batch in enumerate(train_dataloader):
load_data_time = time.perf_counter() - begin
with accelerator.accumulate(ip_adapter):
# Convert images to latent space
with torch.no_grad():
latents = vae.encode(
batch["images"].to(accelerator.device, dtype=weight_dtype)
).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
image_embeds = batch["face_id_embed"].to(accelerator.device, dtype=weight_dtype)
with torch.no_grad():
encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]
noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
# Backpropagate
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
if accelerator.is_main_process:
print(
"Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
)
)
global_step += 1
if global_step % args.save_steps == 0:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
begin = time.perf_counter()
if __name__ == "__main__":
main()


@@ -0,0 +1,422 @@
import argparse
import itertools
import json
import os
import random
import time
from pathlib import Path
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from ip_adapter.ip_adapter import ImageProjModel
from ip_adapter.utils import is_torch2_available
from PIL import Image
from torchvision import transforms
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
if is_torch2_available():
from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
else:
from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
# Dataset
class MyDataset(torch.utils.data.Dataset):
def __init__(
self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""
):
super().__init__()
self.tokenizer = tokenizer
self.size = size
self.i_drop_rate = i_drop_rate
self.t_drop_rate = t_drop_rate
self.ti_drop_rate = ti_drop_rate
self.image_root_path = image_root_path
self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}]
self.transform = transforms.Compose(
[
transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(self.size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
self.clip_image_processor = CLIPImageProcessor()
def __getitem__(self, idx):
item = self.data[idx]
text = item["text"]
image_file = item["image_file"]
# read image
raw_image = Image.open(os.path.join(self.image_root_path, image_file))
image = self.transform(raw_image.convert("RGB"))
clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
# drop
drop_image_embed = 0
rand_num = random.random()
if rand_num < self.i_drop_rate:
drop_image_embed = 1
elif rand_num < (self.i_drop_rate + self.t_drop_rate):
text = ""
elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
text = ""
drop_image_embed = 1
# get text and tokenize
text_input_ids = self.tokenizer(
text,
max_length=self.tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt",
).input_ids
return {
"image": image,
"text_input_ids": text_input_ids,
"clip_image": clip_image,
"drop_image_embed": drop_image_embed,
}
def __len__(self):
return len(self.data)
def collate_fn(data):
images = torch.stack([example["image"] for example in data])
text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
drop_image_embeds = [example["drop_image_embed"] for example in data]
return {
"images": images,
"text_input_ids": text_input_ids,
"clip_images": clip_images,
"drop_image_embeds": drop_image_embeds,
}
class IPAdapter(torch.nn.Module):
"""IP-Adapter"""
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
super().__init__()
self.unet = unet
self.image_proj_model = image_proj_model
self.adapter_modules = adapter_modules
if ckpt_path is not None:
self.load_from_checkpoint(ckpt_path)
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
ip_tokens = self.image_proj_model(image_embeds)
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
# Predict the noise residual
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
return noise_pred
def load_from_checkpoint(self, ckpt_path: str):
# Calculate original checksums
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
state_dict = torch.load(ckpt_path, map_location="cpu")
# Load state dict for image_proj_model and adapter_modules
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
# Calculate new checksums
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
# Verify if the weights have changed
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--pretrained_ip_adapter_path",
type=str,
default=None,
help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
)
parser.add_argument(
"--data_json_file",
type=str,
default=None,
required=True,
help="Training data",
)
parser.add_argument(
"--data_root_path",
type=str,
default="",
required=True,
help="Training data root path",
)
parser.add_argument(
"--image_encoder_path",
type=str,
default=None,
required=True,
help="Path to CLIP image encoder",
)
parser.add_argument(
"--output_dir",
type=str,
default="sd-ip_adapter",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--resolution",
type=int,
default=512,
help=("The resolution for input images"),
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Learning rate to use.",
)
parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument(
"--save_steps",
type=int,
default=2000,
help=("Save a checkpoint of the training state every X updates"),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
return args
def main():
args = parse_args()
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load scheduler, tokenizer and models.
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
image_encoder.requires_grad_(False)
# ip-adapter
image_proj_model = ImageProjModel(
cross_attention_dim=unet.config.cross_attention_dim,
clip_embeddings_dim=image_encoder.config.projection_dim,
clip_extra_context_tokens=4,
)
# init adapter modules
attn_procs = {}
unet_sd = unet.state_dict()
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
layer_name = name.split(".processor")[0]
weights = {
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
}
attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
attn_procs[name].load_state_dict(weights)
unet.set_attn_processor(attn_procs)
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
image_encoder.to(accelerator.device, dtype=weight_dtype)
# optimizer
params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
# dataloader
train_dataset = MyDataset(
args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
# Prepare everything with our `accelerator`.
ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
global_step = 0
for epoch in range(0, args.num_train_epochs):
begin = time.perf_counter()
for step, batch in enumerate(train_dataloader):
load_data_time = time.perf_counter() - begin
with accelerator.accumulate(ip_adapter):
# Convert images to latent space
with torch.no_grad():
latents = vae.encode(
batch["images"].to(accelerator.device, dtype=weight_dtype)
).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
with torch.no_grad():
image_embeds = image_encoder(
batch["clip_images"].to(accelerator.device, dtype=weight_dtype)
).image_embeds
image_embeds_ = []
for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]):
if drop_image_embed == 1:
image_embeds_.append(torch.zeros_like(image_embed))
else:
image_embeds_.append(image_embed)
image_embeds = torch.stack(image_embeds_)
with torch.no_grad():
encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]
noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
# Backpropagate
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
if accelerator.is_main_process:
print(
"Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
)
)
global_step += 1
if global_step % args.save_steps == 0:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
begin = time.perf_counter()
if __name__ == "__main__":
main()


@@ -0,0 +1,445 @@
import argparse
import itertools
import json
import os
import random
import time
from pathlib import Path
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from ip_adapter.resampler import Resampler
from ip_adapter.utils import is_torch2_available
from PIL import Image
from torchvision import transforms
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
if is_torch2_available():
from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
else:
from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
# Dataset
class MyDataset(torch.utils.data.Dataset):
def __init__(
self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""
):
super().__init__()
self.tokenizer = tokenizer
self.size = size
self.i_drop_rate = i_drop_rate
self.t_drop_rate = t_drop_rate
self.ti_drop_rate = ti_drop_rate
self.image_root_path = image_root_path
self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}]
self.transform = transforms.Compose(
[
transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(self.size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
self.clip_image_processor = CLIPImageProcessor()
def __getitem__(self, idx):
item = self.data[idx]
text = item["text"]
image_file = item["image_file"]
# read image
raw_image = Image.open(os.path.join(self.image_root_path, image_file))
image = self.transform(raw_image.convert("RGB"))
clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
# drop
drop_image_embed = 0
rand_num = random.random()
if rand_num < self.i_drop_rate:
drop_image_embed = 1
elif rand_num < (self.i_drop_rate + self.t_drop_rate):
text = ""
elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
text = ""
drop_image_embed = 1
# get text and tokenize
text_input_ids = self.tokenizer(
text,
max_length=self.tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt",
).input_ids
return {
"image": image,
"text_input_ids": text_input_ids,
"clip_image": clip_image,
"drop_image_embed": drop_image_embed,
}
def __len__(self):
return len(self.data)
def collate_fn(data):
images = torch.stack([example["image"] for example in data])
text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
drop_image_embeds = [example["drop_image_embed"] for example in data]
return {
"images": images,
"text_input_ids": text_input_ids,
"clip_images": clip_images,
"drop_image_embeds": drop_image_embeds,
}
class IPAdapter(torch.nn.Module):
"""IP-Adapter"""
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
super().__init__()
self.unet = unet
self.image_proj_model = image_proj_model
self.adapter_modules = adapter_modules
if ckpt_path is not None:
self.load_from_checkpoint(ckpt_path)
def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
ip_tokens = self.image_proj_model(image_embeds)
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
# Predict the noise residual
noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
return noise_pred
def load_from_checkpoint(self, ckpt_path: str):
# Calculate original checksums
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
state_dict = torch.load(ckpt_path, map_location="cpu")
# Check if 'latents' exists in both the saved state_dict and the current model's state_dict
strict_load_image_proj_model = True
if "latents" in state_dict["image_proj"] and "latents" in self.image_proj_model.state_dict():
# Check if the shapes are mismatched
if state_dict["image_proj"]["latents"].shape != self.image_proj_model.state_dict()["latents"].shape:
print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.")
print("Removing 'latents' from checkpoint and loading the rest of the weights.")
del state_dict["image_proj"]["latents"]
strict_load_image_proj_model = False
# Load state dict for image_proj_model and adapter_modules
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict_load_image_proj_model)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
# Calculate new checksums
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
# Verify if the weights have changed
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--pretrained_ip_adapter_path",
type=str,
default=None,
help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
)
parser.add_argument(
"--num_tokens",
type=int,
default=16,
help="Number of tokens to query from the CLIP image encoding.",
)
parser.add_argument(
"--data_json_file",
type=str,
default=None,
required=True,
help="Training data",
)
parser.add_argument(
"--data_root_path",
type=str,
default="",
required=True,
help="Training data root path",
)
parser.add_argument(
"--image_encoder_path",
type=str,
default=None,
required=True,
help="Path to CLIP image encoder",
)
parser.add_argument(
"--output_dir",
type=str,
default="sd-ip_adapter",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--resolution",
type=int,
default=512,
help=("The resolution for input images"),
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Learning rate to use.",
)
parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument(
"--save_steps",
type=int,
default=2000,
help=("Save a checkpoint of the training state every X updates"),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
return args
def main():
args = parse_args()
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load scheduler, tokenizer and models.
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
image_encoder.requires_grad_(False)
# ip-adapter-plus
image_proj_model = Resampler(
dim=unet.config.cross_attention_dim,
depth=4,
dim_head=64,
heads=12,
num_queries=args.num_tokens,
embedding_dim=image_encoder.config.hidden_size,
output_dim=unet.config.cross_attention_dim,
ff_mult=4,
)
# init adapter modules
attn_procs = {}
unet_sd = unet.state_dict()
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
layer_name = name.split(".processor")[0]
weights = {
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
}
attn_procs[name] = IPAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens
)
attn_procs[name].load_state_dict(weights)
unet.set_attn_processor(attn_procs)
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
image_encoder.to(accelerator.device, dtype=weight_dtype)
# optimizer
params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
# dataloader
train_dataset = MyDataset(
args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
# Prepare everything with our `accelerator`.
ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
global_step = 0
for epoch in range(0, args.num_train_epochs):
begin = time.perf_counter()
for step, batch in enumerate(train_dataloader):
load_data_time = time.perf_counter() - begin
with accelerator.accumulate(ip_adapter):
# Convert images to latent space
with torch.no_grad():
latents = vae.encode(
batch["images"].to(accelerator.device, dtype=weight_dtype)
).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
clip_images = []
for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):
if drop_image_embed == 1:
clip_images.append(torch.zeros_like(clip_image))
else:
clip_images.append(clip_image)
clip_images = torch.stack(clip_images, dim=0)
with torch.no_grad():
image_embeds = image_encoder(
clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True
).hidden_states[-2]
with torch.no_grad():
encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]
noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
# Backpropagate
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
if accelerator.is_main_process:
print(
"Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
)
)
global_step += 1
if global_step % args.save_steps == 0:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
begin = time.perf_counter()
if __name__ == "__main__":
main()


@@ -0,0 +1,520 @@
import argparse
import itertools
import json
import os
import random
import time
from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from ip_adapter.ip_adapter import ImageProjModel
from ip_adapter.utils import is_torch2_available
from PIL import Image
from torchvision import transforms
from transformers import (
CLIPImageProcessor,
CLIPTextModel,
CLIPTextModelWithProjection,
CLIPTokenizer,
CLIPVisionModelWithProjection,
)
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
if is_torch2_available():
from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
else:
from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor
# Dataset
class MyDataset(torch.utils.data.Dataset):
def __init__(
self,
json_file,
tokenizer,
tokenizer_2,
size=1024,
center_crop=True,
t_drop_rate=0.05,
i_drop_rate=0.05,
ti_drop_rate=0.05,
image_root_path="",
):
super().__init__()
self.tokenizer = tokenizer
self.tokenizer_2 = tokenizer_2
self.size = size
self.center_crop = center_crop
self.i_drop_rate = i_drop_rate
self.t_drop_rate = t_drop_rate
self.ti_drop_rate = ti_drop_rate
self.image_root_path = image_root_path
self.data = json.load(open(json_file)) # list of dict: [{"image_file": "1.png", "text": "A dog"}]
self.transform = transforms.Compose(
[
transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
self.clip_image_processor = CLIPImageProcessor()
def __getitem__(self, idx):
item = self.data[idx]
text = item["text"]
image_file = item["image_file"]
# read image
raw_image = Image.open(os.path.join(self.image_root_path, image_file))
# original size
original_width, original_height = raw_image.size
original_size = torch.tensor([original_height, original_width])
image_tensor = self.transform(raw_image.convert("RGB"))
# random crop
delta_h = image_tensor.shape[1] - self.size
delta_w = image_tensor.shape[2] - self.size
assert not all([delta_h, delta_w])
if self.center_crop:
top = delta_h // 2
left = delta_w // 2
else:
top = np.random.randint(0, delta_h + 1)
left = np.random.randint(0, delta_w + 1)
image = transforms.functional.crop(image_tensor, top=top, left=left, height=self.size, width=self.size)
crop_coords_top_left = torch.tensor([top, left])
clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values
# drop
drop_image_embed = 0
rand_num = random.random()
if rand_num < self.i_drop_rate:
drop_image_embed = 1
elif rand_num < (self.i_drop_rate + self.t_drop_rate):
text = ""
elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
text = ""
drop_image_embed = 1
# get text and tokenize
text_input_ids = self.tokenizer(
text,
max_length=self.tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt",
).input_ids
text_input_ids_2 = self.tokenizer_2(
text,
max_length=self.tokenizer_2.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt",
).input_ids
return {
"image": image,
"text_input_ids": text_input_ids,
"text_input_ids_2": text_input_ids_2,
"clip_image": clip_image,
"drop_image_embed": drop_image_embed,
"original_size": original_size,
"crop_coords_top_left": crop_coords_top_left,
"target_size": torch.tensor([self.size, self.size]),
}
def __len__(self):
return len(self.data)
def collate_fn(data):
images = torch.stack([example["image"] for example in data])
text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
text_input_ids_2 = torch.cat([example["text_input_ids_2"] for example in data], dim=0)
clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
drop_image_embeds = [example["drop_image_embed"] for example in data]
original_size = torch.stack([example["original_size"] for example in data])
crop_coords_top_left = torch.stack([example["crop_coords_top_left"] for example in data])
target_size = torch.stack([example["target_size"] for example in data])
return {
"images": images,
"text_input_ids": text_input_ids,
"text_input_ids_2": text_input_ids_2,
"clip_images": clip_images,
"drop_image_embeds": drop_image_embeds,
"original_size": original_size,
"crop_coords_top_left": crop_coords_top_left,
"target_size": target_size,
}
class IPAdapter(torch.nn.Module):
"""IP-Adapter"""
def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
super().__init__()
self.unet = unet
self.image_proj_model = image_proj_model
self.adapter_modules = adapter_modules
if ckpt_path is not None:
self.load_from_checkpoint(ckpt_path)
def forward(self, noisy_latents, timesteps, encoder_hidden_states, unet_added_cond_kwargs, image_embeds):
ip_tokens = self.image_proj_model(image_embeds)
encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
# Predict the noise residual
noise_pred = self.unet(
noisy_latents, timesteps, encoder_hidden_states, added_cond_kwargs=unet_added_cond_kwargs
).sample
return noise_pred
def load_from_checkpoint(self, ckpt_path: str):
# Calculate original checksums
orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
state_dict = torch.load(ckpt_path, map_location="cpu")
# Load state dict for image_proj_model and adapter_modules
self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)
# Calculate new checksums
new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))
# Verify if the weights have changed
assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
print(f"Successfully loaded weights from checkpoint {ckpt_path}")
def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--pretrained_ip_adapter_path",
type=str,
default=None,
help="Path to pretrained ip adapter model. If not specified weights are initialized randomly.",
)
parser.add_argument(
"--data_json_file",
type=str,
default=None,
required=True,
help="Training data",
)
parser.add_argument(
"--data_root_path",
type=str,
default="",
required=True,
help="Training data root path",
)
parser.add_argument(
"--image_encoder_path",
type=str,
default=None,
required=True,
help="Path to CLIP image encoder",
)
parser.add_argument(
"--output_dir",
type=str,
default="sd-ip_adapter",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--resolution",
type=int,
default=512,
help=("The resolution for input images"),
)
parser.add_argument(
"--learning_rate",
type=float,
default=1e-4,
help="Learning rate to use.",
)
parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--num_train_epochs", type=int, default=100)
parser.add_argument(
"--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--noise_offset", type=float, default=None, help="noise offset")
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument(
"--save_steps",
type=int,
default=2000,
help=("Save a checkpoint of the training state every X updates"),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
return args
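# Training entry point:
# 1. set up the Accelerator and load the frozen SDXL components (tokenizers, text encoders, VAE, UNet) and the CLIP image encoder,
# 2. build the image projection model and swap the UNet cross-attention processors for IP-Adapter processors,
# 3. train only the adapter parameters with a plain AdamW loop, checkpointing every --save_steps updates.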
def main():
args = parse_args()
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load scheduler, tokenizer and models.
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
tokenizer_2 = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer_2")
text_encoder_2 = CLIPTextModelWithProjection.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2"
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
# freeze parameters of models to save more memory
unet.requires_grad_(False)
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
text_encoder_2.requires_grad_(False)
image_encoder.requires_grad_(False)
    # IP-Adapter: project the pooled CLIP image embedding into `num_tokens` extra context tokens that are
    # appended to the text conditioning in every cross-attention layer.
num_tokens = 4
image_proj_model = ImageProjModel(
cross_attention_dim=unet.config.cross_attention_dim,
clip_embeddings_dim=image_encoder.config.projection_dim,
clip_extra_context_tokens=num_tokens,
)
    # Replace the UNet attention processors: self-attention layers (attn1) keep the default processor, while
    # cross-attention layers get an IPAttnProcessor whose new to_k_ip/to_v_ip projections are initialized
    # from the original to_k/to_v weights.
attn_procs = {}
unet_sd = unet.state_dict()
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
layer_name = name.split(".processor")[0]
weights = {
"to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
"to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
}
attn_procs[name] = IPAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens
)
attn_procs[name].load_state_dict(weights)
unet.set_attn_processor(attn_procs)
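    # Gather every attention processor in a ModuleList so the IP-Adapter weights can be trained and saved together.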
adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())
ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
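    # Frozen encoders are cast to the mixed-precision dtype below; the VAE stays in fp32 (the SDXL VAE is
    # numerically unstable in fp16) and the trainable adapter weights remain in fp32 for the optimizer.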
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device) # use fp32
text_encoder.to(accelerator.device, dtype=weight_dtype)
text_encoder_2.to(accelerator.device, dtype=weight_dtype)
image_encoder.to(accelerator.device, dtype=weight_dtype)
    # Optimizer over the trainable parameters only: the image projection model and the adapter attention weights.
params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
# dataloader
train_dataset = MyDataset(
args.data_json_file,
tokenizer=tokenizer,
tokenizer_2=tokenizer_2,
size=args.resolution,
image_root_path=args.data_root_path,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
# Prepare everything with our `accelerator`.
ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)
global_step = 0
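    # Training loop: encode images to latents with the fp32 VAE, add noise at a random timestep, build the
    # SDXL text/image conditioning, predict the noise with the adapted UNet and take a plain MSE/AdamW step.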
for epoch in range(0, args.num_train_epochs):
begin = time.perf_counter()
for step, batch in enumerate(train_dataloader):
load_data_time = time.perf_counter() - begin
with accelerator.accumulate(ip_adapter):
# Convert images to latent space
with torch.no_grad():
# vae of sdxl should use fp32
latents = vae.encode(batch["images"].to(accelerator.device, dtype=vae.dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
latents = latents.to(accelerator.device, dtype=weight_dtype)
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
if args.noise_offset:
# https://www.crosslabs.org//blog/diffusion-with-offset-noise
noise += args.noise_offset * torch.randn((latents.shape[0], latents.shape[1], 1, 1)).to(
accelerator.device, dtype=weight_dtype
)
bsz = latents.shape[0]
# Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
with torch.no_grad():
image_embeds = image_encoder(
batch["clip_images"].to(accelerator.device, dtype=weight_dtype)
).image_embeds
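                # Samples flagged by the dataset get a zeroed image embedding so the adapter also learns the
                # unconditional branch used for classifier-free guidance at inference time.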
image_embeds_ = []
for image_embed, drop_image_embed in zip(image_embeds, batch["drop_image_embeds"]):
if drop_image_embed == 1:
image_embeds_.append(torch.zeros_like(image_embed))
else:
image_embeds_.append(image_embed)
image_embeds = torch.stack(image_embeds_)
with torch.no_grad():
encoder_output = text_encoder(
batch["text_input_ids"].to(accelerator.device), output_hidden_states=True
)
text_embeds = encoder_output.hidden_states[-2]
encoder_output_2 = text_encoder_2(
batch["text_input_ids_2"].to(accelerator.device), output_hidden_states=True
)
pooled_text_embeds = encoder_output_2[0]
text_embeds_2 = encoder_output_2.hidden_states[-2]
                    text_embeds = torch.concat([text_embeds, text_embeds_2], dim=-1)  # concat both encoders' penultimate hidden states, as SDXL expects
                # SDXL micro-conditioning: original size, crop coordinates and target size are packed into time_ids.
add_time_ids = [
batch["original_size"].to(accelerator.device),
batch["crop_coords_top_left"].to(accelerator.device),
batch["target_size"].to(accelerator.device),
]
add_time_ids = torch.cat(add_time_ids, dim=1).to(accelerator.device, dtype=weight_dtype)
unet_added_cond_kwargs = {"text_embeds": pooled_text_embeds, "time_ids": add_time_ids}
noise_pred = ip_adapter(noisy_latents, timesteps, text_embeds, unet_added_cond_kwargs, image_embeds)
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()
# Backpropagate
accelerator.backward(loss)
optimizer.step()
optimizer.zero_grad()
if accelerator.is_main_process:
print(
"Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
)
)
global_step += 1
if global_step % args.save_steps == 0:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
begin = time.perf_counter()
if __name__ == "__main__":
main()
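# Example launch (illustrative only; assumes this script is saved as tutorial_train_sdxl.py and that the
# placeholders in curly braces are replaced with your own image encoder and data locations):
#
# accelerate launch --mixed_precision "fp16" \
#   tutorial_train_sdxl.py \
#   --pretrained_model_name_or_path="stabilityai/stable-diffusion-xl-base-1.0" \
#   --image_encoder_path="{image_encoder_path}" \
#   --data_json_file="{data.json}" \
#   --data_root_path="{image_path}" \
#   --resolution=512 \
#   --train_batch_size=8 \
#   --dataloader_num_workers=4 \
#   --save_steps=2000 \
#   --output_dir="sdxl-ip_adapter"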