[LoRA] Add LoRA training script (#1884)

* [Lora] first upload
* add first lora version
* upload
* more
* first training
* up
* correct
* improve
* finish loaders and inference
* up
* up
* fix more
* up
* finish more
* finish more
* up
* up
* change year
* revert year change
* Change lines
* Add cloneofsimo as co-author.

Co-authored-by: Simo Ryu <cloneofsimo@gmail.com>

* finish
* fix docs
* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* upload
* finish

Co-authored-by: Simo Ryu <cloneofsimo@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Committed via GitHub · commit ed616bd8a8 · parent ac3fc64906
@@ -90,6 +90,8 @@
     title: Configuration
   - local: api/outputs
     title: Outputs
+  - local: api/loaders
+    title: Loaders
   title: Main Classes
 - sections:
   - local: api/pipelines/overview
docs/source/en/api/loaders.mdx (new file, 30 lines)
@@ -0,0 +1,30 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Loaders

There are many ways to train adapter neural networks for diffusion models, such as
- [Textual Inversion](./training/text_inversion.mdx)
- [LoRA](https://github.com/cloneofsimo/lora)
- [Hypernetworks](https://arxiv.org/abs/1609.09106)

Such adapter neural networks often consist of only a fraction of the number of weights of the pretrained model and as such are very portable. The Diffusers library offers an easy-to-use API to load such adapter neural networks via the [`loaders.py` module](https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders.py).

**Note**: This module is still highly experimental and prone to future changes.

## LoaderMixins

### UNet2DConditionLoadersMixin

[[autodoc]] loaders.UNet2DConditionLoadersMixin
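As a quick orientation, loading LoRA attention weights previously saved with `save_attn_procs` looks roughly like this (the checkpoint path is a placeholder):

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# load LoRA attention processor weights from a local directory or a Hub repository
pipe.unet.load_attn_procs("path-to-saved-lora-weights")
```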
@@ -1,4 +1,4 @@
-<!--Copyright 2020 The HuggingFace Team. All rights reserved.
+<!--Copyright 2022 The HuggingFace Team. All rights reserved.

 Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
 the License. You may obtain a copy of the License at
@@ -5,6 +5,7 @@ The `train_dreambooth.py` script shows how to implement the training procedure a

## Running locally with PyTorch

### Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:
@@ -235,6 +236,102 @@ image.save("dog-bucket.png")

You can also perform inference from one of the checkpoints saved during the training process, if you used the `--checkpointing_steps` argument. Please refer to [the documentation](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint) to see how to do it.
## Training with Low-Rank Adaptation of Large Language Models (LoRA)

Low-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.

In a nutshell, LoRA adapts pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
- The pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114).
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers let you control the extent to which the model is adapted towards new training images via a `scale` parameter (a sketch of the idea follows below).

[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
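To make the rank-decomposition idea concrete, here is a minimal, self-contained sketch of how a LoRA update augments a frozen linear weight (illustrative only, not the implementation used by the training script; the dimensions are arbitrary):

```python
import torch

hidden_dim, rank = 320, 4
frozen_weight = torch.randn(hidden_dim, hidden_dim)  # pretrained weight W, kept frozen

# LoRA trains only a pair of rank-decomposition matrices: W' = W + scale * (up @ down)
lora_down = torch.nn.Parameter(torch.randn(rank, hidden_dim) * 0.01)
lora_up = torch.nn.Parameter(torch.zeros(hidden_dim, rank))  # zero-init: training starts at W


def lora_linear(x, scale=1.0):
    # original projection plus the low-rank update, weighted by `scale`
    return x @ frozen_weight.T + scale * ((x @ lora_down.T) @ lora_up.T)
```

With `hidden_dim=320` and `rank=4`, the trainable matrices hold `2 * 320 * 4 = 2560` parameters instead of the `320 * 320 = 102400` in the frozen weight, which is why LoRA checkpoints stay so small.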
### Training

Let's get started with a simple example. We will re-use the dog example of the [previous section](#dog-toy-example).

First, you need to set up your dreambooth training example as explained in the [installation section](#Installing-the-dependencies).
Next, let's download the dog dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. Make sure to set `INSTANCE_DIR` to the name of your directory further below. This will be our training data.

Now, you can launch the training. Here we will use [Stable Diffusion 1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5).

**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**

**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [wandb](https://docs.wandb.ai/quickstart) is a nice solution to easily see generated images during training. All you need to do is to run `pip install wandb` before training and pass `--report_to="wandb"` to automatically log images.___**

```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"
```
For this example we want to directly store the trained LoRA embeddings on the Hub, so we need to be logged in and add the `--push_to_hub` flag.

```bash
huggingface-cli login
```

Now we can start training!
```bash
accelerate launch train_dreambooth_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --checkpointing_steps=100 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=50 \
  --seed="0" \
  --push_to_hub
```
**___Note: When using LoRA we can use a much higher learning rate compared to vanilla dreambooth. Here we use *1e-4* instead of the usual *2e-6*.___**

The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dreambooth_dog_example](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example). **___Note: [The final weights](https://huggingface.co/patrickvonplaten/lora/blob/main/pytorch_attn_procs.bin) are only 3 MB in size, which is orders of magnitude smaller than the original model.___**

The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
You can use the `Step` slider to see how the model learned the features of our subject while the model trained.
### Inference

After training, LoRA weights can be loaded very easily into the original pipeline. First, you need to load the original pipeline:

```python
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch

pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
```
Next, we can load the adapter layers into the UNet with the [`load_attn_procs` function](https://huggingface.co/docs/diffusers/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs). Note that the method lives on the UNet, not on the pipeline itself:

```python
pipe.unet.load_attn_procs("patrickvonplaten/lora")
```
Finally, we can run the model for inference.

```python
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
```
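The `scale` argument of the LoRA attention processors controls how strongly the adapter steers the frozen model: `0.0` recovers the original model and `1.0` uses the fully adapted weights. As a sketch, assuming a diffusers version whose pipelines forward `cross_attention_kwargs` to the attention processors, you could blend the two:

```python
# interpolate 50/50 between the original and the LoRA-adapted model
image = pipe(
    "A picture of a sks dog in a bucket",
    num_inference_steps=25,
    cross_attention_kwargs={"scale": 0.5},
).images[0]
```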
## Training with Flax/JAX

For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
 import argparse
 import hashlib
 import itertools
examples/dreambooth/train_dreambooth_lora.py (new file, 950 lines)
@@ -0,0 +1,950 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and

import argparse
import hashlib
import logging
import math
import os
import warnings
from pathlib import Path
from typing import Optional

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset

import datasets
import diffusers
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    DiffusionPipeline,
    DPMSolverMultistepScheduler,
    UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig


# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.12.0.dev0")

logger = get_logger(__name__)

def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
    text_encoder_config = PretrainedConfig.from_pretrained(
        pretrained_model_name_or_path,
        subfolder="text_encoder",
        revision=revision,
    )
    model_class = text_encoder_config.architectures[0]

    if model_class == "CLIPTextModel":
        from transformers import CLIPTextModel

        return CLIPTextModel
    elif model_class == "RobertaSeriesModelWithTransformation":
        from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation

        return RobertaSeriesModelWithTransformation
    else:
        raise ValueError(f"{model_class} is not supported.")

def parse_args(input_args=None):
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--instance_data_dir",
        type=str,
        default=None,
        required=True,
        help="A folder containing the training data of instance images.",
    )
    parser.add_argument(
        "--class_data_dir",
        type=str,
        default=None,
        required=False,
        help="A folder containing the training data of class images.",
    )
    parser.add_argument(
        "--instance_prompt",
        type=str,
        default=None,
        required=True,
        help="The prompt with identifier specifying the instance",
    )
    parser.add_argument(
        "--class_prompt",
        type=str,
        default=None,
        help="The prompt to specify images in the same class as provided instance images.",
    )
    parser.add_argument(
        "--validation_prompt",
        type=str,
        default=None,
        help="A prompt that is used during validation to verify that the model is learning.",
    )
    parser.add_argument(
        "--num_validation_images",
        type=int,
        default=4,
        help="Number of images that should be generated during validation with `validation_prompt`.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=50,
        help=(
            "Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
            " `args.validation_prompt` multiple times: `args.num_validation_images`."
        ),
    )
    parser.add_argument(
        "--with_prior_preservation",
        default=False,
        action="store_true",
        help="Flag to add prior preservation loss.",
    )
    parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=100,
        help=(
            "Minimal class images for prior preservation loss. If there are not enough images already present in"
            " class_data_dir, additional images will be sampled with class_prompt."
        ),
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="lora-dreambooth-model",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
    )
    parser.add_argument(
        "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
    )
    parser.add_argument("--num_train_epochs", type=int, default=1)
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=500,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
            " checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=5e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=False,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_num_cycles",
        type=int,
        default=1,
        help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
    )
    parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
    parser.add_argument(
        "--hub_model_id",
        type=str,
        default=None,
        help="The name of the repository to keep in sync with the local `output_dir`.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--allow_tf32",
        action="store_true",
        help=(
            "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
            " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--prior_generation_precision",
        type=str,
        default=None,
        choices=["no", "fp32", "fp16", "bf16"],
        help=(
            "Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
    parser.add_argument(
        "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
    )

    if input_args is not None:
        args = parser.parse_args(input_args)
    else:
        args = parser.parse_args()

    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    if args.with_prior_preservation:
        if args.class_data_dir is None:
            raise ValueError("You must specify a data directory for class images.")
        if args.class_prompt is None:
            raise ValueError("You must specify a prompt for class images.")
    else:
        # logger is not available yet
        if args.class_data_dir is not None:
            warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
        if args.class_prompt is not None:
            warnings.warn("You need not use --class_prompt without --with_prior_preservation.")

    return args

class DreamBoothDataset(Dataset):
    """
    A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
    It pre-processes the images and tokenizes the prompts.
    """

    def __init__(
        self,
        instance_data_root,
        instance_prompt,
        tokenizer,
        class_data_root=None,
        class_prompt=None,
        size=512,
        center_crop=False,
    ):
        self.size = size
        self.center_crop = center_crop
        self.tokenizer = tokenizer

        self.instance_data_root = Path(instance_data_root)
        if not self.instance_data_root.exists():
            raise ValueError("Instance images root doesn't exist.")

        self.instance_images_path = list(Path(instance_data_root).iterdir())
        self.num_instance_images = len(self.instance_images_path)
        self.instance_prompt = instance_prompt
        self._length = self.num_instance_images

        if class_data_root is not None:
            self.class_data_root = Path(class_data_root)
            self.class_data_root.mkdir(parents=True, exist_ok=True)
            self.class_images_path = list(self.class_data_root.iterdir())
            self.num_class_images = len(self.class_images_path)
            self._length = max(self.num_class_images, self.num_instance_images)
            self.class_prompt = class_prompt
        else:
            self.class_data_root = None

        self.image_transforms = transforms.Compose(
            [
                transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5]),
            ]
        )

    def __len__(self):
        return self._length

    def __getitem__(self, index):
        example = {}
        instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
        if not instance_image.mode == "RGB":
            instance_image = instance_image.convert("RGB")
        example["instance_images"] = self.image_transforms(instance_image)
        example["instance_prompt_ids"] = self.tokenizer(
            self.instance_prompt,
            truncation=True,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            return_tensors="pt",
        ).input_ids

        if self.class_data_root:
            class_image = Image.open(self.class_images_path[index % self.num_class_images])
            if not class_image.mode == "RGB":
                class_image = class_image.convert("RGB")
            example["class_images"] = self.image_transforms(class_image)
            example["class_prompt_ids"] = self.tokenizer(
                self.class_prompt,
                truncation=True,
                padding="max_length",
                max_length=self.tokenizer.model_max_length,
                return_tensors="pt",
            ).input_ids

        return example

def collate_fn(examples, with_prior_preservation=False):
    input_ids = [example["instance_prompt_ids"] for example in examples]
    pixel_values = [example["instance_images"] for example in examples]

    # Concat class and instance examples for prior preservation.
    # We do this to avoid doing two forward passes.
    if with_prior_preservation:
        input_ids += [example["class_prompt_ids"] for example in examples]
        pixel_values += [example["class_images"] for example in examples]

    pixel_values = torch.stack(pixel_values)
    pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

    input_ids = torch.cat(input_ids, dim=0)

    batch = {
        "input_ids": input_ids,
        "pixel_values": pixel_values,
    }
    return batch


class PromptDataset(Dataset):
    "A simple dataset to prepare the prompts to generate class images on multiple GPUs."

    def __init__(self, prompt, num_samples):
        self.prompt = prompt
        self.num_samples = num_samples

    def __len__(self):
        return self.num_samples

    def __getitem__(self, index):
        example = {}
        example["prompt"] = self.prompt
        example["index"] = index
        return example


def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
    if token is None:
        token = HfFolder.get_token()
    if organization is None:
        username = whoami(token)["name"]
        return f"{username}/{model_id}"
    else:
        return f"{organization}/{model_id}"

def main(args):
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        logging_dir=logging_dir,
    )

    if args.report_to == "wandb":
        if not is_wandb_available():
            raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
        import wandb

    # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
    # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)
    if accelerator.is_local_main_process:
        datasets.utils.logging.set_verbosity_warning()
        transformers.utils.logging.set_verbosity_warning()
        diffusers.utils.logging.set_verbosity_info()
    else:
        datasets.utils.logging.set_verbosity_error()
        transformers.utils.logging.set_verbosity_error()
        diffusers.utils.logging.set_verbosity_error()

    # If passed along, set the training seed now.
    if args.seed is not None:
        set_seed(args.seed)

    # Generate class images if prior preservation is enabled.
    if args.with_prior_preservation:
        class_images_dir = Path(args.class_data_dir)
        if not class_images_dir.exists():
            class_images_dir.mkdir(parents=True)
        cur_class_images = len(list(class_images_dir.iterdir()))

        if cur_class_images < args.num_class_images:
            torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
            if args.prior_generation_precision == "fp32":
                torch_dtype = torch.float32
            elif args.prior_generation_precision == "fp16":
                torch_dtype = torch.float16
            elif args.prior_generation_precision == "bf16":
                torch_dtype = torch.bfloat16
            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                torch_dtype=torch_dtype,
                safety_checker=None,
                revision=args.revision,
            )
            pipeline.set_progress_bar_config(disable=True)

            num_new_images = args.num_class_images - cur_class_images
            logger.info(f"Number of class images to sample: {num_new_images}.")

            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)

            sample_dataloader = accelerator.prepare(sample_dataloader)
            pipeline.to(accelerator.device)

            for example in tqdm(
                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
            ):
                images = pipeline(example["prompt"]).images

                for i, image in enumerate(images):
                    hash_image = hashlib.sha1(image.tobytes()).hexdigest()
                    image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
                    image.save(image_filename)

            del pipeline
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Handle the repository creation
    if accelerator.is_main_process:
        if args.push_to_hub:
            if args.hub_model_id is None:
                repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
            else:
                repo_name = args.hub_model_id

            repo_name = create_repo(repo_name, exist_ok=True)
            repo = Repository(args.output_dir, clone_from=repo_name)

            with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
                if "step_*" not in gitignore:
                    gitignore.write("step_*\n")
                if "epoch_*" not in gitignore:
                    gitignore.write("epoch_*\n")
        elif args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load the tokenizer
    if args.tokenizer_name:
        tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
    elif args.pretrained_model_name_or_path:
        tokenizer = AutoTokenizer.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="tokenizer",
            revision=args.revision,
            use_fast=False,
        )

    # import correct text encoder class
    text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)

    # Load scheduler and models
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    text_encoder = text_encoder_cls.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
    unet = UNet2DConditionModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
    )

    # We only train the additional adapter LoRA layers
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    # For mixed precision training we cast the text_encoder and vae weights to half-precision
    # as these models are only used for inference, keeping weights in full precision is not required.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Move unet, vae and text_encoder to device and cast to weight_dtype
    unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)

    if args.enable_xformers_memory_efficient_attention:
        if is_xformers_available():
            unet.enable_xformers_memory_efficient_attention()
        else:
            raise ValueError("xformers is not available. Make sure it is installed correctly")

    # now we will add new LoRA weights to the attention layers
    # It's important to realize here how many attention weights will be added and of which sizes
    # The sizes of the attention layers consist only of two different variables:
    # 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
    # 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.

    # Let's first see how many attention processors we will have to set.
    # For Stable Diffusion, it should be equal to:
    # - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
    # - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
    # - up blocks (2x attention layers) * (3x transformer layers) * (3x up blocks) = 18
    # => 32 layers

    # Set correct lora layers
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        lora_attn_procs[name] = LoRACrossAttnProcessor(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
        )

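    # `AttnProcsLayers` (from diffusers.loaders) wraps the dict of attention processors in a
    # single `torch.nn.Module`, so the optimizer, gradient clipping and `accelerate`
    # checkpointing below can treat all LoRA weights as one set of trainable parameters.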
    unet.set_attn_processor(lora_attn_procs)
    lora_layers = AttnProcsLayers(unet.attn_processors)

    accelerator.register_for_checkpointing(lora_layers)

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
        )

    # Enable TF32 for faster training on Ampere GPUs,
    # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    if args.allow_tf32:
        torch.backends.cuda.matmul.allow_tf32 = True

    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
            )

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Optimizer creation
    optimizer = optimizer_class(
        lora_layers.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Dataset and DataLoaders creation:
    train_dataset = DreamBoothDataset(
        instance_data_root=args.instance_data_dir,
        instance_prompt=args.instance_prompt,
        class_data_root=args.class_data_dir if args.with_prior_preservation else None,
        class_prompt=args.class_prompt,
        tokenizer=tokenizer,
        size=args.resolution,
        center_crop=args.center_crop,
    )

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.train_batch_size,
        shuffle=True,
        collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
        num_workers=1,
    )

    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
        num_cycles=args.lr_num_cycles,
        power=args.lr_power,
    )

    # Prepare everything with our `accelerator`.
    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        lora_layers, optimizer, train_dataloader, lr_scheduler
    )

    # We need to recalculate our total training steps as the size of the training dataloader may have changed.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if overrode_max_train_steps:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
    # Afterwards we recalculate our number of training epochs
    args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initialize automatically on the main process.
    if accelerator.is_main_process:
        accelerator.init_trackers("dreambooth-lora", config=vars(args))

    # Train!
    total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num batches each epoch = {len(train_dataloader)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    logger.info(f"  Total optimization steps = {args.max_train_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            path = os.path.basename(args.resume_from_checkpoint)
        else:
            # Get the most recent checkpoint
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1]
        accelerator.print(f"Resuming from checkpoint {path}")
        accelerator.load_state(os.path.join(args.output_dir, path))
        global_step = int(path.split("-")[1])

        resume_global_step = global_step * args.gradient_accumulation_steps
        first_epoch = resume_global_step // num_update_steps_per_epoch
        resume_step = resume_global_step % num_update_steps_per_epoch

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
    progress_bar.set_description("Steps")

    for epoch in range(first_epoch, args.num_train_epochs):
        unet.train()
        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue

            with accelerator.accumulate(unet):
                # Convert images to latent space
                latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                latents = latents * 0.18215

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # Get the text embedding for conditioning
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                # Predict the noise residual
                model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

                # Get the target for loss depending on the prediction type
                if noise_scheduler.config.prediction_type == "epsilon":
                    target = noise
                elif noise_scheduler.config.prediction_type == "v_prediction":
                    target = noise_scheduler.get_velocity(latents, noise, timesteps)
                else:
                    raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

                if args.with_prior_preservation:
                    # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
                    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
                    target, target_prior = torch.chunk(target, 2, dim=0)

                    # Compute instance loss
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                    # Compute prior loss
                    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

                    # Add the prior loss to the instance loss.
                    loss = loss + args.prior_loss_weight * prior_loss
                else:
                    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    params_to_clip = lora_layers.parameters()
                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
            progress_bar.set_postfix(**logs)
            accelerator.log(logs, step=global_step)

            if global_step >= args.max_train_steps:
                break

        if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
            logger.info(
                f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
                f" {args.validation_prompt}."
            )
            # create pipeline
            pipeline = DiffusionPipeline.from_pretrained(
                args.pretrained_model_name_or_path,
                unet=accelerator.unwrap_model(unet),
                text_encoder=accelerator.unwrap_model(text_encoder),
                revision=args.revision,
                torch_dtype=weight_dtype,
            )
            pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
            pipeline = pipeline.to(accelerator.device)
            pipeline.set_progress_bar_config(disable=True)

            # run inference
            generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
            prompt = args.num_validation_images * [args.validation_prompt]
            images = pipeline(prompt, num_inference_steps=25, generator=generator).images

            for tracker in accelerator.trackers:
                if tracker.name == "wandb":
                    tracker.log(
                        {
                            "validation": [
                                wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
                                for i, image in enumerate(images)
                            ]
                        }
                    )

            del pipeline
            torch.cuda.empty_cache()

    # Save the lora layers
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unet = unet.to(torch.float32)
        unet.save_attn_procs(args.output_dir)

        if args.push_to_hub:
            repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)

    # Final inference
    # Load previous pipeline
    pipeline = DiffusionPipeline.from_pretrained(
        args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
    )
    pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
    pipeline = pipeline.to(accelerator.device)

    # load attention processors
    pipeline.unet.load_attn_procs(args.output_dir)

    # run inference
    generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
    prompt = args.num_validation_images * [args.validation_prompt]
    images = pipeline(prompt, num_inference_steps=25, generator=generator).images

    for tracker in accelerator.trackers:
        if tracker.name == "wandb":
            tracker.log(
                {
                    "test": [
                        wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
                    ]
                }
            )

    accelerator.end_training()


if __name__ == "__main__":
    args = parse_args()
    main(args)
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
 import argparse
 import copy
 import logging
@@ -1,3 +1,18 @@
+#!/usr/bin/env python
+# coding=utf-8
+# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+
 import argparse
 import logging
 import math
@@ -1,4 +1,4 @@
-# Copyright 2020 The HuggingFace Team. All rights reserved.
+# Copyright 2022 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
src/diffusers/loaders.py (new file, 243 lines)
@@ -0,0 +1,243 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import defaultdict
from typing import Callable, Dict, Union

import torch

from .models.cross_attention import LoRACrossAttnProcessor
from .models.modeling_utils import _get_model_file
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, logging


logger = logging.get_logger(__name__)


LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"

class AttnProcsLayers(torch.nn.Module):
    def __init__(self, state_dict: Dict[str, torch.Tensor]):
        super().__init__()
        self.layers = torch.nn.ModuleList(state_dict.values())
        self.mapping = {k: v for k, v in enumerate(state_dict.keys())}
        self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}

        # we add a hook to state_dict() and load_state_dict() so that the
        # naming fits with `unet.attn_processors`
        def map_to(module, state_dict, *args, **kwargs):
            new_state_dict = {}
            for key, value in state_dict.items():
                num = int(key.split(".")[1])  # 0 is always "layers"
                new_key = key.replace(f"layers.{num}", module.mapping[num])
                new_state_dict[new_key] = value

            return new_state_dict

        def map_from(module, state_dict, *args, **kwargs):
            all_keys = list(state_dict.keys())
            for key in all_keys:
                replace_key = key.split(".processor")[0] + ".processor"
                new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
                state_dict[new_key] = state_dict[key]
                del state_dict[key]

        self._register_state_dict_hook(map_to)
        self._register_load_state_dict_pre_hook(map_from, with_module=True)

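# Illustrative example of the mapping performed by the hooks above: the module-local key
# `layers.0.to_k_lora.down.weight` round-trips to the corresponding processor key, e.g.
# `down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_k_lora.down.weight`.
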
class UNet2DConditionLoadersMixin:
    def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
        r"""
        Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be
        defined in
        [cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
        and be a `torch.nn.Module` class.

        <Tip warning={true}>

        This function is experimental and might change in the future.

        </Tip>

        Parameters:
            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
                Can be either:

                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
                      Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
                    - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
                      `./my_model_directory/`.
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory in which a downloaded pretrained model configuration should be cached if the
                standard cache should not be used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
                file exists.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether or not to only look at local files (i.e., do not try to download the model).
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
                when running `diffusers-cli login` (stored in `~/.huggingface`).
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
                identifier allowed by git.
            subfolder (`str`, *optional*, defaults to `""`):
                In case the relevant files are located inside a subfolder of the model repo (either remote in
                huggingface.co or downloaded locally), you can specify the folder name here.

            mirror (`str`, *optional*):
                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety
                of the mirror. Please refer to the mirror site for more information.

        <Tip>

        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
        models](https://huggingface.co/docs/hub/models-gated#gated-models).

        </Tip>

        <Tip>

        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
        this method in a firewalled environment.

        </Tip>
        """

        cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
        use_auth_token = kwargs.pop("use_auth_token", None)
        revision = kwargs.pop("revision", None)
        subfolder = kwargs.pop("subfolder", None)
        weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME)

        user_agent = {
            "file_type": "attn_procs_weights",
            "framework": "pytorch",
        }

        if not isinstance(pretrained_model_name_or_path_or_dict, dict):
            model_file = _get_model_file(
                pretrained_model_name_or_path_or_dict,
                weights_name=weight_name,
                cache_dir=cache_dir,
                force_download=force_download,
                resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                use_auth_token=use_auth_token,
                revision=revision,
                subfolder=subfolder,
                user_agent=user_agent,
            )
            state_dict = torch.load(model_file, map_location="cpu")
        else:
            state_dict = pretrained_model_name_or_path_or_dict

        # fill attn processors
        attn_processors = {}

        is_lora = all("lora" in k for k in state_dict.keys())

        if is_lora:
            lora_grouped_dict = defaultdict(dict)
            for key, value in state_dict.items():
                attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
                lora_grouped_dict[attn_processor_key][sub_key] = value

            for key, value_dict in lora_grouped_dict.items():
                rank = value_dict["to_k_lora.down.weight"].shape[0]
                cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
                hidden_size = value_dict["to_k_lora.up.weight"].shape[0]

                attn_processors[key] = LoRACrossAttnProcessor(
                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
                )
                attn_processors[key].load_state_dict(value_dict)

        else:
            raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")

        # set correct dtype & device
        attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}

        # set layers
        self.set_attn_processor(attn_processors)

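For orientation, a minimal usage sketch of the loader above; the model id is only an example, and `path/to/lora` is a placeholder for a directory containing the `pytorch_lora_weights.bin` file written by `save_attn_procs` below:

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
# swaps the UNet's attention processors for LoRA-augmented ones and
# loads the low-rank weights from the given directory
pipe.unet.load_attn_procs("path/to/lora")
```
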
    def save_attn_procs(
        self,
        save_directory: Union[str, os.PathLike],
        is_main_process: bool = True,
        weights_name: str = LORA_WEIGHT_NAME,
        save_function: Callable = None,
    ):
        r"""
        Save an attention processor to a directory, so that it can be re-loaded using the
        [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.

        Arguments:
            save_directory (`str` or `os.PathLike`):
                Directory to which to save. Will be created if it doesn't exist.
            is_main_process (`bool`, *optional*, defaults to `True`):
                Whether the process calling this is the main process or not. Useful during distributed training (e.g.,
                on TPUs), where this function needs to be called on all processes. In this case, set
                `is_main_process=True` only on the main process to avoid race conditions.
            save_function (`Callable`):
                The function to use to save the state dictionary. Useful during distributed training (e.g., on TPUs),
                where `torch.save` needs to be replaced by another method. Can be configured with the environment
                variable `DIFFUSERS_SAVE_MODE`.
        """
        if os.path.isfile(save_directory):
            logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
            return

        if save_function is None:
            save_function = torch.save

        os.makedirs(save_directory, exist_ok=True)

        model_to_save = AttnProcsLayers(self.attn_processors)

        # Save the model
        state_dict = model_to_save.state_dict()

        # Clean the folder from a previous save
        for filename in os.listdir(save_directory):
            full_filename = os.path.join(save_directory, filename)
            # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
            # in distributed settings to avoid race conditions.
            weights_no_suffix = weights_name.replace(".bin", "")
            if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
                os.remove(full_filename)

        # Save the model
        save_function(state_dict, os.path.join(save_directory, weights_name))

        logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")

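And the save side, a sketch assuming `unet` is a `UNet2DConditionModel` that already carries trained `LoRACrossAttnProcessor` layers (the output directory name is arbitrary):

```python
# persists only the adapter weights (a few megabytes) as
# pytorch_lora_weights.bin, not the full UNet checkpoint
unet.save_attn_procs("sd-dog-lora")
```
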
@@ -246,6 +246,68 @@ class CrossAttnProcessor:
        return hidden_states


class LoRALinearLayer(nn.Module):
    def __init__(self, in_features, out_features, rank=4):
        super().__init__()

        if rank > min(in_features, out_features):
            raise ValueError(f"LoRA rank {rank} must be less than or equal to {min(in_features, out_features)}")

        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        self.scale = 1.0

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        dtype = self.down.weight.dtype

        down_hidden_states = self.down(hidden_states.to(dtype))
        up_hidden_states = self.up(down_hidden_states)

        return up_hidden_states.to(orig_dtype)

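Note the zero initialization of `up`: a freshly constructed layer is an exact no-op, so training starts from the unmodified pretrained weights. A quick sketch to illustrate (shapes chosen arbitrarily):

```python
import torch

layer = LoRALinearLayer(in_features=320, out_features=320, rank=4)
x = torch.randn(2, 77, 320)
# up.weight starts at zero, so the LoRA branch contributes nothing yet
assert layer(x).abs().max().item() == 0.0
```
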
class LoRACrossAttnProcessor(nn.Module):
    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
        super().__init__()

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states

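Each projection in the processor above thus computes `base(x) + scale * lora(x)`, where `scale=0.0` recovers the frozen model. The trainable footprint is tiny; a sketch with sizes typical of a Stable Diffusion cross-attention block (illustrative only):

```python
proc = LoRACrossAttnProcessor(hidden_size=320, cross_attention_dim=768, rank=4)
# two rank-4 matrices per projection, four projections in total
num_params = sum(p.numel() for p in proc.parameters())
print(num_params)  # 13824, versus roughly 0.7M for the projections it adapts
```
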
class CrossAttnAddedKVProcessor:
    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
@@ -312,6 +374,41 @@ class XFormersCrossAttnProcessor:
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states


class LoRAXFormersCrossAttnProcessor(nn.Module):
    def __init__(self, hidden_size, cross_attention_dim, rank=4):
        super().__init__()

        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query).contiguous()

        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states

@@ -438,7 +438,7 @@ class ModelMixin(torch.nn.Module):

        model_file = None
        if from_flax:
            model_file = cls._get_model_file(
            model_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=FLAX_WEIGHTS_NAME,
                cache_dir=cache_dir,
@@ -474,7 +474,7 @@ class ModelMixin(torch.nn.Module):
        else:
            if is_safetensors_available():
                try:
                    model_file = cls._get_model_file(
                    model_file = _get_model_file(
                        pretrained_model_name_or_path,
                        weights_name=SAFETENSORS_WEIGHTS_NAME,
                        cache_dir=cache_dir,
@@ -490,7 +490,7 @@ class ModelMixin(torch.nn.Module):
                except:
                    pass
            if model_file is None:
                model_file = cls._get_model_file(
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=WEIGHTS_NAME,
                    cache_dir=cache_dir,
@@ -599,92 +599,6 @@ class ModelMixin(torch.nn.Module):

        return model

    @classmethod
    def _get_model_file(
        cls,
        pretrained_model_name_or_path,
        *,
        weights_name,
        subfolder,
        cache_dir,
        force_download,
        proxies,
        resume_download,
        local_files_only,
        use_auth_token,
        user_agent,
        revision,
    ):
        pretrained_model_name_or_path = str(pretrained_model_name_or_path)
        if os.path.isdir(pretrained_model_name_or_path):
            if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
                # Load from a PyTorch checkpoint
                model_file = os.path.join(pretrained_model_name_or_path, weights_name)
            elif subfolder is not None and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
            ):
                model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
            else:
                raise EnvironmentError(
                    f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
                )
            return model_file
        else:
            try:
                # Load from URL or cache if already cached
                model_file = hf_hub_download(
                    pretrained_model_name_or_path,
                    filename=weights_name,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    resume_download=resume_download,
                    local_files_only=local_files_only,
                    use_auth_token=use_auth_token,
                    user_agent=user_agent,
                    subfolder=subfolder,
                    revision=revision,
                )
                return model_file

            except RepositoryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                    "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
                    "token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
                    "login`."
                )
            except RevisionNotFoundError:
                raise EnvironmentError(
                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
                    "this model name. Check the model page at "
                    f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
                )
            except EntryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
                )
            except HTTPError as err:
                raise EnvironmentError(
                    "There was a specific connection error when trying to load"
                    f" {pretrained_model_name_or_path}:\n{err}"
                )
            except ValueError:
                raise EnvironmentError(
                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                    f" directory containing a file named {weights_name} or"
                    " \nCheckout your internet connection or see how to run the library in"
                    " offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
                )
            except EnvironmentError:
                raise EnvironmentError(
                    f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                    f"containing a file named {weights_name}"
                )

    @classmethod
    def _load_pretrained_model(
        cls,
@@ -848,7 +762,9 @@ def _get_model_file(
    revision,
):
    pretrained_model_name_or_path = str(pretrained_model_name_or_path)
    if os.path.isdir(pretrained_model_name_or_path):
    if os.path.isfile(pretrained_model_name_or_path):
        return pretrained_model_name_or_path
    elif os.path.isdir(pretrained_model_name_or_path):
        if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
            # Load from a PyTorch checkpoint
            model_file = os.path.join(pretrained_model_name_or_path, weights_name)

@@ -19,6 +19,7 @@ import torch.nn as nn
import torch.utils.checkpoint

from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from .cross_attention import AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
@@ -49,7 +50,7 @@ class UNet2DConditionOutput(BaseOutput):
    sample: torch.FloatTensor


class UNet2DConditionModel(ModelMixin, ConfigMixin):
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
    and returns a sample-shaped output.
@@ -266,17 +267,59 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
        self.conv_act = nn.SiLU()
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def set_attn_processor(self, processor: AttnProcessor):
    @property
    def attn_processors(self) -> Dict[str, AttnProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight name.
        """
        # set recursively
        def fn_recursive_attn_processor(module: torch.nn.Module):
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
            if hasattr(module, "set_processor"):
                module.set_processor(processor)
                processors[f"{name}.processor"] = module.processor

            for child in module.children():
                fn_recursive_attn_processor(child)
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

        for module in self.children():
            fn_recursive_attn_processor(module)
            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
        r"""
        Parameters:
            processor (`dict` of `AttnProcessor` or `AttnProcessor`):
                The instantiated processor class, or a dictionary of processor classes, that will be set as the
                processor of **all** `CrossAttention` layers. If `processor` is a dict, the keys need to define the
                paths to the corresponding cross attention processors. This is strongly recommended when setting
                trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_attention_slice(self, slice_size):
        r"""

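To see how the new `attn_processors` property and the extended `set_attn_processor` compose, a brief sketch (the model id is only an example; any `AttnProcessor`-style class works):

```python
from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import CrossAttnProcessor

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# one shared processor for every CrossAttention layer ...
unet.set_attn_processor(CrossAttnProcessor())

# ... or a per-layer dict keyed by the weight paths that attn_processors exposes
unet.set_attn_processor({name: CrossAttnProcessor() for name in unet.attn_processors.keys()})
```
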
@@ -353,17 +353,59 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
        self.conv_act = nn.SiLU()
        self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def set_attn_processor(self, processor: AttnProcessor):
    @property
    def attn_processors(self) -> Dict[str, AttnProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight name.
        """
        # set recursively
        def fn_recursive_attn_processor(module: torch.nn.Module):
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
            if hasattr(module, "set_processor"):
                module.set_processor(processor)
                processors[f"{name}.processor"] = module.processor

            for child in module.children():
                fn_recursive_attn_processor(child)
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

        for module in self.children():
            fn_recursive_attn_processor(module)
            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
        r"""
        Parameters:
            processor (`dict` of `AttnProcessor` or `AttnProcessor`):
                The instantiated processor class, or a dictionary of processor classes, that will be set as the
                processor of **all** `CrossAttention` layers. If `processor` is a dict, the keys need to define the
                paths to the corresponding cross attention processors. This is strongly recommended when setting
                trainable attention processors.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_attention_slice(self, slice_size):
        r"""

@@ -58,6 +58,7 @@ from .import_utils import (
    is_transformers_available,
    is_transformers_version,
    is_unidecode_available,
    is_wandb_available,
    is_xformers_available,
    requires_backends,
)

@@ -217,6 +217,13 @@ try:
except importlib_metadata.PackageNotFoundError:
    _k_diffusion_available = False

_wandb_available = importlib.util.find_spec("wandb") is not None
try:
    _wandb_version = importlib_metadata.version("wandb")
    logger.debug(f"Successfully imported wandb version {_wandb_version}")
except importlib_metadata.PackageNotFoundError:
    _wandb_available = False


def is_torch_available():
    return _torch_available
@@ -274,6 +281,10 @@ def is_k_diffusion_available():
    return _k_diffusion_available


def is_wandb_available():
    return _wandb_available


# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -328,6 +339,12 @@ K_DIFFUSION_IMPORT_ERROR = """
install k-diffusion`
"""

# docstyle-ignore
WANDB_IMPORT_ERROR = """
{0} requires the wandb library but it was not found in your environment. You can install it with pip: `pip
install wandb`
"""


BACKENDS_MAPPING = OrderedDict(
    [
@@ -340,6 +357,7 @@ BACKENDS_MAPPING = OrderedDict(
        ("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
        ("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
        ("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)),
        ("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)),
    ]
)

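A sketch of how such an availability flag is typically consumed by a training script (the project name and metric are illustrative):

```python
from diffusers.utils import is_wandb_available

if is_wandb_available():
    import wandb

    wandb.init(project="dreambooth-lora")  # hypothetical project name
    wandb.log({"train_loss": 0.123})
```
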
@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2020 Optuna, Hugging Face
# Copyright 2022 Optuna, Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

@@ -20,18 +20,8 @@ import unittest

import torch

from diffusers import UNet2DConditionModel, UNet2DModel
from diffusers.utils import (
    floats_tensor,
    load_hf_numpy,
    logging,
    require_torch_gpu,
    slow,
    torch_all_close,
    torch_device,
)
from diffusers.utils.import_utils import is_xformers_available
from parameterized import parameterized
from diffusers import UNet2DModel
from diffusers.utils import floats_tensor, logging, slow, torch_all_close, torch_device

from ..test_modeling_common import ModelTesterMixin

@@ -218,237 +208,6 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
        self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))


class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = UNet2DConditionModel

    @property
    def dummy_input(self):
        batch_size = 4
        num_channels = 4
        sizes = (32, 32)

        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
        time_step = torch.tensor([10]).to(torch_device)
        encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)

        return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}

    @property
    def input_shape(self):
        return (4, 32, 32)

    @property
    def output_shape(self):
        return (4, 32, 32)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "block_out_channels": (32, 64),
            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
            "cross_attention_dim": 32,
            "attention_head_dim": 8,
            "out_channels": 4,
            "in_channels": 4,
            "layers_per_block": 2,
            "sample_size": 32,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    @unittest.skipIf(
        torch_device != "cuda" or not is_xformers_available(),
        reason="XFormers attention is only available with CUDA and `xformers` installed",
    )
    def test_xformers_enable_works(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)

        model.enable_xformers_memory_efficient_attention()

        assert (
            model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers
        ), "xformers is not enabled"

    @unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
    def test_gradient_checkpointing(self):
        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
        model = self.model_class(**init_dict)
        model.to(torch_device)

        assert not model.is_gradient_checkpointing and model.training

        out = model(**inputs_dict).sample
        # run the backward pass on the model; for simplicity we don't compute a real
        # loss and instead backprop on a scalar derived from the output
        model.zero_grad()

        labels = torch.randn_like(out)
        loss = (out - labels).mean()
        loss.backward()

        # re-instantiate the model, now enabling gradient checkpointing
        model_2 = self.model_class(**init_dict)
        # clone model
        model_2.load_state_dict(model.state_dict())
        model_2.to(torch_device)
        model_2.enable_gradient_checkpointing()

        assert model_2.is_gradient_checkpointing and model_2.training

        out_2 = model_2(**inputs_dict).sample
        # run the backward pass on the model; for simplicity we don't compute a real
        # loss and instead backprop on a scalar derived from the output
        model_2.zero_grad()
        loss_2 = (out_2 - labels).mean()
        loss_2.backward()

        # compare the output and parameters gradients
        self.assertTrue((loss - loss_2).abs() < 1e-5)
        named_params = dict(model.named_parameters())
        named_params_2 = dict(model_2.named_parameters())
        for name, param in named_params.items():
            self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))

    def test_model_with_attention_head_dim_tuple(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["attention_head_dim"] = (8, 16)

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.sample

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_model_with_use_linear_projection(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["use_linear_projection"] = True

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        with torch.no_grad():
            output = model(**inputs_dict)

            if isinstance(output, dict):
                output = output.sample

        self.assertIsNotNone(output)
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

    def test_model_attention_slicing(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["attention_head_dim"] = (8, 16)

        model = self.model_class(**init_dict)
        model.to(torch_device)
        model.eval()

        model.set_attention_slice("auto")
        with torch.no_grad():
            output = model(**inputs_dict)
        assert output is not None

        model.set_attention_slice("max")
        with torch.no_grad():
            output = model(**inputs_dict)
        assert output is not None

        model.set_attention_slice(2)
        with torch.no_grad():
            output = model(**inputs_dict)
        assert output is not None

    def test_model_slicable_head_dim(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["attention_head_dim"] = (8, 16)

        model = self.model_class(**init_dict)

        def check_slicable_dim_attr(module: torch.nn.Module):
            if hasattr(module, "set_attention_slice"):
                assert isinstance(module.sliceable_head_dim, int)

            for child in module.children():
                check_slicable_dim_attr(child)

        # retrieve number of attention layers
        for module in model.children():
            check_slicable_dim_attr(module)

    def test_special_attn_proc(self):
        class AttnEasyProc(torch.nn.Module):
            def __init__(self, num):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.tensor(num))
                self.is_run = False
                self.number = 0
                self.counter = 0

            def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
                batch_size, sequence_length, _ = hidden_states.shape
                attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

                query = attn.to_q(hidden_states)

                encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
                key = attn.to_k(encoder_hidden_states)
                value = attn.to_v(encoder_hidden_states)

                query = attn.head_to_batch_dim(query)
                key = attn.head_to_batch_dim(key)
                value = attn.head_to_batch_dim(value)

                attention_probs = attn.get_attention_scores(query, key, attention_mask)
                hidden_states = torch.bmm(attention_probs, value)
                hidden_states = attn.batch_to_head_dim(hidden_states)

                # linear proj
                hidden_states = attn.to_out[0](hidden_states)
                # dropout
                hidden_states = attn.to_out[1](hidden_states)

                hidden_states += self.weight

                self.is_run = True
                self.counter += 1
                self.number = number

                return hidden_states

        # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["attention_head_dim"] = (8, 16)

        model = self.model_class(**init_dict)
        model.to(torch_device)

        processor = AttnEasyProc(5.0)

        model.set_attn_processor(processor)
        model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample

        assert processor.counter == 12
        assert processor.is_run
        assert processor.number == 123


class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = UNet2DModel

@@ -564,310 +323,3 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
    def test_forward_with_norm_groups(self):
        # not required for this model
        pass


@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
    def get_file_format(self, seed, shape):
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
        dtype = torch.float16 if fp16 else torch.float32
        image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
        return image

    def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
        revision = "fp16" if fp16 else None
        torch_dtype = torch.float16 if fp16 else torch.float32

        model = UNet2DConditionModel.from_pretrained(
            model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision
        )
        model.to(torch_device).eval()

        return model

    def test_set_attention_slice_auto(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        unet = self.get_unet_model()
        unet.set_attention_slice("auto")

        latents = self.get_latents(33)
        encoder_hidden_states = self.get_encoder_hidden_states(33)
        timestep = 1

        with torch.no_grad():
            _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        mem_bytes = torch.cuda.max_memory_allocated()

        assert mem_bytes < 5 * 10**9

    def test_set_attention_slice_max(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        unet = self.get_unet_model()
        unet.set_attention_slice("max")

        latents = self.get_latents(33)
        encoder_hidden_states = self.get_encoder_hidden_states(33)
        timestep = 1

        with torch.no_grad():
            _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        mem_bytes = torch.cuda.max_memory_allocated()

        assert mem_bytes < 5 * 10**9

    def test_set_attention_slice_int(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        unet = self.get_unet_model()
        unet.set_attention_slice(2)

        latents = self.get_latents(33)
        encoder_hidden_states = self.get_encoder_hidden_states(33)
        timestep = 1

        with torch.no_grad():
            _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        mem_bytes = torch.cuda.max_memory_allocated()

        assert mem_bytes < 5 * 10**9

    def test_set_attention_slice_list(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        # there are 32 sliceable layers
        slice_list = 16 * [2, 3]
        unet = self.get_unet_model()
        unet.set_attention_slice(slice_list)

        latents = self.get_latents(33)
        encoder_hidden_states = self.get_encoder_hidden_states(33)
        timestep = 1

        with torch.no_grad():
            _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        mem_bytes = torch.cuda.max_memory_allocated()

        assert mem_bytes < 5 * 10**9

    def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
        dtype = torch.float16 if fp16 else torch.float32
        hidden_states = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
        return hidden_states

    @parameterized.expand(
        [
            # fmt: off
            [33, 4, [-0.4424, 0.1510, -0.1937, 0.2118, 0.3746, -0.3957, 0.0160, -0.0435]],
            [47, 0.55, [-0.1508, 0.0379, -0.3075, 0.2540, 0.3633, -0.0821, 0.1719, -0.0207]],
            [21, 0.89, [-0.6479, 0.6364, -0.3464, 0.8697, 0.4443, -0.6289, -0.0091, 0.1778]],
            [9, 1000, [0.8888, -0.5659, 0.5834, -0.7469, 1.1912, -0.3923, 1.1241, -0.4424]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4")
        latents = self.get_latents(seed)
        encoder_hidden_states = self.get_encoder_hidden_states(seed)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == latents.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)

    @parameterized.expand(
        [
            # fmt: off
            [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
            [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
            [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
            [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
        latents = self.get_latents(seed, fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == latents.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

    @parameterized.expand(
        [
            # fmt: off
            [33, 4, [-0.4430, 0.1570, -0.1867, 0.2376, 0.3205, -0.3681, 0.0525, -0.0722]],
            [47, 0.55, [-0.1415, 0.0129, -0.3136, 0.2257, 0.3430, -0.0536, 0.2114, -0.0436]],
            [21, 0.89, [-0.7091, 0.6664, -0.3643, 0.9032, 0.4499, -0.6541, 0.0139, 0.1750]],
            [9, 1000, [0.8878, -0.5659, 0.5844, -0.7442, 1.1883, -0.3927, 1.1192, -0.4423]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5")
        latents = self.get_latents(seed)
        encoder_hidden_states = self.get_encoder_hidden_states(seed)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == latents.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)

    @parameterized.expand(
        [
            # fmt: off
            [83, 4, [-0.2695, -0.1669, 0.0073, -0.3181, -0.1187, -0.1676, -0.1395, -0.5972]],
            [17, 0.55, [-0.1290, -0.2588, 0.0551, -0.0916, 0.3286, 0.0238, -0.3669, 0.0322]],
            [8, 0.89, [-0.5283, 0.1198, 0.0870, -0.1141, 0.9189, -0.0150, 0.5474, 0.4319]],
            [3, 1000, [-0.5601, 0.2411, -0.5435, 0.1268, 1.1338, -0.2427, -0.0280, -1.0020]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5", fp16=True)
        latents = self.get_latents(seed, fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == latents.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

    @parameterized.expand(
        [
            # fmt: off
            [33, 4, [-0.7639, 0.0106, -0.1615, -0.3487, -0.0423, -0.7972, 0.0085, -0.4858]],
            [47, 0.55, [-0.6564, 0.0795, -1.9026, -0.6258, 1.8235, 1.2056, 1.2169, 0.9073]],
            [21, 0.89, [0.0327, 0.4399, -0.6358, 0.3417, 0.4120, -0.5621, -0.0397, -1.0430]],
            [9, 1000, [0.1600, 0.7303, -1.0556, -0.3515, -0.7440, -1.2037, -1.8149, -1.8931]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting")
        latents = self.get_latents(seed, shape=(4, 9, 64, 64))
        encoder_hidden_states = self.get_encoder_hidden_states(seed)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == (4, 4, 64, 64)

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)

    @parameterized.expand(
        [
            # fmt: off
            [83, 4, [-0.1047, -1.7227, 0.1067, 0.0164, -0.5698, -0.4172, -0.1388, 1.1387]],
            [17, 0.55, [0.0975, -0.2856, -0.3508, -0.4600, 0.3376, 0.2930, -0.2747, -0.7026]],
            [8, 0.89, [-0.0952, 0.0183, -0.5825, -0.1981, 0.1131, 0.4668, -0.0395, -0.3486]],
            [3, 1000, [0.4790, 0.4949, -1.0732, -0.7158, 0.7959, -0.9478, 0.1105, -0.9741]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting", fp16=True)
        latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == (4, 4, 64, 64)

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

    @parameterized.expand(
        [
            # fmt: off
            [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
            [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
            [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
            [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
        latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == latents.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

688
tests/models/test_models_unet_2d_condition.py
Normal file
@@ -0,0 +1,688 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import tempfile
|
||||
import unittest
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import UNet2DConditionModel
|
||||
from diffusers.models.cross_attention import LoRACrossAttnProcessor
|
||||
from diffusers.utils import (
|
||||
floats_tensor,
|
||||
load_hf_numpy,
|
||||
logging,
|
||||
require_torch_gpu,
|
||||
slow,
|
||||
torch_all_close,
|
||||
torch_device,
|
||||
)
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from parameterized import parameterized
|
||||
|
||||
from ..test_modeling_common import ModelTesterMixin
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
|
||||
model_class = UNet2DConditionModel
|
||||
|
||||
@property
|
||||
def dummy_input(self):
|
||||
batch_size = 4
|
||||
num_channels = 4
|
||||
sizes = (32, 32)
|
||||
|
||||
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
|
||||
time_step = torch.tensor([10]).to(torch_device)
|
||||
encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
|
||||
|
||||
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
|
||||
|
||||
@property
|
||||
def input_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
@property
|
||||
def output_shape(self):
|
||||
return (4, 32, 32)
|
||||
|
||||
def prepare_init_args_and_inputs_for_common(self):
|
||||
init_dict = {
|
||||
"block_out_channels": (32, 64),
|
||||
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
|
||||
"up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
|
||||
"cross_attention_dim": 32,
|
||||
"attention_head_dim": 8,
|
||||
"out_channels": 4,
|
||||
"in_channels": 4,
|
||||
"layers_per_block": 2,
|
||||
"sample_size": 32,
|
||||
}
|
||||
inputs_dict = self.dummy_input
|
||||
return init_dict, inputs_dict
|
||||
|
||||
@unittest.skipIf(
|
||||
torch_device != "cuda" or not is_xformers_available(),
|
||||
reason="XFormers attention is only available with CUDA and `xformers` installed",
|
||||
)
|
||||
def test_xformers_enable_works(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
model.enable_xformers_memory_efficient_attention()
|
||||
|
||||
assert (
|
||||
model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers
|
||||
), "xformers is not enabled"
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
|
||||
def test_gradient_checkpointing(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
assert not model.is_gradient_checkpointing and model.training
|
||||
|
||||
out = model(**inputs_dict).sample
|
||||
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||
# we won't calculate the loss and rather backprop on out.sum()
|
||||
model.zero_grad()
|
||||
|
||||
labels = torch.randn_like(out)
|
||||
loss = (out - labels).mean()
|
||||
loss.backward()
|
||||
|
||||
# re-instantiate the model now enabling gradient checkpointing
|
||||
model_2 = self.model_class(**init_dict)
|
||||
# clone model
|
||||
model_2.load_state_dict(model.state_dict())
|
||||
model_2.to(torch_device)
|
||||
model_2.enable_gradient_checkpointing()
|
||||
|
||||
assert model_2.is_gradient_checkpointing and model_2.training
|
||||
|
||||
out_2 = model_2(**inputs_dict).sample
|
||||
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||
# we won't calculate the loss and rather backprop on out.sum()
|
||||
model_2.zero_grad()
|
||||
loss_2 = (out_2 - labels).mean()
|
||||
loss_2.backward()
|
||||
|
||||
# compare the output and parameters gradients
|
||||
self.assertTrue((loss - loss_2).abs() < 1e-5)
|
||||
named_params = dict(model.named_parameters())
|
||||
named_params_2 = dict(model_2.named_parameters())
|
||||
for name, param in named_params.items():
|
||||
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
|
||||
|
||||
def test_model_with_attention_head_dim_tuple(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_with_use_linear_projection(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["use_linear_projection"] = True
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
|
||||
if isinstance(output, dict):
|
||||
output = output.sample
|
||||
|
||||
self.assertIsNotNone(output)
|
||||
expected_shape = inputs_dict["sample"].shape
|
||||
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
|
||||
|
||||
def test_model_attention_slicing(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
model.eval()
|
||||
|
||||
model.set_attention_slice("auto")
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
assert output is not None
|
||||
|
||||
model.set_attention_slice("max")
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
assert output is not None
|
||||
|
||||
model.set_attention_slice(2)
|
||||
with torch.no_grad():
|
||||
output = model(**inputs_dict)
|
||||
assert output is not None
|
||||
|
||||
def test_model_slicable_head_dim(self):
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
|
||||
def check_slicable_dim_attr(module: torch.nn.Module):
|
||||
if hasattr(module, "set_attention_slice"):
|
||||
assert isinstance(module.sliceable_head_dim, int)
|
||||
|
||||
for child in module.children():
|
||||
check_slicable_dim_attr(child)
|
||||
|
||||
# retrieve number of attention layers
|
||||
for module in model.children():
|
||||
check_slicable_dim_attr(module)
|
||||
|
||||
def test_special_attn_proc(self):
|
||||
class AttnEasyProc(torch.nn.Module):
|
||||
def __init__(self, num):
|
||||
super().__init__()
|
||||
self.weight = torch.nn.Parameter(torch.tensor(num))
|
||||
self.is_run = False
|
||||
self.number = 0
|
||||
self.counter = 0
|
||||
|
||||
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
hidden_states += self.weight
|
||||
|
||||
self.is_run = True
|
||||
self.counter += 1
|
||||
self.number = number
|
||||
|
||||
return hidden_states
|
||||
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
processor = AttnEasyProc(5.0)
|
||||
|
||||
model.set_attn_processor(processor)
|
||||
model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample
|
||||
|
||||
assert processor.counter == 12
|
||||
assert processor.is_run
|
||||
assert processor.number == 123
|
||||
|
||||
def test_lora_processors(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
|
||||
init_dict["attention_head_dim"] = (8, 16)
|
||||
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
with torch.no_grad():
|
||||
sample1 = model(**inputs_dict).sample
|
||||
|
||||
        lora_attn_procs = {}
        for name in model.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = model.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = model.config.block_out_channels[block_id]

            lora_attn_procs[name] = LoRACrossAttnProcessor(
                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
            )

            # add 1 to weights to mock trained weights
            with torch.no_grad():
                lora_attn_procs[name].to_q_lora.up.weight += 1
                lora_attn_procs[name].to_k_lora.up.weight += 1
                lora_attn_procs[name].to_v_lora.up.weight += 1
                lora_attn_procs[name].to_out_lora.up.weight += 1

        # make sure we can set a dict of attention processors
        model.set_attn_processor(lora_attn_procs)
        model.to(torch_device)

        # test that attn processors can be set to themselves
        model.set_attn_processor(model.attn_processors)

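        # scale=0.0 switches the LoRA branch off entirely, so sample2 must match
        # the pre-LoRA sample1; repeating the same non-zero scale must be
        # deterministic (sample3 == sample4)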
        with torch.no_grad():
            sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
            sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
            sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

        assert (sample1 - sample2).abs().max() < 1e-4
        assert (sample3 - sample4).abs().max() < 1e-4

        # sample 2 and sample 3 should be different
        assert (sample2 - sample3).abs().max() > 1e-4

    def test_lora_save_load(self):
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

        init_dict["attention_head_dim"] = (8, 16)

        torch.manual_seed(0)
        model = self.model_class(**init_dict)
        model.to(torch_device)

        with torch.no_grad():
            old_sample = model(**inputs_dict).sample

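        # same per-layer LoRA processor construction as in test_lora_processors above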
        lora_attn_procs = {}
        for name in model.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = model.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(model.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = model.config.block_out_channels[block_id]

            lora_attn_procs[name] = LoRACrossAttnProcessor(
                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
            )
            lora_attn_procs[name] = lora_attn_procs[name].to(model.device)

            # add 1 to weights to mock trained weights
            with torch.no_grad():
                lora_attn_procs[name].to_q_lora.up.weight += 1
                lora_attn_procs[name].to_k_lora.up.weight += 1
                lora_attn_procs[name].to_v_lora.up.weight += 1
                lora_attn_procs[name].to_out_lora.up.weight += 1

        model.set_attn_processor(lora_attn_procs)

        with torch.no_grad():
            sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

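        # round-trip the LoRA layers through save_attn_procs / load_attn_procs and
        # check that a freshly seeded model with the loaded weights reproduces the
        # same sample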
        with tempfile.TemporaryDirectory() as tmpdirname:
            model.save_attn_procs(tmpdirname)
            torch.manual_seed(0)
            new_model = self.model_class(**init_dict)
            new_model.to(torch_device)
            new_model.load_attn_procs(tmpdirname)

        with torch.no_grad():
            new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

        assert (sample - new_sample).abs().max() < 1e-4

        # LoRA and no LoRA should NOT be the same
        assert (sample - old_sample).abs().max() > 1e-4


@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
    def get_file_format(self, seed, shape):
        return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"

    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
        dtype = torch.float16 if fp16 else torch.float32
        image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
        return image

    def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
        revision = "fp16" if fp16 else None
        torch_dtype = torch.float16 if fp16 else torch.float32

        model = UNet2DConditionModel.from_pretrained(
            model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision
        )
        model.to(torch_device).eval()

        return model

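    # The slicing tests below each run a single denoising step on fixed latents
    # and assert that peak CUDA memory stays below 5 GB, i.e. that attention
    # slicing actually bounds memory use.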
    def test_set_attention_slice_auto(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        unet = self.get_unet_model()
        unet.set_attention_slice("auto")

        latents = self.get_latents(33)
        encoder_hidden_states = self.get_encoder_hidden_states(33)
        timestep = 1

        with torch.no_grad():
            _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        mem_bytes = torch.cuda.max_memory_allocated()

        assert mem_bytes < 5 * 10**9

    def test_set_attention_slice_max(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        unet = self.get_unet_model()
        unet.set_attention_slice("max")

        latents = self.get_latents(33)
        encoder_hidden_states = self.get_encoder_hidden_states(33)
        timestep = 1

        with torch.no_grad():
            _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        mem_bytes = torch.cuda.max_memory_allocated()

        assert mem_bytes < 5 * 10**9

    def test_set_attention_slice_int(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        unet = self.get_unet_model()
        unet.set_attention_slice(2)

        latents = self.get_latents(33)
        encoder_hidden_states = self.get_encoder_hidden_states(33)
        timestep = 1

        with torch.no_grad():
            _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        mem_bytes = torch.cuda.max_memory_allocated()

        assert mem_bytes < 5 * 10**9

    def test_set_attention_slice_list(self):
        torch.cuda.empty_cache()
        torch.cuda.reset_max_memory_allocated()
        torch.cuda.reset_peak_memory_stats()

        # there are 32 sliceable layers
        slice_list = 16 * [2, 3]
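        # a list assigns each sliceable layer its own slice size, in traversal order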
        unet = self.get_unet_model()
        unet.set_attention_slice(slice_list)

        latents = self.get_latents(33)
        encoder_hidden_states = self.get_encoder_hidden_states(33)
        timestep = 1

        with torch.no_grad():
            _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        mem_bytes = torch.cuda.max_memory_allocated()

        assert mem_bytes < 5 * 10**9

    def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
        dtype = torch.float16 if fp16 else torch.float32
        hidden_states = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
        return hidden_states

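    # Each parameterized case below is (seed, timestep, expected output slice);
    # fp32 outputs are compared at atol=1e-3, fp16 outputs at a looser atol=5e-3.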
    @parameterized.expand(
        [
            # fmt: off
            [33, 4, [-0.4424, 0.1510, -0.1937, 0.2118, 0.3746, -0.3957, 0.0160, -0.0435]],
            [47, 0.55, [-0.1508, 0.0379, -0.3075, 0.2540, 0.3633, -0.0821, 0.1719, -0.0207]],
            [21, 0.89, [-0.6479, 0.6364, -0.3464, 0.8697, 0.4443, -0.6289, -0.0091, 0.1778]],
            [9, 1000, [0.8888, -0.5659, 0.5834, -0.7469, 1.1912, -0.3923, 1.1241, -0.4424]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4")
        latents = self.get_latents(seed)
        encoder_hidden_states = self.get_encoder_hidden_states(seed)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == latents.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)

    @parameterized.expand(
        [
            # fmt: off
            [83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
            [17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
            [8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
            [3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
        latents = self.get_latents(seed, fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == latents.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

    @parameterized.expand(
        [
            # fmt: off
            [33, 4, [-0.4430, 0.1570, -0.1867, 0.2376, 0.3205, -0.3681, 0.0525, -0.0722]],
            [47, 0.55, [-0.1415, 0.0129, -0.3136, 0.2257, 0.3430, -0.0536, 0.2114, -0.0436]],
            [21, 0.89, [-0.7091, 0.6664, -0.3643, 0.9032, 0.4499, -0.6541, 0.0139, 0.1750]],
            [9, 1000, [0.8878, -0.5659, 0.5844, -0.7442, 1.1883, -0.3927, 1.1192, -0.4423]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5")
        latents = self.get_latents(seed)
        encoder_hidden_states = self.get_encoder_hidden_states(seed)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == latents.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)

    @parameterized.expand(
        [
            # fmt: off
            [83, 4, [-0.2695, -0.1669, 0.0073, -0.3181, -0.1187, -0.1676, -0.1395, -0.5972]],
            [17, 0.55, [-0.1290, -0.2588, 0.0551, -0.0916, 0.3286, 0.0238, -0.3669, 0.0322]],
            [8, 0.89, [-0.5283, 0.1198, 0.0870, -0.1141, 0.9189, -0.0150, 0.5474, 0.4319]],
            [3, 1000, [-0.5601, 0.2411, -0.5435, 0.1268, 1.1338, -0.2427, -0.0280, -1.0020]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5", fp16=True)
        latents = self.get_latents(seed, fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == latents.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

    @parameterized.expand(
        [
            # fmt: off
            [33, 4, [-0.7639, 0.0106, -0.1615, -0.3487, -0.0423, -0.7972, 0.0085, -0.4858]],
            [47, 0.55, [-0.6564, 0.0795, -1.9026, -0.6258, 1.8235, 1.2056, 1.2169, 0.9073]],
            [21, 0.89, [0.0327, 0.4399, -0.6358, 0.3417, 0.4120, -0.5621, -0.0397, -1.0430]],
            [9, 1000, [0.1600, 0.7303, -1.0556, -0.3515, -0.7440, -1.2037, -1.8149, -1.8931]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting")
        latents = self.get_latents(seed, shape=(4, 9, 64, 64))
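        # the inpainting UNet takes 9 input channels: 4 noisy latents, 4
        # masked-image latents, and 1 downsampled mask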
        encoder_hidden_states = self.get_encoder_hidden_states(seed)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == (4, 4, 64, 64)

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)

    @parameterized.expand(
        [
            # fmt: off
            [83, 4, [-0.1047, -1.7227, 0.1067, 0.0164, -0.5698, -0.4172, -0.1388, 1.1387]],
            [17, 0.55, [0.0975, -0.2856, -0.3508, -0.4600, 0.3376, 0.2930, -0.2747, -0.7026]],
            [8, 0.89, [-0.0952, 0.0183, -0.5825, -0.1981, 0.1131, 0.4668, -0.0395, -0.3486]],
            [3, 1000, [0.4790, 0.4949, -1.0732, -0.7158, 0.7959, -0.9478, 0.1105, -0.9741]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting", fp16=True)
        latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == (4, 4, 64, 64)

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

    @parameterized.expand(
        [
            # fmt: off
            [83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
            [17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
            [8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
            [3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
            # fmt: on
        ]
    )
    @require_torch_gpu
    def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice):
        model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
        latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)
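        # Stable Diffusion 2 generates at 768px (96x96 latents) and conditions on
        # OpenCLIP embeddings of width 1024 instead of CLIP ViT-L's 768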

        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)

        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

        assert sample.shape == latents.shape

        output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
        expected_output_slice = torch.tensor(expected_slice)

        assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2020 The HuggingFace Inc. team.
+# Copyright 2022 The HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2020 The HuggingFace Inc. team.
+# Copyright 2022 The HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2020 The HuggingFace Inc. team.
+# Copyright 2022 The HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2021 The HuggingFace Inc. team.
+# Copyright 2022 The HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.

@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright 2020 The HuggingFace Inc. team.
+# Copyright 2022 The HuggingFace Inc. team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.