diff --git a/examples/dreambooth/train_dreambooth_lora_sdxl.py b/examples/dreambooth/train_dreambooth_lora_sdxl.py
index 2cc2ab79db..aa09bf9a0e 100644
--- a/examples/dreambooth/train_dreambooth_lora_sdxl.py
+++ b/examples/dreambooth/train_dreambooth_lora_sdxl.py
@@ -19,6 +19,7 @@ import itertools
 import logging
 import math
 import os
+import random
 import shutil
 import warnings
 from pathlib import Path
@@ -40,6 +41,7 @@ from PIL import Image
 from PIL.ImageOps import exif_transpose
 from torch.utils.data import Dataset
 from torchvision import transforms
+from torchvision.transforms.functional import crop
 from tqdm.auto import tqdm
 from transformers import AutoTokenizer, PretrainedConfig
 
@@ -304,18 +306,6 @@ def parse_args(input_args=None):
             " resolution"
         ),
     )
-    parser.add_argument(
-        "--crops_coords_top_left_h",
-        type=int,
-        default=0,
-        help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
-    )
-    parser.add_argument(
-        "--crops_coords_top_left_w",
-        type=int,
-        default=0,
-        help=("Coordinate for (the height) to be included in the crop coordinate embeddings needed by SDXL UNet."),
-    )
     parser.add_argument(
         "--center_crop",
         default=False,
@@ -325,6 +315,11 @@
             " cropped. The images will be resized to the resolution first before cropping."
         ),
     )
+    parser.add_argument(
+        "--random_flip",
+        action="store_true",
+        help="whether to randomly flip images horizontally",
+    )
     parser.add_argument(
         "--train_text_encoder",
         action="store_true",
@@ -669,6 +664,41 @@ class DreamBoothDataset(Dataset):
         self.instance_images = []
         for img in instance_images:
             self.instance_images.extend(itertools.repeat(img, repeats))
+
+        # image processing to prepare for using SD-XL micro-conditioning
+        self.original_sizes = []
+        self.crop_top_lefts = []
+        self.pixel_values = []
+        train_resize = transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR)
+        train_crop = transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size)
+        train_flip = transforms.RandomHorizontalFlip(p=1.0)
+        train_transforms = transforms.Compose(
+            [
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
+        for image in self.instance_images:
+            image = exif_transpose(image)
+            if not image.mode == "RGB":
+                image = image.convert("RGB")
+            self.original_sizes.append((image.height, image.width))
+            image = train_resize(image)
+            if args.random_flip and random.random() < 0.5:
+                # flip
+                image = train_flip(image)
+            if args.center_crop:
+                y1 = max(0, int(round((image.height - args.resolution) / 2.0)))
+                x1 = max(0, int(round((image.width - args.resolution) / 2.0)))
+                image = train_crop(image)
+            else:
+                y1, x1, h, w = train_crop.get_params(image, (args.resolution, args.resolution))
+                image = crop(image, y1, x1, h, w)
+            crop_top_left = (y1, x1)
+            self.crop_top_lefts.append(crop_top_left)
+            image = train_transforms(image)
+            self.pixel_values.append(image)
+
         self.num_instance_images = len(self.instance_images)
         self._length = self.num_instance_images
 
@@ -698,12 +728,12 @@
 
     def __getitem__(self, index):
         example = {}
-        instance_image = self.instance_images[index % self.num_instance_images]
-        instance_image = exif_transpose(instance_image)
-
-        if not instance_image.mode == "RGB":
-            instance_image = instance_image.convert("RGB")
-        example["instance_images"] = self.image_transforms(instance_image)
+        instance_image = self.pixel_values[index % self.num_instance_images]
+        original_size = self.original_sizes[index % self.num_instance_images]
+        crop_top_left = self.crop_top_lefts[index % self.num_instance_images]
+        example["instance_images"] = instance_image
+        example["original_size"] = original_size
+        example["crop_top_left"] = crop_top_left
 
         if self.custom_instance_prompts:
             caption = self.custom_instance_prompts[index % self.num_instance_images]
@@ -730,6 +760,8 @@
 def collate_fn(examples, with_prior_preservation=False):
     pixel_values = [example["instance_images"] for example in examples]
     prompts = [example["instance_prompt"] for example in examples]
+    original_sizes = [example["original_size"] for example in examples]
+    crop_top_lefts = [example["crop_top_left"] for example in examples]
 
     # Concat class and instance examples for prior preservation.
     # We do this to avoid doing two forward passes.
@@ -740,7 +772,12 @@
     pixel_values = torch.stack(pixel_values)
     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
 
-    batch = {"pixel_values": pixel_values, "prompts": prompts}
+    batch = {
+        "pixel_values": pixel_values,
+        "prompts": prompts,
+        "original_sizes": original_sizes,
+        "crop_top_lefts": crop_top_lefts,
+    }
     return batch
 
 
@@ -1233,11 +1270,9 @@
 
     # pooled text embeddings
    # time ids
-    def compute_time_ids():
+    def compute_time_ids(original_size, crops_coords_top_left):
         # Adapted from pipeline.StableDiffusionXLPipeline._get_add_time_ids
-        original_size = (args.resolution, args.resolution)
         target_size = (args.resolution, args.resolution)
-        crops_coords_top_left = (args.crops_coords_top_left_h, args.crops_coords_top_left_w)
         add_time_ids = list(original_size + crops_coords_top_left + target_size)
         add_time_ids = torch.tensor([add_time_ids])
         add_time_ids = add_time_ids.to(accelerator.device, dtype=weight_dtype)
@@ -1254,9 +1289,6 @@
             pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device)
         return prompt_embeds, pooled_prompt_embeds
 
-    # Handle instance prompt.
-    instance_time_ids = compute_time_ids()
-
     # If no type of tuning is done on the text_encoder and custom instance prompts are NOT
     # provided (i.e. the --instance_prompt is used for all images), we encode the instance prompt once to avoid
     # the redundant encoding.
@@ -1267,7 +1299,6 @@
 
     # Handle class prompt for prior-preservation.
     if args.with_prior_preservation:
-        class_time_ids = compute_time_ids()
         if not args.train_text_encoder:
             class_prompt_hidden_states, class_pooled_prompt_embeds = compute_text_embeddings(
                 args.class_prompt, text_encoders, tokenizers
@@ -1282,9 +1313,6 @@
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
     # have to pass them to the dataloader.
-    add_time_ids = instance_time_ids
-    if args.with_prior_preservation:
-        add_time_ids = torch.cat([add_time_ids, class_time_ids], dim=0)
 
     if not train_dataset.custom_instance_prompts:
         if not args.train_text_encoder:
@@ -1436,18 +1464,24 @@
                 # (this is the forward diffusion process)
                 noisy_model_input = noise_scheduler.add_noise(model_input, noise, timesteps)
 
+                # time ids
+                add_time_ids = torch.cat(
+                    [
+                        compute_time_ids(original_size=s, crops_coords_top_left=c)
+                        for s, c in zip(batch["original_sizes"], batch["crop_top_lefts"])
+                    ]
+                )
+
                 # Calculate the elements to repeat depending on the use of prior-preservation and custom captions.
                 if not train_dataset.custom_instance_prompts:
                     elems_to_repeat_text_embeds = bsz // 2 if args.with_prior_preservation else bsz
-                    elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
                 else:
                     elems_to_repeat_text_embeds = 1
-                    elems_to_repeat_time_ids = bsz // 2 if args.with_prior_preservation else bsz
 
                 # Predict the noise residual
                 if not args.train_text_encoder:
                     unet_added_conditions = {
-                        "time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1),
+                        "time_ids": add_time_ids,
                         "text_embeds": unet_add_text_embeds.repeat(elems_to_repeat_text_embeds, 1),
                     }
                     prompt_embeds_input = prompt_embeds.repeat(elems_to_repeat_text_embeds, 1, 1)
@@ -1459,7 +1493,7 @@
                         return_dict=False,
                     )[0]
                 else:
-                    unet_added_conditions = {"time_ids": add_time_ids.repeat(elems_to_repeat_time_ids, 1)}
+                    unet_added_conditions = {"time_ids": add_time_ids}
                     prompt_embeds, pooled_prompt_embeds = encode_prompt(
                         text_encoders=[text_encoder_one, text_encoder_two],
                         tokenizers=None,
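
Note (not part of the patch): a minimal standalone sketch of the SDXL micro-conditioning this change wires up. Each training sample now contributes six numbers — original (height, width), crop (top, left), and target (height, width) — which `compute_time_ids` packs into a `(1, 6)` tensor and the training loop concatenates into a per-batch `add_time_ids`. The resolution and the example sizes/offsets below are assumed purely for illustration.

```python
import torch

resolution = 1024  # assumed training resolution for this sketch


def compute_time_ids(original_size, crop_top_left, target_size=(resolution, resolution)):
    # (orig_h, orig_w, crop_top, crop_left, target_h, target_w) -> tensor of shape (1, 6)
    return torch.tensor([list(original_size + crop_top_left + target_size)], dtype=torch.float32)


# Hypothetical batch of two images with the sizes and crop offsets the dataset now records.
original_sizes = [(1536, 2048), (1024, 1024)]
crop_top_lefts = [(210, 407), (0, 0)]
add_time_ids = torch.cat(
    [compute_time_ids(s, c) for s, c in zip(original_sizes, crop_top_lefts)]
)
print(add_time_ids.shape)  # torch.Size([2, 6])
```

Compared to the old code path, the crop coordinates are no longer global CLI arguments (`--crops_coords_top_left_h`/`_w`) but per-image values recorded during preprocessing, which is why the static `add_time_ids.repeat(...)` is replaced by a per-batch `torch.cat` over `batch["original_sizes"]` and `batch["crop_top_lefts"]`.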