[Wuerstchen] text to image training script (#5052)
* initial script
* formatting
* prior trainer wip
* add efficient_net_encoder
* add CLIPTextModel
* add prior ema support
* optimizer
* fix typo
* add dataloader
* prompt_embeds and image_embeds
* initial training loop
* fix output_dir
* fix add_noise
* accelerator check
* make effnet_transforms dynamic
* fix training loop
* add validation logging
* use loaded text_encoder
* use PreTrainedTokenizerFast
* load weight from pickle
* save_model_card
* remove unused file
* fix typos
* save prior pipeline in its own folder
* fix imports
* fix pipe_t2i
* scale image_embeds
* remove snr_gamma
* format
* initial lora prior training
* log_validation and save
* initial gradient working
* remove save/load hooks
* set set_attn_processor on prior_prior
* add lora script
* typos
* use LoraLoaderMixin for prior pipeline
* fix usage
* make fix-copies
* use repo_id
* write_lora_layers is a staticmethod
* use defaults
* fix defaults
* undo
* Update src/diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Update src/diffusers/loaders.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Update src/diffusers/loaders.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Update src/diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
* Update src/diffusers/loaders.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* Update src/diffusers/loaders.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* add gradient checkpoint support to prior
* gradient_checkpointing
* formatting
* Update examples/wuerstchen/text_to_image/README.md Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update examples/wuerstchen/text_to_image/README.md Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update examples/wuerstchen/text_to_image/README.md Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update examples/wuerstchen/text_to_image/README.md Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update examples/wuerstchen/text_to_image/README.md Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update src/diffusers/loaders.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* Update examples/wuerstchen/text_to_image/train_text_to_image_prior.py Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
* use default unet and text_encoder
* fix test

---------

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
93
examples/wuerstchen/text_to_image/README.md
Normal file
@@ -0,0 +1,93 @@
# Würstchen text-to-image fine-tuning

## Running locally with PyTorch

Before running the scripts, make sure to install the library's training dependencies:

**Important**

To make sure you can successfully run the latest versions of the example scripts, we highly recommend **installing from source** and keeping the install up to date. To do this, execute the following steps in a new virtual environment:

```bash
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .
```

Then cd into the example folder and run:

```bash
cd examples/wuerstchen/text_to_image
pip install -r requirements.txt
```

And initialize an [🤗Accelerate](https://github.com/huggingface/accelerate/) environment with:

```bash
accelerate config
```

For this example we want to store the trained LoRA embeddings directly on the Hub, so we need to be logged in and add the `--push_to_hub` flag to the training script. To log in, run:

```bash
huggingface-cli login
```

## Prior training

You can fine-tune the Würstchen prior model with the `train_text_to_image_prior.py` script. Note that we currently support `--gradient_checkpointing` for prior model fine-tuning, so you can use it in GPU-memory-constrained setups.

<br>

<!-- accelerate_snippet_start -->
```bash
export DATASET_NAME="lambdalabs/pokemon-blip-captions"

accelerate launch train_text_to_image_prior.py \
  --mixed_precision="fp16" \
  --dataset_name=$DATASET_NAME \
  --resolution=768 \
  --train_batch_size=4 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --dataloader_num_workers=4 \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --checkpoints_total_limit=3 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --validation_prompts="A robot pokemon, 4k photo" \
  --report_to="wandb" \
  --push_to_hub \
  --output_dir="wuerstchen-prior-pokemon-model"
```
<!-- accelerate_snippet_end -->
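
Once training finishes and the prior pipeline has been pushed to the Hub, it can be combined with the pretrained decoder for inference. Below is a minimal sketch that mirrors the model card generated by the script; it assumes the run above pushed to the hypothetical Hub repo `your-username/wuerstchen-prior-pokemon-model`:

```python
import torch
from diffusers import DiffusionPipeline

# The fine-tuned prior and the pretrained Würstchen decoder are loaded as two separate pipelines.
pipe_prior = DiffusionPipeline.from_pretrained(
    "your-username/wuerstchen-prior-pokemon-model", torch_dtype=torch.float16
).to("cuda")
pipe_t2i = DiffusionPipeline.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=torch.float16
).to("cuda")

prompt = "A robot pokemon, 4k photo"
(image_embeds,) = pipe_prior(prompt).to_tuple()
image = pipe_t2i(image_embeddings=image_embeds, prompt=prompt).images[0]
image.save("robot_pokemon.png")
```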

## Training with LoRA

Low-Rank Adaptation of Large Language Models (or LoRA) was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.

In a nutshell, LoRA allows adapting pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages (a small illustrative sketch follows the list):

- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers allow controlling the extent to which the model is adapted toward new training images via a `scale` parameter.
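
To make the rank-decomposition idea concrete, here is a minimal, illustrative sketch of a LoRA-wrapped linear layer. The hypothetical `LoRALinear` class below is not what the training script uses (the script relies on diffusers' `LoRAAttnProcessor`); it only shows the mechanism:

```python
import torch.nn as nn


class LoRALinear(nn.Module):
    """Adds a trainable low-rank update on top of a frozen pretrained linear layer."""

    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)  # pretrained weights stay frozen
        self.down = nn.Linear(base.in_features, rank, bias=False)  # A: d_in -> rank
        self.up = nn.Linear(rank, base.out_features, bias=False)   # B: rank -> d_out
        nn.init.zeros_(self.up.weight)  # the update starts as a no-op
        self.scale = scale              # controls how strongly the adaptation is applied

    def forward(self, x):
        return self.base(x) + self.scale * self.up(self.down(x))
```

Only the `down`/`up` pair receives gradients, which is why the saved LoRA weights are a few megabytes rather than a full model checkpoint.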

### Prior Training

First, you need to set up your development environment as explained in the [installation](#running-locally-with-pytorch) section. Make sure to set the `DATASET_NAME` environment variable. Here, we will use the [Pokemon captions dataset](https://huggingface.co/datasets/lambdalabs/pokemon-blip-captions).

```bash
export DATASET_NAME="lambdalabs/pokemon-blip-captions"

accelerate launch train_text_to_image_lora_prior.py \
  --mixed_precision="fp16" \
  --dataset_name=$DATASET_NAME --caption_column="text" \
  --resolution=768 \
  --train_batch_size=8 \
  --num_train_epochs=100 --checkpointing_steps=5000 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --rank=4 \
  --validation_prompts="cute dragon creature" \
  --report_to="wandb" \
  --push_to_hub \
  --output_dir="wuerstchen-prior-pokemon-lora"
```
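
After the run completes, the trained LoRA layers can be attached to the prior sub-pipeline at inference time. Below is a minimal sketch that mirrors the model card generated by the script; it assumes the run above pushed to the hypothetical Hub repo `your-username/wuerstchen-prior-pokemon-lora`:

```python
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "warp-ai/wuerstchen", torch_dtype=torch.float16
).to("cuda")
# Attach the trained LoRA weights to the prior sub-pipeline.
pipeline.prior_pipe.load_lora_weights("your-username/wuerstchen-prior-pokemon-lora")

image = pipeline(prompt="cute dragon creature").images[0]
image.save("dragon.png")
```
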
0
examples/wuerstchen/text_to_image/__init__.py
Normal file
23
examples/wuerstchen/text_to_image/modeling_efficient_net_encoder.py
Normal file
@@ -0,0 +1,23 @@
import torch.nn as nn
from torchvision.models import efficientnet_v2_l, efficientnet_v2_s

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class EfficientNetEncoder(ModelMixin, ConfigMixin):
    @register_to_config
    def __init__(self, c_latent=16, c_cond=1280, effnet="efficientnet_v2_s"):
        super().__init__()

        if effnet == "efficientnet_v2_s":
            self.backbone = efficientnet_v2_s(weights="DEFAULT").features
        else:
            self.backbone = efficientnet_v2_l(weights="DEFAULT").features
        self.mapper = nn.Sequential(
            nn.Conv2d(c_cond, c_latent, kernel_size=1, bias=False),
            nn.BatchNorm2d(c_latent),  # then normalize them to have mean 0 and std 1
        )

    def forward(self, x):
        return self.mapper(self.backbone(x))
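
The encoder is used purely as a frozen feature extractor: the training scripts feed it ImageNet-normalized crops and treat its output (after a fixed `add(1.0).div(42.0)` rescaling) as the image embeddings the prior learns to denoise. A minimal, hypothetical smoke test of the module above (constructing it downloads torchvision's pretrained EfficientNetV2 weights):

```python
import torch

from modeling_efficient_net_encoder import EfficientNetEncoder

encoder = EfficientNetEncoder()       # efficientnet_v2_s backbone by default
encoder.eval().requires_grad_(False)  # never trained, only used to produce targets

# One ImageNet-normalized 768x768 RGB image, as produced by the training transforms.
images = torch.randn(1, 3, 768, 768)

with torch.no_grad():
    image_embeds = encoder(images)                  # c_latent channels at 1/32 resolution
    image_embeds = image_embeds.add(1.0).div(42.0)  # same rescaling the training scripts apply

print(image_embeds.shape)  # expected: torch.Size([1, 16, 24, 24])
```
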
7
examples/wuerstchen/text_to_image/requirements.txt
Normal file
@@ -0,0 +1,7 @@
accelerate>=0.16.0
torchvision
transformers>=4.25.1
wandb
huggingface-cli
bitsandbytes
deepspeed
888
examples/wuerstchen/text_to_image/train_text_to_image_lora_prior.py
Normal file
@@ -0,0 +1,888 @@
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.state import AcceleratorState, is_initialized
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import create_repo, hf_hub_download, upload_folder
|
||||
from modeling_efficient_net_encoder import EfficientNetEncoder
|
||||
from torchvision import transforms
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTextModel, PreTrainedTokenizerFast
|
||||
from transformers.utils import ContextManagers
|
||||
|
||||
from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler, WuerstchenPriorPipeline
|
||||
from diffusers.loaders import AttnProcsLayers
|
||||
from diffusers.models.attention_processor import LoRAAttnProcessor
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior
|
||||
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
|
||||
from diffusers.utils.logging import set_verbosity_error, set_verbosity_info
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.22.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
DATASET_NAME_MAPPING = {
|
||||
"lambdalabs/pokemon-blip-captions": ("image", "text"),
|
||||
}
|
||||
|
||||
|
||||
def save_model_card(
|
||||
args,
|
||||
repo_id: str,
|
||||
images=None,
|
||||
repo_folder=None,
|
||||
):
|
||||
img_str = ""
|
||||
if len(images) > 0:
|
||||
image_grid = make_image_grid(images, 1, len(args.validation_prompts))
|
||||
image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
|
||||
img_str += "\n"
|
||||
|
||||
yaml = f"""
|
||||
---
|
||||
license: mit
|
||||
base_model: {args.pretrained_prior_model_name_or_path}
|
||||
datasets:
|
||||
- {args.dataset_name}
|
||||
tags:
|
||||
- wuerstchen
|
||||
- text-to-image
|
||||
- diffusers
|
||||
- lora
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
model_card = f"""
|
||||
# LoRA Finetuning - {repo_id}
|
||||
|
||||
This pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
|
||||
{img_str}
|
||||
|
||||
## Pipeline usage
|
||||
|
||||
You can use the pipeline like so:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
"{args.pretrained_decoder_model_name_or_path}", torch_dtype={args.weight_dtype}
|
||||
)
|
||||
# load lora weights from folder:
|
||||
pipeline.prior_pipe.load_lora_weights("{repo_id}", torch_dtype={args.weight_dtype})
|
||||
|
||||
image = pipeline(prompt=prompt).images[0]
|
||||
image.save("my_image.png")
|
||||
```
|
||||
|
||||
## Training info
|
||||
|
||||
These are the key hyperparameters used during training:
|
||||
|
||||
* LoRA rank: {args.rank}
|
||||
* Epochs: {args.num_train_epochs}
|
||||
* Learning rate: {args.learning_rate}
|
||||
* Batch size: {args.train_batch_size}
|
||||
* Gradient accumulation steps: {args.gradient_accumulation_steps}
|
||||
* Image resolution: {args.resolution}
|
||||
* Mixed-precision: {args.mixed_precision}
|
||||
|
||||
"""
|
||||
wandb_info = ""
|
||||
if is_wandb_available():
|
||||
wandb_run_url = None
|
||||
if wandb.run is not None:
|
||||
wandb_run_url = wandb.run.url
|
||||
|
||||
if wandb_run_url is not None:
|
||||
wandb_info = f"""
|
||||
More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
|
||||
"""
|
||||
|
||||
model_card += wandb_info
|
||||
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
|
||||
|
||||
def log_validation(text_encoder, tokenizer, attn_processors, args, accelerator, weight_dtype, epoch):
|
||||
logger.info("Running validation... ")
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
args.pretrained_decoder_model_name_or_path,
|
||||
prior_text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
prior_tokenizer=tokenizer,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.prior_prior.set_attn_processor(attn_processors)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
if args.seed is None:
|
||||
generator = None
|
||||
else:
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
|
||||
images = []
|
||||
for i in range(len(args.validation_prompts)):
|
||||
with torch.autocast("cuda"):
|
||||
image = pipeline(
|
||||
args.validation_prompts[i],
|
||||
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
|
||||
generator=generator,
|
||||
height=args.resolution,
|
||||
width=args.resolution,
|
||||
).images[0]
|
||||
|
||||
images.append(image)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
||||
elif tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
"validation": [
|
||||
wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
|
||||
for i, image in enumerate(images)
|
||||
]
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of finetuning Würstchen Prior.")
|
||||
parser.add_argument(
|
||||
"--rank",
|
||||
type=int,
|
||||
default=4,
|
||||
help=("The dimension of the LoRA update matrices."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_decoder_model_name_or_path",
|
||||
type=str,
|
||||
default="warp-ai/wuerstchen",
|
||||
required=False,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_prior_model_name_or_path",
|
||||
type=str,
|
||||
default="warp-ai/wuerstchen-prior",
|
||||
required=False,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
||||
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
||||
" or to a folder containing files that 🤗 Datasets can understand."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The config of the Dataset, leave as None if there's only one config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"A folder containing the training data. Folder contents must follow the structure described in"
|
||||
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
|
||||
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_column",
|
||||
type=str,
|
||||
default="text",
|
||||
help="The column of the dataset containing a caption or a list of captions.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_train_samples",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_prompts",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="+",
|
||||
help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="wuerstchen-model-finetuned-lora",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The directory where the downloaded models and datasets will be stored.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument("--num_train_epochs", type=int, default=100)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="learning rate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allow_tf32",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
||||
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument(
|
||||
"--adam_weight_decay",
|
||||
type=float,
|
||||
default=0.0,
|
||||
required=False,
|
||||
help="Weight decay to use.",
|
||||
)
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
|
||||
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
||||
" training using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoints_total_limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help=("Max number of checkpoints to store."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_epochs",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Run validation every X epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tracker_project_name",
|
||||
type=str,
|
||||
default="text2image-fine-tune",
|
||||
help=(
"The `project_name` argument passed to Accelerator.init_trackers. For"
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
# Sanity checks
|
||||
if args.dataset_name is None and args.train_data_dir is None:
|
||||
raise ValueError("Need either a dataset name or a training folder.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
||||
accelerator_project_config = ProjectConfiguration(
|
||||
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
|
||||
)
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
if accelerator.is_local_main_process:
|
||||
datasets.utils.logging.set_verbosity_warning()
|
||||
transformers.utils.logging.set_verbosity_warning()
|
||||
set_verbosity_info()
|
||||
else:
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
set_verbosity_error()
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
||||
).repo_id
|
||||
|
||||
# Load scheduler, effnet, tokenizer, clip_model
|
||||
noise_scheduler = DDPMWuerstchenScheduler()
|
||||
tokenizer = PreTrainedTokenizerFast.from_pretrained(
|
||||
args.pretrained_prior_model_name_or_path, subfolder="tokenizer"
|
||||
)
|
||||
|
||||
def deepspeed_zero_init_disabled_context_manager():
|
||||
"""
|
||||
returns either a context list that includes one that will disable zero.Init or an empty context list
|
||||
"""
|
||||
deepspeed_plugin = AcceleratorState().deepspeed_plugin if is_initialized() else None
|
||||
if deepspeed_plugin is None:
|
||||
return []
|
||||
|
||||
return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
|
||||
pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
|
||||
state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
|
||||
image_encoder = EfficientNetEncoder()
|
||||
image_encoder.load_state_dict(state_dict["effnet_state_dict"])
|
||||
image_encoder.eval()
|
||||
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
|
||||
).eval()
|
||||
|
||||
# Freeze text_encoder and image_encoder, cast them to weight_dtype and move to device
|
||||
text_encoder.requires_grad_(False)
|
||||
image_encoder.requires_grad_(False)
|
||||
image_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# load prior model, cast to weight_dtype and move to device
|
||||
prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
|
||||
prior.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# lora attn processor
|
||||
lora_attn_procs = {}
|
||||
for name in prior.attn_processors.keys():
|
||||
lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=prior.config["c"], rank=args.rank)
|
||||
prior.set_attn_processor(lora_attn_procs)
|
||||
lora_layers = AttnProcsLayers(prior.attn_processors)
|
||||
|
||||
if args.allow_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
if args.use_8bit_adam:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
||||
)
|
||||
|
||||
optimizer_cls = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_cls = torch.optim.AdamW
|
||||
optimizer = optimizer_cls(
|
||||
lora_layers.parameters(),
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
# Get the datasets: you can either provide your own training and evaluation files (see below)
|
||||
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
|
||||
|
||||
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
||||
# download the dataset.
|
||||
if args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
dataset = load_dataset(
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
else:
|
||||
data_files = {}
|
||||
if args.train_data_dir is not None:
|
||||
data_files["train"] = os.path.join(args.train_data_dir, "**")
|
||||
dataset = load_dataset(
|
||||
"imagefolder",
|
||||
data_files=data_files,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
# See more about loading custom images at
|
||||
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# We need to tokenize inputs and targets.
|
||||
column_names = dataset["train"].column_names
|
||||
|
||||
# Get the column names for input/target.
|
||||
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
|
||||
if args.image_column is None:
|
||||
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
|
||||
else:
|
||||
image_column = args.image_column
|
||||
if image_column not in column_names:
|
||||
raise ValueError(
|
||||
f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
|
||||
)
|
||||
if args.caption_column is None:
|
||||
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
|
||||
else:
|
||||
caption_column = args.caption_column
|
||||
if caption_column not in column_names:
|
||||
raise ValueError(
|
||||
f"--caption_column' value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
|
||||
)
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# We need to tokenize input captions and transform the images
|
||||
def tokenize_captions(examples, is_train=True):
|
||||
captions = []
|
||||
for caption in examples[caption_column]:
|
||||
if isinstance(caption, str):
|
||||
captions.append(caption)
|
||||
elif isinstance(caption, (list, np.ndarray)):
|
||||
# take a random caption if there are multiple
|
||||
captions.append(random.choice(caption) if is_train else caption[0])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Caption column `{caption_column}` should contain either strings or lists of strings."
|
||||
)
|
||||
inputs = tokenizer(
|
||||
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
)
|
||||
text_input_ids = inputs.input_ids
|
||||
text_mask = inputs.attention_mask.bool()
|
||||
return text_input_ids, text_mask
|
||||
|
||||
effnet_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
||||
]
|
||||
)
|
||||
|
||||
def preprocess_train(examples):
|
||||
images = [image.convert("RGB") for image in examples[image_column]]
|
||||
examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images]
|
||||
examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
|
||||
return examples
|
||||
|
||||
with accelerator.main_process_first():
|
||||
if args.max_train_samples is not None:
|
||||
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
||||
# Set the training transforms
|
||||
train_dataset = dataset["train"].with_transform(preprocess_train)
|
||||
|
||||
def collate_fn(examples):
|
||||
effnet_pixel_values = torch.stack([example["effnet_pixel_values"] for example in examples])
|
||||
effnet_pixel_values = effnet_pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
text_input_ids = torch.stack([example["text_input_ids"] for example in examples])
|
||||
text_mask = torch.stack([example["text_mask"] for example in examples])
|
||||
return {"effnet_pixel_values": effnet_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask}
|
||||
|
||||
# DataLoaders creation:
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
batch_size=args.train_batch_size,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
lora_layers, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers initializes automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
tracker_config = dict(vars(args))
|
||||
tracker_config.pop("validation_prompts")
|
||||
accelerator.init_trackers(args.tracker_project_name, tracker_config)
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
|
||||
# Potentially load in the weights and states from a previous save
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint != "latest":
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = os.listdir(args.output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1] if len(dirs) > 0 else None
|
||||
|
||||
if path is None:
|
||||
accelerator.print(
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
prior.train()
|
||||
train_loss = 0.0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(prior):
|
||||
# Convert images to latent space
|
||||
text_input_ids, text_mask, effnet_images = (
|
||||
batch["text_input_ids"],
|
||||
batch["text_mask"],
|
||||
batch["effnet_pixel_values"].to(weight_dtype),
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
text_encoder_output = text_encoder(text_input_ids, attention_mask=text_mask)
|
||||
prompt_embeds = text_encoder_output.last_hidden_state
|
||||
image_embeds = image_encoder(effnet_images)
|
||||
# scale
|
||||
image_embeds = image_embeds.add(1.0).div(42.0)
|
||||
|
||||
# Sample noise that we'll add to the image_embeds
|
||||
noise = torch.randn_like(image_embeds)
|
||||
bsz = image_embeds.shape[0]
|
||||
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.rand((bsz,), device=image_embeds.device, dtype=weight_dtype)
|
||||
|
||||
# add noise to latent
|
||||
noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
|
||||
|
||||
# Predict the noise residual and compute losscd
|
||||
pred_noise = prior(noisy_latents, timesteps, prompt_embeds)
|
||||
|
||||
# vanilla loss
|
||||
loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="mean")
|
||||
|
||||
# Gather the losses across all processes for logging (if we use distributed training).
|
||||
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
||||
train_loss += avg_loss.item() / args.gradient_accumulation_steps
|
||||
|
||||
# Backpropagate
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(lora_layers.parameters(), args.max_grad_norm)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
accelerator.log({"train_loss": train_loss}, step=global_step)
|
||||
train_loss = 0.0
|
||||
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||
if args.checkpoints_total_limit is not None:
|
||||
checkpoints = os.listdir(args.output_dir)
|
||||
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
||||
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
||||
|
||||
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
||||
if len(checkpoints) >= args.checkpoints_total_limit:
|
||||
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
||||
removing_checkpoints = checkpoints[0:num_to_remove]
|
||||
|
||||
logger.info(
|
||||
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
||||
)
|
||||
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
||||
|
||||
for removing_checkpoint in removing_checkpoints:
|
||||
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
||||
shutil.rmtree(removing_checkpoint)
|
||||
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
|
||||
log_validation(
|
||||
text_encoder, tokenizer, prior.attn_processors, args, accelerator, weight_dtype, global_step
|
||||
)
|
||||
|
||||
# Create the pipeline using the trained modules and save it.
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
prior = prior.to(torch.float32)
|
||||
WuerstchenPriorPipeline.save_lora_weights(
|
||||
os.path.join(args.output_dir, "prior_lora"),
|
||||
unet_lora_layers=lora_layers,
|
||||
)
|
||||
|
||||
# Run a final round of inference.
|
||||
images = []
|
||||
if args.validation_prompts is not None:
|
||||
logger.info("Running inference for collecting generated images...")
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
args.pretrained_decoder_model_name_or_path,
|
||||
prior_text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
prior_tokenizer=tokenizer,
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype)
|
||||
# load lora weights
|
||||
pipeline.prior_pipe.load_lora_weights(os.path.join(args.output_dir, "prior_lora"))
|
||||
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
if args.seed is None:
|
||||
generator = None
|
||||
else:
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
|
||||
for i in range(len(args.validation_prompts)):
|
||||
with torch.autocast("cuda"):
|
||||
image = pipeline(
|
||||
args.validation_prompts[i],
|
||||
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
|
||||
generator=generator,
|
||||
width=args.resolution,
|
||||
height=args.resolution,
|
||||
).images[0]
|
||||
images.append(image)
|
||||
|
||||
if args.push_to_hub:
|
||||
save_model_card(args, repo_id, images, repo_folder=args.output_dir)
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
commit_message="End of training",
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
925
examples/wuerstchen/text_to_image/train_text_to_image_prior.py
Normal file
@@ -0,0 +1,925 @@
|
||||
# Copyright 2023 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import accelerate
|
||||
import datasets
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import transformers
|
||||
from accelerate import Accelerator
|
||||
from accelerate.logging import get_logger
|
||||
from accelerate.state import AcceleratorState, is_initialized
|
||||
from accelerate.utils import ProjectConfiguration, set_seed
|
||||
from datasets import load_dataset
|
||||
from huggingface_hub import create_repo, hf_hub_download, upload_folder
|
||||
from modeling_efficient_net_encoder import EfficientNetEncoder
|
||||
from packaging import version
|
||||
from torchvision import transforms
|
||||
from tqdm import tqdm
|
||||
from transformers import CLIPTextModel, PreTrainedTokenizerFast
|
||||
from transformers.utils import ContextManagers
|
||||
|
||||
from diffusers import AutoPipelineForText2Image, DDPMWuerstchenScheduler
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.pipelines.wuerstchen import DEFAULT_STAGE_C_TIMESTEPS, WuerstchenPrior
|
||||
from diffusers.training_utils import EMAModel
|
||||
from diffusers.utils import check_min_version, is_wandb_available, make_image_grid
|
||||
from diffusers.utils.logging import set_verbosity_error, set_verbosity_info
|
||||
|
||||
|
||||
if is_wandb_available():
|
||||
import wandb
|
||||
|
||||
|
||||
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
|
||||
check_min_version("0.22.0")
|
||||
|
||||
logger = get_logger(__name__, log_level="INFO")
|
||||
|
||||
DATASET_NAME_MAPPING = {
|
||||
"lambdalabs/pokemon-blip-captions": ("image", "text"),
|
||||
}
|
||||
|
||||
|
||||
def save_model_card(
|
||||
args,
|
||||
repo_id: str,
|
||||
images=None,
|
||||
repo_folder=None,
|
||||
):
|
||||
img_str = ""
|
||||
if len(images) > 0:
|
||||
image_grid = make_image_grid(images, 1, len(args.validation_prompts))
|
||||
image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
|
||||
img_str += "\n"
|
||||
|
||||
yaml = f"""
|
||||
---
|
||||
license: mit
|
||||
base_model: {args.pretrained_prior_model_name_or_path}
|
||||
datasets:
|
||||
- {args.dataset_name}
|
||||
tags:
|
||||
- wuerstchen
|
||||
- text-to-image
|
||||
- diffusers
|
||||
inference: true
|
||||
---
|
||||
"""
|
||||
model_card = f"""
|
||||
# Finetuning - {repo_id}
|
||||
|
||||
This pipeline was finetuned from **{args.pretrained_prior_model_name_or_path}** on the **{args.dataset_name}** dataset. Below are some example images generated with the finetuned pipeline using the following prompts: {args.validation_prompts}: \n
|
||||
{img_str}
|
||||
|
||||
## Pipeline usage
|
||||
|
||||
You can use the pipeline like so:
|
||||
|
||||
```python
|
||||
from diffusers import DiffusionPipeline
|
||||
import torch
|
||||
|
||||
pipe_prior = DiffusionPipeline.from_pretrained("{repo_id}", torch_dtype={args.weight_dtype})
|
||||
pipe_t2i = DiffusionPipeline.from_pretrained("{args.pretrained_decoder_model_name_or_path}", torch_dtype={args.weight_dtype})
|
||||
prompt = "{args.validation_prompts[0]}"
|
||||
(image_embeds,) = pipe_prior(prompt).to_tuple()
|
||||
image = pipe_t2i(image_embeddings=image_embeds, prompt=prompt).images[0]
|
||||
image.save("my_image.png")
|
||||
```
|
||||
|
||||
## Training info
|
||||
|
||||
These are the key hyperparameters used during training:
|
||||
|
||||
* Epochs: {args.num_train_epochs}
|
||||
* Learning rate: {args.learning_rate}
|
||||
* Batch size: {args.train_batch_size}
|
||||
* Gradient accumulation steps: {args.gradient_accumulation_steps}
|
||||
* Image resolution: {args.resolution}
|
||||
* Mixed-precision: {args.mixed_precision}
|
||||
|
||||
"""
|
||||
wandb_info = ""
|
||||
if is_wandb_available():
|
||||
wandb_run_url = None
|
||||
if wandb.run is not None:
|
||||
wandb_run_url = wandb.run.url
|
||||
|
||||
if wandb_run_url is not None:
|
||||
wandb_info = f"""
|
||||
More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
|
||||
"""
|
||||
|
||||
model_card += wandb_info
|
||||
|
||||
with open(os.path.join(repo_folder, "README.md"), "w") as f:
|
||||
f.write(yaml + model_card)
|
||||
|
||||
|
||||
def log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, epoch):
|
||||
logger.info("Running validation... ")
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
args.pretrained_decoder_model_name_or_path,
|
||||
prior_prior=accelerator.unwrap_model(prior),
|
||||
prior_text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
prior_tokenizer=tokenizer,
|
||||
torch_dtype=weight_dtype,
|
||||
)
|
||||
pipeline = pipeline.to(accelerator.device)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
if args.seed is None:
|
||||
generator = None
|
||||
else:
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
|
||||
images = []
|
||||
for i in range(len(args.validation_prompts)):
|
||||
with torch.autocast("cuda"):
|
||||
image = pipeline(
|
||||
args.validation_prompts[i],
|
||||
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
|
||||
generator=generator,
|
||||
height=args.resolution,
|
||||
width=args.resolution,
|
||||
).images[0]
|
||||
|
||||
images.append(image)
|
||||
|
||||
for tracker in accelerator.trackers:
|
||||
if tracker.name == "tensorboard":
|
||||
np_images = np.stack([np.asarray(img) for img in images])
|
||||
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
|
||||
elif tracker.name == "wandb":
|
||||
tracker.log(
|
||||
{
|
||||
"validation": [
|
||||
wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
|
||||
for i, image in enumerate(images)
|
||||
]
|
||||
}
|
||||
)
|
||||
else:
|
||||
logger.warn(f"image logging not implemented for {tracker.name}")
|
||||
|
||||
del pipeline
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return images
|
||||
|
||||
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(description="Simple example of finetuning Würstchen Prior.")
|
||||
parser.add_argument(
|
||||
"--pretrained_decoder_model_name_or_path",
|
||||
type=str,
|
||||
default="warp-ai/wuerstchen",
|
||||
required=False,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pretrained_prior_model_name_or_path",
|
||||
type=str,
|
||||
default="warp-ai/wuerstchen-prior",
|
||||
required=False,
|
||||
help="Path to pretrained model or model identifier from huggingface.co/models.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
|
||||
" dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
|
||||
" or to a folder containing files that 🤗 Datasets can understand."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dataset_config_name",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The config of the Dataset, leave as None if there's only one config.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_data_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"A folder containing the training data. Folder contents must follow the structure described in"
|
||||
" https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
|
||||
" must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--image_column", type=str, default="image", help="The column of the dataset containing an image."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--caption_column",
|
||||
type=str,
|
||||
default="text",
|
||||
help="The column of the dataset containing a caption or a list of captions.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--max_train_samples",
|
||||
type=int,
|
||||
default=None,
|
||||
help=(
|
||||
"For debugging purposes or quicker training, truncate the number of training examples to this "
|
||||
"value if set."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_prompts",
|
||||
type=str,
|
||||
default=None,
|
||||
nargs="+",
|
||||
help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output_dir",
|
||||
type=str,
|
||||
default="wuerstchen-model-finetuned",
|
||||
help="The output directory where the model predictions and checkpoints will be written.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cache_dir",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The directory where the downloaded models and datasets will be stored.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
|
||||
parser.add_argument(
|
||||
"--resolution",
|
||||
type=int,
|
||||
default=512,
|
||||
help=(
|
||||
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
|
||||
" resolution"
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
|
||||
)
|
||||
parser.add_argument("--num_train_epochs", type=int, default=100)
|
||||
parser.add_argument(
|
||||
"--max_train_steps",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_accumulation_steps",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of updates steps to accumulate before performing a backward/update pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--gradient_checkpointing",
|
||||
action="store_true",
|
||||
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning_rate",
|
||||
type=float,
|
||||
default=1e-4,
|
||||
help="learning rate",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_scheduler",
|
||||
type=str,
|
||||
default="constant",
|
||||
help=(
|
||||
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
|
||||
' "constant", "constant_with_warmup"]'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--allow_tf32",
|
||||
action="store_true",
|
||||
help=(
|
||||
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
|
||||
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
|
||||
),
|
||||
)
|
||||
parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
|
||||
parser.add_argument(
|
||||
"--dataloader_num_workers",
|
||||
type=int,
|
||||
default=0,
|
||||
help=(
|
||||
"Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
|
||||
),
|
||||
)
|
||||
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
|
||||
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
|
||||
parser.add_argument(
|
||||
"--adam_weight_decay",
|
||||
type=float,
|
||||
default=0.0,
|
||||
required=False,
|
||||
help="Weight decay to use.",
|
||||
)
|
||||
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
|
||||
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
|
||||
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
|
||||
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
|
||||
parser.add_argument(
|
||||
"--hub_model_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The name of the repository to keep in sync with the local `output_dir`.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logging_dir",
|
||||
type=str,
|
||||
default="logs",
|
||||
help=(
|
||||
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
|
||||
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mixed_precision",
|
||||
type=str,
|
||||
default=None,
|
||||
choices=["no", "fp16", "bf16"],
|
||||
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--report_to",
|
||||
type=str,
|
||||
default="tensorboard",
|
||||
help=(
|
||||
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
|
||||
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
|
||||
),
|
||||
)
|
||||
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
|
||||
parser.add_argument(
|
||||
"--checkpointing_steps",
|
||||
type=int,
|
||||
default=500,
|
||||
help=(
|
||||
"Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
|
||||
" training using `--resume_from_checkpoint`."
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--checkpoints_total_limit",
|
||||
type=int,
|
||||
default=None,
|
||||
help=("Max number of checkpoints to store."),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--resume_from_checkpoint",
|
||||
type=str,
|
||||
default=None,
|
||||
help=(
|
||||
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
|
||||
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
|
||||
),
|
||||
)
|
||||
parser.add_argument(
|
||||
"--validation_epochs",
|
||||
type=int,
|
||||
default=5,
|
||||
help="Run validation every X epochs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--tracker_project_name",
|
||||
type=str,
|
||||
default="text2image-fine-tune",
|
||||
help=(
"The `project_name` argument passed to Accelerator.init_trackers. For"
" more information, see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
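# Distributed launchers (torchrun / accelerate) export LOCAL_RANK; prefer it over the
# --local_rank flag when the two disagree.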
|
||||
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
|
||||
if env_local_rank != -1 and env_local_rank != args.local_rank:
|
||||
args.local_rank = env_local_rank
|
||||
|
||||
# Sanity checks
|
||||
if args.dataset_name is None and args.train_data_dir is None:
|
||||
raise ValueError("Need either a dataset name or a training folder.")
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = parse_args()
|
||||
logging_dir = os.path.join(args.output_dir, args.logging_dir)
|
||||
accelerator_project_config = ProjectConfiguration(
|
||||
total_limit=args.checkpoints_total_limit, project_dir=args.output_dir, logging_dir=logging_dir
|
||||
)
|
||||
accelerator = Accelerator(
|
||||
gradient_accumulation_steps=args.gradient_accumulation_steps,
|
||||
mixed_precision=args.mixed_precision,
|
||||
log_with=args.report_to,
|
||||
project_config=accelerator_project_config,
|
||||
)
|
||||
|
||||
# Make one log on every process with the configuration for debugging.
|
||||
logging.basicConfig(
|
||||
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
level=logging.INFO,
|
||||
)
|
||||
logger.info(accelerator.state, main_process_only=False)
|
||||
if accelerator.is_local_main_process:
|
||||
datasets.utils.logging.set_verbosity_warning()
|
||||
transformers.utils.logging.set_verbosity_warning()
|
||||
set_verbosity_info()
|
||||
else:
|
||||
datasets.utils.logging.set_verbosity_error()
|
||||
transformers.utils.logging.set_verbosity_error()
|
||||
set_verbosity_error()
|
||||
|
||||
# If passed along, set the training seed now.
|
||||
if args.seed is not None:
|
||||
set_seed(args.seed)
|
||||
|
||||
# Handle the repository creation
|
||||
if accelerator.is_main_process:
|
||||
if args.output_dir is not None:
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
if args.push_to_hub:
|
||||
repo_id = create_repo(
|
||||
repo_id=args.hub_model_id or Path(args.output_dir).name, exist_ok=True, token=args.hub_token
|
||||
).repo_id
|
||||
|
||||
# Load the scheduler and tokenizer; the EfficientNet image encoder and CLIP text encoder are loaded below.
|
||||
noise_scheduler = DDPMWuerstchenScheduler()
|
||||
tokenizer = PreTrainedTokenizerFast.from_pretrained(
|
||||
args.pretrained_prior_model_name_or_path, subfolder="tokenizer"
|
||||
)
|
||||
|
||||
def deepspeed_zero_init_disabled_context_manager():
|
||||
"""
|
||||
Returns either a list containing a context manager that disables zero.Init, or an empty list.
|
||||
"""
|
||||
deepspeed_plugin = AcceleratorState().deepspeed_plugin if is_initialized() else None
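# zero.Init (DeepSpeed ZeRO stage 3) shards module weights at construction time. It is
# disabled here so the frozen text/image encoders loaded below keep their full weights
# on every process; only the prior is wrapped by `accelerator.prepare`.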
|
||||
if deepspeed_plugin is None:
|
||||
return []
|
||||
|
||||
return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
|
||||
|
||||
weight_dtype = torch.float32
|
||||
if accelerator.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif accelerator.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
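# The EfficientNet image encoder is initialized from the original Würstchen stage B
# checkpoint (model_v2_stage_b.pt); only its `effnet_state_dict` entry is used.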
|
||||
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
|
||||
pretrained_checkpoint_file = hf_hub_download("dome272/wuerstchen", filename="model_v2_stage_b.pt")
|
||||
state_dict = torch.load(pretrained_checkpoint_file, map_location="cpu")
|
||||
image_encoder = EfficientNetEncoder()
|
||||
image_encoder.load_state_dict(state_dict["effnet_state_dict"])
|
||||
image_encoder.eval()
|
||||
|
||||
text_encoder = CLIPTextModel.from_pretrained(
|
||||
args.pretrained_prior_model_name_or_path, subfolder="text_encoder", torch_dtype=weight_dtype
|
||||
).eval()
|
||||
|
||||
# Freeze text_encoder and image_encoder
|
||||
text_encoder.requires_grad_(False)
|
||||
image_encoder.requires_grad_(False)
|
||||
|
||||
# load prior model
|
||||
prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
|
||||
|
||||
# Create EMA for the prior
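# (the EMA copy is updated after every optimizer step and swapped in for validation and
# the final export)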
|
||||
if args.use_ema:
|
||||
ema_prior = WuerstchenPrior.from_pretrained(args.pretrained_prior_model_name_or_path, subfolder="prior")
|
||||
ema_prior = EMAModel(ema_prior.parameters(), model_cls=WuerstchenPrior, model_config=ema_prior.config)
|
||||
ema_prior.to(accelerator.device)
|
||||
|
||||
# `accelerate` 0.16.0 will have better support for customized saving
|
||||
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
|
||||
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
|
||||
def save_model_hook(models, weights, output_dir):
|
||||
if args.use_ema:
|
||||
ema_prior.save_pretrained(os.path.join(output_dir, "prior_ema"))
|
||||
|
||||
for i, model in enumerate(models):
|
||||
model.save_pretrained(os.path.join(output_dir, "prior"))
|
||||
|
||||
# make sure to pop weight so that corresponding model is not saved again
|
||||
weights.pop()
|
||||
|
||||
def load_model_hook(models, input_dir):
|
||||
if args.use_ema:
|
||||
load_model = EMAModel.from_pretrained(os.path.join(input_dir, "prior_ema"), WuerstchenPrior)
|
||||
ema_prior.load_state_dict(load_model.state_dict())
|
||||
ema_prior.to(accelerator.device)
|
||||
del load_model
|
||||
|
||||
for i in range(len(models)):
|
||||
# pop models so that they are not loaded again
|
||||
model = models.pop()
|
||||
|
||||
# load the diffusers-format weights into the model
|
||||
load_model = WuerstchenPrior.from_pretrained(input_dir, subfolder="prior")
|
||||
model.register_to_config(**load_model.config)
|
||||
|
||||
model.load_state_dict(load_model.state_dict())
|
||||
del load_model
|
||||
|
||||
accelerator.register_save_state_pre_hook(save_model_hook)
|
||||
accelerator.register_load_state_pre_hook(load_model_hook)
|
||||
|
||||
if args.gradient_checkpointing:
|
||||
prior.enable_gradient_checkpointing()
|
||||
|
||||
if args.allow_tf32:
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
if args.use_8bit_adam:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
||||
)
|
||||
|
||||
optimizer_cls = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_cls = torch.optim.AdamW
|
||||
optimizer = optimizer_cls(
|
||||
prior.parameters(),
|
||||
lr=args.learning_rate,
|
||||
betas=(args.adam_beta1, args.adam_beta2),
|
||||
weight_decay=args.adam_weight_decay,
|
||||
eps=args.adam_epsilon,
|
||||
)
|
||||
|
||||
# Get the datasets: you can either provide your own training and evaluation files (see below)
|
||||
# or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).
|
||||
|
||||
# In distributed training, the load_dataset function guarantees that only one local process can concurrently
|
||||
# download the dataset.
|
||||
if args.dataset_name is not None:
|
||||
# Downloading and loading a dataset from the hub.
|
||||
dataset = load_dataset(
|
||||
args.dataset_name,
|
||||
args.dataset_config_name,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
else:
|
||||
data_files = {}
|
||||
if args.train_data_dir is not None:
|
||||
data_files["train"] = os.path.join(args.train_data_dir, "**")
|
||||
dataset = load_dataset(
|
||||
"imagefolder",
|
||||
data_files=data_files,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
# See more about loading custom images at
|
||||
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# We need to tokenize inputs and targets.
|
||||
column_names = dataset["train"].column_names
|
||||
|
||||
# Get the column names for input/target.
|
||||
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
|
||||
if args.image_column is None:
|
||||
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
|
||||
else:
|
||||
image_column = args.image_column
|
||||
if image_column not in column_names:
|
||||
raise ValueError(
f"--image_column value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
|
||||
)
|
||||
if args.caption_column is None:
|
||||
caption_column = dataset_columns[1] if dataset_columns is not None else column_names[1]
|
||||
else:
|
||||
caption_column = args.caption_column
|
||||
if caption_column not in column_names:
|
||||
raise ValueError(
f"--caption_column value '{args.caption_column}' needs to be one of: {', '.join(column_names)}"
|
||||
)
|
||||
|
||||
# Preprocessing the datasets.
|
||||
# We need to tokenize input captions and transform the images
|
||||
def tokenize_captions(examples, is_train=True):
|
||||
captions = []
|
||||
for caption in examples[caption_column]:
|
||||
if isinstance(caption, str):
|
||||
captions.append(caption)
|
||||
elif isinstance(caption, (list, np.ndarray)):
|
||||
# take a random caption if there are multiple
|
||||
captions.append(random.choice(caption) if is_train else caption[0])
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Caption column `{caption_column}` should contain either strings or lists of strings."
|
||||
)
|
||||
inputs = tokenizer(
|
||||
captions, max_length=tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
|
||||
)
|
||||
text_input_ids = inputs.input_ids
|
||||
text_mask = inputs.attention_mask.bool()
|
||||
return text_input_ids, text_mask
|
||||
|
||||
effnet_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
|
||||
transforms.CenterCrop(args.resolution),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
|
||||
]
|
||||
)
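# The Normalize statistics above are the standard ImageNet mean/std that EfficientNet
# backbones are typically pretrained with.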
|
||||
|
||||
def preprocess_train(examples):
|
||||
images = [image.convert("RGB") for image in examples[image_column]]
|
||||
examples["effnet_pixel_values"] = [effnet_transforms(image) for image in images]
|
||||
examples["text_input_ids"], examples["text_mask"] = tokenize_captions(examples)
|
||||
return examples
|
||||
|
||||
with accelerator.main_process_first():
|
||||
if args.max_train_samples is not None:
|
||||
dataset["train"] = dataset["train"].shuffle(seed=args.seed).select(range(args.max_train_samples))
|
||||
# Set the training transforms
|
||||
train_dataset = dataset["train"].with_transform(preprocess_train)
|
||||
|
||||
def collate_fn(examples):
|
||||
effnet_pixel_values = torch.stack([example["effnet_pixel_values"] for example in examples])
|
||||
effnet_pixel_values = effnet_pixel_values.to(memory_format=torch.contiguous_format).float()
|
||||
text_input_ids = torch.stack([example["text_input_ids"] for example in examples])
|
||||
text_mask = torch.stack([example["text_mask"] for example in examples])
|
||||
return {"effnet_pixel_values": effnet_pixel_values, "text_input_ids": text_input_ids, "text_mask": text_mask}
|
||||
|
||||
# DataLoaders creation:
|
||||
train_dataloader = torch.utils.data.DataLoader(
|
||||
train_dataset,
|
||||
shuffle=True,
|
||||
collate_fn=collate_fn,
|
||||
batch_size=args.train_batch_size,
|
||||
num_workers=args.dataloader_num_workers,
|
||||
)
|
||||
|
||||
# Scheduler and math around the number of training steps.
|
||||
overrode_max_train_steps = False
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if args.max_train_steps is None:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
overrode_max_train_steps = True
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
optimizer=optimizer,
|
||||
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
|
||||
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
|
||||
)
|
||||
|
||||
prior, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
|
||||
prior, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
image_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
if overrode_max_train_steps:
|
||||
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
|
||||
# Afterwards we recalculate our number of training epochs
|
||||
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
|
||||
|
||||
# We need to initialize the trackers we use, and also store our configuration.
|
||||
# The trackers are initialized automatically on the main process.
|
||||
if accelerator.is_main_process:
|
||||
tracker_config = dict(vars(args))
|
||||
tracker_config.pop("validation_prompts")
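# validation_prompts is a list and is dropped from the logged config, since not every
# tracker backend can serialize list-valued hyperparameters.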
|
||||
accelerator.init_trackers(args.tracker_project_name, tracker_config)
|
||||
|
||||
# Train!
|
||||
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
|
||||
|
||||
logger.info("***** Running training *****")
|
||||
logger.info(f" Num examples = {len(train_dataset)}")
|
||||
logger.info(f" Num Epochs = {args.num_train_epochs}")
|
||||
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
|
||||
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
|
||||
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
|
||||
logger.info(f" Total optimization steps = {args.max_train_steps}")
|
||||
global_step = 0
|
||||
first_epoch = 0
|
||||
|
||||
# Potentially load in the weights and states from a previous save
|
||||
if args.resume_from_checkpoint:
|
||||
if args.resume_from_checkpoint != "latest":
|
||||
path = os.path.basename(args.resume_from_checkpoint)
|
||||
else:
|
||||
# Get the most recent checkpoint
|
||||
dirs = os.listdir(args.output_dir)
|
||||
dirs = [d for d in dirs if d.startswith("checkpoint")]
|
||||
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
|
||||
path = dirs[-1] if len(dirs) > 0 else None
|
||||
|
||||
if path is None:
|
||||
accelerator.print(
|
||||
f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
|
||||
)
|
||||
args.resume_from_checkpoint = None
|
||||
else:
|
||||
accelerator.print(f"Resuming from checkpoint {path}")
|
||||
accelerator.load_state(os.path.join(args.output_dir, path))
|
||||
global_step = int(path.split("-")[1])
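# global_step counts optimizer updates; multiplying by gradient_accumulation_steps gives
# the number of dataloader steps to skip when fast-forwarding through the resumed epoch.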
|
||||
|
||||
resume_global_step = global_step * args.gradient_accumulation_steps
|
||||
first_epoch = global_step // num_update_steps_per_epoch
|
||||
resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
|
||||
|
||||
# Only show the progress bar once on each machine.
|
||||
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
|
||||
progress_bar.set_description("Steps")
|
||||
|
||||
for epoch in range(first_epoch, args.num_train_epochs):
|
||||
prior.train()
|
||||
train_loss = 0.0
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
# Skip steps until we reach the resumed step
|
||||
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
|
||||
if step % args.gradient_accumulation_steps == 0:
|
||||
progress_bar.update(1)
|
||||
continue
|
||||
|
||||
with accelerator.accumulate(prior):
|
||||
# Unpack the batch: token ids, attention masks, and EfficientNet pixel values
|
||||
text_input_ids, text_mask, effnet_images = (
|
||||
batch["text_input_ids"],
|
||||
batch["text_mask"],
|
||||
batch["effnet_pixel_values"].to(weight_dtype),
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
text_encoder_output = text_encoder(text_input_ids, attention_mask=text_mask)
|
||||
prompt_embeds = text_encoder_output.last_hidden_state
|
||||
image_embeds = image_encoder(effnet_images)
|
||||
# normalize the image embeddings (shift by 1.0, scale by 1/42.0) before adding noise
|
||||
image_embeds = image_embeds.add(1.0).div(42.0)
|
||||
|
||||
# Sample noise that we'll add to the image_embeds
|
||||
noise = torch.randn_like(image_embeds)
|
||||
bsz = image_embeds.shape[0]
|
||||
|
||||
# Sample a random timestep for each image
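# (DDPMWuerstchenScheduler works with continuous timesteps drawn uniformly from [0, 1)
# rather than integer timestep indices)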
|
||||
timesteps = torch.rand((bsz,), device=image_embeds.device, dtype=weight_dtype)
|
||||
|
||||
# add noise to latent
|
||||
noisy_latents = noise_scheduler.add_noise(image_embeds, noise, timesteps)
|
||||
|
||||
# Predict the noise residual and compute loss
|
||||
pred_noise = prior(noisy_latents, timesteps, prompt_embeds)
|
||||
|
||||
# vanilla loss
|
||||
loss = F.mse_loss(pred_noise.float(), noise.float(), reduction="mean")
|
||||
|
||||
# Gather the losses across all processes for logging (if we use distributed training).
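# Dividing by gradient_accumulation_steps makes `train_loss` the mean loss per optimizer
# update by the time it is logged after the accumulated backward passes.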
|
||||
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
|
||||
train_loss += avg_loss.item() / args.gradient_accumulation_steps
|
||||
|
||||
# Backpropagate
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(prior.parameters(), args.max_grad_norm)
|
||||
optimizer.step()
|
||||
lr_scheduler.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
# Checks if the accelerator has performed an optimization step behind the scenes
|
||||
if accelerator.sync_gradients:
|
||||
if args.use_ema:
|
||||
ema_prior.step(prior.parameters())
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
accelerator.log({"train_loss": train_loss}, step=global_step)
|
||||
train_loss = 0.0
|
||||
|
||||
if global_step % args.checkpointing_steps == 0:
|
||||
if accelerator.is_main_process:
|
||||
# _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
|
||||
if args.checkpoints_total_limit is not None:
|
||||
checkpoints = os.listdir(args.output_dir)
|
||||
checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
|
||||
checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))
|
||||
|
||||
# before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
|
||||
if len(checkpoints) >= args.checkpoints_total_limit:
|
||||
num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
|
||||
removing_checkpoints = checkpoints[0:num_to_remove]
|
||||
|
||||
logger.info(
|
||||
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
||||
)
|
||||
logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
||||
|
||||
for removing_checkpoint in removing_checkpoints:
|
||||
removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
|
||||
shutil.rmtree(removing_checkpoint)
|
||||
|
||||
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
|
||||
accelerator.save_state(save_path)
|
||||
logger.info(f"Saved state to {save_path}")
|
||||
|
||||
logs = {"step_loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
|
||||
progress_bar.set_postfix(**logs)
|
||||
|
||||
if global_step >= args.max_train_steps:
|
||||
break
|
||||
|
||||
if accelerator.is_main_process:
|
||||
if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
|
||||
if args.use_ema:
|
||||
# Store the prior parameters temporarily and load the EMA parameters to perform inference.
|
||||
ema_prior.store(prior.parameters())
|
||||
ema_prior.copy_to(prior.parameters())
|
||||
log_validation(text_encoder, tokenizer, prior, args, accelerator, weight_dtype, global_step)
|
||||
if args.use_ema:
|
||||
# Switch back to the original prior parameters.
|
||||
ema_prior.restore(prior.parameters())
|
||||
|
||||
# Create the pipeline using the trained modules and save it.
|
||||
accelerator.wait_for_everyone()
|
||||
if accelerator.is_main_process:
|
||||
prior = accelerator.unwrap_model(prior)
|
||||
if args.use_ema:
|
||||
ema_prior.copy_to(prior.parameters())
|
||||
|
||||
pipeline = AutoPipelineForText2Image.from_pretrained(
|
||||
args.pretrained_decoder_model_name_or_path,
|
||||
prior_prior=prior,
|
||||
prior_text_encoder=accelerator.unwrap_model(text_encoder),
|
||||
prior_tokenizer=tokenizer,
|
||||
)
|
||||
pipeline.prior_pipe.save_pretrained(os.path.join(args.output_dir, "prior_pipeline"))
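# Only the prior sub-pipeline (prior, text encoder, tokenizer, scheduler) is saved here;
# the stage B decoder weights are reused unchanged from the pretrained checkpoint.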
|
||||
|
||||
# Run a final round of inference.
|
||||
images = []
|
||||
if args.validation_prompts is not None:
|
||||
logger.info("Running inference for collecting generated images...")
|
||||
pipeline = pipeline.to(accelerator.device, torch_dtype=weight_dtype)
|
||||
pipeline.set_progress_bar_config(disable=True)
|
||||
|
||||
if args.seed is None:
|
||||
generator = None
|
||||
else:
|
||||
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
|
||||
|
||||
for i in range(len(args.validation_prompts)):
|
||||
with torch.autocast("cuda"):
|
||||
image = pipeline(
|
||||
args.validation_prompts[i],
|
||||
prior_timesteps=DEFAULT_STAGE_C_TIMESTEPS,
|
||||
generator=generator,
|
||||
width=args.resolution,
|
||||
height=args.resolution,
|
||||
).images[0]
|
||||
images.append(image)
|
||||
|
||||
if args.push_to_hub:
|
||||
save_model_card(args, repo_id, images, repo_folder=args.output_dir)
|
||||
upload_folder(
|
||||
repo_id=repo_id,
|
||||
folder_path=args.output_dir,
|
||||
commit_message="End of training",
|
||||
ignore_patterns=["step_*", "epoch_*"],
|
||||
)
|
||||
|
||||
accelerator.end_training()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
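# A minimal usage sketch (not part of the training script) for reloading the exported
# prior for inference; the local path and the decoder repo id below are assumptions:
#
#   from diffusers import AutoPipelineForText2Image, WuerstchenPriorPipeline
#
#   prior_pipe = WuerstchenPriorPipeline.from_pretrained("wuerstchen-model-finetuned/prior_pipeline")
#   pipe = AutoPipelineForText2Image.from_pretrained(
#       "warp-ai/wuerstchen",
#       prior_prior=prior_pipe.prior,
#       prior_text_encoder=prior_pipe.text_encoder,
#       prior_tokenizer=prior_pipe.tokenizer,
#   )
#   image = pipe("a photo of an astronaut riding a horse", width=512, height=512).images[0]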
|
||||
@@ -1208,7 +1208,7 @@ class LoraLoaderMixin:
|
||||
self.load_lora_into_unet(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
unet=self.unet,
|
||||
unet=getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
adapter_name=adapter_name,
|
||||
_pipeline=self,
|
||||
@@ -1216,7 +1216,9 @@ class LoraLoaderMixin:
|
||||
self.load_lora_into_text_encoder(
|
||||
state_dict,
|
||||
network_alphas=network_alphas,
|
||||
text_encoder=self.text_encoder,
|
||||
text_encoder=getattr(self, self.text_encoder_name)
|
||||
if not hasattr(self, "text_encoder")
|
||||
else self.text_encoder,
|
||||
lora_scale=self.lora_scale,
|
||||
low_cpu_mem_usage=low_cpu_mem_usage,
|
||||
adapter_name=adapter_name,
|
||||
@@ -1577,7 +1579,7 @@ class LoraLoaderMixin:
|
||||
"""
|
||||
low_cpu_mem_usage = low_cpu_mem_usage if low_cpu_mem_usage is not None else _LOW_CPU_MEM_USAGE_DEFAULT
|
||||
# If the serialization format is new (introduced in https://github.com/huggingface/diffusers/pull/2918),
|
||||
# then the `state_dict` keys should have `self.unet_name` and/or `self.text_encoder_name` as
|
||||
# then the `state_dict` keys should have `cls.unet_name` and/or `cls.text_encoder_name` as
|
||||
# their prefixes.
|
||||
keys = list(state_dict.keys())
|
||||
|
||||
@@ -1961,7 +1963,7 @@ class LoraLoaderMixin:
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
self,
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, torch.nn.Module] = None,
|
||||
@@ -2001,7 +2003,7 @@ class LoraLoaderMixin:
|
||||
unet_lora_layers.state_dict() if isinstance(unet_lora_layers, torch.nn.Module) else unet_lora_layers
|
||||
)
|
||||
|
||||
unet_lora_state_dict = {f"{self.unet_name}.{module_name}": param for module_name, param in weights.items()}
|
||||
unet_lora_state_dict = {f"{cls.unet_name}.{module_name}": param for module_name, param in weights.items()}
|
||||
state_dict.update(unet_lora_state_dict)
|
||||
|
||||
if text_encoder_lora_layers is not None:
|
||||
@@ -2012,12 +2014,12 @@ class LoraLoaderMixin:
|
||||
)
|
||||
|
||||
text_encoder_lora_state_dict = {
|
||||
f"{self.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
|
||||
f"{cls.text_encoder_name}.{module_name}": param for module_name, param in weights.items()
|
||||
}
|
||||
state_dict.update(text_encoder_lora_state_dict)
|
||||
|
||||
# Save the model
|
||||
self.write_lora_layers(
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
@@ -2026,6 +2028,7 @@ class LoraLoaderMixin:
|
||||
safe_serialization=safe_serialization,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def write_lora_layers(
|
||||
state_dict: Dict[str, torch.Tensor],
|
||||
save_directory: str,
|
||||
@@ -3248,7 +3251,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
|
||||
@classmethod
|
||||
def save_lora_weights(
|
||||
self,
|
||||
cls,
|
||||
save_directory: Union[str, os.PathLike],
|
||||
unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
|
||||
@@ -3299,7 +3302,7 @@ class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin):
|
||||
state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
|
||||
state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
|
||||
|
||||
self.write_lora_layers(
|
||||
cls.write_lora_layers(
|
||||
state_dict=state_dict,
|
||||
save_directory=save_directory,
|
||||
is_main_process=is_main_process,
|
||||
|
||||
@@ -14,16 +14,29 @@
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from typing import Dict, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...loaders import UNet2DConditionLoadersMixin
|
||||
from ...models.attention_processor import (
|
||||
ADDED_KV_ATTENTION_PROCESSORS,
|
||||
CROSS_ATTENTION_PROCESSORS,
|
||||
AttentionProcessor,
|
||||
AttnAddedKVProcessor,
|
||||
AttnProcessor,
|
||||
)
|
||||
from ...models.modeling_utils import ModelMixin
|
||||
from ...utils import is_torch_version
|
||||
from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
|
||||
|
||||
|
||||
class WuerstchenPrior(ModelMixin, ConfigMixin):
|
||||
class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
|
||||
unet_name = "prior"
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
|
||||
super().__init__()
|
||||
@@ -45,6 +58,90 @@ class WuerstchenPrior(ModelMixin, ConfigMixin):
|
||||
nn.Conv2d(c, c_in * 2, kernel_size=1),
|
||||
)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
self.set_default_attn_processor()
|
||||
|
||||
@property
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
|
||||
def attn_processors(self) -> Dict[str, AttentionProcessor]:
|
||||
r"""
|
||||
Returns:
|
||||
`dict` of attention processors: A dictionary containing all attention processors used in the model,
indexed by their weight names.
|
||||
"""
|
||||
# set recursively
|
||||
processors = {}
|
||||
|
||||
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
|
||||
if hasattr(module, "get_processor"):
|
||||
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
|
||||
|
||||
return processors
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_add_processors(name, module, processors)
|
||||
|
||||
return processors
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
|
||||
def set_attn_processor(
|
||||
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
|
||||
):
|
||||
r"""
|
||||
Sets the attention processor to use to compute attention.
|
||||
|
||||
Parameters:
|
||||
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
|
||||
The instantiated processor class or a dictionary of processor classes that will be set as the processor
|
||||
for **all** `Attention` layers.
|
||||
|
||||
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
|
||||
processor. This is strongly recommended when setting trainable attention processors.
|
||||
|
||||
"""
|
||||
count = len(self.attn_processors.keys())
|
||||
|
||||
if isinstance(processor, dict) and len(processor) != count:
|
||||
raise ValueError(
|
||||
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
|
||||
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
|
||||
)
|
||||
|
||||
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
|
||||
if hasattr(module, "set_processor"):
|
||||
if not isinstance(processor, dict):
|
||||
module.set_processor(processor, _remove_lora=_remove_lora)
|
||||
else:
|
||||
module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
|
||||
|
||||
for sub_name, child in module.named_children():
|
||||
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
|
||||
|
||||
for name, module in self.named_children():
|
||||
fn_recursive_attn_processor(name, module, processor)
|
||||
|
||||
# Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
|
||||
def set_default_attn_processor(self):
|
||||
"""
|
||||
Disables custom attention processors and sets the default attention implementation.
|
||||
"""
|
||||
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnAddedKVProcessor()
|
||||
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
|
||||
processor = AttnProcessor()
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
|
||||
)
|
||||
|
||||
self.set_attn_processor(processor, _remove_lora=True)
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
self.gradient_checkpointing = value
|
||||
|
||||
def gen_r_embedding(self, r, max_positions=10000):
|
||||
r = r * max_positions
|
||||
half_dim = self.c_r // 2
|
||||
@@ -61,12 +158,42 @@ class WuerstchenPrior(ModelMixin, ConfigMixin):
|
||||
x = self.projection(x)
|
||||
c_embed = self.cond_mapper(c)
|
||||
r_embed = self.gen_r_embedding(r)
|
||||
for block in self.blocks:
|
||||
if isinstance(block, AttnBlock):
|
||||
x = block(x, c_embed)
|
||||
elif isinstance(block, TimestepBlock):
|
||||
x = block(x, r_embed)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
if is_torch_version(">=", "1.11.0"):
|
||||
for block in self.blocks:
|
||||
if isinstance(block, AttnBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, c_embed, use_reentrant=False
|
||||
)
|
||||
elif isinstance(block, TimestepBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(block), x, r_embed, use_reentrant=False
|
||||
)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
|
||||
else:
|
||||
x = block(x)
|
||||
for block in self.blocks:
|
||||
if isinstance(block, AttnBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, c_embed)
|
||||
elif isinstance(block, TimestepBlock):
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, r_embed)
|
||||
else:
|
||||
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
|
||||
else:
|
||||
for block in self.blocks:
|
||||
if isinstance(block, AttnBlock):
|
||||
x = block(x, c_embed)
|
||||
elif isinstance(block, TimestepBlock):
|
||||
x = block(x, r_embed)
|
||||
else:
|
||||
x = block(x)
|
||||
a, b = self.out(x).chunk(2, dim=1)
|
||||
return (x_in - a) / ((1 - b).abs() + 1e-5)
|
||||
|
||||
@@ -20,6 +20,7 @@ import numpy as np
|
||||
import torch
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...loaders import LoraLoaderMixin
|
||||
from ...schedulers import DDPMWuerstchenScheduler
|
||||
from ...utils import (
|
||||
BaseOutput,
|
||||
@@ -65,7 +66,7 @@ class WuerstchenPriorPipelineOutput(BaseOutput):
|
||||
image_embeddings: Union[torch.FloatTensor, np.ndarray]
|
||||
|
||||
|
||||
class WuerstchenPriorPipeline(DiffusionPipeline):
|
||||
class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
|
||||
"""
|
||||
Pipeline for generating image prior for Wuerstchen.
|
||||
|
||||
@@ -90,6 +91,8 @@ class WuerstchenPriorPipeline(DiffusionPipeline):
|
||||
Default resolution for multiple images generated.
|
||||
"""
|
||||
|
||||
unet_name = "prior"
|
||||
text_encoder_name = "text_encoder"
|
||||
model_cpu_offload_seq = "text_encoder->prior"
|
||||
|
||||
def __init__(
|
||||
|
||||
@@ -211,24 +211,15 @@ class DDPMWuerstchenScheduler(SchedulerMixin, ConfigMixin):
|
||||
self,
|
||||
original_samples: torch.FloatTensor,
|
||||
noise: torch.FloatTensor,
|
||||
timesteps: torch.IntTensor,
|
||||
timesteps: torch.FloatTensor,
|
||||
) -> torch.FloatTensor:
|
||||
# Make sure alphas_cumprod and timestep have same device and dtype as original_samples
|
||||
alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
|
||||
timesteps = timesteps.to(original_samples.device)
|
||||
|
||||
sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.flatten()
|
||||
while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
|
||||
|
||||
sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
|
||||
while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
|
||||
sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
|
||||
|
||||
noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
|
||||
return noisy_samples
|
||||
device = original_samples.device
|
||||
dtype = original_samples.dtype
|
||||
alpha_cumprod = self._alpha_cumprod(timesteps, device=device).view(
|
||||
timesteps.size(0), *[1 for _ in original_samples.shape[1:]]
|
||||
)
|
||||
noisy_samples = alpha_cumprod.sqrt() * original_samples + (1 - alpha_cumprod).sqrt() * noise
|
||||
return noisy_samples.to(dtype=dtype)
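# Note: `add_noise` now evaluates `self._alpha_cumprod` directly at the continuous float
# timesteps instead of indexing a precomputed `alphas_cumprod` table, matching the float
# timesteps sampled in the training loop above.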
|
||||
|
||||
def __len__(self):
|
||||
return self.config.num_train_timesteps