diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md new file mode 100644 index 0000000000..5d6f3b9eb9 --- /dev/null +++ b/examples/dreambooth/README.md @@ -0,0 +1,82 @@ +## DreamBooth training example + +[DreamBooth](https://arxiv.org/abs/2208.12242) is a method to personalize text2image models like stable diffusion given just a few(3~5) images of a subject. +The `train_dreambooth.py` script shows how to implement the training procedure and adapt it for stable diffusion. + +## Running locally +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +```bash +pip install diffusers[training] accelerate transformers +``` + +And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +### Dog toy example + +You need to accept the model license before downloading or using the weights. In this example we'll use model version `v1-4`, so you'll need to visit [its card](https://huggingface.co/CompVis/stable-diffusion-v1-4), read the license and tick the checkbox if you agree. + +You have to be a registered user in 🤗 Hugging Face Hub, and you'll also need to use an access token for the code to work. For more information on access tokens, please refer to [this section of the documentation](https://huggingface.co/docs/hub/security-tokens). + +Run the following command to authenticate your token + +```bash +huggingface-cli login +``` + +If you have already cloned the repo, then you won't need to go through these steps. You can simple remove the `--use_auth_token` arg from the following command. + +
+ +Now let's get our dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data. + +And launch the training using + +```bash +export MODEL_NAME="CompVis/stable-diffusion-v1-4" +export INSTANCE_DIR="path-to-instance-images" +export CLASS_DIR="path-to-class-images" + +python train_dreambooth.py \ + --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \ + --instance_data_dir=$INSTANCE_DIR \ + --class_data_dir=$CLASS_DIR \ + --resolution=512 \ + --train_batch_size=4 \ + --gradient_accumulation_steps=1 \ + --learning_rate=1e-5 \ + --lr_scheduler="constant" \ + --lr_warmup_steps=0 \ + --output_dir="dreambooth_dog" \ + --instance_prompt="a photo of sks dog" \ + --class_prompt="a photo of dog" + --num_class_images=1000 \ + --max_train_steps=3000 +``` + + +### Inference + +Once you have trained a model using above command, the inference can be done simply using the `StableDiffusionPipeline`. Make sure to include the `identifier`(e.g. sks in above example) in your prompt. + +```python + +from torch import autocast +from diffusers import StableDiffusionPipeline + +model_id = "path-to-your-trained-model" +pipe = StableDiffusionPipeline.from_pretrained(model_id,torch_dtype=torch.float16).to("cuda") + +prompt = "A photo of sks dog in a bucket" + +with autocast("cuda"): + image = pipe(prompt, num_inference_steps=50, guidance_scale=7.5).images[0] + +image.save("dog-bucket.png") +``` diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py new file mode 100644 index 0000000000..4363b1c358 --- /dev/null +++ b/examples/dreambooth/train_dreambooth.py @@ -0,0 +1,543 @@ +import argparse +import math +import os +from pathlib import Path +from typing import Optional + +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.utils.data import Dataset + +import PIL +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import set_seed +from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel +from diffusers.optimization import get_scheduler +from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker +from huggingface_hub import HfFolder, Repository, whoami +from PIL import Image +from tqdm.auto import tqdm +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + + +logger = get_logger(__name__) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--tokenizer_name", + type=str, + default=None, + help="Pretrained tokenizer name or path if not the same as model_name", + ) + parser.add_argument( + "--instance_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of instance images.", + ) + parser.add_argument( + "--class_data_dir", + type=str, + default=None, + required=True, + help="A folder containing the training data of class images.", + ) + parser.add_argument( + "--instance_prompt", + type=str, + default=None, + help="The prompt with identifier specifing the instance", + ) + parser.add_argument( + "--class_prompt", + type=str, + default=None, + help="The prompt to specify images in the same class as provided intance images.", + ) + parser.add_argument( + "--without_prior_preservation", + default=False, + action="store_true", + help="Flag to remove prior perservation loss.", + ) + parser.add_argument( + "--num_class_images", + type=int, + default=100, + help=( + "Minimal class images for prior perversation loss. If not have enough images, additional images will be" + " sampled with class_prompt." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="text-inversion-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-5, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument( + "--use_auth_token", + action="store_true", + help=( + "Will use the token generated when running `huggingface-cli login` (necessary to use this script with" + " private models)." + ), + ) + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default="no", + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose" + "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." + "and an Nvidia Ampere GPU." + ), + ) + parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank") + + args = parser.parse_args() + env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) + if env_local_rank != -1 and env_local_rank != args.local_rank: + args.local_rank = env_local_rank + + if args.instance_data_dir is None: + raise ValueError("You must specify a train data directory.") + + return args + + +class DreamBoothDataset(Dataset): + def __init__( + self, + instance_data_root, + instance_prompt, + tokenizer, + class_data_root=None, + class_prompt=None, + size=512, + interpolation="bicubic", + center_crop=False, + ): + self.size = size + self.center_crop = center_crop + self.tokenizer = tokenizer + + self.instance_data_root = Path(instance_data_root) + assert self.instance_data_root.exists(), "Instance images root doesn't exists." + self.instance_images_path = list(Path(instance_data_root).iterdir()) + self.num_instance_images = len(self.instance_images_path) + self.instance_prompt = instance_prompt + self._length = self.num_instance_images + + if class_data_root is not None: + self.class_data_root = Path(class_data_root) + assert self.class_data_root.exists(), "Class images root doesn't exists." + self.class_images_path = list(Path(class_data_root).iterdir()) + self.num_class_images = len(self.class_images_path) + self._length = max(self.num_class_images, self.num_instance_images) + self.class_prompt = class_prompt + else: + self.class_data_root = None + + self.interpolation = { + "linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + + def __len__(self): + return self._length + + def transform_image(self, image: Image): + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + + if self.center_crop: + crop = min(img.shape[0], img.shape[1]) + h, w, = ( + img.shape[0], + img.shape[1], + ) + img = img[(h - crop) // 2 : (h + crop) // 2, (w - crop) // 2 : (w + crop) // 2] + + image = Image.fromarray(img) + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = np.array(image).astype(np.uint8) + image = (image / 127.5 - 1.0).astype(np.float32) + image = torch.from_numpy(image).permute(2, 0, 1) + return image + + def __getitem__(self, index): + example = {} + instance_image = Image.open(self.instance_images_path[index % self.num_instance_images]) + if not instance_image.mode == "RGB": + instance_image = instance_image.convert("RGB") + example["instance_images"] = self.transform_image(instance_image) + example["instance_prompt_ids"] = self.tokenizer( + self.instance_prompt, + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids[0] + + if self.class_data_root: + class_image = Image.open(self.class_images_path[index % self.num_class_images]) + if not class_image.mode == "RGB": + class_image = class_image.convert("RGB") + example["class_images"] = self.transform_image(class_image) + example["class_prompt_ids"] = self.tokenizer( + self.class_prompt, + padding="max_length", + truncation=True, + max_length=self.tokenizer.model_max_length, + return_tensors="pt", + ).input_ids[0] + + return example + + +def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None): + if token is None: + token = HfFolder.get_token() + if organization is None: + username = whoami(token)["name"] + return f"{username}/{model_id}" + else: + return f"{organization}/{model_id}" + + +def main(): + args = parse_args() + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with="tensorboard", + logging_dir=logging_dir, + ) + + if args.seed is not None: + set_seed(args.seed) + + if not args.without_prior_preservation: + class_images_dir = Path(args.class_data_dir) + if not class_images_dir.exists(): + class_images_dir.mkdir(parents=True) + cur_class_images = len(list(class_images_dir.iterdir())) + + if cur_class_images < args.num_class_images: + sd_model = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, use_auth_token=args.use_auth_token + ) + sd_model = accelerator.prepare(sd_model) + sd_model.to(accelerator.device) + + num_new_images = args.num_class_images - cur_class_images + logger.info(f"Number of class images to sample: {num_new_images}.") + total_prompts = [args.class_prompt] * num_new_images + batch_prompts = [ + total_prompts[x : x + args.sample_batch_size] for x in range(0, num_new_images, args.sample_batch_size) + ] + + img_id = cur_class_images + for text in tqdm( + batch_prompts, desc="Generating class images", disable=not accelerator.is_local_main_process + ): + with torch.no_grad(): + images = sd_model(text, height=512, width=512, num_inference_steps=50)["sample"] + + for image in images: + image.save(class_images_dir / f"{img_id}.jpg") + img_id += 1 + del sd_model + + # Handle the repository creation + if accelerator.is_main_process: + if args.push_to_hub: + if args.hub_model_id is None: + repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token) + else: + repo_name = args.hub_model_id + repo = Repository(args.output_dir, clone_from=repo_name) + + with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore: + if "step_*" not in gitignore: + gitignore.write("step_*\n") + if "epoch_*" not in gitignore: + gitignore.write("epoch_*\n") + elif args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + # Load the tokenizer and add the placeholder token as a additional special token + if args.tokenizer_name: + tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) + elif args.pretrained_model_name_or_path: + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, subfolder="tokenizer", use_auth_token=args.use_auth_token + ) + + # Load models and create wrapper for stable diffusion + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", use_auth_token=args.use_auth_token + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, subfolder="vae", use_auth_token=args.use_auth_token + ) + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="unet", use_auth_token=args.use_auth_token + ) + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + optimizer = torch.optim.AdamW( + unet.parameters(), # only optimize unet + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + noise_scheduler = DDPMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt" + ) + + train_dataset = DreamBoothDataset( + instance_data_root=args.instance_data_dir, + instance_prompt=args.instance_prompt, + class_data_root=args.class_data_dir if not args.without_prior_preservation else None, + class_prompt=args.class_prompt, + tokenizer=tokenizer, + size=args.resolution, + center_crop=args.center_crop, + ) + train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=args.train_batch_size, shuffle=True) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) + + text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + text_encoder, optimizer, train_dataloader, lr_scheduler + ) + + # Move vae and unet to device + vae.to(accelerator.device) + unet.to(accelerator.device) + + # Keep vae in eval model as we don't train it + vae.eval() + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + accelerator.init_trackers("dreambooth", config=vars(args)) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + # Only show the progress bar once on each machine. + progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + progress_bar.set_description("Steps") + global_step = 0 + + for epoch in range(args.num_train_epochs): + text_encoder.train() + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(text_encoder): + # Convert images to latent space + if not args.without_prior_preservation: + images = torch.cat([batch["instance_images"], batch["class_images"]], dim=0) + input_ids = torch.cat([batch["instance_prompt_ids"], batch["class_prompt_ids"]], dim=0) + else: + images = batch["instance_images"] + input_ids = batch["instance_prompt_ids"] + + latents = vae.encode(images).latent_dist.sample().detach() + latents = latents * 0.18215 + + # Sample noise that we'll add to the latents + noise = torch.randn(latents.shape).to(latents.device) + bsz = latents.shape[0] + # Sample a random timestep for each image + timesteps = torch.randint( + 0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ).long() + + # Add noise to the latents according to the noise magnitude at each timestep + # (this is the forward diffusion process) + noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the text embedding for conditioning + encoder_hidden_states = text_encoder(input_ids)[0] + + # Predict the noise residual + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + accelerator.backward(loss) + + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + accelerator.wait_for_everyone() + + # Create the pipeline using using the trained modules and save it. + if accelerator.is_main_process: + pipeline = StableDiffusionPipeline( + text_encoder=accelerator.unwrap_model(text_encoder), + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=PNDMScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True + ), + safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), + feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), + ) + pipeline.save_pretrained(args.output_dir) + + if args.push_to_hub: + repo.push_to_hub( + args, pipeline, repo, commit_message="End of training", blocking=False, auto_lfs_prune=True + ) + + accelerator.end_training() + + +if __name__ == "__main__": + main()