1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[LoRA] Add LoRA training script (#1884)

* [Lora] first upload

* add first lora version

* upload

* more

* first training

* up

* correct

* improve

* finish loaders and inference

* up

* up

* fix more

* up

* finish more

* finish more

* up

* up

* change year

* revert year change

* Change lines

* Add cloneofsimo as co-author.

Co-authored-by: Simo Ryu <cloneofsimo@gmail.com>

* finish

* fix docs

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>

* upload

* finish

Co-authored-by: Simo Ryu <cloneofsimo@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
Patrick von Platen
2023-01-18 18:05:51 +01:00
committed by GitHub
parent ac3fc64906
commit ed616bd8a8
24 changed files with 2287 additions and 663 deletions

View File

@@ -90,6 +90,8 @@
title: Configuration
- local: api/outputs
title: Outputs
- local: api/loaders
title: Loaders
title: Main Classes
- sections:
- local: api/pipelines/overview

View File

@@ -0,0 +1,30 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# Loaders
There are many weights to train adapter neural networks for diffusion models, such as
- [Textual Inversion](./training/text_inversion.mdx)
- [LoRA](https://github.com/cloneofsimo/lora)
- [Hypernetworks](https://arxiv.org/abs/1609.09106)
Such adapter neural networks often only consist of a fraction of the number of weights compared
to the pretrained model and as such are very portable. The Diffusers library offers an easy-to-use
API to load such adapter neural networks via the [`loaders.py` module](https://github.com/huggingface/diffusers/blob/main/src/diffusers/loaders.py).
**Note**: This module is still highly experimental and prone to future changes.
## LoaderMixins
### UNet2DConditionLoadersMixin
[[autodoc]] loaders.UNet2DConditionLoadersMixin

View File

@@ -1,4 +1,4 @@
<!--Copyright 2020 The HuggingFace Team. All rights reserved.
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

View File

@@ -5,6 +5,7 @@ The `train_dreambooth.py` script shows how to implement the training procedure a
## Running locally with PyTorch
### Installing the dependencies
Before running the scripts, make sure to install the library's training dependencies:
@@ -235,6 +236,102 @@ image.save("dog-bucket.png")
You can also perform inference from one of the checkpoints saved during the training process, if you used the `--checkpointing_steps` argument. Please, refer to [the documentation](https://huggingface.co/docs/diffusers/main/en/training/dreambooth#performing-inference-using-a-saved-checkpoint) to see how to do it.
## Training with Low-Rank Adaptation of Large Language Models (LoRA)
Low-Rank Adaption of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*
In a nutshell, LoRA allows to adapt pretrained models by adding pairs of rank-decomposition matrices to existing weights and **only** training those newly added weights. This has a couple of advantages:
- Previous pretrained weights are kept frozen so that the model is not prone to [catastrophic forgetting](https://www.pnas.org/doi/10.1073/pnas.1611835114)
- Rank-decomposition matrices have significantly fewer parameters than the original model, which means that trained LoRA weights are easily portable.
- LoRA attention layers allow to control to which extent the model is adapted torwards new training images via a `scale` parameter.
[cloneofsimo](https://github.com/cloneofsimo) was the first to try out LoRA training for Stable Diffusion in
the popular [lora](https://github.com/cloneofsimo/lora) GitHub repository.
### Training
Let's get started with a simple example. We will re-use the dog example of the [previous section](#dog-toy-example).
First, you need to set-up your dreambooth training example as is explained in the [installation section](#Installing-the-dependencies).
Next, let's download the dog dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. Make sure to set `INSTANCE_DIR` to the name of your directory further below. This will be our training data.
Now, you can launch the training. Here we will use [Stable Diffusion 1-5](https://huggingface.co/runwayml/stable-diffusion-v1-5).
**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
**___Note: It is quite useful to monitor the training progress by regularly generating sample images during training. [wandb](https://docs.wandb.ai/quickstart) is a nice solution to easily see generating images during training. All you need to do is to run `pip install wandb` before training and pass `--report_to="wandb"` to automatically log images.___**
```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="path-to-instance-images"
export OUTPUT_DIR="path-to-save-model"
```
For this example we want to directly store the trained LoRA embeddings on the Hub, so
we need to be logged in and add the `--push_to_hub` flag.
```bash
huggingface-cli login
```
Now we can start training!
```bash
accelerate launch train_dreambooth_lora.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--output_dir=$OUTPUT_DIR \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 \
--checkpointing_steps=100 \
--learning_rate=1e-4 \
--report_to="wandb" \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=50 \
--seed="0" \
--push_to_hub
```
**___Note: When using LoRA we can use a much higher learning rate compared to vanilla dreambooth. Here we
use *1e-4* instead of the usual *2e-6*.___**
The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dreambooth_dog_example](https://huggingface.co/patrickvonplaten/lora_dreambooth_dog_example). **___Note: [The final weights](https://huggingface.co/patrickvonplaten/lora/blob/main/pytorch_attn_procs.bin) are only 3 MB in size which is orders of magnitudes smaller than the original model.**
The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
You can use the `Step` slider to see how the model learned the features of our subject while the model trained.
### Inference
After training, LoRA weights can be loaded very easily into the original pipeline. First, you need to
load the original pipeline:
```python
from diffusers import DiffusionPipeline, DPMSolverMultistepScheduler
import torch
pipe = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")
```
Next, we can load the adapter layers into the UNet with the [`load_attn_procs` function](https://huggingface.co/docs/diffusers/api/loaders#diffusers.loaders.UNet2DConditionLoadersMixin.load_attn_procs).
```python
pipe.load_attn_procs("patrickvonplaten/lora")
```
Finally, we can run the model in inference.
```python
image = pipe("A picture of a sks dog in a bucket", num_inference_steps=25).images[0]
```
## Training with Flax/JAX
For faster training on TPUs and GPUs you can leverage the flax training example. Follow the instructions above to get the model and dataset before running the script.

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import argparse
import hashlib
import itertools

View File

@@ -0,0 +1,950 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import argparse
import hashlib
import logging
import math
import os
import warnings
from pathlib import Path
from typing import Optional
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import Dataset
import datasets
import diffusers
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import set_seed
from diffusers import (
AutoencoderKL,
DDPMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available
from huggingface_hub import HfFolder, Repository, create_repo, whoami
from PIL import Image
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import AutoTokenizer, PretrainedConfig
# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.12.0.dev0")
logger = get_logger(__name__)
def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
text_encoder_config = PretrainedConfig.from_pretrained(
pretrained_model_name_or_path,
subfolder="text_encoder",
revision=revision,
)
model_class = text_encoder_config.architectures[0]
if model_class == "CLIPTextModel":
from transformers import CLIPTextModel
return CLIPTextModel
elif model_class == "RobertaSeriesModelWithTransformation":
from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
return RobertaSeriesModelWithTransformation
else:
raise ValueError(f"{model_class} is not supported.")
def parse_args(input_args=None):
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
"--pretrained_model_name_or_path",
type=str,
default=None,
required=True,
help="Path to pretrained model or model identifier from huggingface.co/models.",
)
parser.add_argument(
"--revision",
type=str,
default=None,
required=False,
help="Revision of pretrained model identifier from huggingface.co/models.",
)
parser.add_argument(
"--tokenizer_name",
type=str,
default=None,
help="Pretrained tokenizer name or path if not the same as model_name",
)
parser.add_argument(
"--instance_data_dir",
type=str,
default=None,
required=True,
help="A folder containing the training data of instance images.",
)
parser.add_argument(
"--class_data_dir",
type=str,
default=None,
required=False,
help="A folder containing the training data of class images.",
)
parser.add_argument(
"--instance_prompt",
type=str,
default=None,
required=True,
help="The prompt with identifier specifying the instance",
)
parser.add_argument(
"--class_prompt",
type=str,
default=None,
help="The prompt to specify images in the same class as provided instance images.",
)
parser.add_argument(
"--validation_prompt",
type=str,
default=None,
help="A prompt that is used during validation to verify that the model is learning.",
)
parser.add_argument(
"--num_validation_images",
type=int,
default=4,
help="Number of images that should be generated during validation with `validation_prompt`.",
)
parser.add_argument(
"--validation_epochs",
type=int,
default=50,
help=(
"Run dreambooth validation every X epochs. Dreambooth validation consists of running the prompt"
" `args.validation_prompt` multiple times: `args.num_validation_images`."
),
)
parser.add_argument(
"--with_prior_preservation",
default=False,
action="store_true",
help="Flag to add prior preservation loss.",
)
parser.add_argument("--prior_loss_weight", type=float, default=1.0, help="The weight of prior preservation loss.")
parser.add_argument(
"--num_class_images",
type=int,
default=100,
help=(
"Minimal class images for prior preservation loss. If there are not enough images already present in"
" class_data_dir, additional images will be sampled with class_prompt."
),
)
parser.add_argument(
"--output_dir",
type=str,
default="lora-dreambooth-model",
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--seed", type=int, default=None, help="A seed for reproducible training.")
parser.add_argument(
"--resolution",
type=int,
default=512,
help=(
"The resolution for input images, all the images in the train/validation dataset will be resized to this"
" resolution"
),
)
parser.add_argument(
"--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
parser.add_argument(
"--sample_batch_size", type=int, default=4, help="Batch size (per device) for sampling images."
)
parser.add_argument("--num_train_epochs", type=int, default=1)
parser.add_argument(
"--max_train_steps",
type=int,
default=None,
help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
)
parser.add_argument(
"--checkpointing_steps",
type=int,
default=500,
help=(
"Save a checkpoint of the training state every X updates. These checkpoints can be used both as final"
" checkpoints in case they are better than the last checkpoint, and are also suitable for resuming"
" training using `--resume_from_checkpoint`."
),
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
default=None,
help=(
"Whether training should be resumed from a previous checkpoint. Use a path saved by"
' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
),
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before performing a backward/update pass.",
)
parser.add_argument(
"--gradient_checkpointing",
action="store_true",
help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
)
parser.add_argument(
"--learning_rate",
type=float,
default=5e-4,
help="Initial learning rate (after the potential warmup period) to use.",
)
parser.add_argument(
"--scale_lr",
action="store_true",
default=False,
help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
)
parser.add_argument(
"--lr_scheduler",
type=str,
default="constant",
help=(
'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
' "constant", "constant_with_warmup"]'
),
)
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--lr_num_cycles",
type=int,
default=1,
help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
)
parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")
parser.add_argument(
"--hub_model_id",
type=str,
default=None,
help="The name of the repository to keep in sync with the local `output_dir`.",
)
parser.add_argument(
"--logging_dir",
type=str,
default="logs",
help=(
"[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
" *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
),
)
parser.add_argument(
"--allow_tf32",
action="store_true",
help=(
"Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
" https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
),
)
parser.add_argument(
"--report_to",
type=str,
default="tensorboard",
help=(
'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
),
)
parser.add_argument(
"--mixed_precision",
type=str,
default=None,
choices=["no", "fp16", "bf16"],
help=(
"Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--prior_generation_precision",
type=str,
default=None,
choices=["no", "fp32", "fp16", "bf16"],
help=(
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
if input_args is not None:
args = parser.parse_args(input_args)
else:
args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
if env_local_rank != -1 and env_local_rank != args.local_rank:
args.local_rank = env_local_rank
if args.with_prior_preservation:
if args.class_data_dir is None:
raise ValueError("You must specify a data directory for class images.")
if args.class_prompt is None:
raise ValueError("You must specify prompt for class images.")
else:
# logger is not available yet
if args.class_data_dir is not None:
warnings.warn("You need not use --class_data_dir without --with_prior_preservation.")
if args.class_prompt is not None:
warnings.warn("You need not use --class_prompt without --with_prior_preservation.")
return args
class DreamBoothDataset(Dataset):
"""
A dataset to prepare the instance and class images with the prompts for fine-tuning the model.
It pre-processes the images and the tokenizes prompts.
"""
def __init__(
self,
instance_data_root,
instance_prompt,
tokenizer,
class_data_root=None,
class_prompt=None,
size=512,
center_crop=False,
):
self.size = size
self.center_crop = center_crop
self.tokenizer = tokenizer
self.instance_data_root = Path(instance_data_root)
if not self.instance_data_root.exists():
raise ValueError("Instance images root doesn't exists.")
self.instance_images_path = list(Path(instance_data_root).iterdir())
self.num_instance_images = len(self.instance_images_path)
self.instance_prompt = instance_prompt
self._length = self.num_instance_images
if class_data_root is not None:
self.class_data_root = Path(class_data_root)
self.class_data_root.mkdir(parents=True, exist_ok=True)
self.class_images_path = list(self.class_data_root.iterdir())
self.num_class_images = len(self.class_images_path)
self._length = max(self.num_class_images, self.num_instance_images)
self.class_prompt = class_prompt
else:
self.class_data_root = None
self.image_transforms = transforms.Compose(
[
transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)
def __len__(self):
return self._length
def __getitem__(self, index):
example = {}
instance_image = Image.open(self.instance_images_path[index % self.num_instance_images])
if not instance_image.mode == "RGB":
instance_image = instance_image.convert("RGB")
example["instance_images"] = self.image_transforms(instance_image)
example["instance_prompt_ids"] = self.tokenizer(
self.instance_prompt,
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids
if self.class_data_root:
class_image = Image.open(self.class_images_path[index % self.num_class_images])
if not class_image.mode == "RGB":
class_image = class_image.convert("RGB")
example["class_images"] = self.image_transforms(class_image)
example["class_prompt_ids"] = self.tokenizer(
self.class_prompt,
truncation=True,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
).input_ids
return example
def collate_fn(examples, with_prior_preservation=False):
input_ids = [example["instance_prompt_ids"] for example in examples]
pixel_values = [example["instance_images"] for example in examples]
# Concat class and instance examples for prior preservation.
# We do this to avoid doing two forward passes.
if with_prior_preservation:
input_ids += [example["class_prompt_ids"] for example in examples]
pixel_values += [example["class_images"] for example in examples]
pixel_values = torch.stack(pixel_values)
pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
input_ids = torch.cat(input_ids, dim=0)
batch = {
"input_ids": input_ids,
"pixel_values": pixel_values,
}
return batch
class PromptDataset(Dataset):
"A simple dataset to prepare the prompts to generate class images on multiple GPUs."
def __init__(self, prompt, num_samples):
self.prompt = prompt
self.num_samples = num_samples
def __len__(self):
return self.num_samples
def __getitem__(self, index):
example = {}
example["prompt"] = self.prompt
example["index"] = index
return example
def get_full_repo_name(model_id: str, organization: Optional[str] = None, token: Optional[str] = None):
if token is None:
token = HfFolder.get_token()
if organization is None:
username = whoami(token)["name"]
return f"{username}/{model_id}"
else:
return f"{organization}/{model_id}"
def main(args):
logging_dir = Path(args.output_dir, args.logging_dir)
accelerator = Accelerator(
gradient_accumulation_steps=args.gradient_accumulation_steps,
mixed_precision=args.mixed_precision,
log_with=args.report_to,
logging_dir=logging_dir,
)
if args.report_to == "wandb":
if not is_wandb_available():
raise ImportError("Make sure to install wandb if you want to use it for logging during training.")
import wandb
# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger.info(accelerator.state, main_process_only=False)
if accelerator.is_local_main_process:
datasets.utils.logging.set_verbosity_warning()
transformers.utils.logging.set_verbosity_warning()
diffusers.utils.logging.set_verbosity_info()
else:
datasets.utils.logging.set_verbosity_error()
transformers.utils.logging.set_verbosity_error()
diffusers.utils.logging.set_verbosity_error()
# If passed along, set the training seed now.
if args.seed is not None:
set_seed(args.seed)
# Generate class images if prior preservation is enabled.
if args.with_prior_preservation:
class_images_dir = Path(args.class_data_dir)
if not class_images_dir.exists():
class_images_dir.mkdir(parents=True)
cur_class_images = len(list(class_images_dir.iterdir()))
if cur_class_images < args.num_class_images:
torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
if args.prior_generation_precision == "fp32":
torch_dtype = torch.float32
elif args.prior_generation_precision == "fp16":
torch_dtype = torch.float16
elif args.prior_generation_precision == "bf16":
torch_dtype = torch.bfloat16
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
torch_dtype=torch_dtype,
safety_checker=None,
revision=args.revision,
)
pipeline.set_progress_bar_config(disable=True)
num_new_images = args.num_class_images - cur_class_images
logger.info(f"Number of class images to sample: {num_new_images}.")
sample_dataset = PromptDataset(args.class_prompt, num_new_images)
sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
sample_dataloader = accelerator.prepare(sample_dataloader)
pipeline.to(accelerator.device)
for example in tqdm(
sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
):
images = pipeline(example["prompt"]).images
for i, image in enumerate(images):
hash_image = hashlib.sha1(image.tobytes()).hexdigest()
image_filename = class_images_dir / f"{example['index'][i] + cur_class_images}-{hash_image}.jpg"
image.save(image_filename)
del pipeline
if torch.cuda.is_available():
torch.cuda.empty_cache()
# Handle the repository creation
if accelerator.is_main_process:
if args.push_to_hub:
if args.hub_model_id is None:
repo_name = get_full_repo_name(Path(args.output_dir).name, token=args.hub_token)
else:
repo_name = args.hub_model_id
repo_name = create_repo(repo_name, exist_ok=True)
repo = Repository(args.output_dir, clone_from=repo_name)
with open(os.path.join(args.output_dir, ".gitignore"), "w+") as gitignore:
if "step_*" not in gitignore:
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)
# Load the tokenizer
if args.tokenizer_name:
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
elif args.pretrained_model_name_or_path:
tokenizer = AutoTokenizer.from_pretrained(
args.pretrained_model_name_or_path,
subfolder="tokenizer",
revision=args.revision,
use_fast=False,
)
# import correct text encoder class
text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
# Load scheduler and models
noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
text_encoder = text_encoder_cls.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
unet = UNet2DConditionModel.from_pretrained(
args.pretrained_model_name_or_path, subfolder="unet", revision=args.revision
)
# We only train the additional adapter LoRA layers
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)
# For mixed precision training we cast the text_encoder and vae weights to half-precision
# as these models are only used for inference, keeping weights in full precision is not required.
weight_dtype = torch.float32
if accelerator.mixed_precision == "fp16":
weight_dtype = torch.float16
elif accelerator.mixed_precision == "bf16":
weight_dtype = torch.bfloat16
# Move unet, vae and text_encoder to device and cast to weight_dtype
unet.to(accelerator.device, dtype=weight_dtype)
vae.to(accelerator.device, dtype=weight_dtype)
text_encoder.to(accelerator.device, dtype=weight_dtype)
if args.enable_xformers_memory_efficient_attention:
if is_xformers_available():
unet.enable_xformers_memory_efficient_attention()
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")
# now we will add new LoRA weights to the attention layers
# It's important to realize here how many attention weights will be added and of which sizes
# The sizes of the attention layers consist only of two different variables:
# 1) - the "hidden_size", which is increased according to `unet.config.block_out_channels`.
# 2) - the "cross attention size", which is set to `unet.config.cross_attention_dim`.
# Let's first see how many attention processors we will have to set.
# For Stable Diffusion, it should be equal to:
# - down blocks (2x attention layers) * (2x transformer layers) * (3x down blocks) = 12
# - mid blocks (2x attention layers) * (1x transformer layers) * (1x mid blocks) = 2
# - up blocks (2x attention layers) * (3x transformer layers) * (3x down blocks) = 18
# => 32 layers
# Set correct lora layers
lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
accelerator.register_for_checkpointing(lora_layers)
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Enable TF32 for faster training on Ampere GPUs,
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
if args.allow_tf32:
torch.backends.cuda.matmul.allow_tf32 = True
if args.scale_lr:
args.learning_rate = (
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)
# Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
if args.use_8bit_adam:
try:
import bitsandbytes as bnb
except ImportError:
raise ImportError(
"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
)
optimizer_class = bnb.optim.AdamW8bit
else:
optimizer_class = torch.optim.AdamW
# Optimizer creation
optimizer = optimizer_class(
lora_layers.parameters(),
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
eps=args.adam_epsilon,
)
# Dataset and DataLoaders creation:
train_dataset = DreamBoothDataset(
instance_data_root=args.instance_data_dir,
instance_prompt=args.instance_prompt,
class_data_root=args.class_data_dir if args.with_prior_preservation else None,
class_prompt=args.class_prompt,
tokenizer=tokenizer,
size=args.resolution,
center_crop=args.center_crop,
)
train_dataloader = torch.utils.data.DataLoader(
train_dataset,
batch_size=args.train_batch_size,
shuffle=True,
collate_fn=lambda examples: collate_fn(examples, args.with_prior_preservation),
num_workers=1,
)
# Scheduler and math around the number of training steps.
overrode_max_train_steps = False
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if args.max_train_steps is None:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
overrode_max_train_steps = True
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
num_cycles=args.lr_num_cycles,
power=args.lr_power,
)
# Prepare everything with our `accelerator`.
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
lora_layers, optimizer, train_dataloader, lr_scheduler
)
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
if overrode_max_train_steps:
args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
# Afterwards we recalculate our number of training epochs
args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers("dreambooth-lora", config=vars(args))
# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
logger.info("***** Running training *****")
logger.info(f" Num examples = {len(train_dataset)}")
logger.info(f" Num batches each epoch = {len(train_dataloader)}")
logger.info(f" Num Epochs = {args.num_train_epochs}")
logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
logger.info(f" Total optimization steps = {args.max_train_steps}")
global_step = 0
first_epoch = 0
# Potentially load in the weights and states from a previous save
if args.resume_from_checkpoint:
if args.resume_from_checkpoint != "latest":
path = os.path.basename(args.resume_from_checkpoint)
else:
# Get the mos recent checkpoint
dirs = os.listdir(args.output_dir)
dirs = [d for d in dirs if d.startswith("checkpoint")]
dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
path = dirs[-1]
accelerator.print(f"Resuming from checkpoint {path}")
accelerator.load_state(os.path.join(args.output_dir, path))
global_step = int(path.split("-")[1])
resume_global_step = global_step * args.gradient_accumulation_steps
first_epoch = resume_global_step // num_update_steps_per_epoch
resume_step = resume_global_step % num_update_steps_per_epoch
# Only show the progress bar once on each machine.
progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
progress_bar.set_description("Steps")
for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
if step % args.gradient_accumulation_steps == 0:
progress_bar.update(1)
continue
with accelerator.accumulate(unet):
# Convert images to latent space
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
# Sample noise that we'll add to the latents
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch["input_ids"])[0]
# Predict the noise residual
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# Get the target for loss depending on the prediction type
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
target = noise_scheduler.get_velocity(latents, noise, timesteps)
else:
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
if args.with_prior_preservation:
# Chunk the noise and model_pred into two parts and compute the loss on each part separately.
model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
target, target_prior = torch.chunk(target, 2, dim=0)
# Compute instance loss
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
# Compute prior loss
prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
# Add the prior loss to the instance loss.
loss = loss + args.prior_loss_weight * prior_loss
else:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = lora_layers.parameters()
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
progress_bar.update(1)
global_step += 1
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
logger.info(f"Saved state to {save_path}")
logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
if global_step >= args.max_train_steps:
break
if args.validation_prompt is not None and epoch % args.validation_epochs == 0:
logger.info(
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
f" {args.validation_prompt}."
)
# create pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
unet=accelerator.unwrap_model(unet),
text_encoder=accelerator.unwrap_model(text_encoder),
revision=args.revision,
torch_dtype=weight_dtype,
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
prompt = args.num_validation_images * [args.validation_prompt]
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
for tracker in accelerator.trackers:
if tracker.name == "wandb":
tracker.log(
{
"validation": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}")
for i, image in enumerate(images)
]
}
)
del pipeline
torch.cuda.empty_cache()
# Save the lora layers
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unet.to(torch.float32)
unet.save_attn_procs(args.output_dir)
if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
# Final inference
# Load previous pipeline
pipeline = DiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=weight_dtype
)
pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
pipeline = pipeline.to(accelerator.device)
# load attention processors
pipeline.unet.load_attn_procs(args.output_dir)
# run inference
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)
prompt = args.num_validation_images * [args.validation_prompt]
images = pipeline(prompt, num_inference_steps=25, generator=generator).images
for tracker in accelerator.trackers:
if tracker.name == "wandb":
tracker.log(
{
"test": [
wandb.Image(image, caption=f"{i}: {args.validation_prompt}") for i, image in enumerate(images)
]
}
)
accelerator.end_training()
if __name__ == "__main__":
args = parse_args()
main(args)

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import argparse
import copy
import logging

View File

@@ -1,3 +1,18 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
import argparse
import logging
import math

View File

@@ -1,4 +1,4 @@
# Copyright 2020 The HuggingFace Team. All rights reserved.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

243
src/diffusers/loaders.py Normal file
View File

@@ -0,0 +1,243 @@
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from collections import defaultdict
from typing import Callable, Dict, Union
import torch
from .models.cross_attention import LoRACrossAttnProcessor
from .models.modeling_utils import _get_model_file
from .utils import DIFFUSERS_CACHE, HF_HUB_OFFLINE, logging
logger = logging.get_logger(__name__)
LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
class AttnProcsLayers(torch.nn.Module):
def __init__(self, state_dict: Dict[str, torch.Tensor]):
super().__init__()
self.layers = torch.nn.ModuleList(state_dict.values())
self.mapping = {k: v for k, v in enumerate(state_dict.keys())}
self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
# we add a hook to state_dict() and load_state_dict() so that the
# naming fits with `unet.attn_processors`
def map_to(module, state_dict, *args, **kwargs):
new_state_dict = {}
for key, value in state_dict.items():
num = int(key.split(".")[1]) # 0 is always "layers"
new_key = key.replace(f"layers.{num}", module.mapping[num])
new_state_dict[new_key] = value
return new_state_dict
def map_from(module, state_dict, *args, **kwargs):
all_keys = list(state_dict.keys())
for key in all_keys:
replace_key = key.split(".processor")[0] + ".processor"
new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
state_dict[new_key] = state_dict[key]
del state_dict[key]
self._register_state_dict_hook(map_to)
self._register_load_state_dict_pre_hook(map_from, with_module=True)
class UNet2DConditionLoadersMixin:
def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
r"""
Load pretrained attention processor layers into `UNet2DConditionModel`. Attention processor layers have to be
defined in
[cross_attention.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py)
and be a `torch.nn.Module` class.
<Tip warning={true}>
This function is experimental and might change in the future
</Tip>
Parameters:
pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
Can be either:
- A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
- A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
`./my_model_directory/`.
- A [torch state
dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
cache_dir (`Union[str, os.PathLike]`, *optional*):
Path to a directory in which a downloaded pretrained model configuration should be cached if the
standard cache should not be used.
force_download (`bool`, *optional*, defaults to `False`):
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
cached versions if they exist.
resume_download (`bool`, *optional*, defaults to `False`):
Whether or not to delete incompletely received files. Will attempt to resume the download if such a
file exists.
proxies (`Dict[str, str]`, *optional*):
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
local_files_only(`bool`, *optional*, defaults to `False`):
Whether or not to only look at local files (i.e., do not try to download the model).
use_auth_token (`str` or *bool*, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `diffusers-cli login` (stored in `~/.huggingface`).
revision (`str`, *optional*, defaults to `"main"`):
The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
identifier allowed by git.
subfolder (`str`, *optional*, defaults to `""`):
In case the relevant files are located inside a subfolder of the model repo (either remote in
huggingface.co or downloaded locally), you can specify the folder name here.
mirror (`str`, *optional*):
Mirror source to accelerate downloads in China. If you are from China and have an accessibility
problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
Please refer to the mirror site for more information.
<Tip>
It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
models](https://huggingface.co/docs/hub/models-gated#gated-models).
</Tip>
<Tip>
Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
this method in a firewalled environment.
</Tip>
"""
cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
force_download = kwargs.pop("force_download", False)
resume_download = kwargs.pop("resume_download", False)
proxies = kwargs.pop("proxies", None)
local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
use_auth_token = kwargs.pop("use_auth_token", None)
revision = kwargs.pop("revision", None)
subfolder = kwargs.pop("subfolder", None)
weight_name = kwargs.pop("weight_name", LORA_WEIGHT_NAME)
user_agent = {
"file_type": "attn_procs_weights",
"framework": "pytorch",
}
if not isinstance(pretrained_model_name_or_path_or_dict, dict):
model_file = _get_model_file(
pretrained_model_name_or_path_or_dict,
weights_name=weight_name,
cache_dir=cache_dir,
force_download=force_download,
resume_download=resume_download,
proxies=proxies,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
revision=revision,
subfolder=subfolder,
user_agent=user_agent,
)
state_dict = torch.load(model_file, map_location="cpu")
else:
state_dict = pretrained_model_name_or_path_or_dict
# fill attn processors
attn_processors = {}
is_lora = all("lora" in k for k in state_dict.keys())
if is_lora:
lora_grouped_dict = defaultdict(dict)
for key, value in state_dict.items():
attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
lora_grouped_dict[attn_processor_key][sub_key] = value
for key, value_dict in lora_grouped_dict.items():
rank = value_dict["to_k_lora.down.weight"].shape[0]
cross_attention_dim = value_dict["to_k_lora.down.weight"].shape[1]
hidden_size = value_dict["to_k_lora.up.weight"].shape[0]
attn_processors[key] = LoRACrossAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, rank=rank
)
attn_processors[key].load_state_dict(value_dict)
else:
raise ValueError(f"{model_file} does not seem to be in the correct format expected by LoRA training.")
# set correct dtype & device
attn_processors = {k: v.to(device=self.device, dtype=self.dtype) for k, v in attn_processors.items()}
# set layers
self.set_attn_processor(attn_processors)
def save_attn_procs(
self,
save_directory: Union[str, os.PathLike],
is_main_process: bool = True,
weights_name: str = LORA_WEIGHT_NAME,
save_function: Callable = None,
):
r"""
Save an attention procesor to a directory, so that it can be re-loaded using the
`[`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`]` method.
Arguments:
save_directory (`str` or `os.PathLike`):
Directory to which to save. Will be created if it doesn't exist.
is_main_process (`bool`, *optional*, defaults to `True`):
Whether the process calling this is the main process or not. Useful when in distributed training like
TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
the main process to avoid race conditions.
save_function (`Callable`):
The function to use to save the state dictionary. Useful on distributed training like TPUs when one
need to replace `torch.save` by another method. Can be configured with the environment variable
`DIFFUSERS_SAVE_MODE`.
"""
if os.path.isfile(save_directory):
logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
return
if save_function is None:
save_function = torch.save
os.makedirs(save_directory, exist_ok=True)
model_to_save = AttnProcsLayers(self.attn_processors)
# Save the model
state_dict = model_to_save.state_dict()
# Clean the folder from a previous save
for filename in os.listdir(save_directory):
full_filename = os.path.join(save_directory, filename)
# If we have a shard file that is not going to be replaced, we delete it, but only from the main process
# in distributed settings to avoid race conditions.
weights_no_suffix = weights_name.replace(".bin", "")
if filename.startswith(weights_no_suffix) and os.path.isfile(full_filename) and is_main_process:
os.remove(full_filename)
# Save the model
save_function(state_dict, os.path.join(save_directory, weights_name))
logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")

View File

@@ -246,6 +246,68 @@ class CrossAttnProcessor:
return hidden_states
class LoRALinearLayer(nn.Module):
def __init__(self, in_features, out_features, rank=4):
super().__init__()
if rank > min(in_features, out_features):
raise ValueError(f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}")
self.down = nn.Linear(in_features, rank, bias=False)
self.up = nn.Linear(rank, out_features, bias=False)
self.scale = 1.0
nn.init.normal_(self.down.weight, std=1 / rank)
nn.init.zeros_(self.up.weight)
def forward(self, hidden_states):
orig_dtype = hidden_states.dtype
dtype = self.down.weight.dtype
down_hidden_states = self.down(hidden_states.to(dtype))
up_hidden_states = self.up(down_hidden_states)
return up_hidden_states.to(orig_dtype)
class LoRACrossAttnProcessor(nn.Module):
def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
super().__init__()
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
def __call__(
self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class CrossAttnAddedKVProcessor:
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
residual = hidden_states
@@ -312,6 +374,41 @@ class XFormersCrossAttnProcessor:
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states
class LoRAXFormersCrossAttnProcessor(nn.Module):
def __init__(self, hidden_size, cross_attention_dim, rank=4):
super().__init__()
self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
def __call__(
self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
query = attn.head_to_batch_dim(query).contiguous()
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
# linear proj
hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
return hidden_states

View File

@@ -438,7 +438,7 @@ class ModelMixin(torch.nn.Module):
model_file = None
if from_flax:
model_file = cls._get_model_file(
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=FLAX_WEIGHTS_NAME,
cache_dir=cache_dir,
@@ -474,7 +474,7 @@ class ModelMixin(torch.nn.Module):
else:
if is_safetensors_available():
try:
model_file = cls._get_model_file(
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=SAFETENSORS_WEIGHTS_NAME,
cache_dir=cache_dir,
@@ -490,7 +490,7 @@ class ModelMixin(torch.nn.Module):
except:
pass
if model_file is None:
model_file = cls._get_model_file(
model_file = _get_model_file(
pretrained_model_name_or_path,
weights_name=WEIGHTS_NAME,
cache_dir=cache_dir,
@@ -599,92 +599,6 @@ class ModelMixin(torch.nn.Module):
return model
@classmethod
def _get_model_file(
cls,
pretrained_model_name_or_path,
*,
weights_name,
subfolder,
cache_dir,
force_download,
proxies,
resume_download,
local_files_only,
use_auth_token,
user_agent,
revision,
):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
# Load from a PyTorch checkpoint
model_file = os.path.join(pretrained_model_name_or_path, weights_name)
elif subfolder is not None and os.path.isfile(
os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
):
model_file = os.path.join(pretrained_model_name_or_path, subfolder, weights_name)
else:
raise EnvironmentError(
f"Error no file named {weights_name} found in directory {pretrained_model_name_or_path}."
)
return model_file
else:
try:
# Load from URL or cache if already cached
model_file = hf_hub_download(
pretrained_model_name_or_path,
filename=weights_name,
cache_dir=cache_dir,
force_download=force_download,
proxies=proxies,
resume_download=resume_download,
local_files_only=local_files_only,
use_auth_token=use_auth_token,
user_agent=user_agent,
subfolder=subfolder,
revision=revision,
)
return model_file
except RepositoryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
"listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
"token having permission to this repo with `use_auth_token` or log in with `huggingface-cli "
"login`."
)
except RevisionNotFoundError:
raise EnvironmentError(
f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for "
"this model name. Check the model page at "
f"'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
)
except EntryNotFoundError:
raise EnvironmentError(
f"{pretrained_model_name_or_path} does not appear to have a file named {weights_name}."
)
except HTTPError as err:
raise EnvironmentError(
"There was a specific connection error when trying to load"
f" {pretrained_model_name_or_path}:\n{err}"
)
except ValueError:
raise EnvironmentError(
f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
f" directory containing a file named {weights_name} or"
" \nCheckout your internet connection or see how to run the library in"
" offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
)
except EnvironmentError:
raise EnvironmentError(
f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it from "
"'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
f"containing a file named {weights_name}"
)
@classmethod
def _load_pretrained_model(
cls,
@@ -848,7 +762,9 @@ def _get_model_file(
revision,
):
pretrained_model_name_or_path = str(pretrained_model_name_or_path)
if os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(pretrained_model_name_or_path):
return pretrained_model_name_or_path
elif os.path.isdir(pretrained_model_name_or_path):
if os.path.isfile(os.path.join(pretrained_model_name_or_path, weights_name)):
# Load from a PyTorch checkpoint
model_file = os.path.join(pretrained_model_name_or_path, weights_name)

View File

@@ -19,6 +19,7 @@ import torch.nn as nn
import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import UNet2DConditionLoadersMixin
from ..utils import BaseOutput, logging
from .cross_attention import AttnProcessor
from .embeddings import TimestepEmbedding, Timesteps
@@ -49,7 +50,7 @@ class UNet2DConditionOutput(BaseOutput):
sample: torch.FloatTensor
class UNet2DConditionModel(ModelMixin, ConfigMixin):
class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
r"""
UNet2DConditionModel is a conditional 2D UNet model that takes in a noisy sample, conditional state, and a timestep
and returns sample shaped output.
@@ -266,17 +267,59 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
self.conv_act = nn.SiLU()
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def set_attn_processor(self, processor: AttnProcessor):
@property
def attn_processors(self) -> Dict[str, AttnProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
def fn_recursive_attn_processor(module: torch.nn.Module):
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
if hasattr(module, "set_processor"):
module.set_processor(processor)
processors[f"{name}.processor"] = module.processor
for child in module.children():
fn_recursive_attn_processor(child)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
for module in self.children():
fn_recursive_attn_processor(module)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
r"""
Parameters:
`processor (`dict` of `AttnProcessor` or `AttnProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `CrossAttention` layers.
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.:
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_attention_slice(self, slice_size):
r"""

View File

@@ -353,17 +353,59 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
self.conv_act = nn.SiLU()
self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1)
def set_attn_processor(self, processor: AttnProcessor):
@property
def attn_processors(self) -> Dict[str, AttnProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
def fn_recursive_attn_processor(module: torch.nn.Module):
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttnProcessor]):
if hasattr(module, "set_processor"):
module.set_processor(processor)
processors[f"{name}.processor"] = module.processor
for child in module.children():
fn_recursive_attn_processor(child)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
for module in self.children():
fn_recursive_attn_processor(module)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
def set_attn_processor(self, processor: Union[AttnProcessor, Dict[str, AttnProcessor]]):
r"""
Parameters:
`processor (`dict` of `AttnProcessor` or `AttnProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
of **all** `CrossAttention` layers.
In case `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainablae attention processors.:
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_attention_slice(self, slice_size):
r"""

View File

@@ -58,6 +58,7 @@ from .import_utils import (
is_transformers_available,
is_transformers_version,
is_unidecode_available,
is_wandb_available,
is_xformers_available,
requires_backends,
)

View File

@@ -217,6 +217,13 @@ try:
except importlib_metadata.PackageNotFoundError:
_k_diffusion_available = False
_wandb_available = importlib.util.find_spec("wandb") is not None
try:
_wandb_version = importlib_metadata.version("wandb")
logger.debug(f"Successfully imported k-diffusion version {_wandb_version }")
except importlib_metadata.PackageNotFoundError:
_wandb_available = False
def is_torch_available():
return _torch_available
@@ -274,6 +281,10 @@ def is_k_diffusion_available():
return _k_diffusion_available
def is_wandb_available():
return _wandb_available
# docstyle-ignore
FLAX_IMPORT_ERROR = """
{0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -328,6 +339,12 @@ K_DIFFUSION_IMPORT_ERROR = """
install k-diffusion`
"""
# docstyle-ignore
WANDB_IMPORT_ERROR = """
{0} requires the wandb library but it was not found in your environment. You can install it with pip: `pip
install wandb`
"""
BACKENDS_MAPPING = OrderedDict(
[
@@ -340,6 +357,7 @@ BACKENDS_MAPPING = OrderedDict(
("unidecode", (is_unidecode_available, UNIDECODE_IMPORT_ERROR)),
("librosa", (is_librosa_available, LIBROSA_IMPORT_ERROR)),
("k_diffusion", (is_k_diffusion_available, K_DIFFUSION_IMPORT_ERROR)),
("wandb", (is_wandb_available, WANDB_IMPORT_ERROR)),
]
)

View File

@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2020 Optuna, Hugging Face
# Copyright 2022 Optuna, Hugging Face
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -20,18 +20,8 @@ import unittest
import torch
from diffusers import UNet2DConditionModel, UNet2DModel
from diffusers.utils import (
floats_tensor,
load_hf_numpy,
logging,
require_torch_gpu,
slow,
torch_all_close,
torch_device,
)
from diffusers.utils.import_utils import is_xformers_available
from parameterized import parameterized
from diffusers import UNet2DModel
from diffusers.utils import floats_tensor, logging, slow, torch_all_close, torch_device
from ..test_modeling_common import ModelTesterMixin
@@ -218,237 +208,6 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-3))
class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
@property
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@property
def input_shape(self):
return (4, 32, 32)
@property
def output_shape(self):
return (4, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (32, 64),
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
"up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
"cross_attention_dim": 32,
"attention_head_dim": 8,
"out_channels": 4,
"in_channels": 4,
"layers_per_block": 2,
"sample_size": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers
), "xformers is not enabled"
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model.zero_grad()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()
# re-instantiate the model now enabling gradient checkpointing
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()
# compare the output and parameters gradients
self.assertTrue((loss - loss_2).abs() < 1e-5)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
def test_model_with_attention_head_dim_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_use_linear_projection(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["use_linear_projection"] = True
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_attention_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
model.set_attention_slice("auto")
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None
model.set_attention_slice("max")
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None
model.set_attention_slice(2)
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None
def test_model_slicable_head_dim(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
def check_slicable_dim_attr(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
assert isinstance(module.sliceable_head_dim, int)
for child in module.children():
check_slicable_dim_attr(child)
# retrieve number of attention layers
for module in model.children():
check_slicable_dim_attr(module)
def test_special_attn_proc(self):
class AttnEasyProc(torch.nn.Module):
def __init__(self, num):
super().__init__()
self.weight = torch.nn.Parameter(torch.tensor(num))
self.is_run = False
self.number = 0
self.counter = 0
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states += self.weight
self.is_run = True
self.counter += 1
self.number = number
return hidden_states
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
processor = AttnEasyProc(5.0)
model.set_attn_processor(processor)
model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample
assert processor.counter == 12
assert processor.is_run
assert processor.number == 123
class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNet2DModel
@@ -564,310 +323,3 @@ class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
def test_forward_with_norm_groups(self):
# not required for this model
pass
@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
return image
def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
revision = "fp16" if fp16 else None
torch_dtype = torch.float16 if fp16 else torch.float32
model = UNet2DConditionModel.from_pretrained(
model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision
)
model.to(torch_device).eval()
return model
def test_set_attention_slice_auto(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
unet = self.get_unet_model()
unet.set_attention_slice("auto")
latents = self.get_latents(33)
encoder_hidden_states = self.get_encoder_hidden_states(33)
timestep = 1
with torch.no_grad():
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 5 * 10**9
def test_set_attention_slice_max(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
unet = self.get_unet_model()
unet.set_attention_slice("max")
latents = self.get_latents(33)
encoder_hidden_states = self.get_encoder_hidden_states(33)
timestep = 1
with torch.no_grad():
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 5 * 10**9
def test_set_attention_slice_int(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
unet = self.get_unet_model()
unet.set_attention_slice(2)
latents = self.get_latents(33)
encoder_hidden_states = self.get_encoder_hidden_states(33)
timestep = 1
with torch.no_grad():
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 5 * 10**9
def test_set_attention_slice_list(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
# there are 32 slicable layers
slice_list = 16 * [2, 3]
unet = self.get_unet_model()
unet.set_attention_slice(slice_list)
latents = self.get_latents(33)
encoder_hidden_states = self.get_encoder_hidden_states(33)
timestep = 1
with torch.no_grad():
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 5 * 10**9
def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
hidden_states = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
return hidden_states
@parameterized.expand(
[
# fmt: off
[33, 4, [-0.4424, 0.1510, -0.1937, 0.2118, 0.3746, -0.3957, 0.0160, -0.0435]],
[47, 0.55, [-0.1508, 0.0379, -0.3075, 0.2540, 0.3633, -0.0821, 0.1719, -0.0207]],
[21, 0.89, [-0.6479, 0.6364, -0.3464, 0.8697, 0.4443, -0.6289, -0.0091, 0.1778]],
[9, 1000, [0.8888, -0.5659, 0.5834, -0.7469, 1.1912, -0.3923, 1.1241, -0.4424]],
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4")
latents = self.get_latents(seed)
encoder_hidden_states = self.get_encoder_hidden_states(seed)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
assert sample.shape == latents.shape
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)
@parameterized.expand(
[
# fmt: off
[83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
[17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
[8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
[3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
latents = self.get_latents(seed, fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
assert sample.shape == latents.shape
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
@parameterized.expand(
[
# fmt: off
[33, 4, [-0.4430, 0.1570, -0.1867, 0.2376, 0.3205, -0.3681, 0.0525, -0.0722]],
[47, 0.55, [-0.1415, 0.0129, -0.3136, 0.2257, 0.3430, -0.0536, 0.2114, -0.0436]],
[21, 0.89, [-0.7091, 0.6664, -0.3643, 0.9032, 0.4499, -0.6541, 0.0139, 0.1750]],
[9, 1000, [0.8878, -0.5659, 0.5844, -0.7442, 1.1883, -0.3927, 1.1192, -0.4423]],
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5")
latents = self.get_latents(seed)
encoder_hidden_states = self.get_encoder_hidden_states(seed)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
assert sample.shape == latents.shape
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)
@parameterized.expand(
[
# fmt: off
[83, 4, [-0.2695, -0.1669, 0.0073, -0.3181, -0.1187, -0.1676, -0.1395, -0.5972]],
[17, 0.55, [-0.1290, -0.2588, 0.0551, -0.0916, 0.3286, 0.0238, -0.3669, 0.0322]],
[8, 0.89, [-0.5283, 0.1198, 0.0870, -0.1141, 0.9189, -0.0150, 0.5474, 0.4319]],
[3, 1000, [-0.5601, 0.2411, -0.5435, 0.1268, 1.1338, -0.2427, -0.0280, -1.0020]],
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5", fp16=True)
latents = self.get_latents(seed, fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
assert sample.shape == latents.shape
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
@parameterized.expand(
[
# fmt: off
[33, 4, [-0.7639, 0.0106, -0.1615, -0.3487, -0.0423, -0.7972, 0.0085, -0.4858]],
[47, 0.55, [-0.6564, 0.0795, -1.9026, -0.6258, 1.8235, 1.2056, 1.2169, 0.9073]],
[21, 0.89, [0.0327, 0.4399, -0.6358, 0.3417, 0.4120, -0.5621, -0.0397, -1.0430]],
[9, 1000, [0.1600, 0.7303, -1.0556, -0.3515, -0.7440, -1.2037, -1.8149, -1.8931]],
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting")
latents = self.get_latents(seed, shape=(4, 9, 64, 64))
encoder_hidden_states = self.get_encoder_hidden_states(seed)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
assert sample.shape == (4, 4, 64, 64)
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)
@parameterized.expand(
[
# fmt: off
[83, 4, [-0.1047, -1.7227, 0.1067, 0.0164, -0.5698, -0.4172, -0.1388, 1.1387]],
[17, 0.55, [0.0975, -0.2856, -0.3508, -0.4600, 0.3376, 0.2930, -0.2747, -0.7026]],
[8, 0.89, [-0.0952, 0.0183, -0.5825, -0.1981, 0.1131, 0.4668, -0.0395, -0.3486]],
[3, 1000, [0.4790, 0.4949, -1.0732, -0.7158, 0.7959, -0.9478, 0.1105, -0.9741]],
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting", fp16=True)
latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
assert sample.shape == (4, 4, 64, 64)
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
@parameterized.expand(
[
# fmt: off
[83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
[17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
[8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
[3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
# fmt: on
]
)
@require_torch_gpu
def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
assert sample.shape == latents.shape
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

View File

@@ -0,0 +1,688 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import tempfile
import unittest
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import LoRACrossAttnProcessor
from diffusers.utils import (
floats_tensor,
load_hf_numpy,
logging,
require_torch_gpu,
slow,
torch_all_close,
torch_device,
)
from diffusers.utils.import_utils import is_xformers_available
from parameterized import parameterized
from ..test_modeling_common import ModelTesterMixin
logger = logging.get_logger(__name__)
torch.backends.cuda.matmul.allow_tf32 = False
class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
@property
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@property
def input_shape(self):
return (4, 32, 32)
@property
def output_shape(self):
return (4, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (32, 64),
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
"up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
"cross_attention_dim": 32,
"attention_head_dim": 8,
"out_channels": 4,
"in_channels": 4,
"layers_per_block": 2,
"sample_size": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
@unittest.skipIf(
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.enable_xformers_memory_efficient_attention()
assert (
model.mid_block.attentions[0].transformer_blocks[0].attn1._use_memory_efficient_attention_xformers
), "xformers is not enabled"
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
def test_gradient_checkpointing(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
assert not model.is_gradient_checkpointing and model.training
out = model(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model.zero_grad()
labels = torch.randn_like(out)
loss = (out - labels).mean()
loss.backward()
# re-instantiate the model now enabling gradient checkpointing
model_2 = self.model_class(**init_dict)
# clone model
model_2.load_state_dict(model.state_dict())
model_2.to(torch_device)
model_2.enable_gradient_checkpointing()
assert model_2.is_gradient_checkpointing and model_2.training
out_2 = model_2(**inputs_dict).sample
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
# we won't calculate the loss and rather backprop on out.sum()
model_2.zero_grad()
loss_2 = (out_2 - labels).mean()
loss_2.backward()
# compare the output and parameters gradients
self.assertTrue((loss - loss_2).abs() < 1e-5)
named_params = dict(model.named_parameters())
named_params_2 = dict(model_2.named_parameters())
for name, param in named_params.items():
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
def test_model_with_attention_head_dim_tuple(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_with_use_linear_projection(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["use_linear_projection"] = True
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
with torch.no_grad():
output = model(**inputs_dict)
if isinstance(output, dict):
output = output.sample
self.assertIsNotNone(output)
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")
def test_model_attention_slicing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
model.eval()
model.set_attention_slice("auto")
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None
model.set_attention_slice("max")
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None
model.set_attention_slice(2)
with torch.no_grad():
output = model(**inputs_dict)
assert output is not None
def test_model_slicable_head_dim(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
def check_slicable_dim_attr(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
assert isinstance(module.sliceable_head_dim, int)
for child in module.children():
check_slicable_dim_attr(child)
# retrieve number of attention layers
for module in model.children():
check_slicable_dim_attr(module)
def test_special_attn_proc(self):
class AttnEasyProc(torch.nn.Module):
def __init__(self, num):
super().__init__()
self.weight = torch.nn.Parameter(torch.tensor(num))
self.is_run = False
self.number = 0
self.counter = 0
def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, number=None):
batch_size, sequence_length, _ = hidden_states.shape
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
query = attn.to_q(hidden_states)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
hidden_states += self.weight
self.is_run = True
self.counter += 1
self.number = number
return hidden_states
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
processor = AttnEasyProc(5.0)
model.set_attn_processor(processor)
model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample
assert processor.counter == 12
assert processor.is_run
assert processor.number == 123
def test_lora_processors(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
sample1 = model(**inputs_dict).sample
lora_attn_procs = {}
for name in model.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = model.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
# add 1 to weights to mock trained weights
with torch.no_grad():
lora_attn_procs[name].to_q_lora.up.weight += 1
lora_attn_procs[name].to_k_lora.up.weight += 1
lora_attn_procs[name].to_v_lora.up.weight += 1
lora_attn_procs[name].to_out_lora.up.weight += 1
# make sure we can set a list of attention processors
model.set_attn_processor(lora_attn_procs)
model.to(torch_device)
# test that attn processors can be set to itself
model.set_attn_processor(model.attn_processors)
with torch.no_grad():
sample2 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.0}).sample
sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
assert (sample1 - sample2).abs().max() < 1e-4
assert (sample3 - sample4).abs().max() < 1e-4
# sample 2 and sample 3 should be different
assert (sample2 - sample3).abs().max() > 1e-4
def test_lora_save_load(self):
# enable deterministic behavior for gradient checkpointing
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict["attention_head_dim"] = (8, 16)
torch.manual_seed(0)
model = self.model_class(**init_dict)
model.to(torch_device)
with torch.no_grad():
old_sample = model(**inputs_dict).sample
lora_attn_procs = {}
for name in model.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = model.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(model.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = model.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRACrossAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
lora_attn_procs[name] = lora_attn_procs[name].to(model.device)
# add 1 to weights to mock trained weights
with torch.no_grad():
lora_attn_procs[name].to_q_lora.up.weight += 1
lora_attn_procs[name].to_k_lora.up.weight += 1
lora_attn_procs[name].to_v_lora.up.weight += 1
lora_attn_procs[name].to_out_lora.up.weight += 1
model.set_attn_processor(lora_attn_procs)
with torch.no_grad():
sample = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
with tempfile.TemporaryDirectory() as tmpdirname:
model.save_attn_procs(tmpdirname)
torch.manual_seed(0)
new_model = self.model_class(**init_dict)
new_model.to(torch_device)
new_model.load_attn_procs(tmpdirname)
with torch.no_grad():
new_sample = new_model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
assert (sample - new_sample).abs().max() < 1e-4
# LoRA and no LoRA should NOT be the same
assert (sample - old_sample).abs().max() > 1e-4
@slow
class UNet2DConditionModelIntegrationTests(unittest.TestCase):
def get_file_format(self, seed, shape):
return f"gaussian_noise_s={seed}_shape={'_'.join([str(s) for s in shape])}.npy"
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def get_latents(self, seed=0, shape=(4, 4, 64, 64), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
image = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
return image
def get_unet_model(self, fp16=False, model_id="CompVis/stable-diffusion-v1-4"):
revision = "fp16" if fp16 else None
torch_dtype = torch.float16 if fp16 else torch.float32
model = UNet2DConditionModel.from_pretrained(
model_id, subfolder="unet", torch_dtype=torch_dtype, revision=revision
)
model.to(torch_device).eval()
return model
def test_set_attention_slice_auto(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
unet = self.get_unet_model()
unet.set_attention_slice("auto")
latents = self.get_latents(33)
encoder_hidden_states = self.get_encoder_hidden_states(33)
timestep = 1
with torch.no_grad():
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 5 * 10**9
def test_set_attention_slice_max(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
unet = self.get_unet_model()
unet.set_attention_slice("max")
latents = self.get_latents(33)
encoder_hidden_states = self.get_encoder_hidden_states(33)
timestep = 1
with torch.no_grad():
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 5 * 10**9
def test_set_attention_slice_int(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
unet = self.get_unet_model()
unet.set_attention_slice(2)
latents = self.get_latents(33)
encoder_hidden_states = self.get_encoder_hidden_states(33)
timestep = 1
with torch.no_grad():
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 5 * 10**9
def test_set_attention_slice_list(self):
torch.cuda.empty_cache()
torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_peak_memory_stats()
# there are 32 slicable layers
slice_list = 16 * [2, 3]
unet = self.get_unet_model()
unet.set_attention_slice(slice_list)
latents = self.get_latents(33)
encoder_hidden_states = self.get_encoder_hidden_states(33)
timestep = 1
with torch.no_grad():
_ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
mem_bytes = torch.cuda.max_memory_allocated()
assert mem_bytes < 5 * 10**9
def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
dtype = torch.float16 if fp16 else torch.float32
hidden_states = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)
return hidden_states
@parameterized.expand(
[
# fmt: off
[33, 4, [-0.4424, 0.1510, -0.1937, 0.2118, 0.3746, -0.3957, 0.0160, -0.0435]],
[47, 0.55, [-0.1508, 0.0379, -0.3075, 0.2540, 0.3633, -0.0821, 0.1719, -0.0207]],
[21, 0.89, [-0.6479, 0.6364, -0.3464, 0.8697, 0.4443, -0.6289, -0.0091, 0.1778]],
[9, 1000, [0.8888, -0.5659, 0.5834, -0.7469, 1.1912, -0.3923, 1.1241, -0.4424]],
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_v1_4(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4")
latents = self.get_latents(seed)
encoder_hidden_states = self.get_encoder_hidden_states(seed)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
assert sample.shape == latents.shape
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)
@parameterized.expand(
[
# fmt: off
[83, 4, [-0.2323, -0.1304, 0.0813, -0.3093, -0.0919, -0.1571, -0.1125, -0.5806]],
[17, 0.55, [-0.0831, -0.2443, 0.0901, -0.0919, 0.3396, 0.0103, -0.3743, 0.0701]],
[8, 0.89, [-0.4863, 0.0859, 0.0875, -0.1658, 0.9199, -0.0114, 0.4839, 0.4639]],
[3, 1000, [-0.5649, 0.2402, -0.5518, 0.1248, 1.1328, -0.2443, -0.0325, -1.0078]],
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_v1_4_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="CompVis/stable-diffusion-v1-4", fp16=True)
latents = self.get_latents(seed, fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
assert sample.shape == latents.shape
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
@parameterized.expand(
[
# fmt: off
[33, 4, [-0.4430, 0.1570, -0.1867, 0.2376, 0.3205, -0.3681, 0.0525, -0.0722]],
[47, 0.55, [-0.1415, 0.0129, -0.3136, 0.2257, 0.3430, -0.0536, 0.2114, -0.0436]],
[21, 0.89, [-0.7091, 0.6664, -0.3643, 0.9032, 0.4499, -0.6541, 0.0139, 0.1750]],
[9, 1000, [0.8878, -0.5659, 0.5844, -0.7442, 1.1883, -0.3927, 1.1192, -0.4423]],
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_v1_5(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5")
latents = self.get_latents(seed)
encoder_hidden_states = self.get_encoder_hidden_states(seed)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
assert sample.shape == latents.shape
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)
@parameterized.expand(
[
# fmt: off
[83, 4, [-0.2695, -0.1669, 0.0073, -0.3181, -0.1187, -0.1676, -0.1395, -0.5972]],
[17, 0.55, [-0.1290, -0.2588, 0.0551, -0.0916, 0.3286, 0.0238, -0.3669, 0.0322]],
[8, 0.89, [-0.5283, 0.1198, 0.0870, -0.1141, 0.9189, -0.0150, 0.5474, 0.4319]],
[3, 1000, [-0.5601, 0.2411, -0.5435, 0.1268, 1.1338, -0.2427, -0.0280, -1.0020]],
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_v1_5_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-v1-5", fp16=True)
latents = self.get_latents(seed, fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
assert sample.shape == latents.shape
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
@parameterized.expand(
[
# fmt: off
[33, 4, [-0.7639, 0.0106, -0.1615, -0.3487, -0.0423, -0.7972, 0.0085, -0.4858]],
[47, 0.55, [-0.6564, 0.0795, -1.9026, -0.6258, 1.8235, 1.2056, 1.2169, 0.9073]],
[21, 0.89, [0.0327, 0.4399, -0.6358, 0.3417, 0.4120, -0.5621, -0.0397, -1.0430]],
[9, 1000, [0.1600, 0.7303, -1.0556, -0.3515, -0.7440, -1.2037, -1.8149, -1.8931]],
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting")
latents = self.get_latents(seed, shape=(4, 9, 64, 64))
encoder_hidden_states = self.get_encoder_hidden_states(seed)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
assert sample.shape == (4, 4, 64, 64)
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)
@parameterized.expand(
[
# fmt: off
[83, 4, [-0.1047, -1.7227, 0.1067, 0.0164, -0.5698, -0.4172, -0.1388, 1.1387]],
[17, 0.55, [0.0975, -0.2856, -0.3508, -0.4600, 0.3376, 0.2930, -0.2747, -0.7026]],
[8, 0.89, [-0.0952, 0.0183, -0.5825, -0.1981, 0.1131, 0.4668, -0.0395, -0.3486]],
[3, 1000, [0.4790, 0.4949, -1.0732, -0.7158, 0.7959, -0.9478, 0.1105, -0.9741]],
# fmt: on
]
)
@require_torch_gpu
def test_compvis_sd_inpaint_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="runwayml/stable-diffusion-inpainting", fp16=True)
latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
assert sample.shape == (4, 4, 64, 64)
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)
@parameterized.expand(
[
# fmt: off
[83, 4, [0.1514, 0.0807, 0.1624, 0.1016, -0.1896, 0.0263, 0.0677, 0.2310]],
[17, 0.55, [0.1164, -0.0216, 0.0170, 0.1589, -0.3120, 0.1005, -0.0581, -0.1458]],
[8, 0.89, [-0.1758, -0.0169, 0.1004, -0.1411, 0.1312, 0.1103, -0.1996, 0.2139]],
[3, 1000, [0.1214, 0.0352, -0.0731, -0.1562, -0.0994, -0.0906, -0.2340, -0.0539]],
# fmt: on
]
)
@require_torch_gpu
def test_stabilityai_sd_v2_fp16(self, seed, timestep, expected_slice):
model = self.get_unet_model(model_id="stabilityai/stable-diffusion-2", fp16=True)
latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)
timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
with torch.no_grad():
sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
assert sample.shape == latents.shape
output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
expected_output_slice = torch.tensor(expected_slice)
assert torch_all_close(output_slice, expected_output_slice, atol=5e-3)

View File

@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2021 The HuggingFace Inc. team.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -1,5 +1,5 @@
# coding=utf-8
# Copyright 2020 The HuggingFace Inc. team.
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.