1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

Gligen training (#7906)

* add training code of gligen

* fix code quality tests.

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
This commit is contained in:
Hzzone
2024-06-05 20:26:42 +08:00
committed by GitHub
parent 48207d6689
commit d3881f35b7
7 changed files with 1312 additions and 0 deletions

View File

@@ -0,0 +1,156 @@
# GLIGEN: Open-Set Grounded Text-to-Image Generation
These scripts contain the code to prepare the grounding data and train the GLIGEN model on COCO dataset.
### Install the requirements
```bash
conda create -n diffusers python==3.10
conda activate diffusers
pip install -r requirements.txt
```
And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:
```bash
accelerate config
```
Or for a default accelerate configuration without answering questions about your environment
```bash
accelerate config default
```
Or if your environment doesn't support an interactive shell e.g. a notebook
```python
from accelerate.utils import write_basic_config
write_basic_config()
```
### Prepare the training data
If you want to make your own grounding data, you need to install the requirements.
I used [RAM](https://github.com/xinyu1205/recognize-anything) to tag
images, [Grounding DINO](https://github.com/IDEA-Research/GroundingDINO/issues?q=refer) to detect objects,
and [BLIP2](https://huggingface.co/docs/transformers/en/model_doc/blip-2) to caption instances.
Only RAM needs to be installed manually:
```bash
pip install git+https://github.com/xinyu1205/recognize-anything.git --no-deps
```
Download the pre-trained model:
```bash
huggingface-cli download --resume-download xinyu1205/recognize_anything_model ram_swin_large_14m.pth
huggingface-cli download --resume-download IDEA-Research/grounding-dino-base
huggingface-cli download --resume-download Salesforce/blip2-flan-t5-xxl
huggingface-cli download --resume-download clip-vit-large-patch14
huggingface-cli download --resume-download masterful/gligen-1-4-generation-text-box
```
Make the training data on 8 GPUs:
```bash
torchrun --master_port 17673 --nproc_per_node=8 make_datasets.py \
--data_root /mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017 \
--save_root /root/gligen_data \
--ram_checkpoint /root/.cache/huggingface/hub/models--xinyu1205--recognize_anything_model/snapshots/ebc52dc741e86466202a5ab8ab22eae6e7d48bf1/ram_swin_large_14m.pth
```
You can download the COCO training data from
```bash
huggingface-cli download --resume-download Hzzone/GLIGEN_COCO coco_train2017.pth
```
It's in the format of
```json
[
...
{
'file_path': Path,
'annos': [
{
'caption': Instance
Caption,
'bbox': bbox
in
xyxy,
'text_embeddings_before_projection': CLIP
text
embedding
before
linear
projection
}
]
}
...
]
```
### Training commands
The training script is heavily based
on https://github.com/huggingface/diffusers/blob/main/examples/controlnet/train_controlnet.py
```bash
accelerate launch train_gligen_text.py \
--data_path /root/data/zhizhonghuang/coco_train2017.pth \
--image_path /mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017 \
--train_batch_size 8 \
--max_train_steps 100000 \
--checkpointing_steps 1000 \
--checkpoints_total_limit 10 \
--learning_rate 5e-5 \
--dataloader_num_workers 16 \
--mixed_precision fp16 \
--report_to wandb \
--tracker_project_name gligen \
--output_dir /root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO
```
I trained the model on 8 A100 GPUs for about 11 hours (at least 24GB GPU memory). The generated images will follow the
layout possibly at 50k iterations.
Note that although the pre-trained GLIGEN model has been loaded, the parameters of `fuser` and `position_net` have been reset (see line 420 in `train_gligen_text.py`)
The trained model can be downloaded from
```bash
huggingface-cli download --resume-download Hzzone/GLIGEN_COCO config.json diffusion_pytorch_model.safetensors
```
You can run `demo.ipynb` to visualize the generated images.
Example prompts:
```python
prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'
boxes = [[0.041015625, 0.548828125, 0.453125, 0.859375],
[0.525390625, 0.552734375, 0.93359375, 0.865234375],
[0.12890625, 0.015625, 0.412109375, 0.279296875],
[0.578125, 0.08203125, 0.857421875, 0.27734375]]
gligen_phrases = ['a green car', 'a blue truck', 'a red air balloon', 'a bird']
```
Example images:
![alt text](generated-images-100000-00.png)
### Citation
```
@article{li2023gligen,
title={GLIGEN: Open-Set Grounded Text-to-Image Generation},
author={Li, Yuheng and Liu, Haotian and Wu, Qingyang and Mu, Fangzhou and Yang, Jianwei and Gao, Jianfeng and Li, Chunyuan and Lee, Yong Jae},
journal={CVPR},
year={2023}
}
```

View File

@@ -0,0 +1,110 @@
import os
import random
import torch
import torchvision.transforms as transforms
from PIL import Image
def recalculate_box_and_verify_if_valid(x, y, w, h, image_size, original_image_size, min_box_size):
scale = image_size / min(original_image_size)
crop_y = (original_image_size[1] * scale - image_size) // 2
crop_x = (original_image_size[0] * scale - image_size) // 2
x0 = max(x * scale - crop_x, 0)
y0 = max(y * scale - crop_y, 0)
x1 = min((x + w) * scale - crop_x, image_size)
y1 = min((y + h) * scale - crop_y, image_size)
if (x1 - x0) * (y1 - y0) / (image_size * image_size) < min_box_size:
return False, (None, None, None, None)
return True, (x0, y0, x1, y1)
class COCODataset(torch.utils.data.Dataset):
def __init__(
self,
data_path,
image_path,
image_size=512,
min_box_size=0.01,
max_boxes_per_data=8,
tokenizer=None,
):
super().__init__()
self.min_box_size = min_box_size
self.max_boxes_per_data = max_boxes_per_data
self.image_size = image_size
self.image_path = image_path
self.tokenizer = tokenizer
self.transforms = transforms.Compose(
[
transforms.Resize(image_size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
self.data_list = torch.load(data_path, map_location="cpu")
def __getitem__(self, index):
if self.max_boxes_per_data > 99:
assert False, "Are you sure setting such large number of boxes per image?"
out = {}
data = self.data_list[index]
image = Image.open(os.path.join(self.image_path, data["file_path"])).convert("RGB")
original_image_size = image.size
out["pixel_values"] = self.transforms(image)
annos = data["annos"]
areas, valid_annos = [], []
for anno in annos:
# x, y, w, h = anno['bbox']
x0, y0, x1, y1 = anno["bbox"]
x, y, w, h = x0, y0, x1 - x0, y1 - y0
valid, (x0, y0, x1, y1) = recalculate_box_and_verify_if_valid(
x, y, w, h, self.image_size, original_image_size, self.min_box_size
)
if valid:
anno["bbox"] = [x0, y0, x1, y1]
areas.append((x1 - x0) * (y1 - y0))
valid_annos.append(anno)
# Sort according to area and choose the largest N objects
wanted_idxs = torch.tensor(areas).sort(descending=True)[1]
wanted_idxs = wanted_idxs[: self.max_boxes_per_data]
valid_annos = [valid_annos[i] for i in wanted_idxs]
out["boxes"] = torch.zeros(self.max_boxes_per_data, 4)
out["masks"] = torch.zeros(self.max_boxes_per_data)
out["text_embeddings_before_projection"] = torch.zeros(self.max_boxes_per_data, 768)
for i, anno in enumerate(valid_annos):
out["boxes"][i] = torch.tensor(anno["bbox"]) / self.image_size
out["masks"][i] = 1
out["text_embeddings_before_projection"][i] = anno["text_embeddings_before_projection"]
prob_drop_boxes = 0.1
if random.random() < prob_drop_boxes:
out["masks"][:] = 0
caption = random.choice(data["captions"])
prob_drop_captions = 0.5
if random.random() < prob_drop_captions:
caption = ""
caption = self.tokenizer(
caption,
max_length=self.tokenizer.model_max_length,
padding="max_length",
truncation=True,
return_tensors="pt",
)
out["caption"] = caption
return out
def __len__(self):
return len(self.data_list)

View File

@@ -0,0 +1,201 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The autoreload extension is already loaded. To reload it, use:\n",
" %reload_ext autoreload\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda/envs/densecaption/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
]
}
],
"source": [
"%load_ext autoreload\n",
"%autoreload 2\n",
"\n",
"import torch\n",
"from diffusers import StableDiffusionGLIGENTextImagePipeline, StableDiffusionGLIGENPipeline"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import diffusers\n",
"from diffusers import (\n",
" AutoencoderKL,\n",
" DDPMScheduler,\n",
" UNet2DConditionModel,\n",
" UniPCMultistepScheduler,\n",
" EulerDiscreteScheduler,\n",
")\n",
"from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer\n",
"# pretrained_model_name_or_path = 'masterful/gligen-1-4-generation-text-box'\n",
"\n",
"pretrained_model_name_or_path = '/root/data/zhizhonghuang/checkpoints/models--masterful--gligen-1-4-generation-text-box/snapshots/d2820dc1e9ba6ca082051ce79cfd3eb468ae2c83'\n",
"\n",
"tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder=\"tokenizer\")\n",
"noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder=\"scheduler\")\n",
"text_encoder = CLIPTextModel.from_pretrained(\n",
" pretrained_model_name_or_path, subfolder=\"text_encoder\"\n",
")\n",
"vae = AutoencoderKL.from_pretrained(\n",
" pretrained_model_name_or_path, subfolder=\"vae\"\n",
")\n",
"# unet = UNet2DConditionModel.from_pretrained(\n",
"# pretrained_model_name_or_path, subfolder=\"unet\"\n",
"# )\n",
"\n",
"noise_scheduler = EulerDiscreteScheduler.from_config(noise_scheduler.config)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"unet = UNet2DConditionModel.from_pretrained(\n",
" '/root/data/zhizhonghuang/ckpt/GLIGEN_Text_Retrain_COCO'\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"You have disabled the safety checker for <class 'diffusers.pipelines.stable_diffusion_gligen.pipeline_stable_diffusion_gligen.StableDiffusionGLIGENPipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .\n"
]
}
],
"source": [
"pipe = StableDiffusionGLIGENPipeline(\n",
" vae,\n",
" text_encoder,\n",
" tokenizer,\n",
" unet,\n",
" noise_scheduler,\n",
" safety_checker=None,\n",
" feature_extractor=None,\n",
")\n",
"pipe = pipe.to(\"cuda\")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# prompt = 'A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky'\n",
"# gen_boxes = [('a green car', [21, 281, 211, 159]), ('a blue truck', [269, 283, 209, 160]), ('a red air balloon', [66, 8, 145, 135]), ('a bird', [296, 42, 143, 100])]\n",
"\n",
"# prompt = 'A realistic top-down view of a wooden table with two apples on it'\n",
"# gen_boxes = [('a wooden table', [20, 148, 472, 216]), ('an apple', [150, 226, 100, 100]), ('an apple', [280, 226, 100, 100])]\n",
"\n",
"# prompt = 'A realistic scene of three skiers standing in a line on the snow near a palm tree'\n",
"# gen_boxes = [('a skier', [5, 152, 139, 168]), ('a skier', [278, 192, 121, 158]), ('a skier', [148, 173, 124, 155]), ('a palm tree', [404, 105, 103, 251])]\n",
"\n",
"prompt = 'An oil painting of a pink dolphin jumping on the left of a steam boat on the sea'\n",
"gen_boxes = [('a steam boat', [232, 225, 257, 149]), ('a jumping pink dolphin', [21, 249, 189, 123])]\n",
"\n",
"import numpy as np\n",
"\n",
"boxes = np.array([x[1] for x in gen_boxes])\n",
"boxes = boxes / 512\n",
"boxes[:, 2] = boxes[:, 0] + boxes[:, 2]\n",
"boxes[:, 3] = boxes[:, 1] + boxes[:, 3]\n",
"boxes = boxes.tolist()\n",
"gligen_phrases = [x[0] for x in gen_boxes]"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/root/miniconda/envs/densecaption/lib/python3.11/site-packages/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py:683: FutureWarning: Accessing config attribute `in_channels` directly via 'UNet2DConditionModel' object attribute is deprecated. Please access 'in_channels' over 'UNet2DConditionModel's config object instead, e.g. 'unet.config.in_channels'.\n",
" num_channels_latents = self.unet.in_channels\n",
"/root/miniconda/envs/densecaption/lib/python3.11/site-packages/diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py:716: FutureWarning: Accessing config attribute `cross_attention_dim` directly via 'UNet2DConditionModel' object attribute is deprecated. Please access 'cross_attention_dim' over 'UNet2DConditionModel's config object instead, e.g. 'unet.config.cross_attention_dim'.\n",
" max_objs, self.unet.cross_attention_dim, device=device, dtype=self.text_encoder.dtype\n",
"100%|██████████| 50/50 [01:21<00:00, 1.64s/it]\n"
]
}
],
"source": [
"images = pipe(\n",
" prompt=prompt,\n",
" gligen_phrases=gligen_phrases,\n",
" gligen_boxes=boxes,\n",
" gligen_scheduled_sampling_beta=1.0,\n",
" output_type=\"pil\",\n",
" num_inference_steps=50,\n",
" negative_prompt=\"artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image, bad proportions, duplicate\",\n",
" num_images_per_prompt=16,\n",
").images"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"diffusers.utils.make_image_grid(images, 4, len(images)//4)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "densecaption",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

Binary file not shown.

After

Width:  |  Height:  |  Size: 1.5 MiB

View File

@@ -0,0 +1,119 @@
import argparse
import os
import random
import torch
import torchvision
import torchvision.transforms as TS
from PIL import Image
from ram import inference_ram
from ram.models import ram
from tqdm import tqdm
from transformers import (
AutoModelForZeroShotObjectDetection,
AutoProcessor,
Blip2ForConditionalGeneration,
Blip2Processor,
CLIPTextModel,
CLIPTokenizer,
)
torch.autograd.set_grad_enabled(False)
if __name__ == "__main__":
parser = argparse.ArgumentParser("Caption Generation script", add_help=False)
parser.add_argument("--data_root", type=str, required=True, help="path to COCO")
parser.add_argument("--save_root", type=str, required=True, help="path to save")
parser.add_argument("--ram_checkpoint", type=str, required=True, help="path to save")
args = parser.parse_args()
# ram_checkpoint = '/root/.cache/huggingface/hub/models--xinyu1205--recognize_anything_model/snapshots/ebc52dc741e86466202a5ab8ab22eae6e7d48bf1/ram_swin_large_14m.pth'
# data_root = '/mnt/workspace/workgroup/zhizhonghuang/dataset/COCO/train2017'
# save_root = '/root/gligen_data'
box_threshold = 0.25
text_threshold = 0.2
import torch.distributed as dist
dist.init_process_group(backend="nccl", init_method="env://")
local_rank = torch.distributed.get_rank() % torch.cuda.device_count()
device = f"cuda:{local_rank}"
torch.cuda.set_device(local_rank)
ram_model = ram(pretrained=args.ram_checkpoint, image_size=384, vit="swin_l").cuda().eval()
ram_processor = TS.Compose(
[TS.Resize((384, 384)), TS.ToTensor(), TS.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
)
grounding_dino_processor = AutoProcessor.from_pretrained("IDEA-Research/grounding-dino-base")
grounding_dino_model = AutoModelForZeroShotObjectDetection.from_pretrained(
"IDEA-Research/grounding-dino-base"
).cuda()
blip2_processor = Blip2Processor.from_pretrained("Salesforce/blip2-flan-t5-xxl")
blip2_model = Blip2ForConditionalGeneration.from_pretrained(
"Salesforce/blip2-flan-t5-xxl", torch_dtype=torch.float16
).cuda()
clip_text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").cuda()
clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
image_paths = [os.path.join(args.data_root, x) for x in os.listdir(args.data_root)]
random.shuffle(image_paths)
for image_path in tqdm.tqdm(image_paths):
pth_path = os.path.join(args.save_root, os.path.basename(image_path))
if os.path.exists(pth_path):
continue
sample = {"file_path": os.path.basename(image_path), "annos": []}
raw_image = Image.open(image_path).convert("RGB")
res = inference_ram(ram_processor(raw_image).unsqueeze(0).cuda(), ram_model)
text = res[0].replace(" |", ".")
inputs = grounding_dino_processor(images=raw_image, text=text, return_tensors="pt")
inputs = {k: v.cuda() for k, v in inputs.items()}
outputs = grounding_dino_model(**inputs)
results = grounding_dino_processor.post_process_grounded_object_detection(
outputs,
inputs["input_ids"],
box_threshold=box_threshold,
text_threshold=text_threshold,
target_sizes=[raw_image.size[::-1]],
)
boxes = results[0]["boxes"]
labels = results[0]["labels"]
scores = results[0]["scores"]
indices = torchvision.ops.nms(boxes, scores, 0.5)
boxes = boxes[indices]
category_names = [labels[i] for i in indices]
for i, bbox in enumerate(boxes):
bbox = bbox.tolist()
inputs = blip2_processor(images=raw_image.crop(bbox), return_tensors="pt")
inputs = {k: v.cuda().to(torch.float16) for k, v in inputs.items()}
outputs = blip2_model.generate(**inputs)
caption = blip2_processor.decode(outputs[0], skip_special_tokens=True)
inputs = clip_tokenizer(
caption,
padding="max_length",
max_length=clip_tokenizer.model_max_length,
truncation=True,
return_tensors="pt",
)
inputs = {k: v.cuda() for k, v in inputs.items()}
text_embeddings_before_projection = clip_text_encoder(**inputs).pooler_output.squeeze(0)
sample["annos"].append(
{
"caption": caption,
"bbox": bbox,
"text_embeddings_before_projection": text_embeddings_before_projection,
}
)
torch.save(sample, pth_path)

View File

@@ -0,0 +1,11 @@
accelerate>=0.16.0
torchvision
transformers>=4.25.1
ftfy
tensorboard
Jinja2
diffusers
scipy
timm
fairscale
wandb

View File

@@ -0,0 +1,715 @@
# from accelerate.utils import write_basic_config
#
# write_basic_config()
import argparse
import logging
import math
import os
import shutil
from pathlib import Path
import accelerate
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from packaging import version
from tqdm.auto import tqdm
import diffusers
from diffusers import (
AutoencoderKL,
DDPMScheduler,
EulerDiscreteScheduler,
StableDiffusionGLIGENPipeline,
UNet2DConditionModel,
)
from diffusers.optimization import get_scheduler
from diffusers.utils import is_wandb_available, make_image_grid
from diffusers.utils.import_utils import is_xformers_available
from diffusers.utils.torch_utils import is_compiled_module
if is_wandb_available():
pass
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
# check_min_version("0.28.0.dev0")
logger = get_logger(__name__)
@torch.no_grad()
def log_validation(vae, text_encoder, tokenizer, unet, noise_scheduler, args, accelerator, step, weight_dtype):
if accelerator.is_main_process:
print("generate test images...")
unet = accelerator.unwrap_model(unet)
vae.to(accelerator.device, dtype=torch.float32)
pipeline = StableDiffusionGLIGENPipeline(
vae,
text_encoder,
tokenizer,
unet,
EulerDiscreteScheduler.from_config(noise_scheduler.config),
safety_checker=None,
feature_extractor=None,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=not accelerator.is_main_process)
if args.enable_xformers_memory_efficient_attention:
pipeline.enable_xformers_memory_efficient_attention()
if args.seed is None:
generator = None
else:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
prompt = "A realistic image of landscape scene depicting a green car parking on the left of a blue truck, with a red air balloon and a bird in the sky"
boxes = [
[0.041015625, 0.548828125, 0.453125, 0.859375],
[0.525390625, 0.552734375, 0.93359375, 0.865234375],
[0.12890625, 0.015625, 0.412109375, 0.279296875],
[0.578125, 0.08203125, 0.857421875, 0.27734375],
]
gligen_phrases = ["a green car", "a blue truck", "a red air balloon", "a bird"]
images = pipeline(
prompt=prompt,
gligen_phrases=gligen_phrases,
gligen_boxes=boxes,
gligen_scheduled_sampling_beta=1.0,
output_type="pil",
num_inference_steps=50,
negative_prompt="artifacts, blurry, smooth texture, bad quality, distortions, unrealistic, distorted image, bad proportions, duplicate",
num_images_per_prompt=4,
generator=generator,
).images
os.makedirs(os.path.join(args.output_dir, "images"), exist_ok=True)
make_image_grid(images, 1, 4).save(
os.path.join(args.output_dir, "images", f"generated-images-{step:06d}-{accelerator.process_index:02d}.png")
)
vae.to(accelerator.device, dtype=weight_dtype)
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.")
parser.add_argument(
"--data_path",
type=str,
default="coco_train2017.pth",
help="Path to training dataset.",
)
parser.add_argument(
"--image_path",
type=str,
default="coco_train2017.pth",
help="Path to training images.",
)
parser.add_argument(
"--output_dir",
type=str,
default="controlnet-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=0, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=(
"Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. "
"In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference."
"Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components."
"See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step"
"instructions."
),
)
parser.add_argument(
"--checkpoints_total_limit",
type=int,
default=None,
help=("Max number of checkpoints to store."),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-6,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--lr_num_cycles",
type=int,
default=1,
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
)
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
parser.add_argument(
"--dataloader_num_workers",
type=int,
default=0,
help=(
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
),
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
"--set_grads_to_none",
action="store_true",
help=(
"Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
" behaviors, so disable this argument if it causes any problems. More info:"
" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
),
)
parser.add_argument(
"--tracker_project_name",
type=str,
default="train_controlnet",
help=(
"The `project_name` argument passed to Accelerator.init_trackers for"
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)
args = parser.parse_args()
return args
def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
project_config=accelerator_project_config,
)
# Disable AMP for MPS.
if torch.backends.mps.is_available():
accelerator.native_amp = False
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Handle the repository creation
if accelerator.is_main_process:
if args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# import correct text encoder class
# text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
# Load scheduler and models
from transformers import CLIPTextModel, CLIPTokenizer
pretrained_model_name_or_path = "masterful/gligen-1-4-generation-text-box"
tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder="tokenizer")
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet")
# Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files)
def unwrap_model(model):
model = accelerator.unwrap_model(model)
model = model._orig_mod if is_compiled_module(model) else model
return model
# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
def save_model_hook(models, weights, output_dir):
if accelerator.is_main_process:
i = len(weights) - 1
while len(weights) > 0:
weights.pop()
model = models[i]
sub_dir = "unet"
model.save_pretrained(os.path.join(output_dir, sub_dir))
i -= 1
def load_model_hook(models, input_dir):
while len(models) > 0:
# pop models so that they are not loaded again
model = models.pop()
# load diffusers style into model
load_model = unet.from_pretrained(input_dir, subfolder="unet")
model.register_to_config(**load_model.config)
model.load_state_dict(load_model.state_dict())
del load_model
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
vae.requires_grad_(False)
unet.requires_grad_(False)
text_encoder.requires_grad_(False)
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
import xformers
xformers_version = version.parse(xformers.__version__)
if xformers_version == version.parse("0.0.16"):
logger.warning(
"xFormers 0.0.16 cannot be used for training in some GPUs. If you observe problems during training, please update xFormers to at least 0.0.17. See https://huggingface.co/docs/diffusers/main/en/optimization/xformers for more details."
)
unet.enable_xformers_memory_efficient_attention()
# controlnet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
# if args.gradient_checkpointing:
# controlnet.enable_gradient_checkpointing()
# Check that all trainable models are in full precision
low_precision_error_string = (
" Please make sure to always have all model weights in full float32 precision when starting training - even if"
" doing mixed precision training, copy of the weights should still be float32."
)
if unwrap_model(unet).dtype != torch.float32:
raise ValueError(f"Controlnet loaded as datatype {unwrap_model(unet).dtype}. {low_precision_error_string}")
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
optimizer_class = torch.optim.AdamW
# Optimizer creation
for n, m in unet.named_modules():
if ("fuser" in n) or ("position_net" in n):
import torch.nn as nn
if isinstance(m, (nn.Linear, nn.LayerNorm)):
m.reset_parameters()
params_to_optimize = []
for n, p in unet.named_parameters():
if ("fuser" in n) or ("position_net" in n):
p.requires_grad = True
params_to_optimize.append(p)
optimizer = optimizer_class(
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
from dataset import COCODataset
train_dataset = COCODataset(
data_path=args.data_path,
image_path=args.image_path,
tokenizer=tokenizer,
image_size=args.resolution,
max_boxes_per_data=30,
)
print("num samples: ", len(train_dataset))
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
shuffle=True,
# collate_fn=collate_fn,
batch_size=args.train_batch_size,
num_workers=args.dataloader_num_workers,
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes,
num_training_steps=args.max_train_steps * accelerator.num_processes,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
# Prepare everything with our `accelerator`.
unet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet, optimizer, train_dataloader, lr_scheduler
)
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move vae, unet and text_encoder to device and cast to weight_dtype
vae.to(accelerator.device, dtype=weight_dtype)
# unet.to(accelerator.device, dtype=weight_dtype)
unet.to(accelerator.device, dtype=torch.float32)
text_encoder.to(accelerator.device, dtype=weight_dtype)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
tracker_config = dict(vars(args))
# tensorboard cannot handle list types for config
# tracker_config.pop("validation_prompt")
# tracker_config.pop("validation_image")
accelerator.init_trackers(args.tracker_project_name, config=tracker_config)
# Train!
# total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
# logger.info("***** Running training *****")
# logger.info(f" Num examples = {len(train_dataset)}")
# logger.info(f" Num batches each epoch = {len(train_dataloader)}")
# logger.info(f" Num Epochs = {args.num_train_epochs}")
# logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
# logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
# logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
# logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the most recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1] if len(dirs) > 0 else None
if path is None:
accelerator.print(
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
)
args.resume_from_checkpoint = None
initial_global_step = 0
else:
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
initial_global_step = global_step
first_epoch = global_step // num_update_steps_per_epoch
else:
initial_global_step = 0
progress_bar = tqdm(
range(0, args.max_train_steps),
initial=initial_global_step,
desc="Steps",
# Only show the progress bar once on each machine.
disable=not accelerator.is_local_main_process,
)
log_validation(
vae,
text_encoder,
tokenizer,
unet,
noise_scheduler,
args,
accelerator,
global_step,
weight_dtype,
)
# image_logs = None
for epoch in range(first_epoch, args.num_train_epochs):
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * vae.config.scaling_factor
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
with torch.no_grad():
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(
batch["caption"]["input_ids"].squeeze(1),
# batch['caption']['attention_mask'].squeeze(1),
return_dict=False,
)[0]
cross_attention_kwargs = {}
cross_attention_kwargs["gligen"] = {
"boxes": batch["boxes"],
"positive_embeddings": batch["text_embeddings_before_projection"],
"masks": batch["masks"],
}
# Predict the noise residual
model_pred = unet(
noisy_latents,
timesteps,
encoder_hidden_states=encoder_hidden_states,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
accelerator.clip_grad_norm_(params_to_optimize, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
if args.checkpoints_total_limit is not None:
checkpoints = os.listdir(args.output_dir)
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
if len(checkpoints) >= args.checkpoints_total_limit:
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
removing_checkpoints = checkpoints[0:num_to_remove]
logger.info(
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
)
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
for removing_checkpoint in removing_checkpoints:
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
shutil.rmtree(removing_checkpoint)
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step:06d}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
# if args.validation_prompt is not None and global_step % args.validation_steps == 0:
log_validation(
vae,
text_encoder,
tokenizer,
unet,
noise_scheduler,
args,
accelerator,
global_step,
weight_dtype,
)
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
# Create the pipeline using using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unwrap_model(unet)
unet.save_pretrained(args.output_dir)
#
# # Run a final round of validation.
# image_logs = None
# if args.validation_prompt is not None:
# image_logs = log_validation(
# vae=vae,
# text_encoder=text_encoder,
# tokenizer=tokenizer,
# unet=unet,
# controlnet=None,
# args=args,
# accelerator=accelerator,
# weight_dtype=weight_dtype,
# step=global_step,
# is_final_validation=True,
# )
#
# if args.push_to_hub:
# save_model_card(
# repo_id,
# image_logs=image_logs,
# base_model=args.pretrained_model_name_or_path,
# repo_folder=args.output_dir,
# )
# upload_folder(
# repo_id=repo_id,
# folder_path=args.output_dir,
# commit_message="End of training",
# ignore_patterns=["step_*", "epoch_*"],
# )
accelerator.end_training()
if __name__ == "__main__":
args = parse_args()
main(args)