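# Training script for an IP-Adapter ("plus" variant) on top of a frozen Stable Diffusion model.
# The UNet, VAE, text encoder and CLIP image encoder stay frozen; only the image projection
# model (a Resampler) and the per-layer IP cross-attention weights (to_k_ip / to_v_ip) are trained.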
import argparse
import itertools
import json
import os
import random
import time
from pathlib import Path

import torch
import torch.nn.functional as F
from accelerate import Accelerator
from accelerate.utils import ProjectConfiguration
from ip_adapter.resampler import Resampler
from ip_adapter.utils import is_torch2_available
from PIL import Image
from torchvision import transforms
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel


if is_torch2_available():
    from ip_adapter.attention_processor import AttnProcessor2_0 as AttnProcessor
    from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor
else:
    from ip_adapter.attention_processor import AttnProcessor, IPAttnProcessor

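# NOTE: the `ip_adapter` imports above (Resampler, attention processors, `is_torch2_available`)
# come from the standalone IP-Adapter helper package (https://github.com/tencent-ailab/IP-Adapter),
# not from `diffusers`; it has to be importable, e.g. cloned next to this script or installed on
# the PYTHONPATH.
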
# Dataset
class MyDataset(torch.utils.data.Dataset):
    def __init__(
        self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05, image_root_path=""
    ):
        super().__init__()

        self.tokenizer = tokenizer
        self.size = size
        self.i_drop_rate = i_drop_rate
        self.t_drop_rate = t_drop_rate
        self.ti_drop_rate = ti_drop_rate
        self.image_root_path = image_root_path

        self.data = json.load(open(json_file))  # list of dict: [{"image_file": "1.png", "text": "A dog"}]

        self.transform = transforms.Compose(
            [
                transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(self.size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )
        self.clip_image_processor = CLIPImageProcessor()

    def __getitem__(self, idx):
        item = self.data[idx]
        text = item["text"]
        image_file = item["image_file"]

        # read image
        raw_image = Image.open(os.path.join(self.image_root_path, image_file))
        image = self.transform(raw_image.convert("RGB"))
        clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values

        # randomly drop the image and/or text conditioning (classifier-free guidance style)
        drop_image_embed = 0
        rand_num = random.random()
        if rand_num < self.i_drop_rate:
            drop_image_embed = 1
        elif rand_num < (self.i_drop_rate + self.t_drop_rate):
            text = ""
        elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
            text = ""
            drop_image_embed = 1
        # get text and tokenize
        text_input_ids = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        ).input_ids

        return {
            "image": image,
            "text_input_ids": text_input_ids,
            "clip_image": clip_image,
            "drop_image_embed": drop_image_embed,
        }

    def __len__(self):
        return len(self.data)

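# Illustrative layout of the metadata file passed via --data_json_file (image paths are resolved
# relative to --data_root_path):
#
#     [
#         {"image_file": "1.png", "text": "A dog"},
#         {"image_file": "2.png", "text": "A cat sitting on a sofa"}
#     ]
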
def collate_fn(data):
    images = torch.stack([example["image"] for example in data])
    text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
    clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
    drop_image_embeds = [example["drop_image_embed"] for example in data]

    return {
        "images": images,
        "text_input_ids": text_input_ids,
        "clip_images": clip_images,
        "drop_image_embeds": drop_image_embeds,
    }

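# For a batch of size B, collate_fn yields (assuming the default CLIPImageProcessor and the
# Stable Diffusion v1.x CLIP tokenizer):
#   images:            (B, 3, size, size) float tensor normalized to [-1, 1]
#   text_input_ids:    (B, 77) padded/truncated token ids
#   clip_images:       (B, 3, 224, 224) CLIP-preprocessed pixel values
#   drop_image_embeds: plain Python list of B ints (0/1), consumed per-sample in the training loop
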
class IPAdapter(torch.nn.Module):
    """IP-Adapter"""

    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
        super().__init__()
        self.unet = unet
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules

        if ckpt_path is not None:
            self.load_from_checkpoint(ckpt_path)

    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
        ip_tokens = self.image_proj_model(image_embeds)
        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
        # Predict the noise residual
        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
        return noise_pred

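    # forward() appends the projected image tokens to the text tokens along the sequence
    # dimension, so the UNet receives [text_tokens; ip_tokens] as cross-attention context.
    # The IPAttnProcessor modules installed in main() split that sequence again based on
    # `num_tokens` and attend to text and image tokens with separate key/value projections.
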
    def load_from_checkpoint(self, ckpt_path: str):
        # Calculate original checksums
        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        state_dict = torch.load(ckpt_path, map_location="cpu")

        # Check if 'latents' exists in both the saved state_dict and the current model's state_dict
        strict_load_image_proj_model = True
        if "latents" in state_dict["image_proj"] and "latents" in self.image_proj_model.state_dict():
            # Check if the shapes are mismatched
            if state_dict["image_proj"]["latents"].shape != self.image_proj_model.state_dict()["latents"].shape:
                print(f"Shapes of 'image_proj.latents' in checkpoint {ckpt_path} and current model do not match.")
                print("Removing 'latents' from checkpoint and loading the rest of the weights.")
                del state_dict["image_proj"]["latents"]
                strict_load_image_proj_model = False

        # Load state dict for image_proj_model and adapter_modules
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=strict_load_image_proj_model)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        # Calculate new checksums
        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        # Verify if the weights have changed
        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"

        print(f"Successfully loaded weights from checkpoint {ckpt_path}")

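# --pretrained_ip_adapter_path is expected to point to a checkpoint in the original IP-Adapter
# layout: a dict with an "image_proj" entry (Resampler weights) and an "ip_adapter" entry
# (the to_k_ip / to_v_ip weights of the attention processors).
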
def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--pretrained_ip_adapter_path",
        type=str,
        default=None,
        help=(
            "Path to a pretrained IP-Adapter checkpoint. If not specified, the image projection model is"
            " initialized randomly and the IP attention weights are copied from the UNet's own to_k/to_v projections."
        ),
    )
    parser.add_argument(
        "--num_tokens",
        type=int,
        default=16,
        help="Number of tokens to query from the CLIP image encoding.",
    )
    parser.add_argument(
        "--data_json_file",
        type=str,
        default=None,
        required=True,
        help="Path to the JSON file describing the training data.",
    )
    parser.add_argument(
        "--data_root_path",
        type=str,
        default="",
        required=True,
        help="Root directory containing the training images.",
    )
    parser.add_argument(
        "--image_encoder_path",
        type=str,
        default=None,
        required=True,
        help="Path to the CLIP image encoder.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-ip_adapter",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help="The resolution for input images; all training images are resized and center-cropped to this size.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=100, help="Number of training epochs.")
    parser.add_argument(
        "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help=(
            "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
        ),
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=2000,
        help="Save a checkpoint of the training state every X updates.",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or"
            " the flag passed with the `accelerate.launch` command. Use this argument to override the accelerate"
            " config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args

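# Illustrative launch command (paths and model identifiers are placeholders to adapt):
#
#   accelerate launch <path_to_this_script> \
#     --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \
#     --image_encoder_path="/path/to/clip_image_encoder" \
#     --data_json_file="/path/to/data.json" \
#     --data_root_path="/path/to/images" \
#     --mixed_precision="fp16" \
#     --resolution=512 \
#     --train_batch_size=8 \
#     --dataloader_num_workers=4 \
#     --learning_rate=1e-04 \
#     --weight_decay=0.01 \
#     --output_dir="sd-ip_adapter" \
#     --save_steps=2000
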
def main():
    args = parse_args()
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)
    # freeze parameters of models to save more memory
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    image_encoder.requires_grad_(False)

    # ip-adapter-plus
    image_proj_model = Resampler(
        dim=unet.config.cross_attention_dim,
        depth=4,
        dim_head=64,
        heads=12,
        num_queries=args.num_tokens,
        embedding_dim=image_encoder.config.hidden_size,
        output_dim=unet.config.cross_attention_dim,
        ff_mult=4,
    )
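    # The Resampler is the "plus" image-projection model: a small perceiver-style network whose
    # `num_queries` learned latents cross-attend to the patch-level CLIP features (the penultimate
    # hidden states extracted in the training loop below) and output one token per query in the
    # UNet's cross-attention dimension.
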
    # init adapter modules
    attn_procs = {}
    unet_sd = unet.state_dict()
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor()
        else:
            layer_name = name.split(".processor")[0]
            weights = {
                "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
            }
            attn_procs[name] = IPAttnProcessor(
                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=args.num_tokens
            )
            attn_procs[name].load_state_dict(weights)
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules, args.pretrained_ip_adapter_path)
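    # Every cross-attention layer gets an IPAttnProcessor whose to_k_ip / to_v_ip projections are
    # initialized from that layer's existing to_k / to_v weights, so image tokens initially go
    # through the same projections as text tokens. Wrapping the processors in a ModuleList exposes
    # their parameters to the optimizer (and to accelerator.prepare) through the IPAdapter module.
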
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    # unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    image_encoder.to(accelerator.device, dtype=weight_dtype)

    # optimizer
    params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
    optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)
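    # Note that the UNet is not cast to weight_dtype (the cast above is commented out), presumably
    # because the trainable IP attention weights live inside it and parameters that receive
    # gradients should stay in float32; only the frozen VAE, text encoder and image encoder are
    # moved to the reduced-precision dtype. The optimizer covers exactly the image projection
    # model and the adapter attention weights.
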
    # dataloader
    train_dataset = MyDataset(
        args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path
    )
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Prepare everything with our `accelerator`.
    ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)

    global_step = 0
    for epoch in range(0, args.num_train_epochs):
        begin = time.perf_counter()
        for step, batch in enumerate(train_dataloader):
            load_data_time = time.perf_counter() - begin
            with accelerator.accumulate(ip_adapter):
                # Convert images to latent space
                with torch.no_grad():
                    latents = vae.encode(
                        batch["images"].to(accelerator.device, dtype=weight_dtype)
                    ).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                clip_images = []
                for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):
                    if drop_image_embed == 1:
                        clip_images.append(torch.zeros_like(clip_image))
                    else:
                        clip_images.append(clip_image)
                clip_images = torch.stack(clip_images, dim=0)
                with torch.no_grad():
                    image_embeds = image_encoder(
                        clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True
                    ).hidden_states[-2]

                with torch.no_grad():
                    encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]

                noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)

                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()

                # Backpropagate
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                if accelerator.is_main_process:
                    print(
                        "Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
                            epoch, step, load_data_time, time.perf_counter() - begin, avg_loss
                        )
                    )

            global_step += 1

            if global_step % args.save_steps == 0:
                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                accelerator.save_state(save_path)

            begin = time.perf_counter()

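# Checkpoints are written every --save_steps optimizer steps via accelerator.save_state under
# output_dir/checkpoint-<global_step>. These contain the full accelerate training state rather than
# the inference-time {"image_proj": ..., "ip_adapter": ...} dict that load_from_checkpoint expects,
# so the trained weights likely need a small conversion step before they can be reloaded here or
# used for inference.
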
if __name__ == "__main__":
|
|
main()
|