From adf1f911f0bb243c166d75c6b6d50ed624b6a3d2 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 11 Sep 2024 06:50:02 +0530 Subject: [PATCH 1/6] [Tests] fix some fast gpu tests. (#9379) fix some fast gpu tests. --- examples/dreambooth/train_dreambooth_lora_flux.py | 2 ++ src/diffusers/models/transformers/transformer_flux.py | 1 + tests/pipelines/flux/test_pipeline_flux_img2img.py | 2 +- tests/pipelines/flux/test_pipeline_flux_inpaint.py | 2 +- 4 files changed, 5 insertions(+), 2 deletions(-) diff --git a/examples/dreambooth/train_dreambooth_lora_flux.py b/examples/dreambooth/train_dreambooth_lora_flux.py index 8b4bf989e8..48d669418f 100644 --- a/examples/dreambooth/train_dreambooth_lora_flux.py +++ b/examples/dreambooth/train_dreambooth_lora_flux.py @@ -1597,6 +1597,7 @@ def main(args): tokenizers=[None, None], text_input_ids_list=[tokens_one, tokens_two], max_sequence_length=args.max_sequence_length, + device=accelerator.device, prompt=prompts, ) else: @@ -1606,6 +1607,7 @@ def main(args): tokenizers=[None, None], text_input_ids_list=[tokens_one, tokens_two], max_sequence_length=args.max_sequence_length, + device=accelerator.device, prompt=args.instance_prompt, ) diff --git a/src/diffusers/models/transformers/transformer_flux.py b/src/diffusers/models/transformers/transformer_flux.py index fd0881a148..e38efe668c 100644 --- a/src/diffusers/models/transformers/transformer_flux.py +++ b/src/diffusers/models/transformers/transformer_flux.py @@ -465,6 +465,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig "Please remove the batch dimension and pass it as a 2d torch Tensor" ) img_ids = img_ids[0] + ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) diff --git a/tests/pipelines/flux/test_pipeline_flux_img2img.py b/tests/pipelines/flux/test_pipeline_flux_img2img.py index ec89f05382..a038b17258 100644 --- a/tests/pipelines/flux/test_pipeline_flux_img2img.py +++ b/tests/pipelines/flux/test_pipeline_flux_img2img.py @@ -18,11 +18,11 @@ from ..test_pipelines_common import PipelineTesterMixin enable_full_determinism() -@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.") class FluxImg2ImgPipelineFastTests(unittest.TestCase, PipelineTesterMixin): pipeline_class = FluxImg2ImgPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) + test_xformers_attention = False def get_dummy_components(self): torch.manual_seed(0) diff --git a/tests/pipelines/flux/test_pipeline_flux_inpaint.py b/tests/pipelines/flux/test_pipeline_flux_inpaint.py index 7ad77cb6ea..ac2eb1fa26 100644 --- a/tests/pipelines/flux/test_pipeline_flux_inpaint.py +++ b/tests/pipelines/flux/test_pipeline_flux_inpaint.py @@ -18,11 +18,11 @@ from ..test_pipelines_common import PipelineTesterMixin enable_full_determinism() -@unittest.skipIf(torch_device == "mps", "Flux has a float64 operation which is not supported in MPS.") class FluxInpaintPipelineFastTests(unittest.TestCase, PipelineTesterMixin): pipeline_class = FluxInpaintPipeline params = frozenset(["prompt", "height", "width", "guidance_scale", "prompt_embeds", "pooled_prompt_embeds"]) batch_params = frozenset(["prompt"]) + test_xformers_attention = False def get_dummy_components(self): torch.manual_seed(0) From c002731d930bd3fe893f36841d241e3b86bc22e6 Mon Sep 17 00:00:00 2001 From: Yu Zheng Date: Wed, 11 Sep 2024 09:34:37 +0800 Subject: [PATCH 2/6] [examples] add controlnet sd3 example (#9249) * add controlnet sd3 example * add controlnet sd3 example * update controlnet sd3 example * add controlnet sd3 example test * fix quality and style * update test * update test --------- Co-authored-by: Sayak Paul --- examples/controlnet/README_sd3.md | 152 ++ examples/controlnet/requirements_sd3.txt | 8 + examples/controlnet/test_controlnet.py | 21 + examples/controlnet/train_controlnet_sd3.py | 1415 +++++++++++++++++++ 4 files changed, 1596 insertions(+) create mode 100644 examples/controlnet/README_sd3.md create mode 100644 examples/controlnet/requirements_sd3.txt create mode 100644 examples/controlnet/train_controlnet_sd3.py diff --git a/examples/controlnet/README_sd3.md b/examples/controlnet/README_sd3.md new file mode 100644 index 0000000000..1788e07a21 --- /dev/null +++ b/examples/controlnet/README_sd3.md @@ -0,0 +1,152 @@ +# ControlNet training example for Stable Diffusion 3 (SD3) + +The `train_controlnet_sd3.py` script shows how to implement the ControlNet training procedure and adapt it for [Stable Diffusion 3](https://arxiv.org/abs/2403.03206). + +## Running locally with PyTorch + +### Installing the dependencies + +Before running the scripts, make sure to install the library's training dependencies: + +**Important** + +To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment: + +```bash +git clone https://github.com/huggingface/diffusers +cd diffusers +pip install -e . +``` + +Then cd in the `examples/controlnet` folder and run +```bash +pip install -r requirements_sd3.txt +``` + +And initialize an [πŸ€—Accelerate](https://github.com/huggingface/accelerate/) environment with: + +```bash +accelerate config +``` + +Or for a default accelerate configuration without answering questions about your environment + +```bash +accelerate config default +``` + +Or if your environment doesn't support an interactive shell (e.g., a notebook) + +```python +from accelerate.utils import write_basic_config +write_basic_config() +``` + +When running `accelerate config`, if we specify torch compile mode to True there can be dramatic speedups. + +## Circle filling dataset + +The original dataset is hosted in the [ControlNet repo](https://huggingface.co/lllyasviel/ControlNet/blob/main/training/fill50k.zip). We re-uploaded it to be compatible with `datasets` [here](https://huggingface.co/datasets/fusing/fill50k). Note that `datasets` handles dataloading within the training script. +Please download the dataset and unzip it in the directory `fill50k` in the `examples/controlnet` folder. + +## Training + +First download the SD3 model from [Hugging Face Hub](https://huggingface.co/stabilityai/stable-diffusion-3-medium). We will use it as a base model for the ControlNet training. +> [!NOTE] +> As the model is gated, before using it with diffusers you first need to go to the [Stable Diffusion 3 Medium Hugging Face page](https://huggingface.co/stabilityai/stable-diffusion-3-medium-diffusers), fill in the form and accept the gate. Once you are in, you need to log in so that your system knows you’ve accepted the gate. Use the command below to log in: + +```bash +huggingface-cli login +``` + +This will also allow us to push the trained model parameters to the Hugging Face Hub platform. + + +Our training examples use two test conditioning images. They can be downloaded by running + +```sh +wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png + +wget https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_2.png +``` + +Then run the following commands to train a ControlNet model. + +```bash +export MODEL_DIR="stabilityai/stable-diffusion-3-medium-diffusers" +export OUTPUT_DIR="sd3-controlnet-out" + +accelerate launch train_controlnet_sd3.py \ + --pretrained_model_name_or_path=$MODEL_DIR \ + --output_dir=$OUTPUT_DIR \ + --train_data_dir="fill50k" \ + --resolution=1024 \ + --learning_rate=1e-5 \ + --max_train_steps=15000 \ + --validation_image "./conditioning_image_1.png" "./conditioning_image_2.png" \ + --validation_prompt "red circle with blue background" "cyan circle with brown floral background" \ + --validation_steps=100 \ + --train_batch_size=1 \ + --gradient_accumulation_steps=4 +``` + +To better track our training experiments, we're using flags `validation_image`, `validation_prompt`, and `validation_steps` to allow the script to do a few validation inference runs. This allows us to qualitatively check if the training is progressing as expected. + +Our experiments were conducted on a single 40GB A100 GPU. + +### Inference + +Once training is done, we can perform inference like so: + +```python +from diffusers import StableDiffusion3ControlNetPipeline, SD3ControlNetModel +from diffusers.utils import load_image +import torch + +base_model_path = "stabilityai/stable-diffusion-3-medium-diffusers" +controlnet_path = "sd3-controlnet-out/checkpoint-6500/controlnet" + +controlnet = SD3ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16) +pipe = StableDiffusion3ControlNetPipeline.from_pretrained( + base_model_path, controlnet=controlnet +) +pipe.to("cuda", torch.float16) + + +control_image = load_image("./conditioning_image_1.png").resize((1024, 1024)) +prompt = "pale golden rod circle with old lace background" + +# generate image +generator = torch.manual_seed(0) +image = pipe( + prompt, num_inference_steps=20, generator=generator, control_image=control_image +).images[0] +image.save("./output.png") +``` + +## Notes + +### GPU usage + +SD3 is a large model and requires a lot of GPU memory. +We recommend using one GPU with at least 80GB of memory. +Make sure to use the right GPU when configuring the [accelerator](https://huggingface.co/docs/transformers/en/accelerate). + + +## Example results + +#### After 500 steps with batch size 8 + +| | | +|-------------------|:-------------------------:| +|| pale golden rod circle with old lace background | + ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![pale golden rod circle with old lace background](https://huggingface.co/datasets/DavyMorgan/sd3-controlnet-results/resolve/main/step-500.png) | + + +#### After 6500 steps with batch size 8: + +| | | +|-------------------|:-------------------------:| +|| pale golden rod circle with old lace background | + ![conditioning image](https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/controlnet_training/conditioning_image_1.png) | ![pale golden rod circle with old lace background](https://huggingface.co/datasets/DavyMorgan/sd3-controlnet-results/resolve/main/step-6500.png) | + diff --git a/examples/controlnet/requirements_sd3.txt b/examples/controlnet/requirements_sd3.txt new file mode 100644 index 0000000000..5ab6e9932e --- /dev/null +++ b/examples/controlnet/requirements_sd3.txt @@ -0,0 +1,8 @@ +accelerate>=0.16.0 +torchvision +transformers>=4.25.1 +ftfy +tensorboard +Jinja2 +datasets +wandb diff --git a/examples/controlnet/test_controlnet.py b/examples/controlnet/test_controlnet.py index 8ed9a976cc..77b5614c7f 100644 --- a/examples/controlnet/test_controlnet.py +++ b/examples/controlnet/test_controlnet.py @@ -115,3 +115,24 @@ class ControlNetSDXL(ExamplesTestsAccelerate): run_command(self._launch_args + test_args) self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors"))) + + +class ControlNetSD3(ExamplesTestsAccelerate): + def test_controlnet_sd3(self): + with tempfile.TemporaryDirectory() as tmpdir: + test_args = f""" + examples/controlnet/train_controlnet_sd3.py + --pretrained_model_name_or_path=DavyMorgan/tiny-sd3-pipe + --dataset_name=hf-internal-testing/fill10 + --output_dir={tmpdir} + --resolution=64 + --train_batch_size=1 + --gradient_accumulation_steps=1 + --controlnet_model_name_or_path=DavyMorgan/tiny-controlnet-sd3 + --max_train_steps=4 + --checkpointing_steps=2 + """.split() + + run_command(self._launch_args + test_args) + + self.assertTrue(os.path.isfile(os.path.join(tmpdir, "diffusion_pytorch_model.safetensors"))) diff --git a/examples/controlnet/train_controlnet_sd3.py b/examples/controlnet/train_controlnet_sd3.py new file mode 100644 index 0000000000..052eb9d4bf --- /dev/null +++ b/examples/controlnet/train_controlnet_sd3.py @@ -0,0 +1,1415 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import contextlib +import copy +import functools +import logging +import math +import os +import random +import shutil +from pathlib import Path + +import accelerate +import numpy as np +import torch +import torch.utils.checkpoint +import transformers +from accelerate import Accelerator +from accelerate.logging import get_logger +from accelerate.utils import ProjectConfiguration, set_seed +from datasets import load_dataset +from huggingface_hub import create_repo, upload_folder +from packaging import version +from PIL import Image +from torchvision import transforms +from tqdm.auto import tqdm +from transformers import CLIPTokenizer, PretrainedConfig, T5TokenizerFast + +import diffusers +from diffusers import ( + AutoencoderKL, + FlowMatchEulerDiscreteScheduler, + SD3ControlNetModel, + SD3Transformer2DModel, + StableDiffusion3ControlNetPipeline, +) +from diffusers.optimization import get_scheduler +from diffusers.training_utils import ( + clear_objs_and_retain_memory, + compute_density_for_timestep_sampling, + compute_loss_weighting_for_sd3, +) +from diffusers.utils import check_min_version, is_wandb_available +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card +from diffusers.utils.torch_utils import is_compiled_module + + +if is_wandb_available(): + import wandb + +# Will error if the minimal version of diffusers is not installed. Remove at your own risks. +check_min_version("0.30.0.dev0") + +logger = get_logger(__name__) + + +def image_grid(imgs, rows, cols): + assert len(imgs) == rows * cols + + w, h = imgs[0].size + grid = Image.new("RGB", size=(cols * w, rows * h)) + + for i, img in enumerate(imgs): + grid.paste(img, box=(i % cols * w, i // cols * h)) + return grid + + +def log_validation(controlnet, args, accelerator, weight_dtype, step, is_final_validation=False): + logger.info("Running validation... ") + + if not is_final_validation: + controlnet = accelerator.unwrap_model(controlnet) + else: + controlnet = SD3ControlNetModel.from_pretrained(args.output_dir, torch_dtype=weight_dtype) + + pipeline = StableDiffusion3ControlNetPipeline.from_pretrained( + args.pretrained_model_name_or_path, + controlnet=controlnet, + safety_checker=None, + revision=args.revision, + variant=args.variant, + torch_dtype=weight_dtype, + ) + pipeline = pipeline.to(torch.device(accelerator.device)) + pipeline.set_progress_bar_config(disable=True) + + if args.seed is None: + generator = None + else: + generator = torch.manual_seed(args.seed) + + if len(args.validation_image) == len(args.validation_prompt): + validation_images = args.validation_image + validation_prompts = args.validation_prompt + elif len(args.validation_image) == 1: + validation_images = args.validation_image * len(args.validation_prompt) + validation_prompts = args.validation_prompt + elif len(args.validation_prompt) == 1: + validation_images = args.validation_image + validation_prompts = args.validation_prompt * len(args.validation_image) + else: + raise ValueError( + "number of `args.validation_image` and `args.validation_prompt` should be checked in `parse_args`" + ) + + image_logs = [] + inference_ctx = contextlib.nullcontext() if is_final_validation else torch.autocast(accelerator.device.type) + + for validation_prompt, validation_image in zip(validation_prompts, validation_images): + validation_image = Image.open(validation_image).convert("RGB") + + images = [] + + for _ in range(args.num_validation_images): + with inference_ctx: + image = pipeline( + validation_prompt, control_image=validation_image, num_inference_steps=20, generator=generator + ).images[0] + + images.append(image) + + image_logs.append( + {"validation_image": validation_image, "images": images, "validation_prompt": validation_prompt} + ) + + tracker_key = "test" if is_final_validation else "validation" + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + tracker.writer.add_image( + "Controlnet conditioning", np.asarray([validation_image]), step, dataformats="NHWC" + ) + + formatted_images = [] + for image in images: + formatted_images.append(np.asarray(image)) + + formatted_images = np.stack(formatted_images) + + tracker.writer.add_images(validation_prompt, formatted_images, step, dataformats="NHWC") + elif tracker.name == "wandb": + formatted_images = [] + + for log in image_logs: + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + + formatted_images.append(wandb.Image(validation_image, caption="Controlnet conditioning")) + + for image in images: + image = wandb.Image(image, caption=validation_prompt) + formatted_images.append(image) + + tracker.log({tracker_key: formatted_images}) + else: + logger.warning(f"image logging not implemented for {tracker.name}") + + clear_objs_and_retain_memory(pipeline) + + if not is_final_validation: + controlnet.to(accelerator.device) + + return image_logs + + +# Copied from dreambooth sd3 example +def load_text_encoders(class_one, class_two, class_three): + text_encoder_one = class_one.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision, variant=args.variant + ) + text_encoder_two = class_two.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision, variant=args.variant + ) + text_encoder_three = class_three.from_pretrained( + args.pretrained_model_name_or_path, subfolder="text_encoder_3", revision=args.revision, variant=args.variant + ) + return text_encoder_one, text_encoder_two, text_encoder_three + + +# Copied from dreambooth sd3 example +def import_model_class_from_model_name_or_path( + pretrained_model_name_or_path: str, revision: str, subfolder: str = "text_encoder" +): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, subfolder=subfolder, revision=revision + ) + model_class = text_encoder_config.architectures[0] + if model_class == "CLIPTextModelWithProjection": + from transformers import CLIPTextModelWithProjection + + return CLIPTextModelWithProjection + elif model_class == "T5EncoderModel": + from transformers import T5EncoderModel + + return T5EncoderModel + else: + raise ValueError(f"{model_class} is not supported.") + + +def save_model_card(repo_id: str, image_logs=None, base_model=str, repo_folder=None): + img_str = "" + if image_logs is not None: + img_str = "You can find some example images below.\n\n" + for i, log in enumerate(image_logs): + images = log["images"] + validation_prompt = log["validation_prompt"] + validation_image = log["validation_image"] + validation_image.save(os.path.join(repo_folder, "image_control.png")) + img_str += f"prompt: {validation_prompt}\n" + images = [validation_image] + images + image_grid(images, 1, len(images)).save(os.path.join(repo_folder, f"images_{i}.png")) + img_str += f"![images_{i})](./images_{i}.png)\n" + + model_description = f""" +# SD3 controlnet-{repo_id} + +These are controlnet weights trained on {base_model} with new type of conditioning. +The weights were trained using [ControlNet](https://github.com/lllyasviel/ControlNet) with the [SD3 diffusers trainer](https://github.com/huggingface/diffusers/blob/main/examples/controlnet/README_sd3.md). +{img_str} + +Please adhere to the licensing terms as described `[here](https://huggingface.co/stabilityai/stable-diffusion-3-medium/blob/main/LICENSE)`. +""" + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="openrail++", + base_model=base_model, + model_description=model_description, + inference=True, + ) + + tags = [ + "text-to-image", + "diffusers-training", + "diffusers", + "sd3", + "sd3-diffusers", + "controlnet", + ] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) + + +def parse_args(input_args=None): + parser = argparse.ArgumentParser(description="Simple example of a ControlNet training script.") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--controlnet_model_name_or_path", + type=str, + default=None, + help="Path to pretrained controlnet model or model identifier from huggingface.co/models." + " If not specified controlnet weights are initialized from unet.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--output_dir", + type=str, + default="controlnet-model", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument("--num_train_epochs", type=int, default=1) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--checkpointing_steps", + type=int, + default=500, + help=( + "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via `--resume_from_checkpoint`. " + "In the case that the checkpoint is better than the final trained model, the checkpoint can also be used for inference." + "Using a checkpoint for inference requires separate loading of the original pipeline and the individual checkpointed model components." + "See https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint for step by step" + "instructions." + ), + ) + parser.add_argument( + "--checkpoints_total_limit", + type=int, + default=None, + help=("Max number of checkpoints to store."), + ) + parser.add_argument( + "--resume_from_checkpoint", + type=str, + default=None, + help=( + "Whether training should be resumed from a previous checkpoint. Use a path saved by" + ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.' + ), + ) + parser.add_argument( + "--gradient_accumulation_steps", + type=int, + default=1, + help="Number of updates steps to accumulate before performing a backward/update pass.", + ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=5e-6, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--scale_lr", + action="store_true", + default=False, + help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", + ) + parser.add_argument( + "--lr_scheduler", + type=str, + default="constant", + help=( + 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' + ' "constant", "constant_with_warmup"]' + ), + ) + parser.add_argument( + "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." + ) + parser.add_argument( + "--lr_num_cycles", + type=int, + default=1, + help="Number of hard resets of the lr in cosine_with_restarts scheduler.", + ) + parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.") + parser.add_argument( + "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes." + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--weighting_scheme", + type=str, + default="logit_normal", + choices=["sigma_sqrt", "logit_normal", "mode", "cosmap"], + ) + parser.add_argument( + "--logit_mean", type=float, default=0.0, help="mean to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--logit_std", type=float, default=1.0, help="std to use when using the `'logit_normal'` weighting scheme." + ) + parser.add_argument( + "--mode_scale", + type=float, + default=1.29, + help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.", + ) + parser.add_argument( + "--precondition_outputs", + type=int, + default=1, + help="Flag indicating if we are preconditioning the model outputs or not as done in EDM. This affects how " + "model `target` is calculated.", + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.") + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + parser.add_argument( + "--logging_dir", + type=str, + default="logs", + help=( + "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to" + " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***." + ), + ) + parser.add_argument( + "--allow_tf32", + action="store_true", + help=( + "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see" + " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices" + ), + ) + parser.add_argument( + "--report_to", + type=str, + default="tensorboard", + help=( + 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' + ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' + ), + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument( + "--set_grads_to_none", + action="store_true", + help=( + "Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain" + " behaviors, so disable this argument if it causes any problems. More info:" + " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html" + ), + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that πŸ€— Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing the target image." + ) + parser.add_argument( + "--conditioning_image_column", + type=str, + default="conditioning_image", + help="The column of the dataset containing the controlnet conditioning image.", + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--proportion_empty_prompts", + type=float, + default=0, + help="Proportion of image prompts to be replaced with empty strings. Defaults to 0 (no prompt replacement).", + ) + parser.add_argument( + "--max_sequence_length", + type=int, + default=77, + help="Maximum sequence length to use with with the T5 text encoder", + ) + parser.add_argument( + "--validation_prompt", + type=str, + default=None, + nargs="+", + help=( + "A set of prompts evaluated every `--validation_steps` and logged to `--report_to`." + " Provide either a matching number of `--validation_image`s, a single `--validation_image`" + " to be used with all prompts, or a single prompt that will be used with all `--validation_image`s." + ), + ) + parser.add_argument( + "--validation_image", + type=str, + default=None, + nargs="+", + help=( + "A set of paths to the controlnet conditioning image be evaluated every `--validation_steps`" + " and logged to `--report_to`. Provide either a matching number of `--validation_prompt`s, a" + " a single `--validation_prompt` to be used with all `--validation_image`s, or a single" + " `--validation_image` that will be used with all `--validation_prompt`s." + ), + ) + parser.add_argument( + "--num_validation_images", + type=int, + default=4, + help="Number of images to be generated for each `--validation_image`, `--validation_prompt` pair", + ) + parser.add_argument( + "--validation_steps", + type=int, + default=100, + help=( + "Run validation every X steps. Validation consists of running the prompt" + " `args.validation_prompt` multiple times: `args.num_validation_images`" + " and logging the images." + ), + ) + parser.add_argument( + "--tracker_project_name", + type=str, + default="train_controlnet", + help=( + "The `project_name` argument passed to Accelerator.init_trackers for" + " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" + ), + ) + + if input_args is not None: + args = parser.parse_args(input_args) + else: + args = parser.parse_args() + + if args.dataset_name is None and args.train_data_dir is None: + raise ValueError("Specify either `--dataset_name` or `--train_data_dir`") + + if args.dataset_name is not None and args.train_data_dir is not None: + raise ValueError("Specify only one of `--dataset_name` or `--train_data_dir`") + + if args.proportion_empty_prompts < 0 or args.proportion_empty_prompts > 1: + raise ValueError("`--proportion_empty_prompts` must be in the range [0, 1].") + + if args.validation_prompt is not None and args.validation_image is None: + raise ValueError("`--validation_image` must be set if `--validation_prompt` is set") + + if args.validation_prompt is None and args.validation_image is not None: + raise ValueError("`--validation_prompt` must be set if `--validation_image` is set") + + if ( + args.validation_image is not None + and args.validation_prompt is not None + and len(args.validation_image) != 1 + and len(args.validation_prompt) != 1 + and len(args.validation_image) != len(args.validation_prompt) + ): + raise ValueError( + "Must provide either 1 `--validation_image`, 1 `--validation_prompt`," + " or the same number of `--validation_prompt`s and `--validation_image`s" + ) + + if args.resolution % 8 != 0: + raise ValueError( + "`--resolution` must be divisible by 8 for consistently sized encoded images between the VAE and the controlnet encoder." + ) + + return args + + +def make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, accelerator): + # Get the datasets: you can either provide your own training and evaluation files (see below) + # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub). + + # In distributed training, the load_dataset function guarantees that only one local process can concurrently + # download the dataset. + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + ) + else: + if args.train_data_dir is not None: + dataset = load_dataset( + args.train_data_dir, + cache_dir=args.cache_dir, + ) + # See more about loading custom images at + # https://huggingface.co/docs/datasets/v2.0.0/en/dataset_script + + # Preprocessing the datasets. + # We need to tokenize inputs and targets. + column_names = dataset["train"].column_names + + # 6. Get the column names for input/target. + if args.image_column is None: + image_column = column_names[0] + logger.info(f"image column defaulting to {image_column}") + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"`--image_column` value '{args.image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.caption_column is None: + caption_column = column_names[1] + logger.info(f"caption column defaulting to {caption_column}") + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"`--caption_column` value '{args.caption_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + if args.conditioning_image_column is None: + conditioning_image_column = column_names[2] + logger.info(f"conditioning image column defaulting to {conditioning_image_column}") + else: + conditioning_image_column = args.conditioning_image_column + if conditioning_image_column not in column_names: + raise ValueError( + f"`--conditioning_image_column` value '{args.conditioning_image_column}' not found in dataset columns. Dataset columns are: {', '.join(column_names)}" + ) + + def process_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if random.random() < args.proportion_empty_prompts: + captions.append("") + elif isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + return captions + + image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + conditioning_image_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution), + transforms.ToTensor(), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + images = [image_transforms(image) for image in images] + + conditioning_images = [image.convert("RGB") for image in examples[conditioning_image_column]] + conditioning_images = [conditioning_image_transforms(image) for image in conditioning_images] + + examples["pixel_values"] = images + examples["conditioning_pixel_values"] = conditioning_images + examples["prompts"] = process_captions(examples) + + return examples + + with accelerator.main_process_first(): + if args.max_train_samples is not None: + dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples)) + # Set the training transforms + train_dataset = dataset["train"].with_transform(preprocess_train) + + return train_dataset + + +def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float() + + conditioning_pixel_values = torch.stack([example["conditioning_pixel_values"] for example in examples]) + conditioning_pixel_values = conditioning_pixel_values.to(memory_format=torch.contiguous_format).float() + + prompt_embeds = torch.stack([torch.tensor(example["prompt_embeds"]) for example in examples]) + pooled_prompt_embeds = torch.stack([torch.tensor(example["pooled_prompt_embeds"]) for example in examples]) + + return { + "pixel_values": pixel_values, + "conditioning_pixel_values": conditioning_pixel_values, + "prompt_embeds": prompt_embeds, + "pooled_prompt_embeds": pooled_prompt_embeds, + } + + +# Copied from dreambooth sd3 example +def _encode_prompt_with_t5( + text_encoder, + tokenizer, + max_sequence_length, + prompt=None, + num_images_per_prompt=1, + device=None, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=max_sequence_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder(text_input_ids.to(device))[0] + + dtype = text_encoder.dtype + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds + + +# Copied from dreambooth sd3 example +def _encode_prompt_with_clip( + text_encoder, + tokenizer, + prompt: str, + device=None, + num_images_per_prompt: int = 1, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + batch_size = len(prompt) + + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=77, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + prompt_embeds = prompt_embeds.to(dtype=text_encoder.dtype, device=device) + + _, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + return prompt_embeds, pooled_prompt_embeds + + +# Copied from dreambooth sd3 example +def encode_prompt( + text_encoders, + tokenizers, + prompt: str, + max_sequence_length, + device=None, + num_images_per_prompt: int = 1, +): + prompt = [prompt] if isinstance(prompt, str) else prompt + + clip_tokenizers = tokenizers[:2] + clip_text_encoders = text_encoders[:2] + + clip_prompt_embeds_list = [] + clip_pooled_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(clip_tokenizers, clip_text_encoders): + prompt_embeds, pooled_prompt_embeds = _encode_prompt_with_clip( + text_encoder=text_encoder, + tokenizer=tokenizer, + prompt=prompt, + device=device if device is not None else text_encoder.device, + num_images_per_prompt=num_images_per_prompt, + ) + clip_prompt_embeds_list.append(prompt_embeds) + clip_pooled_prompt_embeds_list.append(pooled_prompt_embeds) + + clip_prompt_embeds = torch.cat(clip_prompt_embeds_list, dim=-1) + pooled_prompt_embeds = torch.cat(clip_pooled_prompt_embeds_list, dim=-1) + + t5_prompt_embed = _encode_prompt_with_t5( + text_encoders[-1], + tokenizers[-1], + max_sequence_length, + prompt=prompt, + num_images_per_prompt=num_images_per_prompt, + device=device if device is not None else text_encoders[-1].device, + ) + + clip_prompt_embeds = torch.nn.functional.pad( + clip_prompt_embeds, (0, t5_prompt_embed.shape[-1] - clip_prompt_embeds.shape[-1]) + ) + prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2) + + return prompt_embeds, pooled_prompt_embeds + + +def main(args): + if args.report_to == "wandb" and args.hub_token is not None: + raise ValueError( + "You cannot use both --report_to=wandb and --hub_token due to a security risk of exposing your token." + " Please use `huggingface-cli login` to authenticate with the Hub." + ) + + if torch.backends.mps.is_available() and args.mixed_precision == "bf16": + # due to pytorch#99272, MPS does not yet support bfloat16. + raise ValueError( + "Mixed precision training with bfloat16 is not supported on MPS. Please use fp16 (recommended) or fp32 instead." + ) + + logging_dir = Path(args.output_dir, args.logging_dir) + + accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir) + + accelerator = Accelerator( + gradient_accumulation_steps=args.gradient_accumulation_steps, + mixed_precision=args.mixed_precision, + log_with=args.report_to, + project_config=accelerator_project_config, + ) + + # Disable AMP for MPS. + if torch.backends.mps.is_available(): + accelerator.native_amp = False + + if args.report_to == "wandb": + if not is_wandb_available(): + raise ImportError("Make sure to install wandb if you want to use it for logging during training.") + + # Make one log on every process with the configuration for debugging. + logging.basicConfig( + format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", + datefmt="%m/%d/%Y %H:%M:%S", + level=logging.INFO, + ) + logger.info(accelerator.state, main_process_only=False) + if accelerator.is_local_main_process: + transformers.utils.logging.set_verbosity_warning() + diffusers.utils.logging.set_verbosity_info() + else: + transformers.utils.logging.set_verbosity_error() + diffusers.utils.logging.set_verbosity_error() + + # If passed along, set the training seed now. + if args.seed is not None: + set_seed(args.seed) + + # Handle the repository creation + if accelerator.is_main_process: + if args.output_dir is not None: + os.makedirs(args.output_dir, exist_ok=True) + + if args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + # Load the tokenizer + tokenizer_one = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + tokenizer_two = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_2", + revision=args.revision, + ) + tokenizer_three = T5TokenizerFast.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer_3", + revision=args.revision, + ) + + # import correct text encoder class + text_encoder_cls_one = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision + ) + text_encoder_cls_two = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_2" + ) + text_encoder_cls_three = import_model_class_from_model_name_or_path( + args.pretrained_model_name_or_path, args.revision, subfolder="text_encoder_3" + ) + + # Load scheduler and models + noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( + args.pretrained_model_name_or_path, subfolder="scheduler" + ) + noise_scheduler_copy = copy.deepcopy(noise_scheduler) + text_encoder_one, text_encoder_two, text_encoder_three = load_text_encoders( + text_encoder_cls_one, text_encoder_cls_two, text_encoder_cls_three + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + transformer = SD3Transformer2DModel.from_pretrained( + args.pretrained_model_name_or_path, subfolder="transformer", revision=args.revision, variant=args.variant + ) + + if args.controlnet_model_name_or_path: + logger.info("Loading existing controlnet weights") + controlnet = SD3ControlNetModel.from_pretrained(args.controlnet_model_name_or_path) + else: + logger.info("Initializing controlnet weights from transformer") + controlnet = SD3ControlNetModel.from_transformer(transformer) + + transformer.requires_grad_(False) + vae.requires_grad_(False) + text_encoder_one.requires_grad_(False) + text_encoder_two.requires_grad_(False) + text_encoder_three.requires_grad_(False) + controlnet.train() + + # Taken from [Sayak Paul's Diffusers PR #6511](https://github.com/huggingface/diffusers/pull/6511/files) + def unwrap_model(model): + model = accelerator.unwrap_model(model) + model = model._orig_mod if is_compiled_module(model) else model + return model + + # `accelerate` 0.16.0 will have better support for customized saving + if version.parse(accelerate.__version__) >= version.parse("0.16.0"): + # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format + def save_model_hook(models, weights, output_dir): + if accelerator.is_main_process: + i = len(weights) - 1 + + while len(weights) > 0: + weights.pop() + model = models[i] + + sub_dir = "controlnet" + model.save_pretrained(os.path.join(output_dir, sub_dir)) + + i -= 1 + + def load_model_hook(models, input_dir): + while len(models) > 0: + # pop models so that they are not loaded again + model = models.pop() + + # load diffusers style into model + load_model = SD3ControlNetModel.from_pretrained(input_dir, subfolder="controlnet") + model.register_to_config(**load_model.config) + + model.load_state_dict(load_model.state_dict()) + del load_model + + accelerator.register_save_state_pre_hook(save_model_hook) + accelerator.register_load_state_pre_hook(load_model_hook) + + if args.gradient_checkpointing: + controlnet.enable_gradient_checkpointing() + + # Check that all trainable models are in full precision + low_precision_error_string = ( + " Please make sure to always have all model weights in full float32 precision when starting training - even if" + " doing mixed precision training, copy of the weights should still be float32." + ) + + if unwrap_model(controlnet).dtype != torch.float32: + raise ValueError( + f"Controlnet loaded as datatype {unwrap_model(controlnet).dtype}. {low_precision_error_string}" + ) + + # Enable TF32 for faster training on Ampere GPUs, + # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices + if args.allow_tf32: + torch.backends.cuda.matmul.allow_tf32 = True + + if args.scale_lr: + args.learning_rate = ( + args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes + ) + + # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs + if args.use_8bit_adam: + try: + import bitsandbytes as bnb + except ImportError: + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) + + optimizer_class = bnb.optim.AdamW8bit + else: + optimizer_class = torch.optim.AdamW + + # Optimizer creation + params_to_optimize = controlnet.parameters() + optimizer = optimizer_class( + params_to_optimize, + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + ) + + # For mixed precision training we cast the text_encoder and vae weights to half-precision + # as these models are only used for inference, keeping weights in full precision is not required. + weight_dtype = torch.float32 + if accelerator.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif accelerator.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + # Move vae, transformer and text_encoder to device and cast to weight_dtype + vae.to(accelerator.device, dtype=torch.float32) + transformer.to(accelerator.device, dtype=weight_dtype) + text_encoder_one.to(accelerator.device, dtype=weight_dtype) + text_encoder_two.to(accelerator.device, dtype=weight_dtype) + text_encoder_three.to(accelerator.device, dtype=weight_dtype) + + train_dataset = make_train_dataset(args, tokenizer_one, tokenizer_two, tokenizer_three, accelerator) + + tokenizers = [tokenizer_one, tokenizer_two, tokenizer_three] + text_encoders = [text_encoder_one, text_encoder_two, text_encoder_three] + + def compute_text_embeddings(batch, text_encoders, tokenizers): + with torch.no_grad(): + prompt = batch["prompts"] + prompt_embeds, pooled_prompt_embeds = encode_prompt( + text_encoders, tokenizers, prompt, args.max_sequence_length + ) + prompt_embeds = prompt_embeds.to(accelerator.device) + pooled_prompt_embeds = pooled_prompt_embeds.to(accelerator.device) + return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds} + + compute_embeddings_fn = functools.partial( + compute_text_embeddings, + text_encoders=text_encoders, + tokenizers=tokenizers, + ) + with accelerator.main_process_first(): + from datasets.fingerprint import Hasher + + # fingerprint used by the cache for the other processes to load the result + # details: https://github.com/huggingface/diffusers/pull/4038#discussion_r1266078401 + new_fingerprint = Hasher.hash(args) + train_dataset = train_dataset.map(compute_embeddings_fn, batched=True, new_fingerprint=new_fingerprint) + + clear_objs_and_retain_memory(text_encoders + tokenizers) + + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + shuffle=True, + collate_fn=collate_fn, + batch_size=args.train_batch_size, + num_workers=args.dataloader_num_workers, + ) + + # Scheduler and math around the number of training steps. + overrode_max_train_steps = False + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if args.max_train_steps is None: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + overrode_max_train_steps = True + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * accelerator.num_processes, + num_training_steps=args.max_train_steps * accelerator.num_processes, + num_cycles=args.lr_num_cycles, + power=args.lr_power, + ) + + # Prepare everything with our `accelerator`. + controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare( + controlnet, optimizer, train_dataloader, lr_scheduler + ) + + # We need to recalculate our total training steps as the size of the training dataloader may have changed. + num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) + if overrode_max_train_steps: + args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch + # Afterwards we recalculate our number of training epochs + args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + + # We need to initialize the trackers we use, and also store our configuration. + # The trackers initializes automatically on the main process. + if accelerator.is_main_process: + tracker_config = dict(vars(args)) + + # tensorboard cannot handle list types for config + tracker_config.pop("validation_prompt") + tracker_config.pop("validation_image") + + accelerator.init_trackers(args.tracker_project_name, config=tracker_config) + + # Train! + total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps + + logger.info("***** Running training *****") + logger.info(f" Num examples = {len(train_dataset)}") + logger.info(f" Num batches each epoch = {len(train_dataloader)}") + logger.info(f" Num Epochs = {args.num_train_epochs}") + logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") + logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") + logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") + logger.info(f" Total optimization steps = {args.max_train_steps}") + global_step = 0 + first_epoch = 0 + + # Potentially load in the weights and states from a previous save + if args.resume_from_checkpoint: + if args.resume_from_checkpoint != "latest": + path = os.path.basename(args.resume_from_checkpoint) + else: + # Get the most recent checkpoint + dirs = os.listdir(args.output_dir) + dirs = [d for d in dirs if d.startswith("checkpoint")] + dirs = sorted(dirs, key=lambda x: int(x.split("-")[1])) + path = dirs[-1] if len(dirs) > 0 else None + + if path is None: + accelerator.print( + f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run." + ) + args.resume_from_checkpoint = None + initial_global_step = 0 + else: + accelerator.print(f"Resuming from checkpoint {path}") + accelerator.load_state(os.path.join(args.output_dir, path)) + global_step = int(path.split("-")[1]) + + initial_global_step = global_step + first_epoch = global_step // num_update_steps_per_epoch + else: + initial_global_step = 0 + + progress_bar = tqdm( + range(0, args.max_train_steps), + initial=initial_global_step, + desc="Steps", + # Only show the progress bar once on each machine. + disable=not accelerator.is_local_main_process, + ) + + def get_sigmas(timesteps, n_dim=4, dtype=torch.float32): + sigmas = noise_scheduler_copy.sigmas.to(device=accelerator.device, dtype=dtype) + schedule_timesteps = noise_scheduler_copy.timesteps.to(accelerator.device) + timesteps = timesteps.to(accelerator.device) + step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] + + sigma = sigmas[step_indices].flatten() + while len(sigma.shape) < n_dim: + sigma = sigma.unsqueeze(-1) + return sigma + + image_logs = None + for epoch in range(first_epoch, args.num_train_epochs): + for step, batch in enumerate(train_dataloader): + with accelerator.accumulate(controlnet): + # Convert images to latent space + pixel_values = batch["pixel_values"].to(dtype=vae.dtype) + model_input = vae.encode(pixel_values).latent_dist.sample() + model_input = (model_input - vae.config.shift_factor) * vae.config.scaling_factor + model_input = model_input.to(dtype=weight_dtype) + + # Sample noise that we'll add to the latents + noise = torch.randn_like(model_input) + bsz = model_input.shape[0] + # Sample a random timestep for each image + # for weighting schemes where we sample timesteps non-uniformly + u = compute_density_for_timestep_sampling( + weighting_scheme=args.weighting_scheme, + batch_size=bsz, + logit_mean=args.logit_mean, + logit_std=args.logit_std, + mode_scale=args.mode_scale, + ) + indices = (u * noise_scheduler_copy.config.num_train_timesteps).long() + timesteps = noise_scheduler_copy.timesteps[indices].to(device=model_input.device) + + # Add noise according to flow matching. + # zt = (1 - texp) * x + texp * z1 + sigmas = get_sigmas(timesteps, n_dim=model_input.ndim, dtype=model_input.dtype) + noisy_model_input = (1.0 - sigmas) * model_input + sigmas * noise + + # Get the text embedding for conditioning + prompt_embeds = batch["prompt_embeds"] + pooled_prompt_embeds = batch["pooled_prompt_embeds"] + + # controlnet(s) inference + controlnet_image = batch["conditioning_pixel_values"].to(dtype=weight_dtype) + controlnet_image = vae.encode(controlnet_image).latent_dist.sample() + controlnet_image = controlnet_image * vae.config.scaling_factor + + control_block_res_samples = controlnet( + hidden_states=noisy_model_input, + timestep=timesteps, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + controlnet_cond=controlnet_image, + return_dict=False, + )[0] + control_block_res_samples = [sample.to(dtype=weight_dtype) for sample in control_block_res_samples] + + # Predict the noise residual + model_pred = transformer( + hidden_states=noisy_model_input, + timestep=timesteps, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + block_controlnet_hidden_states=control_block_res_samples, + return_dict=False, + )[0] + + # Follow: Section 5 of https://arxiv.org/abs/2206.00364. + # Preconditioning of the model outputs. + if args.precondition_outputs: + model_pred = model_pred * (-sigmas) + noisy_model_input + + # these weighting schemes use a uniform timestep sampling + # and instead post-weight the loss + weighting = compute_loss_weighting_for_sd3(weighting_scheme=args.weighting_scheme, sigmas=sigmas) + + # flow matching loss + if args.precondition_outputs: + target = model_input + else: + target = noise - model_input + + # Compute regular loss. + loss = torch.mean( + (weighting.float() * (model_pred.float() - target.float()) ** 2).reshape(target.shape[0], -1), + 1, + ) + loss = loss.mean() + + accelerator.backward(loss) + if accelerator.sync_gradients: + params_to_clip = controlnet.parameters() + accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad(set_to_none=args.set_grads_to_none) + + # Checks if the accelerator has performed an optimization step behind the scenes + if accelerator.sync_gradients: + progress_bar.update(1) + global_step += 1 + + if accelerator.is_main_process: + if global_step % args.checkpointing_steps == 0: + # _before_ saving state, check if this save would set us over the `checkpoints_total_limit` + if args.checkpoints_total_limit is not None: + checkpoints = os.listdir(args.output_dir) + checkpoints = [d for d in checkpoints if d.startswith("checkpoint")] + checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1])) + + # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints + if len(checkpoints) >= args.checkpoints_total_limit: + num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1 + removing_checkpoints = checkpoints[0:num_to_remove] + + logger.info( + f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints" + ) + logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}") + + for removing_checkpoint in removing_checkpoints: + removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint) + shutil.rmtree(removing_checkpoint) + + save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}") + accelerator.save_state(save_path) + logger.info(f"Saved state to {save_path}") + + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + image_logs = log_validation( + controlnet, + args, + accelerator, + weight_dtype, + global_step, + ) + + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} + progress_bar.set_postfix(**logs) + accelerator.log(logs, step=global_step) + + if global_step >= args.max_train_steps: + break + + # Create the pipeline using using the trained modules and save it. + accelerator.wait_for_everyone() + if accelerator.is_main_process: + controlnet = unwrap_model(controlnet) + controlnet.save_pretrained(args.output_dir) + + # Run a final round of validation. + image_logs = None + if args.validation_prompt is not None: + image_logs = log_validation( + controlnet=None, + args=args, + accelerator=accelerator, + weight_dtype=weight_dtype, + step=global_step, + is_final_validation=True, + ) + + if args.push_to_hub: + save_model_card( + repo_id, + image_logs=image_logs, + base_model=args.pretrained_model_name_or_path, + repo_folder=args.output_dir, + ) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + accelerator.end_training() + + +if __name__ == "__main__": + args = parse_args() + main(args) From b19827f6b45b01551e0f1f5458073eb46dd1b4ea Mon Sep 17 00:00:00 2001 From: dianyo Date: Wed, 11 Sep 2024 12:29:15 +0800 Subject: [PATCH 3/6] Migrate the BrownianTree to BrownianInterval in DPM solver (#9335) migrate the BrownianTree to BrownianInterval Co-authored-by: YiYi Xu --- .../schedulers/scheduling_dpmsolver_sde.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py index bea6e5e075..7f2dd08157 100644 --- a/src/diffusers/schedulers/scheduling_dpmsolver_sde.py +++ b/src/diffusers/schedulers/scheduling_dpmsolver_sde.py @@ -38,7 +38,20 @@ class BatchedBrownianTree: except TypeError: seed = [seed] self.batched = False - self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + self.trees = [ + torchsde.BrownianInterval( + t0=t0, + t1=t1, + size=w0.shape, + dtype=w0.dtype, + device=w0.device, + entropy=s, + tol=1e-6, + pool_size=24, + halfway_tree=True, + ) + for s in seed + ] @staticmethod def sort(a, b): From b9e2f886cd6e9182f1bf1bf7421c6363956f94c5 Mon Sep 17 00:00:00 2001 From: asfiyab-nvidia <117682710+asfiyab-nvidia@users.noreply.github.com> Date: Tue, 10 Sep 2024 22:12:36 -0700 Subject: [PATCH 4/6] FluxPosEmbed: Remove Squeeze No-op (#9409) Remove Squeeze op Signed-off-by: Asfiya Baig Co-authored-by: YiYi Xu --- src/diffusers/models/embeddings.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index eb5067c377..0b946e1878 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -690,7 +690,7 @@ class FluxPosEmbed(nn.Module): n_axes = ids.shape[-1] cos_out = [] sin_out = [] - pos = ids.squeeze().float() + pos = ids.float() is_mps = ids.device.type == "mps" freqs_dtype = torch.float32 if is_mps else torch.float64 for i in range(n_axes): From 5e1427a7da6e878b958fd5a2422c7763a94ff02b Mon Sep 17 00:00:00 2001 From: Aryan Date: Thu, 12 Sep 2024 01:29:58 +0530 Subject: [PATCH 5/6] [docs] AnimateDiff FreeNoise (#9414) * update docs * apply suggestions from review * Update docs/source/en/api/pipelines/animatediff.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/api/pipelines/animatediff.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * Update docs/source/en/api/pipelines/animatediff.md Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> * apply suggestions from review --------- Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com> --- docs/source/en/api/pipelines/animatediff.md | 83 +++++++++++++++++++++ 1 file changed, 83 insertions(+) diff --git a/docs/source/en/api/pipelines/animatediff.md b/docs/source/en/api/pipelines/animatediff.md index 7cacad87d7..7359012803 100644 --- a/docs/source/en/api/pipelines/animatediff.md +++ b/docs/source/en/api/pipelines/animatediff.md @@ -914,6 +914,89 @@ export_to_gif(frames, "animatelcm-motion-lora.gif") +## Using FreeNoise + +[FreeNoise: Tuning-Free Longer Video Diffusion via Noise Rescheduling](https://arxiv.org/abs/2310.15169) by Haonan Qiu, Menghan Xia, Yong Zhang, Yingqing He, Xintao Wang, Ying Shan, Ziwei Liu. + +FreeNoise is a sampling mechanism that can generate longer videos with short-video generation models by employing noise-rescheduling, temporal attention over sliding windows, and weighted averaging of latent frames. It also can be used with multiple prompts to allow for interpolated video generations. More details are available in the paper. + +The currently supported AnimateDiff pipelines that can be used with FreeNoise are: +- [`AnimateDiffPipeline`] +- [`AnimateDiffControlNetPipeline`] +- [`AnimateDiffVideoToVideoPipeline`] +- [`AnimateDiffVideoToVideoControlNetPipeline`] + +In order to use FreeNoise, a single line needs to be added to the inference code after loading your pipelines. + +```diff ++ pipe.enable_free_noise() +``` + +After this, either a single prompt could be used, or multiple prompts can be passed as a dictionary of integer-string pairs. The integer keys of the dictionary correspond to the frame index at which the influence of that prompt would be maximum. Each frame index should map to a single string prompt. The prompts for intermediate frame indices, that are not passed in the dictionary, are created by interpolating between the frame prompts that are passed. By default, simple linear interpolation is used. However, you can customize this behaviour with a callback to the `prompt_interpolation_callback` parameter when enabling FreeNoise. + +Full example: + +```python +import torch +from diffusers import AutoencoderKL, AnimateDiffPipeline, LCMScheduler, MotionAdapter +from diffusers.utils import export_to_video, load_image + +# Load pipeline +dtype = torch.float16 +motion_adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM", torch_dtype=dtype) +vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=dtype) + +pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=motion_adapter, vae=vae, torch_dtype=dtype) +pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear") + +pipe.load_lora_weights( + "wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm_lora" +) +pipe.set_adapters(["lcm_lora"], [0.8]) + +# Enable FreeNoise for long prompt generation +pipe.enable_free_noise(context_length=16, context_stride=4) +pipe.to("cuda") + +# Can be a single prompt, or a dictionary with frame timesteps +prompt = { + 0: "A caterpillar on a leaf, high quality, photorealistic", + 40: "A caterpillar transforming into a cocoon, on a leaf, near flowers, photorealistic", + 80: "A cocoon on a leaf, flowers in the backgrond, photorealistic", + 120: "A cocoon maturing and a butterfly being born, flowers and leaves visible in the background, photorealistic", + 160: "A beautiful butterfly, vibrant colors, sitting on a leaf, flowers in the background, photorealistic", + 200: "A beautiful butterfly, flying away in a forest, photorealistic", + 240: "A cyberpunk butterfly, neon lights, glowing", +} +negative_prompt = "bad quality, worst quality, jpeg artifacts" + +# Run inference +output = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=256, + guidance_scale=2.5, + num_inference_steps=10, + generator=torch.Generator("cpu").manual_seed(0), +) + +# Save video +frames = output.frames[0] +export_to_video(frames, "output.mp4", fps=16) +``` + +### FreeNoise memory savings + +Since FreeNoise processes multiple frames together, there are parts in the modeling where the memory required exceeds that available on normal consumer GPUs. The main memory bottlenecks that we identified are spatial and temporal attention blocks, upsampling and downsampling blocks, resnet blocks and feed-forward layers. Since most of these blocks operate effectively only on the channel/embedding dimension, one can perform chunked inference across the batch dimensions. The batch dimension in AnimateDiff are either spatial (`[B x F, H x W, C]`) or temporal (`B x H x W, F, C`) in nature (note that it may seem counter-intuitive, but the batch dimension here are correct, because spatial blocks process across the `B x F` dimension while the temporal blocks process across the `B x H x W` dimension). We introduce a `SplitInferenceModule` that makes it easier to chunk across any dimension and perform inference. This saves a lot of memory but comes at the cost of requiring more time for inference. + +```diff +# Load pipeline and adapters +# ... ++ pipe.enable_free_noise_split_inference() ++ pipe.unet.enable_forward_chunking(16) +``` + +The call to `pipe.enable_free_noise_split_inference` method accepts two parameters: `spatial_split_size` (defaults to `256`) and `temporal_split_size` (defaults to `16`). These can be configured based on how much VRAM you have available. A lower split size results in lower memory usage but slower inference, whereas a larger split size results in faster inference at the cost of more memory. ## Using `from_single_file` with the MotionAdapter From 45aa8bb1877272631ac6721bac9d04ed23372651 Mon Sep 17 00:00:00 2001 From: Juan Acevedo Date: Wed, 11 Sep 2024 20:05:06 -0700 Subject: [PATCH 6/6] Ptxla sd training (#9381) * enable pxla training of stable diffusion 2.x models. * run linter/style and run pipeline test for stable diffusion and fix issues. * update xla libraries * fix read me newline. * move files to research folder. * update per comments. * rename readme. --------- Co-authored-by: Juan Acevedo Co-authored-by: Sayak Paul --- .../research_projects/pytorch_xla/README.md | 167 +++++ .../pytorch_xla/requirements.txt | 8 + .../pytorch_xla/train_text_to_image_xla.py | 669 ++++++++++++++++++ .../pipeline_stable_diffusion.py | 11 + 4 files changed, 855 insertions(+) create mode 100644 examples/research_projects/pytorch_xla/README.md create mode 100644 examples/research_projects/pytorch_xla/requirements.txt create mode 100644 examples/research_projects/pytorch_xla/train_text_to_image_xla.py diff --git a/examples/research_projects/pytorch_xla/README.md b/examples/research_projects/pytorch_xla/README.md new file mode 100644 index 0000000000..a6901d5ada --- /dev/null +++ b/examples/research_projects/pytorch_xla/README.md @@ -0,0 +1,167 @@ +# Stable Diffusion text-to-image fine-tuning using PyTorch/XLA + +The `train_text_to_image_xla.py` script shows how to fine-tune stable diffusion model on TPU devices using PyTorch/XLA. + +It has been tested on v4 and v5p TPU versions. Training code has been tested on multi-host. + +This script implements Distributed Data Parallel using GSPMD feature in XLA compiler +where we shard the input batches over the TPU devices. + +As of 9-11-2024, these are some expected step times. + +| accelerator | global batch size | step time (seconds) | +| ----------- | ----------------- | --------- | +| v5p-128 | 1024 | 0.245 | +| v5p-256 | 2048 | 0.234 | +| v5p-512 | 4096 | 0.2498 | + +## Create TPU + +To create a TPU on Google Cloud first set these environment variables: + +```bash +export TPU_NAME= +export PROJECT_ID= +export ZONE= +export ACCELERATOR_TYPE= +export RUNTIME_VERSION= +``` + +Then run the create TPU command: +```bash +gcloud alpha compute tpus tpu-vm create ${TPU_NAME} --project ${PROJECT_ID} +--zone ${ZONE} --accelerator-type ${ACCELERATOR_TYPE} --version ${RUNTIME_VERSION} +--reserved +``` + +You can also use other ways to reserve TPUs like GKE or queued resources. + +## Setup TPU environment + +Install PyTorch and PyTorch/XLA nightly versions: +```bash +gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ +--project=${PROJECT_ID} --zone=${ZONE} --worker=all \ +--command=' +pip3 install --pre torch==2.5.0.dev20240905+cpu torchvision==0.20.0.dev20240905+cpu --index-url https://download.pytorch.org/whl/nightly/cpu +pip3 install "torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240905-cp310-cp310-linux_x86_64.whl" -f https://storage.googleapis.com/libtpu-releases/index.html +' +``` + +Verify that PyTorch and PyTorch/XLA were installed correctly: + +```bash +gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ +--project ${PROJECT_ID} --zone ${ZONE} --worker=all \ +--command='python3 -c "import torch; import torch_xla;"' +``` + +Install dependencies: +```bash +gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ +--project=${PROJECT_ID} --zone=${ZONE} --worker=all \ +--command=' +git clone https://github.com/huggingface/diffusers.git +cd diffusers +git checkout main +cd examples/research_projects/pytorch_xla +pip3 install -r requirements.txt +pip3 install pillow --upgrade +cd ../../.. +pip3 install .' +``` + +## Run the training job + +### Authenticate + +Run the following command to authenticate your token. + +```bash +huggingface-cli login +``` + +This script only trains the unet part of the network. The VAE and text encoder +are fixed. + +```bash +gcloud compute tpus tpu-vm ssh ${TPU_NAME} \ +--project=${PROJECT_ID} --zone=${ZONE} --worker=all \ +--command=' +export XLA_DISABLE_FUNCTIONALIZATION=1 +export PROFILE_DIR=/tmp/ +export CACHE_DIR=/tmp/ +export DATASET_NAME=lambdalabs/naruto-blip-captions +export PER_HOST_BATCH_SIZE=32 # This is known to work on TPU v4. Can set this to 64 for TPU v5p +export TRAIN_STEPS=50 +export OUTPUT_DIR=/tmp/trained-model/ +python diffusers/examples/research_projects/pytorch_xla/train_text_to_image_xla.py --pretrained_model_name_or_path=stabilityai/stable-diffusion-2-base --dataset_name=$DATASET_NAME --resolution=512 --center_crop --random_flip --train_batch_size=$PER_HOST_BATCH_SIZE --max_train_steps=$TRAIN_STEPS --learning_rate=1e-06 --mixed_precision=bf16 --profile_duration=80000 --output_dir=$OUTPUT_DIR --dataloader_num_workers=4 --loader_prefetch_size=4 --device_prefetch_size=4' + +``` + +### Environment Envs Explained + +* `XLA_DISABLE_FUNCTIONALIZATION`: To optimize the performance for AdamW optimizer. +* `PROFILE_DIR`: Specify where to put the profiling results. +* `CACHE_DIR`: Directory to store XLA compiled graphs for persistent caching. +* `DATASET_NAME`: Dataset to train the model. +* `PER_HOST_BATCH_SIZE`: Size of the batch to load per CPU host. For e.g. for a v5p-16 with 2 CPU hosts, the global batch size will be 2xPER_HOST_BATCH_SIZE. The input batch is sharded along the batch axis. +* `TRAIN_STEPS`: Total number of training steps to run the training for. +* `OUTPUT_DIR`: Directory to store the fine-tuned model. + +## Run inference using the output model + +To run inference using the output, you can simply load the model and pass it +input prompts. The first pass will compile the graph and takes longer with the following passes running much faster. + +```bash +export CACHE_DIR=/tmp/ +``` + +```python +import torch +import os +import sys +import numpy as np + +import torch_xla.core.xla_model as xm +from time import time +from diffusers import StableDiffusionPipeline +import torch_xla.runtime as xr + +CACHE_DIR = os.environ.get("CACHE_DIR", None) +if CACHE_DIR: + xr.initialize_cache(CACHE_DIR, readonly=False) + +def main(): + device = xm.xla_device() + model_path = "jffacevedo/pxla_trained_model" + pipe = StableDiffusionPipeline.from_pretrained( + model_path, + torch_dtype=torch.bfloat16 + ) + pipe.to(device) + prompt = ["A naruto with green eyes and red legs."] + start = time() + print("compiling...") + image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0] + print(f"compile time: {time() - start}") + print("generate...") + start = time() + image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0] + print(f"generation time (after compile) : {time() - start}") + image.save("naruto.png") + +if __name__ == '__main__': + main() +``` + +Expected Results: + +```bash +compiling... +100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 30/30 [10:03<00:00, 20.10s/it] +compile time: 720.656970500946 +generate... +100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 30/30 [00:01<00:00, 17.65it/s] +generation time (after compile) : 1.8461642265319824 \ No newline at end of file diff --git a/examples/research_projects/pytorch_xla/requirements.txt b/examples/research_projects/pytorch_xla/requirements.txt new file mode 100644 index 0000000000..c3ffa42f0e --- /dev/null +++ b/examples/research_projects/pytorch_xla/requirements.txt @@ -0,0 +1,8 @@ +accelerate>=0.16.0 +torchvision +transformers>=4.25.1 +datasets>=2.19.1 +ftfy +tensorboard +Jinja2 +peft==0.7.0 diff --git a/examples/research_projects/pytorch_xla/train_text_to_image_xla.py b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py new file mode 100644 index 0000000000..5d9d8c540f --- /dev/null +++ b/examples/research_projects/pytorch_xla/train_text_to_image_xla.py @@ -0,0 +1,669 @@ +import argparse +import os +import random +import time +from pathlib import Path + +import datasets +import numpy as np +import torch +import torch.nn.functional as F +import torch.utils.checkpoint +import torch_xla.core.xla_model as xm +import torch_xla.debug.profiler as xp +import torch_xla.distributed.parallel_loader as pl +import torch_xla.distributed.spmd as xs +import torch_xla.runtime as xr +from huggingface_hub import create_repo, upload_folder +from torchvision import transforms +from transformers import CLIPTextModel, CLIPTokenizer + +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + StableDiffusionPipeline, + UNet2DConditionModel, +) +from diffusers.training_utils import compute_snr +from diffusers.utils import is_wandb_available +from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card + + +if is_wandb_available(): + pass + +PROFILE_DIR = os.environ.get("PROFILE_DIR", None) +CACHE_DIR = os.environ.get("CACHE_DIR", None) +if CACHE_DIR: + xr.initialize_cache(CACHE_DIR, readonly=False) +xr.use_spmd() +DATASET_NAME_MAPPING = { + "lambdalabs/naruto-blip-captions": ("image", "text"), +} +PORT = 9012 + + +def save_model_card( + args, + repo_id: str, + repo_folder: str = None, +): + model_description = f""" +# Text-to-image finetuning - {repo_id} + +This pipeline was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset. \n + +## Pipeline usage + +You can use the pipeline like so: + +```python +import torch +import os +import sys +import numpy as np + +import torch_xla.core.xla_model as xm +from time import time +from typing import Tuple +from diffusers import StableDiffusionPipeline + +def main(args): + device = xm.xla_device() + model_path = + pipe = StableDiffusionPipeline.from_pretrained( + model_path, + torch_dtype=torch.bfloat16 + ) + pipe.to(device) + prompt = ["A naruto with green eyes and red legs."] + image = pipe(prompt, num_inference_steps=30, guidance_scale=7.5).images[0] + image.save("naruto.png") + +if __name__ == '__main__': + main() +``` + +## Training info + +These are the key hyperparameters used during training: + +* Steps: {args.max_train_steps} +* Learning rate: {args.learning_rate} +* Batch size: {args.train_batch_size} +* Image resolution: {args.resolution} +* Mixed-precision: {args.mixed_precision} + +""" + + model_card = load_or_create_model_card( + repo_id_or_path=repo_id, + from_training=True, + license="creativeml-openrail-m", + base_model=args.pretrained_model_name_or_path, + model_description=model_description, + inference=True, + ) + + tags = ["stable-diffusion", "stable-diffusion-diffusers", "text-to-image", "diffusers", "diffusers-training"] + model_card = populate_model_card(model_card, tags=tags) + + model_card.save(os.path.join(repo_folder, "README.md")) + + +class TrainSD: + def __init__( + self, + vae, + weight_dtype, + device, + noise_scheduler, + unet, + optimizer, + text_encoder, + dataloader, + args, + ): + self.vae = vae + self.weight_dtype = weight_dtype + self.device = device + self.noise_scheduler = noise_scheduler + self.unet = unet + self.optimizer = optimizer + self.text_encoder = text_encoder + self.args = args + self.mesh = xs.get_global_mesh() + self.dataloader = iter(dataloader) + self.global_step = 0 + + def run_optimizer(self): + self.optimizer.step() + + def start_training(self): + times = [] + last_time = time.time() + step = 0 + while True: + if self.global_step >= self.args.max_train_steps: + xm.mark_step() + break + if step == 4 and PROFILE_DIR is not None: + xm.wait_device_ops() + xp.trace_detached(f"localhost:{PORT}", PROFILE_DIR, duration_ms=args.profile_duration) + try: + batch = next(self.dataloader) + except Exception as e: + print(e) + break + loss = self.step_fn(batch["pixel_values"], batch["input_ids"]) + step_time = time.time() - last_time + if step >= 10: + times.append(step_time) + print(f"step: {step}, step_time: {step_time}") + if step % 5 == 0: + print(f"step: {step}, loss: {loss}") + last_time = time.time() + self.global_step += 1 + step += 1 + # print(f"Average step time: {sum(times)/len(times)}") + xm.wait_device_ops() + + def step_fn( + self, + pixel_values, + input_ids, + ): + with xp.Trace("model.forward"): + self.optimizer.zero_grad() + latents = self.vae.encode(pixel_values).latent_dist.sample() + latents = latents * self.vae.config.scaling_factor + noise = torch.randn_like(latents).to(self.device, dtype=self.weight_dtype) + bsz = latents.shape[0] + timesteps = torch.randint( + 0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device + ) + timesteps = timesteps.long() + + noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) + encoder_hidden_states = self.text_encoder(input_ids, return_dict=False)[0] + if self.args.prediction_type is not None: + # set prediction_type of scheduler if defined + self.noise_scheduler.register_to_config(prediction_type=self.args.prediction_type) + + if self.noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif self.noise_scheduler.config.prediction_type == "v_prediction": + target = self.noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}") + model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] + with xp.Trace("model.backward"): + if self.args.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Since we predict the noise instead of x_0, the original formulation is slightly changed. + # This is discussed in Section 4.2 of the same paper. + snr = compute_snr(self.noise_scheduler, timesteps) + mse_loss_weights = torch.stack([snr, self.args.snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] + if self.noise_scheduler.config.prediction_type == "epsilon": + mse_loss_weights = mse_loss_weights / snr + elif self.noise_scheduler.config.prediction_type == "v_prediction": + mse_loss_weights = mse_loss_weights / (snr + 1) + + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + loss.backward() + with xp.Trace("optimizer_step"): + self.run_optimizer() + return loss + + +def parse_args(): + parser = argparse.ArgumentParser(description="Simple example of a training script.") + parser.add_argument( + "--input_perturbation", type=float, default=0, help="The scale of input perturbation. Recommended 0.1." + ) + parser.add_argument("--profile_duration", type=int, default=10000, help="Profile duration in ms") + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + default=None, + required=True, + help="Path to pretrained model or model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + required=False, + help="Revision of pretrained model identifier from huggingface.co/models.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Variant of the model files of the pretrained model identifier from huggingface.co/models, 'e.g.' fp16", + ) + parser.add_argument( + "--dataset_name", + type=str, + default=None, + help=( + "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private," + " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem," + " or to a folder containing files that πŸ€— Datasets can understand." + ), + ) + parser.add_argument( + "--dataset_config_name", + type=str, + default=None, + help="The config of the Dataset, leave as None if there's only one config.", + ) + parser.add_argument( + "--train_data_dir", + type=str, + default=None, + help=( + "A folder containing the training data. Folder contents must follow the structure described in" + " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file" + " must exist to provide the captions for the images. Ignored if `dataset_name` is specified." + ), + ) + parser.add_argument( + "--image_column", type=str, default="image", help="The column of the dataset containing an image." + ) + parser.add_argument( + "--caption_column", + type=str, + default="text", + help="The column of the dataset containing a caption or a list of captions.", + ) + parser.add_argument( + "--max_train_samples", + type=int, + default=None, + help=( + "For debugging purposes or quicker training, truncate the number of training examples to this " + "value if set." + ), + ) + parser.add_argument( + "--output_dir", + type=str, + default="sd-model-finetuned", + help="The output directory where the model predictions and checkpoints will be written.", + ) + parser.add_argument( + "--cache_dir", + type=str, + default=None, + help="The directory where the downloaded models and datasets will be stored.", + ) + parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.") + parser.add_argument( + "--resolution", + type=int, + default=512, + help=( + "The resolution for input images, all the images in the train/validation dataset will be resized to this" + " resolution" + ), + ) + parser.add_argument( + "--center_crop", + default=False, + action="store_true", + help=( + "Whether to center crop the input images to the resolution. If not set, the images will be randomly" + " cropped. The images will be resized to the resolution first before cropping." + ), + ) + parser.add_argument( + "--random_flip", + action="store_true", + help="whether to randomly flip images horizontally", + ) + parser.add_argument( + "--train_batch_size", type=int, default=16, help="Batch size (per device) for the training dataloader." + ) + parser.add_argument( + "--max_train_steps", + type=int, + default=None, + help="Total number of training steps to perform. If provided, overrides num_train_epochs.", + ) + parser.add_argument( + "--learning_rate", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--snr_gamma", + type=float, + default=None, + help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. " + "More details here: https://arxiv.org/abs/2303.09556.", + ) + parser.add_argument( + "--non_ema_revision", + type=str, + default=None, + required=False, + help=( + "Revision of pretrained non-ema model identifier. Must be a branch, tag or git identifier of the local or" + " remote repository specified with --pretrained_model_name_or_path." + ), + ) + parser.add_argument( + "--dataloader_num_workers", + type=int, + default=0, + help=( + "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process." + ), + ) + parser.add_argument( + "--loader_prefetch_size", + type=int, + default=1, + help=("Number of subprocesses to use for data loading to cpu."), + ) + parser.add_argument( + "--device_prefetch_size", + type=int, + default=1, + help=("Number of subprocesses to use for data loading to tpu from cpu. "), + ) + parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.") + parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.") + parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.") + parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer") + parser.add_argument( + "--prediction_type", + type=str, + default=None, + help="The prediction_type that shall be used for training. Choose between 'epsilon' or 'v_prediction' or leave `None`. If left to `None` the default prediction type of the scheduler: `noise_scheduler.config.prediction_type` is chosen.", + ) + parser.add_argument( + "--mixed_precision", + type=str, + default=None, + choices=["no", "fp16", "bf16"], + help=( + "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" + " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" + " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." + ), + ) + parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.") + parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.") + parser.add_argument( + "--hub_model_id", + type=str, + default=None, + help="The name of the repository to keep in sync with the local `output_dir`.", + ) + + args = parser.parse_args() + + # default to using the same revision for the non-ema model if not specified + if args.non_ema_revision is None: + args.non_ema_revision = args.revision + + return args + + +def setup_optimizer(unet, args): + optimizer_cls = torch.optim.AdamW + return optimizer_cls( + unet.parameters(), + lr=args.learning_rate, + betas=(args.adam_beta1, args.adam_beta2), + weight_decay=args.adam_weight_decay, + eps=args.adam_epsilon, + foreach=True, + ) + + +def load_dataset(args): + if args.dataset_name is not None: + # Downloading and loading a dataset from the hub. + dataset = datasets.load_dataset( + args.dataset_name, + args.dataset_config_name, + cache_dir=args.cache_dir, + data_dir=args.train_data_dir, + ) + else: + data_files = {} + if args.train_data_dir is not None: + data_files["train"] = os.path.join(args.train_data_dir, "**") + dataset = datasets.load_dataset( + "imagefolder", + data_files=data_files, + cache_dir=args.cache_dir, + ) + return dataset + + +def get_column_names(dataset, args): + column_names = dataset["train"].column_names + + dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None) + if args.image_column is None: + image_column = dataset_columns[0] if dataset_columns is not None else column_names[0] + else: + image_column = args.image_column + if image_column not in column_names: + raise ValueError( + f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}" + ) + if args.caption_column is None: + caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1] + else: + caption_column = args.caption_column + if caption_column not in column_names: + raise ValueError( + f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}" + ) + return image_column, caption_column + + +def main(args): + args = parse_args() + + _ = xp.start_server(PORT) + + num_devices = xr.global_runtime_device_count() + device_ids = np.arange(num_devices) + mesh_shape = (num_devices, 1) + mesh = xs.Mesh(device_ids, mesh_shape, ("x", "y")) + xs.set_global_mesh(mesh) + + text_encoder = CLIPTextModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="text_encoder", + revision=args.revision, + variant=args.variant, + ) + vae = AutoencoderKL.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="vae", + revision=args.revision, + variant=args.variant, + ) + + unet = UNet2DConditionModel.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="unet", + revision=args.non_ema_revision, + ) + + if xm.is_master_ordinal() and args.push_to_hub: + repo_id = create_repo( + repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token + ).repo_id + + noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler") + tokenizer = CLIPTokenizer.from_pretrained( + args.pretrained_model_name_or_path, + subfolder="tokenizer", + revision=args.revision, + ) + + from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear + + unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward) + + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.train() + + # For mixed precision training we cast all non-trainable weights (vae, + # non-lora text_encoder and non-lora unet) to half-precision + # as these weights are only used for inference, keeping weights in full + # precision is not required. + weight_dtype = torch.float32 + if args.mixed_precision == "fp16": + weight_dtype = torch.float16 + elif args.mixed_precision == "bf16": + weight_dtype = torch.bfloat16 + + device = xm.xla_device() + print("device: ", device) + print("weight_dtype: ", weight_dtype) + + text_encoder = text_encoder.to(device, dtype=weight_dtype) + vae = vae.to(device, dtype=weight_dtype) + unet = unet.to(device, dtype=weight_dtype) + optimizer = setup_optimizer(unet, args) + vae.requires_grad_(False) + text_encoder.requires_grad_(False) + unet.train() + + dataset = load_dataset(args) + image_column, caption_column = get_column_names(dataset, args) + + def tokenize_captions(examples, is_train=True): + captions = [] + for caption in examples[caption_column]: + if isinstance(caption, str): + captions.append(caption) + elif isinstance(caption, (list, np.ndarray)): + # take a random caption if there are multiple + captions.append(random.choice(caption) if is_train else caption[0]) + else: + raise ValueError( + f"Caption column `{caption_column}` should contain either strings or lists of strings." + ) + inputs = tokenizer( + captions, + max_length=tokenizer.model_max_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ) + return inputs.input_ids + + train_transforms = transforms.Compose( + [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + (transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution)), + (transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x)), + transforms.ToTensor(), + transforms.Normalize([0.5], [0.5]), + ] + ) + + def preprocess_train(examples): + images = [image.convert("RGB") for image in examples[image_column]] + examples["pixel_values"] = [train_transforms(image) for image in images] + examples["input_ids"] = tokenize_captions(examples) + return examples + + train_dataset = dataset["train"] + train_dataset.set_format("torch") + train_dataset.set_transform(preprocess_train) + + def collate_fn(examples): + pixel_values = torch.stack([example["pixel_values"] for example in examples]) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format).to(weight_dtype) + input_ids = torch.stack([example["input_ids"] for example in examples]) + return {"pixel_values": pixel_values, "input_ids": input_ids} + + g = torch.Generator() + g.manual_seed(xr.host_index()) + sampler = torch.utils.data.RandomSampler(train_dataset, replacement=True, num_samples=int(1e10), generator=g) + train_dataloader = torch.utils.data.DataLoader( + train_dataset, + sampler=sampler, + collate_fn=collate_fn, + num_workers=args.dataloader_num_workers, + batch_size=args.train_batch_size, + ) + + train_dataloader = pl.MpDeviceLoader( + train_dataloader, + device, + input_sharding={ + "pixel_values": xs.ShardingSpec(mesh, ("x", None, None, None), minibatch=True), + "input_ids": xs.ShardingSpec(mesh, ("x", None), minibatch=True), + }, + loader_prefetch_size=args.loader_prefetch_size, + device_prefetch_size=args.device_prefetch_size, + ) + + if xm.is_master_ordinal(): + print("***** Running training *****") + print(f"Instantaneous batch size per device = {args.train_batch_size}") + print( + f"Total train batch size (w. parallel, distributed & accumulation) = {args.train_batch_size * num_devices}" + ) + print(f" Total optimization steps = {args.max_train_steps}") + + trainer = TrainSD( + vae=vae, + weight_dtype=weight_dtype, + device=device, + noise_scheduler=noise_scheduler, + unet=unet, + optimizer=optimizer, + text_encoder=text_encoder, + dataloader=train_dataloader, + args=args, + ) + + trainer.start_training() + unet = trainer.unet.to("cpu") + vae = trainer.vae.to("cpu") + text_encoder = trainer.text_encoder.to("cpu") + + pipeline = StableDiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + text_encoder=text_encoder, + vae=vae, + unet=unet, + revision=args.revision, + variant=args.variant, + ) + pipeline.save_pretrained(args.output_dir) + + if xm.is_master_ordinal() and args.push_to_hub: + save_model_card(args, repo_id, repo_folder=args.output_dir) + upload_folder( + repo_id=repo_id, + folder_path=args.output_dir, + commit_message="End of training", + ignore_patterns=["step_*", "epoch_*"], + ) + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 1ca9c59169..a2bbec7b3c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -28,6 +28,7 @@ from ...schedulers import KarrasDiffusionSchedulers from ...utils import ( USE_PEFT_BACKEND, deprecate, + is_torch_xla_available, logging, replace_example_docstring, scale_lora_layers, @@ -39,6 +40,13 @@ from .pipeline_output import StableDiffusionPipelineOutput from .safety_checker import StableDiffusionSafetyChecker +if is_torch_xla_available(): + import torch_xla.core.xla_model as xm + + XLA_AVAILABLE = True +else: + XLA_AVAILABLE = False + logger = logging.get_logger(__name__) # pylint: disable=invalid-name EXAMPLE_DOC_STRING = """ @@ -1036,6 +1044,9 @@ class StableDiffusionPipeline( step_idx = i // getattr(self.scheduler, "order", 1) callback(step_idx, t, latents) + if XLA_AVAILABLE: + xm.mark_step() + if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ 0