Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
[dreambooth lora sdxl] add sdxl micro conditioning (#6795)
* add micro conditioning
* remove redundant lines
* style
* fix missing 's'
* fix missing shape bug due to missing RGB if statement
* remove redundant if, change arg order

---------

Co-authored-by: Sayak Paul <spsayakpaul@gmail.com>
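For readers unfamiliar with the feature: the SDXL UNet is conditioned on "micro-conditioning" time ids built from each training image's original size, the top-left corner of the crop applied during preprocessing, and the target resolution. A minimal sketch of that six-value vector (illustrative only; make_time_ids and the example sizes are hypothetical, but the ordering matches compute_time_ids in the diff below):

    import torch

    def make_time_ids(original_size, crop_top_left, target_size=(1024, 1024)):
        # (orig_h, orig_w) + (crop_top, crop_left) + (target_h, target_w) -> six values per image
        return torch.tensor([list(original_size + crop_top_left + target_size)])

    time_ids = make_time_ids((768, 512), (0, 128))  # tensor([[768, 512, 0, 128, 1024, 1024]]), shape (1, 6)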
@@ -19,6 +19,7 @@ import itertools
 import logging
 import math
 import os
+import random
 import shutil
 import warnings
 from pathlib import Path
@@ -40,6 +41,7 @@ from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
 from torchvision import transforms
+from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig
 
@@ -304,18 +306,6 @@ def parse_args(input_args=None):
             " resolution"
         ),
     )
-    parser.add_argument(
-        "--crops_coords_top_left_h",
-        type=int,
-        default=0,
-        help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
-    )
-    parser.add_argument(
-        "--crops_coords_top_left_w",
-        type=int,
-        default=0,
-        help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
-    )
     parser.add_argument(
         "--center_crop",
         default=False,
@@ -325,6 +315,11 @@ def parse_args(input_args=None):
             " cropped. The images will be resized to the resolution first before cropping."
         ),
     )
+    parser.add_argument(
+        "--random_flip",
+        action="store_true",
+        help="whether to randomly flip images horizontally",
+    )
     parser.add_argument(
         "--train_text_encoder",
         action="store_true",
@@ -669,6 +664,41 @@ class DreamBoothDataset(Dataset):
         self.instance_images = []
         for img in instance_images:
             self.instance_images.extend(itertools.repeat(img, repeats))
 
+        # image processing to prepare for using SD-XL micro-conditioning
+        self.original_sizes = []
+        self.crop_top_lefts = []
+        self.pixel_values = []
+        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+        train_flip = transforms.RandomHorizontalFlip(p=1.0)
+        train_transforms = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+        for image in self.instance_images:
+            image = exif_transpose(image)
+            if not image.mode == "RGB":
+                image = image.convert("RGB")
+            self.original_sizes.append((image.height, image.width))
+            image = train_resize(image)
+            if args.random_flip and random.random() < 0.5:
+                # flip
+                image = train_flip(image)
+            if args.center_crop:
+                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+                image = train_crop(image)
+            else:
+                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+                image = crop(image, y1, x1, h, w)
+            crop_top_left = (y1, x1)
+            self.crop_top_lefts.append(crop_top_left)
+            image = train_transforms(image)
+            self.pixel_values.append(image)
+
         self.num_instance_images = len(self.instance_images)
         self._length = self.num_instance_images
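The bookkeeping above works because torchvision's RandomCrop.get_params returns the sampled crop offsets, which torchvision.transforms.functional.crop then applies explicitly. A standalone sketch of the pattern (hypothetical helper, not part of the script):

    from PIL import Image
    from torchvision import transforms
    from torchvision.transforms.functional import crop

    def crop_and_record(image: Image.Image, resolution: int):
        # sample crop offsets, apply the crop, and keep (top, left) for the SDXL crop-coordinate embedding
        y1, x1, h, w = transforms.RandomCrop(resolution).get_params(image, (resolution, resolution))
        return crop(image, y1, x1, h, w), (y1, x1)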
@@ -698,12 +728,12 @@ class DreamBoothDataset(Dataset):
 
     def __getitem__(self, index):
         example = {}
-        instance_image = self.instance_images[index % self.num_instance_images]
-        instance_image = exif_transpose(instance_image)
-
-        if not instance_image.mode == "RGB":
-            instance_image = instance_image.convert("RGB")
-        example["instance_images"] = self.image_transforms(instance_image)
+        instance_image = self.pixel_values[index % self.num_instance_images]
+        original_size = self.original_sizes[index % self.num_instance_images]
+        crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
+        example["instance_images"] = instance_image
+        example["original_size"] = original_size
+        example["crop_top_left"] = crop_top_left
 
         if self.custom_instance_prompts:
             caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -730,6 +760,8 @@ class DreamBoothDataset(Dataset):
 def collate_fn(examples, with_prior_preservation=False):
     pixel_values = [example["instance_images"] for example in examples]
     prompts = [example["instance_prompt"] for example in examples]
+    original_sizes = [example["original_size"] for example in examples]
+    crop_top_lefts = [example["crop_top_left"] for example in examples]
 
     # Concat class and instance examples for prior preservation.
     # We do this to avoid doing two forward passes.
@@ -740,7 +772,12 @@ def collate_fn(examples, with_prior_preservation=False):
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 
-    batch = {"pixel_values": pixel_values, "prompts": prompts}
+    batch = {
+        "pixel_values": pixel_values,
+        "prompts": prompts,
+        "original_sizes": original_sizes,
+        "crop_top_lefts": crop_top_lefts,
+    }
     return batch
 
 
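A rough usage sketch of the updated collate_fn (made-up example values; assumes the function as defined in this script and with_prior_preservation=False):

    import torch

    examples = [
        {"instance_images": torch.randn(3, 1024, 1024), "instance_prompt": "a photo of sks dog",
         "original_size": (768, 512), "crop_top_left": (0, 128)},
        {"instance_images": torch.randn(3, 1024, 1024), "instance_prompt": "a photo of sks dog",
         "original_size": (2048, 1365), "crop_top_left": (170, 0)},
    ]
    batch = collate_fn(examples)
    # batch["pixel_values"].shape == (2, 3, 1024, 1024)
    # batch["original_sizes"] == [(768, 512), (2048, 1365)]
    # batch["crop_top_lefts"] == [(0, 128), (170, 0)]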
@@ -1233,11 +1270,9 @@ def main(args):
     # pooled text embeddings
     # time ids
 
-    def compute_time_ids():
+    def compute_time_ids(original_size, crops_coords_top_left):
         # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
-        original_size = (args.resolution, args.resolution)
         target_size = (args.resolution, args.resolution)
-        crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
         add_time_ids = list(original_size + crops_coords_top_left + target_size)
         add_time_ids = torch.tensor([add_time_ids])
         add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
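With the new signature, original_size and crops_coords_top_left come from the data instead of args; only target_size is still derived from --resolution. A hedged call example (values are made up):

    # one instance image that was originally 768x512 and cropped at (top=0, left=128)
    time_ids = compute_time_ids(original_size=(768, 512), crops_coords_top_left=(0, 128))
    # -> a (1, 6) tensor like [[768, 512, 0, 128, 1024, 1024]] on the accelerator device, assuming --resolution 1024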
@@ -1254,9 +1289,6 @@ def main(args):
             pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
         return prompt_embeds, pooled_prompt_embeds
 
-    # Handle instance prompt.
-    instance_time_ids = compute_time_ids()
-
     # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
     # the redundant encoding.
@@ -1267,7 +1299,6 @@ def main(args):
 
     # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
-        class_time_ids = compute_time_ids()
         if not args.train_text_encoder:
             class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
                 args.class_prompt, text_encoders, tokenizers
@@ -1282,9 +1313,6 @@ def main(args):
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
-    add_time_ids = instance_time_ids
-    if args.with_prior_preservation:
-        add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
 
     if not train_dataset.custom_instance_prompts:
         if not args.train_text_encoder:
@@ -1436,18 +1464,24 @@ def main(args):
                 # (this is the forward diffusion process)
                 noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
 
+                # time ids
+                add_time_ids = torch.cat(
+                    [
+                        compute_time_ids(original_size=s, crops_coords_top_left=c)
+                        for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
+                    ]
+                )
+
                 # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
                 if not train_dataset.custom_instance_prompts:
                     elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
-                    elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
                 else:
                     elems_to_repeat_text_embeds = 1
-                    elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
 
                 # Predict the noise residual
                 if not args.train_text_encoder:
                     unet_added_conditions = {
-                        "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
+                        "time_ids": add_time_ids,
                         "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
                     }
                     prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
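Because the list comprehension above yields one (1, 6) tensor per sample, torch.cat already produces a (bsz, 6) add_time_ids that lines up with the batch, which is why the old .repeat(elems_to_repeat_time_ids, 1) is dropped. A tiny shape illustration (made-up values):

    import torch

    per_example = [torch.zeros(1, 6), torch.ones(1, 6)]  # one row of time ids per image in the batch
    add_time_ids = torch.cat(per_example)                # shape (2, 6): one row per sample, no repeat required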
@@ -1459,7 +1493,7 @@ def main(args):
                         return_dict=False,
                     )[0]
                 else:
-                    unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
+                    unet_added_conditions = {"time_ids": add_time_ids}
                     prompt_embeds, pooled_prompt_embeds = encode_prompt(
                         text_encoders=[text_encoder_one, text_encoder_two],
                         tokenizers=None,