1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

VQ-diffusion (#658)

* Changes for VQ-diffusion VQVAE

Add specify dimension of embeddings to VQModel:
`VQModel` will by default set the dimension of embeddings to the number
of latent channels. The VQ-diffusion VQVAE has a smaller
embedding dimension, 128, than number of latent channels, 256.

Add AttnDownEncoderBlock2D and AttnUpDecoderBlock2D to the up and down
unet block helpers. VQ-diffusion's VQVAE uses those two block types.

* Changes for VQ-diffusion transformer

Modify attention.py so SpatialTransformer can be used for
VQ-diffusion's transformer.

SpatialTransformer:
- Can now operate over discrete inputs (classes of vector embeddings) as well as continuous.
- `in_channels` was made optional in the constructor so two locations where it was passed as a positional arg were moved to kwargs
- modified forward pass to take optional timestep embeddings

ImagePositionalEmbeddings:
- added to provide positional embeddings to discrete inputs for latent pixels

BasicTransformerBlock:
- norm layers were made configurable so that the VQ-diffusion could use AdaLayerNorm with timestep embeddings
- modified forward pass to take optional timestep embeddings

CrossAttention:
- now may optionally take a bias parameter for its query, key, and value linear layers

FeedForward:
- Internal layers are now configurable

ApproximateGELU:
- Activation function in VQ-diffusion's feedforward layer

AdaLayerNorm:
- Norm layer modified to incorporate timestep embeddings

* Add VQ-diffusion scheduler

* Add VQ-diffusion pipeline

* Add VQ-diffusion convert script to diffusers

* Add VQ-diffusion dummy objects

* Add VQ-diffusion markdown docs

* Add VQ-diffusion tests

* some renaming

* some fixes

* more renaming

* correct

* fix typo

* correct weights

* finalize

* fix tests

* Apply suggestions from code review

Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>

* Apply suggestions from code review

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* finish

* finish

* up

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
This commit is contained in:
Will Berman
2022-11-03 08:10:28 -07:00
committed by GitHub
parent 269109dbfb
commit ef2ea33c3b
25 changed files with 2674 additions and 223 deletions

View File

@@ -96,6 +96,8 @@
title: "Stochastic Karras VE"
- local: api/pipelines/dance_diffusion
title: "Dance Diffusion"
- local: api/pipelines/vq_diffusion
title: "VQ Diffusion"
- local: api/pipelines/repaint
title: "RePaint"
title: "Pipelines"

View File

@@ -49,6 +49,12 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module
## AutoencoderKL
[[autodoc]] AutoencoderKL
## Transformer2DModel
[[autodoc]] Transformer2DModel
## Transformer2DModelOutput
[[autodoc]] models.attention.Transformer2DModelOutput
## FlaxModelMixin
[[autodoc]] FlaxModelMixin

View File

@@ -41,22 +41,22 @@ If you are looking for *official* training examples, please have a look at [exam
The following table summarizes all officially supported pipelines, their corresponding paper, and if
available a colab notebook to directly try them out.
| Pipeline | Paper | Tasks | Colab
|------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------|:-------------------------------------:|:---:|
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Text-to-Image Generation |
| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [pndm](./pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| Pipeline | Paper | Tasks | Colab
|---|---|:---:|:---:|
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [pndm](./pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [vq_diffusion](./vq_diffusion) | [**Vector Quantized Diffusion Model for Text-to-Image Synthesis**](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
| [repaint](./repaint) | [**RePaint: Inpainting using Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2201.09865) | Image Inpainting |
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
However, most of them can be adapted to use different scheduler components or even different model components. Some pipeline examples are shown in the [Examples](#examples) below.

View File

@@ -0,0 +1,34 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# VQDiffusion
## Overview
[Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) by Shuyang Gu, Dong Chen, Jianmin Bao, Fang Wen, Bo Zhang, Dongdong Chen, Lu Yuan, Baining Guo
The abstract of the paper is the following:
We present the vector quantized diffusion (VQ-Diffusion) model for text-to-image generation. This method is based on a vector quantized variational autoencoder (VQ-VAE) whose latent space is modeled by a conditional variant of the recently developed Denoising Diffusion Probabilistic Model (DDPM). We find that this latent-space method is well-suited for text-to-image generation tasks because it not only eliminates the unidirectional bias with existing methods but also allows us to incorporate a mask-and-replace diffusion strategy to avoid the accumulation of errors, which is a serious problem with existing methods. Our experiments show that the VQ-Diffusion produces significantly better text-to-image generation results when compared with conventional autoregressive (AR) models with similar numbers of parameters. Compared with previous GAN-based text-to-image methods, our VQ-Diffusion can handle more complex scenes and improve the synthesized image quality by a large margin. Finally, we show that the image generation computation in our method can be made highly efficient by reparameterization. With traditional AR methods, the text-to-image generation time increases linearly with the output image resolution and hence is quite time consuming even for normal size images. The VQ-Diffusion allows us to achieve a better trade-off between quality and speed. Our experiments indicate that the VQ-Diffusion model with the reparameterization is fifteen times faster than traditional AR methods while achieving a better image quality.
The original codebase can be found [here](https://github.com/microsoft/VQ-Diffusion).
## Available Pipelines:
| Pipeline | Tasks | Colab
|---|---|:---:|
| [pipeline_vq_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py) | *Text-to-Image Generation* | - |
## VQDiffusionPipeline
[[autodoc]] pipelines.vq_diffusion.pipeline_vq_diffusion.VQDiffusionPipeline
- __call__

View File

@@ -113,7 +113,6 @@ Score SDE-VP is under construction.
[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler
#### Euler scheduler
Euler scheduler (Algorithm 2) from the paper [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) by Karras et al. (2022). Based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51) implementation by Katherine Crowson.
@@ -130,6 +129,12 @@ Fast scheduler which often times generates good outputs with 20-30 steps.
[[autodoc]] EulerAncestralDiscreteScheduler
#### VQDiffusionScheduler
Original paper can be found [here](https://arxiv.org/abs/2111.14822)
[[autodoc]] VQDiffusionScheduler
#### RePaint scheduler
DDPM-based inpainting scheduler for unsupervised inpainting with extreme masks.
@@ -137,4 +142,4 @@ Intended for use with [`RePaintPipeline`].
Based on the paper [RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2201.09865)
and the original implementation by Andreas Lugmayr et al.: https://github.com/andreas128/RePaint
[[autodoc]] RePaintScheduler
[[autodoc]] RePaintScheduler

View File

@@ -34,6 +34,7 @@ available a colab notebook to directly try them out.
| Pipeline | Paper | Tasks | Colab
|---|---|:---:|:---:|
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752)| Text-to-Image Generation |
@@ -45,5 +46,6 @@ available a colab notebook to directly try them out.
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [vq_diffusion](./api/pipelines/vq_diffusion) | [Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.

View File

@@ -0,0 +1,885 @@
"""
This script ports models from VQ-diffusion (https://github.com/microsoft/VQ-Diffusion) to diffusers.
It currently only supports porting the ITHQ dataset.
ITHQ dataset:
```sh
# From the root directory of diffusers.
# Download the VQVAE checkpoint
$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_vqvae.pth?sv=2020-10-02&st=2022-05-30T15%3A17%3A18Z&se=2030-05-31T15%3A17%3A00Z&sr=b&sp=r&sig=1jVavHFPpUjDs%2FTO1V3PTezaNbPp2Nx8MxiWI7y6fEY%3D -O ithq_vqvae.pth
# Download the VQVAE config
# NOTE that in VQ-diffusion the documented file is `configs/ithq.yaml` but the target class
# `image_synthesis.modeling.codecs.image_codec.ema_vqvae.PatchVQVAE`
# loads `OUTPUT/pretrained_model/taming_dvae/config.yaml`
$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/OUTPUT/pretrained_model/taming_dvae/config.yaml -O ithq_vqvae.yaml
# Download the main model checkpoint
$ wget https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_learnable.pth?sv=2020-10-02&st=2022-05-30T10%3A22%3A06Z&se=2030-05-31T10%3A22%3A00Z&sr=b&sp=r&sig=GOE%2Bza02%2FPnGxYVOOPtwrTR4RA3%2F5NVgMxdW4kjaEZ8%3D -O ithq_learnable.pth
# Download the main model config
$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/configs/ithq.yaml -O ithq.yaml
# run the convert script
$ python ./scripts/convert_vq_diffusion_to_diffusers.py \
--checkpoint_path ./ithq_learnable.pth \
--original_config_file ./ithq.yaml \
--vqvae_checkpoint_path ./ithq_vqvae.pth \
--vqvae_original_config_file ./ithq_vqvae.yaml \
--dump_path <path to save pre-trained `VQDiffusionPipeline`>
```
"""
import argparse
import tempfile
import torch
import yaml
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.models.attention import Transformer2DModel
from transformers import CLIPTextModel, CLIPTokenizer
from yaml.loader import FullLoader
try:
from omegaconf import OmegaConf
except ImportError:
raise ImportError(
"OmegaConf is required to convert the VQ Diffusion checkpoints. Please install it with `pip install"
" OmegaConf`."
)
# vqvae model
PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"]
def vqvae_model_from_original_config(original_config):
assert original_config.target in PORTED_VQVAES, f"{original_config.target} has not yet been ported to diffusers."
original_config = original_config.params
original_encoder_config = original_config.encoder_config.params
original_decoder_config = original_config.decoder_config.params
in_channels = original_encoder_config.in_channels
out_channels = original_decoder_config.out_ch
down_block_types = get_down_block_types(original_encoder_config)
up_block_types = get_up_block_types(original_decoder_config)
assert original_encoder_config.ch == original_decoder_config.ch
assert original_encoder_config.ch_mult == original_decoder_config.ch_mult
block_out_channels = tuple(
[original_encoder_config.ch * a_ch_mult for a_ch_mult in original_encoder_config.ch_mult]
)
assert original_encoder_config.num_res_blocks == original_decoder_config.num_res_blocks
layers_per_block = original_encoder_config.num_res_blocks
assert original_encoder_config.z_channels == original_decoder_config.z_channels
latent_channels = original_encoder_config.z_channels
num_vq_embeddings = original_config.n_embed
# Hard coded value for ResnetBlock.GoupNorm(num_groups) in VQ-diffusion
norm_num_groups = 32
e_dim = original_config.embed_dim
model = VQModel(
in_channels=in_channels,
out_channels=out_channels,
down_block_types=down_block_types,
up_block_types=up_block_types,
block_out_channels=block_out_channels,
layers_per_block=layers_per_block,
latent_channels=latent_channels,
num_vq_embeddings=num_vq_embeddings,
norm_num_groups=norm_num_groups,
vq_embed_dim=e_dim,
)
return model
def get_down_block_types(original_encoder_config):
attn_resolutions = coerce_attn_resolutions(original_encoder_config.attn_resolutions)
num_resolutions = len(original_encoder_config.ch_mult)
resolution = coerce_resolution(original_encoder_config.resolution)
curr_res = resolution
down_block_types = []
for _ in range(num_resolutions):
if curr_res in attn_resolutions:
down_block_type = "AttnDownEncoderBlock2D"
else:
down_block_type = "DownEncoderBlock2D"
down_block_types.append(down_block_type)
curr_res = [r // 2 for r in curr_res]
return down_block_types
def get_up_block_types(original_decoder_config):
attn_resolutions = coerce_attn_resolutions(original_decoder_config.attn_resolutions)
num_resolutions = len(original_decoder_config.ch_mult)
resolution = coerce_resolution(original_decoder_config.resolution)
curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution]
up_block_types = []
for _ in reversed(range(num_resolutions)):
if curr_res in attn_resolutions:
up_block_type = "AttnUpDecoderBlock2D"
else:
up_block_type = "UpDecoderBlock2D"
up_block_types.append(up_block_type)
curr_res = [r * 2 for r in curr_res]
return up_block_types
def coerce_attn_resolutions(attn_resolutions):
attn_resolutions = OmegaConf.to_object(attn_resolutions)
attn_resolutions_ = []
for ar in attn_resolutions:
if isinstance(ar, (list, tuple)):
attn_resolutions_.append(list(ar))
else:
attn_resolutions_.append([ar, ar])
return attn_resolutions_
def coerce_resolution(resolution):
resolution = OmegaConf.to_object(resolution)
if isinstance(resolution, int):
resolution = [resolution, resolution] # H, W
elif isinstance(resolution, (tuple, list)):
resolution = list(resolution)
else:
raise ValueError("Unknown type of resolution:", resolution)
return resolution
# done vqvae model
# vqvae checkpoint
def vqvae_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
diffusers_checkpoint = {}
diffusers_checkpoint.update(vqvae_encoder_to_diffusers_checkpoint(model, checkpoint))
# quant_conv
diffusers_checkpoint.update(
{
"quant_conv.weight": checkpoint["quant_conv.weight"],
"quant_conv.bias": checkpoint["quant_conv.bias"],
}
)
# quantize
diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding"]})
# post_quant_conv
diffusers_checkpoint.update(
{
"post_quant_conv.weight": checkpoint["post_quant_conv.weight"],
"post_quant_conv.bias": checkpoint["post_quant_conv.bias"],
}
)
# decoder
diffusers_checkpoint.update(vqvae_decoder_to_diffusers_checkpoint(model, checkpoint))
return diffusers_checkpoint
def vqvae_encoder_to_diffusers_checkpoint(model, checkpoint):
diffusers_checkpoint = {}
# conv_in
diffusers_checkpoint.update(
{
"encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"],
"encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"],
}
)
# down_blocks
for down_block_idx, down_block in enumerate(model.encoder.down_blocks):
diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}"
down_block_prefix = f"encoder.down.{down_block_idx}"
# resnets
for resnet_idx, resnet in enumerate(down_block.resnets):
diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}"
resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}"
diffusers_checkpoint.update(
vqvae_resnet_to_diffusers_checkpoint(
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
)
)
# downsample
# do not include the downsample when on the last down block
# There is no downsample on the last down block
if down_block_idx != len(model.encoder.down_blocks) - 1:
# There's a single downsample in the original checkpoint but a list of downsamples
# in the diffusers model.
diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv"
downsample_prefix = f"{down_block_prefix}.downsample.conv"
diffusers_checkpoint.update(
{
f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
}
)
# attentions
if hasattr(down_block, "attentions"):
for attention_idx, _ in enumerate(down_block.attentions):
diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}"
attention_prefix = f"{down_block_prefix}.attn.{attention_idx}"
diffusers_checkpoint.update(
vqvae_attention_to_diffusers_checkpoint(
checkpoint,
diffusers_attention_prefix=diffusers_attention_prefix,
attention_prefix=attention_prefix,
)
)
# mid block
# mid block attentions
# There is a single hardcoded attention block in the middle of the VQ-diffusion encoder
diffusers_attention_prefix = "encoder.mid_block.attentions.0"
attention_prefix = "encoder.mid.attn_1"
diffusers_checkpoint.update(
vqvae_attention_to_diffusers_checkpoint(
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
)
)
# mid block resnets
for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets):
diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}"
# the hardcoded prefixes to `block_` are 1 and 2
orig_resnet_idx = diffusers_resnet_idx + 1
# There are two hardcoded resnets in the middle of the VQ-diffusion encoder
resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}"
diffusers_checkpoint.update(
vqvae_resnet_to_diffusers_checkpoint(
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
)
)
diffusers_checkpoint.update(
{
# conv_norm_out
"encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"],
"encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"],
# conv_out
"encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"],
"encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"],
}
)
return diffusers_checkpoint
def vqvae_decoder_to_diffusers_checkpoint(model, checkpoint):
diffusers_checkpoint = {}
# conv in
diffusers_checkpoint.update(
{
"decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"],
"decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"],
}
)
# up_blocks
for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks):
# up_blocks are stored in reverse order in the VQ-diffusion checkpoint
orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx
diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}"
up_block_prefix = f"decoder.up.{orig_up_block_idx}"
# resnets
for resnet_idx, resnet in enumerate(up_block.resnets):
diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}"
diffusers_checkpoint.update(
vqvae_resnet_to_diffusers_checkpoint(
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
)
)
# upsample
# there is no up sample on the last up block
if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1:
# There's a single upsample in the VQ-diffusion checkpoint but a list of downsamples
# in the diffusers model.
diffusers_downsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv"
downsample_prefix = f"{up_block_prefix}.upsample.conv"
diffusers_checkpoint.update(
{
f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
}
)
# attentions
if hasattr(up_block, "attentions"):
for attention_idx, _ in enumerate(up_block.attentions):
diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}"
attention_prefix = f"{up_block_prefix}.attn.{attention_idx}"
diffusers_checkpoint.update(
vqvae_attention_to_diffusers_checkpoint(
checkpoint,
diffusers_attention_prefix=diffusers_attention_prefix,
attention_prefix=attention_prefix,
)
)
# mid block
# mid block attentions
# There is a single hardcoded attention block in the middle of the VQ-diffusion decoder
diffusers_attention_prefix = "decoder.mid_block.attentions.0"
attention_prefix = "decoder.mid.attn_1"
diffusers_checkpoint.update(
vqvae_attention_to_diffusers_checkpoint(
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
)
)
# mid block resnets
for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets):
diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}"
# the hardcoded prefixes to `block_` are 1 and 2
orig_resnet_idx = diffusers_resnet_idx + 1
# There are two hardcoded resnets in the middle of the VQ-diffusion decoder
resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}"
diffusers_checkpoint.update(
vqvae_resnet_to_diffusers_checkpoint(
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
)
)
diffusers_checkpoint.update(
{
# conv_norm_out
"decoder.conv_norm_out.weight": checkpoint["decoder.norm_out.weight"],
"decoder.conv_norm_out.bias": checkpoint["decoder.norm_out.bias"],
# conv_out
"decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"],
"decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"],
}
)
return diffusers_checkpoint
def vqvae_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
rv = {
# norm1
f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"],
f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"],
# conv1
f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"],
f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"],
# norm2
f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"],
f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"],
# conv2
f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"],
f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"],
}
if resnet.conv_shortcut is not None:
rv.update(
{
f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"],
f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"],
}
)
return rv
def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
return {
# group_norm
f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],
# query
f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0],
f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"],
# key
f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0],
f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"],
# value
f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0],
f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"],
# proj_attn
f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][
:, :, 0, 0
],
f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
}
# done vqvae checkpoint
# transformer model
PORTED_DIFFUSIONS = ["image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer"]
PORTED_TRANSFORMERS = ["image_synthesis.modeling.transformers.transformer_utils.Text2ImageTransformer"]
PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding"]
def transformer_model_from_original_config(
original_diffusion_config, original_transformer_config, original_content_embedding_config
):
assert (
original_diffusion_config.target in PORTED_DIFFUSIONS
), f"{original_diffusion_config.target} has not yet been ported to diffusers."
assert (
original_transformer_config.target in PORTED_TRANSFORMERS
), f"{original_transformer_config.target} has not yet been ported to diffusers."
assert (
original_content_embedding_config.target in PORTED_CONTENT_EMBEDDINGS
), f"{original_content_embedding_config.target} has not yet been ported to diffusers."
original_diffusion_config = original_diffusion_config.params
original_transformer_config = original_transformer_config.params
original_content_embedding_config = original_content_embedding_config.params
inner_dim = original_transformer_config["n_embd"]
n_heads = original_transformer_config["n_head"]
# VQ-Diffusion gives dimension of the multi-headed attention layers as the
# number of attention heads times the sequence length (the dimension) of a
# single head. We want to specify our attention blocks with those values
# specified separately
assert inner_dim % n_heads == 0
d_head = inner_dim // n_heads
depth = original_transformer_config["n_layer"]
context_dim = original_transformer_config["condition_dim"]
num_embed = original_content_embedding_config["num_embed"]
# the number of embeddings in the transformer includes the mask embedding.
# the content embedding (the vqvae) does not include the mask embedding.
num_embed = num_embed + 1
height = original_transformer_config["content_spatial_size"][0]
width = original_transformer_config["content_spatial_size"][1]
assert width == height, "width has to be equal to height"
dropout = original_transformer_config["resid_pdrop"]
num_embeds_ada_norm = original_diffusion_config["diffusion_step"]
model_kwargs = {
"attention_bias": True,
"cross_attention_dim": context_dim,
"attention_head_dim": d_head,
"num_layers": depth,
"dropout": dropout,
"num_attention_heads": n_heads,
"num_vector_embeds": num_embed,
"num_embeds_ada_norm": num_embeds_ada_norm,
"norm_num_groups": 32,
"sample_size": width,
"activation_fn": "geglu-approximate",
}
model = Transformer2DModel(**model_kwargs)
return model
# done transformer model
# transformer checkpoint
def transformer_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
diffusers_checkpoint = {}
transformer_prefix = "transformer.transformer"
diffusers_latent_image_embedding_prefix = "latent_image_embedding"
latent_image_embedding_prefix = f"{transformer_prefix}.content_emb"
# DalleMaskImageEmbedding
diffusers_checkpoint.update(
{
f"{diffusers_latent_image_embedding_prefix}.emb.weight": checkpoint[
f"{latent_image_embedding_prefix}.emb.weight"
],
f"{diffusers_latent_image_embedding_prefix}.height_emb.weight": checkpoint[
f"{latent_image_embedding_prefix}.height_emb.weight"
],
f"{diffusers_latent_image_embedding_prefix}.width_emb.weight": checkpoint[
f"{latent_image_embedding_prefix}.width_emb.weight"
],
}
)
# transformer blocks
for transformer_block_idx, transformer_block in enumerate(model.transformer_blocks):
diffusers_transformer_block_prefix = f"transformer_blocks.{transformer_block_idx}"
transformer_block_prefix = f"{transformer_prefix}.blocks.{transformer_block_idx}"
# ada norm block
diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm1"
ada_norm_prefix = f"{transformer_block_prefix}.ln1"
diffusers_checkpoint.update(
transformer_ada_norm_to_diffusers_checkpoint(
checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix
)
)
# attention block
diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn1"
attention_prefix = f"{transformer_block_prefix}.attn1"
diffusers_checkpoint.update(
transformer_attention_to_diffusers_checkpoint(
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
)
)
# ada norm block
diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm2"
ada_norm_prefix = f"{transformer_block_prefix}.ln1_1"
diffusers_checkpoint.update(
transformer_ada_norm_to_diffusers_checkpoint(
checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix
)
)
# attention block
diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn2"
attention_prefix = f"{transformer_block_prefix}.attn2"
diffusers_checkpoint.update(
transformer_attention_to_diffusers_checkpoint(
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
)
)
# norm block
diffusers_norm_block_prefix = f"{diffusers_transformer_block_prefix}.norm3"
norm_block_prefix = f"{transformer_block_prefix}.ln2"
diffusers_checkpoint.update(
{
f"{diffusers_norm_block_prefix}.weight": checkpoint[f"{norm_block_prefix}.weight"],
f"{diffusers_norm_block_prefix}.bias": checkpoint[f"{norm_block_prefix}.bias"],
}
)
# feedforward block
diffusers_feedforward_prefix = f"{diffusers_transformer_block_prefix}.ff"
feedforward_prefix = f"{transformer_block_prefix}.mlp"
diffusers_checkpoint.update(
transformer_feedforward_to_diffusers_checkpoint(
checkpoint,
diffusers_feedforward_prefix=diffusers_feedforward_prefix,
feedforward_prefix=feedforward_prefix,
)
)
# to logits
diffusers_norm_out_prefix = "norm_out"
norm_out_prefix = f"{transformer_prefix}.to_logits.0"
diffusers_checkpoint.update(
{
f"{diffusers_norm_out_prefix}.weight": checkpoint[f"{norm_out_prefix}.weight"],
f"{diffusers_norm_out_prefix}.bias": checkpoint[f"{norm_out_prefix}.bias"],
}
)
diffusers_out_prefix = "out"
out_prefix = f"{transformer_prefix}.to_logits.1"
diffusers_checkpoint.update(
{
f"{diffusers_out_prefix}.weight": checkpoint[f"{out_prefix}.weight"],
f"{diffusers_out_prefix}.bias": checkpoint[f"{out_prefix}.bias"],
}
)
return diffusers_checkpoint
def transformer_ada_norm_to_diffusers_checkpoint(checkpoint, *, diffusers_ada_norm_prefix, ada_norm_prefix):
return {
f"{diffusers_ada_norm_prefix}.emb.weight": checkpoint[f"{ada_norm_prefix}.emb.weight"],
f"{diffusers_ada_norm_prefix}.linear.weight": checkpoint[f"{ada_norm_prefix}.linear.weight"],
f"{diffusers_ada_norm_prefix}.linear.bias": checkpoint[f"{ada_norm_prefix}.linear.bias"],
}
def transformer_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
return {
# key
f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.key.weight"],
f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.key.bias"],
# query
f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.query.weight"],
f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.query.bias"],
# value
f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.value.weight"],
f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.value.bias"],
# linear out
f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj.weight"],
f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj.bias"],
}
def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_feedforward_prefix, feedforward_prefix):
return {
f"{diffusers_feedforward_prefix}.net.0.proj.weight": checkpoint[f"{feedforward_prefix}.0.weight"],
f"{diffusers_feedforward_prefix}.net.0.proj.bias": checkpoint[f"{feedforward_prefix}.0.bias"],
f"{diffusers_feedforward_prefix}.net.2.weight": checkpoint[f"{feedforward_prefix}.2.weight"],
f"{diffusers_feedforward_prefix}.net.2.bias": checkpoint[f"{feedforward_prefix}.2.bias"],
}
# done transformer checkpoint
def read_config_file(filename):
# The yaml file contains annotations that certain values should
# loaded as tuples. By default, OmegaConf will panic when reading
# these. Instead, we can manually read the yaml with the FullLoader and then
# construct the OmegaConf object.
with open(filename) as f:
original_config = yaml.load(f, FullLoader)
return OmegaConf.create(original_config)
# We take separate arguments for the vqvae because the ITHQ vqvae config file
# is separate from the config file for the rest of the model.
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--vqvae_checkpoint_path",
default=None,
type=str,
required=True,
help="Path to the vqvae checkpoint to convert.",
)
parser.add_argument(
"--vqvae_original_config_file",
default=None,
type=str,
required=True,
help="The YAML config file corresponding to the original architecture for the vqvae.",
)
parser.add_argument(
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
)
parser.add_argument(
"--original_config_file",
default=None,
type=str,
required=True,
help="The YAML config file corresponding to the original architecture.",
)
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
parser.add_argument(
"--checkpoint_load_device",
default="cpu",
type=str,
required=False,
help="The device passed to `map_location` when loading checkpoints.",
)
# See link for how ema weights are always selected
# https://github.com/microsoft/VQ-Diffusion/blob/3c98e77f721db7c787b76304fa2c96a36c7b00af/inference_VQ_Diffusion.py#L65
parser.add_argument(
"--no_use_ema",
action="store_true",
required=False,
help=(
"Set to not use the ema weights from the original VQ-Diffusion checkpoint. You probably do not want to set"
" it as the original VQ-Diffusion always uses the ema weights when loading models."
),
)
args = parser.parse_args()
use_ema = not args.no_use_ema
print(f"loading checkpoints to {args.checkpoint_load_device}")
checkpoint_map_location = torch.device(args.checkpoint_load_device)
# vqvae_model
print(f"loading vqvae, config: {args.vqvae_original_config_file}, checkpoint: {args.vqvae_checkpoint_path}")
vqvae_original_config = read_config_file(args.vqvae_original_config_file).model
vqvae_checkpoint = torch.load(args.vqvae_checkpoint_path, map_location=checkpoint_map_location)["model"]
with init_empty_weights():
vqvae_model = vqvae_model_from_original_config(vqvae_original_config)
vqvae_diffusers_checkpoint = vqvae_original_checkpoint_to_diffusers_checkpoint(vqvae_model, vqvae_checkpoint)
with tempfile.NamedTemporaryFile() as vqvae_diffusers_checkpoint_file:
torch.save(vqvae_diffusers_checkpoint, vqvae_diffusers_checkpoint_file.name)
del vqvae_diffusers_checkpoint
del vqvae_checkpoint
load_checkpoint_and_dispatch(vqvae_model, vqvae_diffusers_checkpoint_file.name, device_map="auto")
print("done loading vqvae")
# done vqvae_model
# transformer_model
print(
f"loading transformer, config: {args.original_config_file}, checkpoint: {args.checkpoint_path}, use ema:"
f" {use_ema}"
)
original_config = read_config_file(args.original_config_file).model
diffusion_config = original_config.params.diffusion_config
transformer_config = original_config.params.diffusion_config.params.transformer_config
content_embedding_config = original_config.params.diffusion_config.params.content_emb_config
pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location)
if use_ema:
if "ema" in pre_checkpoint:
checkpoint = {}
for k, v in pre_checkpoint["model"].items():
checkpoint[k] = v
for k, v in pre_checkpoint["ema"].items():
# The ema weights are only used on the transformer. To mimic their key as if they came
# from the state_dict for the top level model, we prefix with an additional "transformer."
# See the source linked in the args.use_ema config for more information.
checkpoint[f"transformer.{k}"] = v
else:
print("attempted to load ema weights but no ema weights are specified in the loaded checkpoint.")
checkpoint = pre_checkpoint["model"]
else:
checkpoint = pre_checkpoint["model"]
del pre_checkpoint
with init_empty_weights():
transformer_model = transformer_model_from_original_config(
diffusion_config, transformer_config, content_embedding_config
)
diffusers_transformer_checkpoint = transformer_original_checkpoint_to_diffusers_checkpoint(
transformer_model, checkpoint
)
with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
del diffusers_transformer_checkpoint
del checkpoint
load_checkpoint_and_dispatch(transformer_model, diffusers_transformer_checkpoint_file.name, device_map="auto")
print("done loading transformer")
# done transformer_model
# text encoder
print("loading CLIP text encoder")
clip_name = "openai/clip-vit-base-patch32"
# The original VQ-Diffusion specifies the pad value by the int used in the
# returned tokens. Each model uses `0` as the pad value. The transformers clip api
# specifies the pad value via the token before it has been tokenized. The `!` pad
# token is the same as padding with the `0` pad value.
pad_token = "!"
tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto")
assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0
text_encoder_model = CLIPTextModel.from_pretrained(
clip_name,
# `CLIPTextModel` does not support device_map="auto"
# device_map="auto"
)
print("done loading CLIP text encoder")
# done text encoder
# scheduler
scheduler_model = VQDiffusionScheduler(
# the scheduler has the same number of embeddings as the transformer
num_vec_classes=transformer_model.num_vector_embeds
)
# done scheduler
print(f"saving VQ diffusion model, path: {args.dump_path}")
pipe = VQDiffusionPipeline(
vqvae=vqvae_model,
transformer=transformer_model,
tokenizer=tokenizer_model,
text_encoder=text_encoder_model,
scheduler=scheduler_model,
)
pipe.save_pretrained(args.dump_path)
print("done writing VQ diffusion model")

View File

@@ -18,7 +18,7 @@ from .utils import logging
if is_torch_available():
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
@@ -38,6 +38,7 @@ if is_torch_available():
PNDMPipeline,
RePaintPipeline,
ScoreSdeVePipeline,
VQDiffusionPipeline,
)
from .schedulers import (
DDIMScheduler,
@@ -50,6 +51,7 @@ if is_torch_available():
RePaintScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
VQDiffusionScheduler,
)
from .training_utils import EMAModel
else:

View File

@@ -16,6 +16,7 @@ from ..utils import is_flax_available, is_torch_available
if is_torch_available():
from .attention import Transformer2DModel
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel

View File

@@ -12,13 +12,30 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.utils.import_utils import is_xformers_available
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput
from ..utils.import_utils import is_xformers_available
@dataclass
class Transformer2DModelOutput(BaseOutput):
"""
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
for the unnoised latent pixels.
"""
sample: torch.FloatTensor
if is_xformers_available():
@@ -28,6 +45,186 @@ else:
xformers = None
class Transformer2DModel(ModelMixin, ConfigMixin):
"""
Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
embeddings) inputs.
When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
transformer action. Finally, reshape to image.
When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
classes of unnoised image.
Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input and output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
num_vector_embeds (`int`, *optional*):
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more than steps than `num_embeds_ada_norm`.
attention_bias (`bool`, *optional*):
Configure if the TransformerBlocks' attention should contain a bias parameter.
"""
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
# 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
self.is_input_continuous = in_channels is not None
self.is_input_vectorized = num_vector_embeds is not None
if self.is_input_continuous and self.is_input_vectorized:
raise ValueError(
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
" sure that either `in_channels` or `num_vector_embeds` is None."
)
elif not self.is_input_continuous and not self.is_input_vectorized:
raise ValueError(
f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
" sure that either `in_channels` or `num_vector_embeds` is not None."
)
# 2. Define input layers
if self.is_input_continuous:
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
self.height = sample_size
self.width = sample_size
self.num_vector_embeds = num_vector_embeds
self.num_latent_pixels = self.height * self.width
self.latent_image_embedding = ImagePositionalEmbeddings(
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
)
for d in range(num_layers)
]
)
# 4. Define output layers
if self.is_input_continuous:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
def _set_attention_slice(self, slice_size):
for block in self.transformer_blocks:
block._set_attention_slice(slice_size)
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
"""
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
hidden_states
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Returns:
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
tensor.
"""
# 1. Input
if self.is_input_continuous:
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
# 3. Output
if self.is_input_continuous:
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
hidden_states = self.proj_out(hidden_states)
output = hidden_states + residual
elif self.is_input_vectorized:
hidden_states = self.norm_out(hidden_states)
logits = self.out(hidden_states)
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
logits = logits.permute(0, 2, 1)
# log(p(x_0))
output = F.log_softmax(logits.double(), dim=1).float()
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for block in self.transformer_blocks:
block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
class AttentionBlock(nn.Module):
"""
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
@@ -36,19 +233,19 @@ class AttentionBlock(nn.Module):
Uses three q, k, v linear layers to compute attention.
Parameters:
channels (:obj:`int`): The number of channels in the input and output.
num_head_channels (:obj:`int`, *optional*):
channels (`int`): The number of channels in the input and output.
num_head_channels (`int`, *optional*):
The number of channels in each head. If None, then `num_heads` = 1.
num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
"""
def __init__(
self,
channels: int,
num_head_channels: Optional[int] = None,
num_groups: int = 32,
norm_num_groups: int = 32,
rescale_output_factor: float = 1.0,
eps: float = 1e-5,
):
@@ -57,7 +254,7 @@ class AttentionBlock(nn.Module):
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
self.num_head_size = num_head_channels
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
# define q,k,v as linear layers
self.query = nn.Linear(channels, channels)
@@ -113,107 +310,61 @@ class AttentionBlock(nn.Module):
return hidden_states
class SpatialTransformer(nn.Module):
"""
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
standard transformer action. Finally, reshape to image.
Parameters:
in_channels (:obj:`int`): The number of channels in the input and output.
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
d_head (:obj:`int`): The number of channels in each head.
depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
"""
def __init__(
self,
in_channels: int,
n_heads: int,
d_head: int,
depth: int = 1,
dropout: float = 0.0,
num_groups: int = 32,
context_dim: Optional[int] = None,
):
super().__init__()
self.n_heads = n_heads
self.d_head = d_head
self.in_channels = in_channels
inner_dim = n_heads * d_head
self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
for d in range(depth)
]
)
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
def _set_attention_slice(self, slice_size):
for block in self.transformer_blocks:
block._set_attention_slice(slice_size)
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
for block in self.transformer_blocks:
block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
def forward(self, hidden_states, context=None):
# note: if no context is given, cross-attention defaults to self-attention
batch, channel, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
for block in self.transformer_blocks:
hidden_states = block(hidden_states, context=context)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2)
hidden_states = self.proj_out(hidden_states)
return hidden_states + residual
class BasicTransformerBlock(nn.Module):
r"""
A basic Transformer block.
Parameters:
dim (:obj:`int`): The number of channels in the input and output.
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
d_head (:obj:`int`): The number of channels in each head.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network.
checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing.
dim (`int`): The number of channels in the input and output.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (:
obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter.
"""
def __init__(
self,
dim: int,
n_heads: int,
d_head: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
context_dim: Optional[int] = None,
gated_ff: bool = True,
checkpoint: bool = True,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
):
super().__init__()
self.attn1 = CrossAttention(
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
) # is a self-attention
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.attn2 = CrossAttention(
query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
) # is self-attn if context is none
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
# layer norms
self.use_ada_layer_norm = num_embeds_ada_norm is not None
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
else:
self.norm1 = nn.LayerNorm(dim)
self.norm2 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim)
self.checkpoint = checkpoint
def _set_attention_slice(self, slice_size):
self.attn1._slice_size = slice_size
@@ -245,10 +396,22 @@ class BasicTransformerBlock(nn.Module):
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
def forward(self, hidden_states, context=None):
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
def forward(self, hidden_states, context=None, timestep=None):
# 1. Self-Attention
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
hidden_states = self.attn1(norm_hidden_states) + hidden_states
# 2. Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
# 3. Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
return hidden_states
@@ -257,20 +420,28 @@ class CrossAttention(nn.Module):
A cross attention layer.
Parameters:
query_dim (:obj:`int`): The number of channels in the query.
context_dim (:obj:`int`, *optional*):
query_dim (`int`): The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the context. If not given, defaults to `query_dim`.
heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
bias (`bool`, *optional*, defaults to False):
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
"""
def __init__(
self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias=False,
):
super().__init__()
inner_dim = dim_head * heads
context_dim = context_dim if context_dim is not None else query_dim
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.scale = dim_head**-0.5
self.heads = heads
@@ -280,9 +451,9 @@ class CrossAttention(nn.Module):
self._slice_size = None
self._use_memory_efficient_attention_xformers = False
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim))
@@ -394,23 +565,33 @@ class FeedForward(nn.Module):
A feed-forward layer.
Parameters:
dim (:obj:`int`): The number of channels in the input.
dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
dim (`int`): The number of channels in the input.
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
"""
def __init__(
self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
self,
dim: int,
dim_out: Optional[int] = None,
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
self.net = nn.ModuleList([])
if activation_fn == "geglu":
geglu = GEGLU(dim, inner_dim)
elif activation_fn == "geglu-approximate":
geglu = ApproximateGELU(dim, inner_dim)
self.net = nn.ModuleList([])
# project in
self.net.append(GEGLU(dim, inner_dim))
self.net.append(geglu)
# project dropout
self.net.append(nn.Dropout(dropout))
# project out
@@ -428,8 +609,8 @@ class GEGLU(nn.Module):
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
Parameters:
dim_in (:obj:`int`): The number of channels in the input.
dim_out (:obj:`int`): The number of channels in the output.
dim_in (`int`): The number of channels in the input.
dim_out (`int`): The number of channels in the output.
"""
def __init__(self, dim_in: int, dim_out: int):
@@ -445,3 +626,38 @@ class GEGLU(nn.Module):
def forward(self, hidden_states):
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
return hidden_states * self.gelu(gate)
class ApproximateGELU(nn.Module):
"""
The approximate form of Gaussian Error Linear Unit (GELU)
For more details, see section 2: https://arxiv.org/abs/1606.08415
"""
def __init__(self, dim_in: int, dim_out: int):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)
def forward(self, x):
x = self.proj(x)
return x * torch.sigmoid(1.702 * x)
class AdaLayerNorm(nn.Module):
"""
Norm layer modified to incorporate timestep embeddings.
"""
def __init__(self, embedding_dim, num_embeddings):
super().__init__()
self.emb = nn.Embedding(num_embeddings, embedding_dim)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
def forward(self, x, timestep):
emb = self.linear(self.silu(self.emb(timestep)))
scale, shift = torch.chunk(emb, 2)
x = self.norm(x) * (1 + scale) + shift
return x

View File

@@ -142,7 +142,7 @@ class FlaxBasicTransformerBlock(nn.Module):
return hidden_states
class FlaxSpatialTransformer(nn.Module):
class FlaxTransformer2DModel(nn.Module):
r"""
A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
https://arxiv.org/pdf/1506.02025.pdf

View File

@@ -126,3 +126,68 @@ class GaussianFourierProjection(nn.Module):
else:
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
return out
class ImagePositionalEmbeddings(nn.Module):
"""
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
height and width of the latent space.
For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
For VQ-diffusion:
Output vector embeddings are used as input for the transformer.
Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
Args:
num_embed (`int`):
Number of embeddings for the latent pixels embeddings.
height (`int`):
Height of the latent image i.e. the number of height embeddings.
width (`int`):
Width of the latent image i.e. the number of width embeddings.
embed_dim (`int`):
Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
"""
def __init__(
self,
num_embed: int,
height: int,
width: int,
embed_dim: int,
):
super().__init__()
self.height = height
self.width = width
self.num_embed = num_embed
self.embed_dim = embed_dim
self.emb = nn.Embedding(self.num_embed, embed_dim)
self.height_emb = nn.Embedding(self.height, embed_dim)
self.width_emb = nn.Embedding(self.width, embed_dim)
def forward(self, index):
emb = self.emb(index)
height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
# 1 x H x D -> 1 x H x 1 x D
height_emb = height_emb.unsqueeze(2)
width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
# 1 x W x D -> 1 x 1 x W x D
width_emb = width_emb.unsqueeze(1)
pos_emb = height_emb + width_emb
# 1 x H x W x D -> 1 x L xD
pos_emb = pos_emb.view(1, self.height * self.width, -1)
emb = emb + pos_emb[:, : emb.shape[1], :]
return emb

View File

@@ -15,7 +15,7 @@ import numpy as np
import torch
from torch import nn
from .attention import AttentionBlock, SpatialTransformer
from .attention import AttentionBlock, Transformer2DModel
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
@@ -109,6 +109,19 @@ def get_down_block(
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
)
elif down_block_type == "AttnDownEncoderBlock2D":
return AttnDownEncoderBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
)
raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(
@@ -200,6 +213,17 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
)
elif up_block_type == "AttnUpDecoderBlock2D":
return AttnUpDecoderBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
attn_num_head_channels=attn_num_head_channels,
)
raise ValueError(f"{up_block_type} does not exist.")
@@ -249,7 +273,7 @@ class UNetMidBlock2D(nn.Module):
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
norm_num_groups=resnet_groups,
)
)
resnets.append(
@@ -325,13 +349,13 @@ class UNetMidBlock2DCrossAttn(nn.Module):
for _ in range(num_layers):
attentions.append(
SpatialTransformer(
in_channels,
Transformer2DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
num_groups=resnet_groups,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
resnets.append(
@@ -374,7 +398,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states)
hidden_states = attn(hidden_states, encoder_hidden_states).sample
hidden_states = resnet(hidden_states, temb)
return hidden_states
@@ -427,7 +451,7 @@ class AttnDownBlock2D(nn.Module):
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
norm_num_groups=resnet_groups,
)
)
@@ -506,13 +530,13 @@ class CrossAttnDownBlock2D(nn.Module):
)
)
attentions.append(
SpatialTransformer(
out_channels,
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
num_groups=resnet_groups,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
self.attentions = nn.ModuleList(attentions)
@@ -556,19 +580,22 @@ class CrossAttnDownBlock2D(nn.Module):
for resnet, attn in zip(self.resnets, self.attentions):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
return module(*inputs)
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn), hidden_states, encoder_hidden_states
)
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
output_states += (hidden_states,)
@@ -763,7 +790,7 @@ class AttnDownEncoderBlock2D(nn.Module):
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
norm_num_groups=resnet_groups,
)
)
@@ -1014,7 +1041,7 @@ class AttnUpBlock2D(nn.Module):
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
norm_num_groups=resnet_groups,
)
)
@@ -1089,13 +1116,13 @@ class CrossAttnUpBlock2D(nn.Module):
)
)
attentions.append(
SpatialTransformer(
out_channels,
Transformer2DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
depth=1,
context_dim=cross_attention_dim,
num_groups=resnet_groups,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
)
)
self.attentions = nn.ModuleList(attentions)
@@ -1145,19 +1172,22 @@ class CrossAttnUpBlock2D(nn.Module):
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
return module(*inputs)
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn), hidden_states, encoder_hidden_states
)
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
)[0]
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -1337,7 +1367,7 @@ class AttnUpDecoderBlock2D(nn.Module):
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
num_groups=resnet_groups,
norm_num_groups=resnet_groups,
)
)

View File

@@ -15,7 +15,7 @@
import flax.linen as nn
import jax.numpy as jnp
from .attention_flax import FlaxSpatialTransformer
from .attention_flax import FlaxTransformer2DModel
from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
@@ -63,7 +63,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
)
resnets.append(res_block)
attn_block = FlaxSpatialTransformer(
attn_block = FlaxTransformer2DModel(
in_channels=self.out_channels,
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
@@ -196,7 +196,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
)
resnets.append(res_block)
attn_block = FlaxSpatialTransformer(
attn_block = FlaxTransformer2DModel(
in_channels=self.out_channels,
n_heads=self.attn_num_head_channels,
d_head=self.out_channels // self.attn_num_head_channels,
@@ -326,7 +326,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
attentions = []
for _ in range(self.num_layers):
attn_block = FlaxSpatialTransformer(
attn_block = FlaxTransformer2DModel(
in_channels=self.in_channels,
n_heads=self.attn_num_head_channels,
d_head=self.in_channels // self.attn_num_head_channels,

View File

@@ -233,14 +233,16 @@ class VectorQuantizer(nn.Module):
# NOTE: due to a bug the beta term was applied to the wrong term. for
# backwards compatibility we use the buggy version by default, but you can
# specify legacy=False to fix it.
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
def __init__(
self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True
):
super().__init__()
self.n_e = n_e
self.e_dim = e_dim
self.vq_embed_dim = vq_embed_dim
self.beta = beta
self.legacy = legacy
self.embedding = nn.Embedding(self.n_e, self.e_dim)
self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
self.remap = remap
@@ -287,7 +289,7 @@ class VectorQuantizer(nn.Module):
def forward(self, z):
# reshape z -> (batch, height, width, channel) and flatten
z = z.permute(0, 2, 3, 1).contiguous()
z_flattened = z.view(-1, self.e_dim)
z_flattened = z.view(-1, self.vq_embed_dim)
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
d = (
@@ -409,6 +411,7 @@ class VQModel(ModelMixin, ConfigMixin):
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
sample_size (`int`, *optional*, defaults to `32`): TODO
num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
"""
@register_to_config
@@ -425,6 +428,7 @@ class VQModel(ModelMixin, ConfigMixin):
sample_size: int = 32,
num_vq_embeddings: int = 256,
norm_num_groups: int = 32,
vq_embed_dim: Optional[int] = None,
):
super().__init__()
@@ -440,11 +444,11 @@ class VQModel(ModelMixin, ConfigMixin):
double_z=False,
)
self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
self.quantize = VectorQuantizer(
num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False
)
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
self.quant_conv = torch.nn.Conv2d(latent_channels, vq_embed_dim, 1)
self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
self.post_quant_conv = torch.nn.Conv2d(vq_embed_dim, latent_channels, 1)
# pass init params to Decoder
self.decoder = Decoder(

View File

@@ -21,6 +21,7 @@ if is_torch_available() and is_transformers_available():
StableDiffusionInpaintPipelineLegacy,
StableDiffusionPipeline,
)
from .vq_diffusion import VQDiffusionPipeline
if is_transformers_available() and is_onnx_available():
from .stable_diffusion import (

View File

@@ -0,0 +1 @@
from .pipeline_vq_diffusion import VQDiffusionPipeline

View File

@@ -0,0 +1,253 @@
# Copyright 2022 Microsoft and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Callable, List, Optional, Tuple, Union
import torch
from diffusers import Transformer2DModel, VQModel
from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from ...utils import logging
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class VQDiffusionPipeline(DiffusionPipeline):
r"""
Pipeline for text-to-image generation using VQ Diffusion
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Args:
vqvae ([`VQModel`]):
Vector Quantized Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent
representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. VQ Diffusion uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
transformer ([`Transformer2DModel`]):
Conditional transformer to denoise the encoded image latents.
scheduler ([`VQDiffusionScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
vqvae: VQModel
text_encoder: CLIPTextModel
tokenizer: CLIPTokenizer
transformer: Transformer2DModel
scheduler: VQDiffusionScheduler
def __init__(
self,
vqvae: VQModel,
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
transformer: Transformer2DModel,
scheduler: VQDiffusionScheduler,
):
super().__init__()
self.register_modules(
vqvae=vqvae,
transformer=transformer,
text_encoder=text_encoder,
tokenizer=tokenizer,
scheduler=scheduler,
)
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
num_inference_steps: int = 100,
truncation_rate: float = 1.0,
num_images_per_prompt: int = 1,
generator: Optional[torch.Generator] = None,
latents: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
) -> Union[ImagePipelineOutput, Tuple]:
"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`):
The prompt or prompts to guide the image generation.
num_inference_steps (`int`, *optional*, defaults to 100):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)):
Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at
most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above
`truncation_rate` are set to zero.
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
latents (`torch.FloatTensor` of shape (batch), *optional*):
Pre-generated noisy latents to be used as inputs for image generation. Must be valid embedding indices.
Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will
be generated of completely masked latent pixels.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
Returns:
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~ pipeline_utils.ImagePipelineOutput `] if
`return_dict` is True, otherwise a `tuple. When returning a tuple, the first element is a list with the
generated images.
"""
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
batch_size = batch_size * num_images_per_prompt
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
# NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
# While CLIP does normalize the pooled output of the text transformer when combining
# the image and text embeddings, CLIP does not directly normalize the last hidden state.
#
# CLIP normalizing the pooled output.
# https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
# duplicate text embeddings for each generation per prompt
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
# get the initial completely masked latents unless the user supplied it
latents_shape = (batch_size, self.transformer.num_latent_pixels)
if latents is None:
mask_class = self.transformer.num_vector_embeds - 1
latents = torch.full(latents_shape, mask_class).to(self.device)
else:
if latents.shape != latents_shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
if (latents < 0).any() or (latents >= self.transformer.num_vector_embeds).any():
raise ValueError(
"Unexpected latents value(s). All latents be valid embedding indices i.e. in the range 0,"
f" {self.transformer.num_vector_embeds - 1} (inclusive)."
)
latents = latents.to(self.device)
# set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
timesteps_tensor = self.scheduler.timesteps.to(self.device)
sample = latents
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
# predict the un-noised image
# model_output == `log_p_x_0`
model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
model_output = self.truncate(model_output, truncation_rate)
# remove `log(0)`'s (`-inf`s)
model_output = model_output.clamp(-70)
# compute the previous noisy sample x_t -> x_t-1
sample = self.scheduler.step(model_output, timestep=t, sample=sample, generator=generator).prev_sample
# call the callback, if provided
if callback is not None and i % callback_steps == 0:
callback(i, t, sample)
embedding_channels = self.vqvae.config.vq_embed_dim
embeddings_shape = (batch_size, self.transformer.height, self.transformer.width, embedding_channels)
embeddings = self.vqvae.quantize.get_codebook_entry(sample, shape=embeddings_shape)
image = self.vqvae.decode(embeddings, force_not_quantize=True).sample
image = (image / 2 + 0.5).clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
def truncate(self, log_p_x_0: torch.FloatTensor, truncation_rate: float) -> torch.FloatTensor:
"""
Truncates log_p_x_0 such that for each column vector, the total cumulative probability is `truncation_rate` The
lowest probabilities that would increase the cumulative probability above `truncation_rate` are set to zero.
"""
sorted_log_p_x_0, indices = torch.sort(log_p_x_0, 1, descending=True)
sorted_p_x_0 = torch.exp(sorted_log_p_x_0)
keep_mask = sorted_p_x_0.cumsum(dim=1) < truncation_rate
# Ensure that at least the largest probability is not zeroed out
all_true = torch.full_like(keep_mask[:, 0:1, :], True)
keep_mask = torch.cat((all_true, keep_mask), dim=1)
keep_mask = keep_mask[:, :-1, :]
keep_mask = keep_mask.gather(1, indices.argsort(1))
rv = log_p_x_0.clone()
rv[~keep_mask] = -torch.inf # -inf = log(0)
return rv

View File

@@ -28,6 +28,7 @@ if is_torch_available():
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_utils import SchedulerMixin
from .scheduling_vq_diffusion import VQDiffusionScheduler
else:
from ..utils.dummy_pt_objects import * # noqa F403

View File

@@ -0,0 +1,494 @@
# Copyright 2022 Microsoft and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
class VQDiffusionSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
Computed sample x_{t-1} of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
"""
prev_sample: torch.LongTensor
def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor:
"""
Convert batch of vector of class indices into batch of log onehot vectors
Args:
x (`torch.LongTensor` of shape `(batch size, vector length)`):
Batch of class indices
num_classes (`int`):
number of classes to be used for the onehot vectors
Returns:
`torch.FloatTensor` of shape `(batch size, num classes, vector length)`:
Log onehot vectors
"""
x_onehot = F.one_hot(x, num_classes)
x_onehot = x_onehot.permute(0, 2, 1)
log_x = torch.log(x_onehot.float().clamp(min=1e-30))
return log_x
def gumbel_noised(logits: torch.FloatTensor, generator: Optional[torch.Generator]) -> torch.FloatTensor:
"""
Apply gumbel noise to `logits`
"""
uniform = torch.rand(logits.shape, device=logits.device, generator=generator)
gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
noised = gumbel_noise + logits
return noised
def alpha_schedules(num_diffusion_timesteps: int, alpha_cum_start=0.99999, alpha_cum_end=0.000009):
"""
Cumulative and non-cumulative alpha schedules.
See section 4.1.
"""
att = (
np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (alpha_cum_end - alpha_cum_start)
+ alpha_cum_start
)
att = np.concatenate(([1], att))
at = att[1:] / att[:-1]
att = np.concatenate((att[1:], [1]))
return at, att
def gamma_schedules(num_diffusion_timesteps: int, gamma_cum_start=0.000009, gamma_cum_end=0.99999):
"""
Cumulative and non-cumulative gamma schedules.
See section 4.1.
"""
ctt = (
np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (gamma_cum_end - gamma_cum_start)
+ gamma_cum_start
)
ctt = np.concatenate(([0], ctt))
one_minus_ctt = 1 - ctt
one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1]
ct = 1 - one_minus_ct
ctt = np.concatenate((ctt[1:], [0]))
return ct, ctt
class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
"""
The VQ-diffusion transformer outputs predicted probabilities of the initial unnoised image.
The VQ-diffusion scheduler converts the transformer's output into a sample for the unnoised image at the previous
diffusion timestep.
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
[`~ConfigMixin.from_config`] functions.
For more details, see the original paper: https://arxiv.org/abs/2111.14822
Args:
num_vec_classes (`int`):
The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked
latent pixel.
num_train_timesteps (`int`):
Number of diffusion steps used to train the model.
alpha_cum_start (`float`):
The starting cumulative alpha value.
alpha_cum_end (`float`):
The ending cumulative alpha value.
gamma_cum_start (`float`):
The starting cumulative gamma value.
gamma_cum_end (`float`):
The ending cumulative gamma value.
"""
@register_to_config
def __init__(
self,
num_vec_classes: int,
num_train_timesteps: int = 100,
alpha_cum_start: float = 0.99999,
alpha_cum_end: float = 0.000009,
gamma_cum_start: float = 0.000009,
gamma_cum_end: float = 0.99999,
):
self.num_embed = num_vec_classes
# By convention, the index for the mask class is the last class index
self.mask_class = self.num_embed - 1
at, att = alpha_schedules(num_train_timesteps, alpha_cum_start=alpha_cum_start, alpha_cum_end=alpha_cum_end)
ct, ctt = gamma_schedules(num_train_timesteps, gamma_cum_start=gamma_cum_start, gamma_cum_end=gamma_cum_end)
num_non_mask_classes = self.num_embed - 1
bt = (1 - at - ct) / num_non_mask_classes
btt = (1 - att - ctt) / num_non_mask_classes
at = torch.tensor(at.astype("float64"))
bt = torch.tensor(bt.astype("float64"))
ct = torch.tensor(ct.astype("float64"))
log_at = torch.log(at)
log_bt = torch.log(bt)
log_ct = torch.log(ct)
att = torch.tensor(att.astype("float64"))
btt = torch.tensor(btt.astype("float64"))
ctt = torch.tensor(ctt.astype("float64"))
log_cumprod_at = torch.log(att)
log_cumprod_bt = torch.log(btt)
log_cumprod_ct = torch.log(ctt)
self.log_at = log_at.float()
self.log_bt = log_bt.float()
self.log_ct = log_ct.float()
self.log_cumprod_at = log_cumprod_at.float()
self.log_cumprod_bt = log_cumprod_bt.float()
self.log_cumprod_ct = log_cumprod_ct.float()
# setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
device (`str` or `torch.device`):
device to place the timesteps and the diffusion process parameters (alpha, beta, gamma) on.
"""
self.num_inference_steps = num_inference_steps
timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
self.timesteps = torch.from_numpy(timesteps).to(device)
self.log_at = self.log_at.to(device)
self.log_bt = self.log_bt.to(device)
self.log_ct = self.log_ct.to(device)
self.log_cumprod_at = self.log_cumprod_at.to(device)
self.log_cumprod_bt = self.log_cumprod_bt.to(device)
self.log_cumprod_ct = self.log_cumprod_ct.to(device)
def step(
self,
model_output: torch.FloatTensor,
timestep: torch.long,
sample: torch.LongTensor,
generator: Optional[torch.Generator] = None,
return_dict: bool = True,
) -> Union[VQDiffusionSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep via the reverse transition distribution i.e. Equation (11). See the
docstring for `self.q_posterior` for more in depth docs on how Equation (11) is computed.
Args:
log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
The log probabilities for the predicted classes of the initial latent pixels. Does not include a
prediction for the masked class as the initial unnoised image cannot be masked.
t (`torch.long`):
The timestep that determines which transition matrices are used.
x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
The classes of each latent pixel at time `t`
generator: (`torch.Generator` or None):
RNG for the noise applied to p(x_{t-1} | x_t) before it is sampled from.
return_dict (`bool`):
option for returning tuple rather than VQDiffusionSchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is the sample tensor.
"""
if timestep == 0:
log_p_x_t_min_1 = model_output
else:
log_p_x_t_min_1 = self.q_posterior(model_output, sample, timestep)
log_p_x_t_min_1 = gumbel_noised(log_p_x_t_min_1, generator)
x_t_min_1 = log_p_x_t_min_1.argmax(dim=1)
if not return_dict:
return (x_t_min_1,)
return VQDiffusionSchedulerOutput(prev_sample=x_t_min_1)
def q_posterior(self, log_p_x_0, x_t, t):
"""
Calculates the log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11).
Instead of directly computing equation (11), we use Equation (5) to restate Equation (11) in terms of only
forward probabilities.
Equation (11) stated in terms of forward probabilities via Equation (5):
Where:
- the sum is over x_0 = {C_0 ... C_{k-1}} (classes for x_0)
p(x_{t-1} | x_t) = sum( q(x_t | x_{t-1}) * q(x_{t-1} | x_0) * p(x_0) / q(x_t | x_0) )
Args:
log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
The log probabilities for the predicted classes of the initial latent pixels. Does not include a
prediction for the masked class as the initial unnoised image cannot be masked.
x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
The classes of each latent pixel at time `t`
t (torch.Long):
The timestep that determines which transition matrix is used.
Returns:
`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`:
The log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11).
"""
log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed)
log_q_x_t_given_x_0 = self.log_Q_t_transitioning_to_known_class(
t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=True
)
log_q_t_given_x_t_min_1 = self.log_Q_t_transitioning_to_known_class(
t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=False
)
# p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0)
# . . .
# . . .
# . . .
# p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1})
q = log_p_x_0 - log_q_x_t_given_x_0
# sum_0 = p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}), ... ,
# sum_n = p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1})
q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True)
# p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0 ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n
# . . .
# . . .
# . . .
# p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0 ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n
q = q - q_log_sum_exp
# (p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}
# . . .
# . . .
# . . .
# (p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}
# c_cumulative_{t-1} ... c_cumulative_{t-1}
q = self.apply_cumulative_transitions(q, t - 1)
# ((p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_0 ... ((p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_n
# . . .
# . . .
# . . .
# ((p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_0 ... ((p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_n
# c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 ... c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0
log_p_x_t_min_1 = q + log_q_t_given_x_t_min_1 + q_log_sum_exp
# For each column, there are two possible cases.
#
# Where:
# - sum(p_n(x_0))) is summing over all classes for x_0
# - C_i is the class transitioning from (not to be confused with c_t and c_cumulative_t being used for gamma's)
# - C_j is the class transitioning to
#
# 1. x_t is masked i.e. x_t = c_k
#
# Simplifying the expression, the column vector is:
# .
# .
# .
# (c_t / c_cumulative_t) * (a_cumulative_{t-1} * p_n(x_0 = C_i | x_t) + b_cumulative_{t-1} * sum(p_n(x_0)))
# .
# .
# .
# (c_cumulative_{t-1} / c_cumulative_t) * sum(p_n(x_0))
#
# From equation (11) stated in terms of forward probabilities, the last row is trivially verified.
#
# For the other rows, we can state the equation as ...
#
# (c_t / c_cumulative_t) * [b_cumulative_{t-1} * p(x_0=c_0) + ... + (a_cumulative_{t-1} + b_cumulative_{t-1}) * p(x_0=C_i) + ... + b_cumulative_{k-1} * p(x_0=c_{k-1})]
#
# This verifies the other rows.
#
# 2. x_t is not masked
#
# Simplifying the expression, there are two cases for the rows of the column vector, where C_j = C_i and where C_j != C_i:
# .
# .
# .
# C_j != C_i: b_t * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / b_cumulative_t) * p_n(x_0 = C_i) + ... + (b_cumulative_{t-1} / (a_cumulative_t + b_cumulative_t)) * p_n(c_0=C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1}))
# .
# .
# .
# C_j = C_i: (a_t + b_t) * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / (a_cumulative_t + b_cumulative_t)) * p_n(x_0 = C_i = C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1}))
# .
# .
# .
# 0
#
# The last row is trivially verified. The other rows can be verified by directly expanding equation (11) stated in terms of forward probabilities.
return log_p_x_t_min_1
def log_Q_t_transitioning_to_known_class(
self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.FloatTensor, cumulative: bool
):
"""
Returns the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each
latent pixel in `x_t`.
See equation (7) for the complete non-cumulative transition matrix. The complete cumulative transition matrix
is the same structure except the parameters (alpha, beta, gamma) are the cumulative analogs.
Args:
t (torch.Long):
The timestep that determines which transition matrix is used.
x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
The classes of each latent pixel at time `t`.
log_onehot_x_t (`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`):
The log one-hot vectors of `x_t`
cumulative (`bool`):
If cumulative is `False`, we use the single step transition matrix `t-1`->`t`. If cumulative is `True`,
we use the cumulative transition matrix `0`->`t`.
Returns:
`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`:
Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability
transition matrix.
When non cumulative, returns `self.num_classes - 1` rows because the initial latent pixel cannot be
masked.
Where:
- `q_n` is the probability distribution for the forward process of the `n`th latent pixel.
- C_0 is a class of a latent pixel embedding
- C_k is the class of the masked latent pixel
non-cumulative result (omitting logarithms):
```
q_0(x_t | x_{t-1} = C_0) ... q_n(x_t | x_{t-1} = C_0)
. . .
. . .
. . .
q_0(x_t | x_{t-1} = C_k) ... q_n(x_t | x_{t-1} = C_k)
```
cumulative result (omitting logarithms):
```
q_0_cumulative(x_t | x_0 = C_0) ... q_n_cumulative(x_t | x_0 = C_0)
. . .
. . .
. . .
q_0_cumulative(x_t | x_0 = C_{k-1}) ... q_n_cumulative(x_t | x_0 = C_{k-1})
```
"""
if cumulative:
a = self.log_cumprod_at[t]
b = self.log_cumprod_bt[t]
c = self.log_cumprod_ct[t]
else:
a = self.log_at[t]
b = self.log_bt[t]
c = self.log_ct[t]
if not cumulative:
# The values in the onehot vector can also be used as the logprobs for transitioning
# from masked latent pixels. If we are not calculating the cumulative transitions,
# we need to save these vectors to be re-appended to the final matrix so the values
# aren't overwritten.
#
# `P(x_t!=mask|x_{t-1=mask}) = 0` and 0 will be the value of the last row of the onehot vector
# if x_t is not masked
#
# `P(x_t=mask|x_{t-1=mask}) = 1` and 1 will be the value of the last row of the onehot vector
# if x_t is masked
log_onehot_x_t_transitioning_from_masked = log_onehot_x_t[:, -1, :].unsqueeze(1)
# `index_to_log_onehot` will add onehot vectors for masked pixels,
# so the default one hot matrix has one too many rows. See the doc string
# for an explanation of the dimensionality of the returned matrix.
log_onehot_x_t = log_onehot_x_t[:, :-1, :]
# this is a cheeky trick to produce the transition probabilities using log one-hot vectors.
#
# Don't worry about what values this sets in the columns that mark transitions
# to masked latent pixels. They are overwrote later with the `mask_class_mask`.
#
# Looking at the below logspace formula in non-logspace, each value will evaluate to either
# `1 * a + b = a + b` where `log_Q_t` has the one hot value in the column
# or
# `0 * a + b = b` where `log_Q_t` has the 0 values in the column.
#
# See equation 7 for more details.
log_Q_t = (log_onehot_x_t + a).logaddexp(b)
# The whole column of each masked pixel is `c`
mask_class_mask = x_t == self.mask_class
mask_class_mask = mask_class_mask.unsqueeze(1).expand(-1, self.num_embed - 1, -1)
log_Q_t[mask_class_mask] = c
if not cumulative:
log_Q_t = torch.cat((log_Q_t, log_onehot_x_t_transitioning_from_masked), dim=1)
return log_Q_t
def apply_cumulative_transitions(self, q, t):
bsz = q.shape[0]
a = self.log_cumprod_at[t]
b = self.log_cumprod_bt[t]
c = self.log_cumprod_ct[t]
num_latent_pixels = q.shape[2]
c = c.expand(bsz, 1, num_latent_pixels)
q = (q + a).logaddexp(b)
q = torch.cat((q, c), dim=1)
return q

View File

@@ -34,6 +34,21 @@ class AutoencoderKL(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class Transformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class UNet1DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -257,6 +272,21 @@ class ScoreSdeVePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class VQDiffusionPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class DDIMScheduler(metaclass=DummyObject):
_backends = ["torch"]
@@ -407,6 +437,21 @@ class ScoreSdeVeScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class VQDiffusionScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class EMAModel(metaclass=DummyObject):
_backends = ["torch"]

View File

View File

@@ -0,0 +1,175 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.utils import load_image, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
from ...test_pipelines_common import PipelineTesterMixin
torch.backends.cuda.matmul.allow_tf32 = False
class VQDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
@property
def num_embed(self):
return 12
@property
def num_embeds_ada_norm(self):
return 12
@property
def dummy_vqvae(self):
torch.manual_seed(0)
model = VQModel(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=3,
num_vq_embeddings=self.num_embed,
vq_embed_dim=3,
)
return model
@property
def dummy_tokenizer(self):
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
return tokenizer
@property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=32,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
return CLIPTextModel(config)
@property
def dummy_transformer(self):
torch.manual_seed(0)
height = 12
width = 12
model_kwargs = {
"attention_bias": True,
"cross_attention_dim": 32,
"attention_head_dim": height * width,
"num_attention_heads": 1,
"num_vector_embeds": self.num_embed,
"num_embeds_ada_norm": self.num_embeds_ada_norm,
"norm_num_groups": 32,
"sample_size": width,
"activation_fn": "geglu-approximate",
}
model = Transformer2DModel(**model_kwargs)
return model
def test_vq_diffusion(self):
device = "cpu"
vqvae = self.dummy_vqvae
text_encoder = self.dummy_text_encoder
tokenizer = self.dummy_tokenizer
transformer = self.dummy_transformer
scheduler = VQDiffusionScheduler(self.num_embed)
pipe = VQDiffusionPipeline(
vqvae=vqvae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer, scheduler=scheduler
)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
prompt = "teddy bear playing in the pool"
generator = torch.Generator(device=device).manual_seed(0)
output = pipe([prompt], generator=generator, num_inference_steps=2, output_type="np")
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = pipe(
[prompt], generator=generator, output_type="np", return_dict=False, num_inference_steps=2
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 24, 24, 3)
expected_slice = np.array([0.6583, 0.6410, 0.5325, 0.5635, 0.5563, 0.4234, 0.6008, 0.5491, 0.4880])
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@slow
@require_torch_gpu
class VQDiffusionPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_vq_diffusion(self):
expected_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/vq_diffusion/teddy_bear_pool.png"
)
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
pipeline = pipeline.to(torch_device)
pipeline.set_progress_bar_config(disable=None)
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipeline(
"teddy bear playing in the pool",
truncation_rate=0.86,
num_images_per_prompt=1,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (256, 256, 3)
assert np.abs(expected_image - image).max() < 1e-2

View File

@@ -18,8 +18,9 @@ import unittest
import numpy as np
import torch
from torch import nn
from diffusers.models.attention import AttentionBlock, SpatialTransformer
from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock, Transformer2DModel
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.resnet import Downsample2D, Upsample2D
from diffusers.utils import torch_device
@@ -235,7 +236,7 @@ class AttentionBlockTests(unittest.TestCase):
num_head_channels=1,
rescale_output_factor=1.0,
eps=1e-6,
num_groups=32,
norm_num_groups=32,
).to(torch_device)
with torch.no_grad():
attention_scores = attentionBlock(sample)
@@ -259,7 +260,7 @@ class AttentionBlockTests(unittest.TestCase):
channels=512,
rescale_output_factor=1.0,
eps=1e-6,
num_groups=32,
norm_num_groups=32,
).to(torch_device)
with torch.no_grad():
attention_scores = attentionBlock(sample)
@@ -273,22 +274,22 @@ class AttentionBlockTests(unittest.TestCase):
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
class SpatialTransformerTests(unittest.TestCase):
class Transformer2DModelTests(unittest.TestCase):
def test_spatial_transformer_default(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
sample = torch.randn(1, 32, 64, 64).to(torch_device)
spatial_transformer_block = SpatialTransformer(
spatial_transformer_block = Transformer2DModel(
in_channels=32,
n_heads=1,
d_head=32,
num_attention_heads=1,
attention_head_dim=32,
dropout=0.0,
context_dim=None,
cross_attention_dim=None,
).to(torch_device)
with torch.no_grad():
attention_scores = spatial_transformer_block(sample)
attention_scores = spatial_transformer_block(sample).sample
assert attention_scores.shape == (1, 32, 64, 64)
output_slice = attention_scores[0, -1, -3:, -3:]
@@ -298,22 +299,22 @@ class SpatialTransformerTests(unittest.TestCase):
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_spatial_transformer_context_dim(self):
def test_spatial_transformer_cross_attention_dim(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
sample = torch.randn(1, 64, 64, 64).to(torch_device)
spatial_transformer_block = SpatialTransformer(
spatial_transformer_block = Transformer2DModel(
in_channels=64,
n_heads=2,
d_head=32,
num_attention_heads=2,
attention_head_dim=32,
dropout=0.0,
context_dim=64,
cross_attention_dim=64,
).to(torch_device)
with torch.no_grad():
context = torch.randn(1, 4, 64).to(torch_device)
attention_scores = spatial_transformer_block(sample, context)
attention_scores = spatial_transformer_block(sample, context).sample
assert attention_scores.shape == (1, 64, 64, 64)
output_slice = attention_scores[0, -1, -3:, -3:]
@@ -323,6 +324,44 @@ class SpatialTransformerTests(unittest.TestCase):
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_spatial_transformer_timestep(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
num_embeds_ada_norm = 5
sample = torch.randn(1, 64, 64, 64).to(torch_device)
spatial_transformer_block = Transformer2DModel(
in_channels=64,
num_attention_heads=2,
attention_head_dim=32,
dropout=0.0,
cross_attention_dim=64,
num_embeds_ada_norm=num_embeds_ada_norm,
).to(torch_device)
with torch.no_grad():
timestep_1 = torch.tensor(1, dtype=torch.long).to(torch_device)
timestep_2 = torch.tensor(2, dtype=torch.long).to(torch_device)
attention_scores_1 = spatial_transformer_block(sample, timestep=timestep_1).sample
attention_scores_2 = spatial_transformer_block(sample, timestep=timestep_2).sample
assert attention_scores_1.shape == (1, 64, 64, 64)
assert attention_scores_2.shape == (1, 64, 64, 64)
output_slice_1 = attention_scores_1[0, -1, -3:, -3:]
output_slice_2 = attention_scores_2[0, -1, -3:, -3:]
expected_slice_1 = torch.tensor(
[-0.1874, -0.9704, -1.4290, -1.3357, 1.5138, 0.3036, -0.0976, -1.1667, 0.1283], device=torch_device
)
expected_slice_2 = torch.tensor(
[-0.3493, -1.0924, -1.6161, -1.5016, 1.4245, 0.1367, -0.2526, -1.3109, -0.0547], device=torch_device
)
assert torch.allclose(output_slice_1.flatten(), expected_slice_1, atol=1e-3)
assert torch.allclose(output_slice_2.flatten(), expected_slice_2, atol=1e-3)
def test_spatial_transformer_dropout(self):
torch.manual_seed(0)
if torch.cuda.is_available():
@@ -330,18 +369,18 @@ class SpatialTransformerTests(unittest.TestCase):
sample = torch.randn(1, 32, 64, 64).to(torch_device)
spatial_transformer_block = (
SpatialTransformer(
Transformer2DModel(
in_channels=32,
n_heads=2,
d_head=16,
num_attention_heads=2,
attention_head_dim=16,
dropout=0.3,
context_dim=None,
cross_attention_dim=None,
)
.to(torch_device)
.eval()
)
with torch.no_grad():
attention_scores = spatial_transformer_block(sample)
attention_scores = spatial_transformer_block(sample).sample
assert attention_scores.shape == (1, 32, 64, 64)
output_slice = attention_scores[0, -1, -3:, -3:]
@@ -350,3 +389,107 @@ class SpatialTransformerTests(unittest.TestCase):
[-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091], device=torch_device
)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
@unittest.skipIf(torch_device == "mps", "MPS does not support float64")
def test_spatial_transformer_discrete(self):
torch.manual_seed(0)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(0)
num_embed = 5
sample = torch.randint(0, num_embed, (1, 32)).to(torch_device)
spatial_transformer_block = (
Transformer2DModel(
num_attention_heads=1,
attention_head_dim=32,
num_vector_embeds=num_embed,
sample_size=16,
)
.to(torch_device)
.eval()
)
with torch.no_grad():
attention_scores = spatial_transformer_block(sample).sample
assert attention_scores.shape == (1, num_embed - 1, 32)
output_slice = attention_scores[0, -2:, -3:]
expected_slice = torch.tensor([-0.8957, -1.8370, -1.3390, -0.9152, -0.5187, -1.1702], device=torch_device)
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
def test_spatial_transformer_default_norm_layers(self):
spatial_transformer_block = Transformer2DModel(num_attention_heads=1, attention_head_dim=32, in_channels=32)
assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == nn.LayerNorm
assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == nn.LayerNorm
assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm
def test_spatial_transformer_ada_norm_layers(self):
spatial_transformer_block = Transformer2DModel(
num_attention_heads=1,
attention_head_dim=32,
in_channels=32,
num_embeds_ada_norm=5,
)
assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == AdaLayerNorm
assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == AdaLayerNorm
assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm
def test_spatial_transformer_default_ff_layers(self):
spatial_transformer_block = Transformer2DModel(
num_attention_heads=1,
attention_head_dim=32,
in_channels=32,
)
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == GEGLU
assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
dim = 32
inner_dim = 128
# First dimension change
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.in_features == dim
# NOTE: inner_dim * 2 because GEGLU
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.out_features == inner_dim * 2
# Second dimension change
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim
def test_spatial_transformer_geglu_approx_ff_layers(self):
spatial_transformer_block = Transformer2DModel(
num_attention_heads=1,
attention_head_dim=32,
in_channels=32,
activation_fn="geglu-approximate",
)
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == ApproximateGELU
assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
dim = 32
inner_dim = 128
# First dimension change
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.in_features == dim
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.out_features == inner_dim
# Second dimension change
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim
def test_spatial_transformer_attention_bias(self):
spatial_transformer_block = Transformer2DModel(
num_attention_heads=1, attention_head_dim=32, in_channels=32, attention_bias=True
)
assert spatial_transformer_block.transformer_blocks[0].attn1.to_q.bias is not None
assert spatial_transformer_block.transformer_blocks[0].attn1.to_k.bias is not None
assert spatial_transformer_block.transformer_blocks[0].attn1.to_v.bias is not None

View File

@@ -19,6 +19,7 @@ from typing import Dict, List, Tuple
import numpy as np
import torch
import torch.nn.functional as F
from diffusers import (
DDIMScheduler,
@@ -29,6 +30,7 @@ from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
ScoreSdeVeScheduler,
VQDiffusionScheduler,
)
from diffusers.utils import torch_device
@@ -85,12 +87,18 @@ class SchedulerCommonTest(unittest.TestCase):
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
time_step = float(time_step)
sample = self.dummy_sample
residual = 0.1 * sample
scheduler_config = self.get_scheduler_config(**config)
scheduler = scheduler_class(**scheduler_config)
if scheduler_class == VQDiffusionScheduler:
num_vec_classes = scheduler_config["num_vec_classes"]
sample = self.dummy_sample(num_vec_classes)
model = self.dummy_model(num_vec_classes)
residual = model(sample, time_step)
else:
sample = self.dummy_sample
residual = 0.1 * sample
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
@@ -122,12 +130,18 @@ class SchedulerCommonTest(unittest.TestCase):
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
time_step = float(time_step)
sample = self.dummy_sample
residual = 0.1 * sample
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
if scheduler_class == VQDiffusionScheduler:
num_vec_classes = scheduler_config["num_vec_classes"]
sample = self.dummy_sample(num_vec_classes)
model = self.dummy_model(num_vec_classes)
residual = model(sample, time_step)
else:
sample = self.dummy_sample
residual = 0.1 * sample
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
new_scheduler = scheduler_class.from_config(tmpdirname)
@@ -154,15 +168,21 @@ class SchedulerCommonTest(unittest.TestCase):
num_inference_steps = kwargs.pop("num_inference_steps", None)
for scheduler_class in self.scheduler_classes:
sample = self.dummy_sample
residual = 0.1 * sample
timestep = 1
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
timestep = float(timestep)
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timestep = 1
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
timestep = float(timestep)
if scheduler_class == VQDiffusionScheduler:
num_vec_classes = scheduler_config["num_vec_classes"]
sample = self.dummy_sample(num_vec_classes)
model = self.dummy_model(num_vec_classes)
residual = model(sample, timestep)
else:
sample = self.dummy_sample
residual = 0.1 * sample
with tempfile.TemporaryDirectory() as tmpdirname:
scheduler.save_config(tmpdirname)
@@ -200,8 +220,14 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
sample = self.dummy_sample
residual = 0.1 * sample
if scheduler_class == VQDiffusionScheduler:
num_vec_classes = scheduler_config["num_vec_classes"]
sample = self.dummy_sample(num_vec_classes)
model = self.dummy_model(num_vec_classes)
residual = model(sample, timestep_0)
else:
sample = self.dummy_sample
residual = 0.1 * sample
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
@@ -255,8 +281,14 @@ class SchedulerCommonTest(unittest.TestCase):
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
sample = self.dummy_sample
residual = 0.1 * sample
if scheduler_class == VQDiffusionScheduler:
num_vec_classes = scheduler_config["num_vec_classes"]
sample = self.dummy_sample(num_vec_classes)
model = self.dummy_model(num_vec_classes)
residual = model(sample, timestep)
else:
sample = self.dummy_sample
residual = 0.1 * sample
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
scheduler.set_timesteps(num_inference_steps)
@@ -284,22 +316,26 @@ class SchedulerCommonTest(unittest.TestCase):
for scheduler_class in self.scheduler_classes:
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
self.assertTrue(
hasattr(scheduler, "init_noise_sigma"),
f"{scheduler_class} does not implement a required attribute `init_noise_sigma`",
)
self.assertTrue(
hasattr(scheduler, "scale_model_input"),
f"{scheduler_class} does not implement a required class method `scale_model_input(sample, timestep)`",
)
if scheduler_class != VQDiffusionScheduler:
self.assertTrue(
hasattr(scheduler, "init_noise_sigma"),
f"{scheduler_class} does not implement a required attribute `init_noise_sigma`",
)
self.assertTrue(
hasattr(scheduler, "scale_model_input"),
f"{scheduler_class} does not implement a required class method `scale_model_input(sample,"
" timestep)`",
)
self.assertTrue(
hasattr(scheduler, "step"),
f"{scheduler_class} does not implement a required class method `step(...)`",
)
sample = self.dummy_sample
scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape)
if scheduler_class != VQDiffusionScheduler:
sample = self.dummy_sample
scaled_sample = scheduler.scale_model_input(sample, 0.0)
self.assertEqual(sample.shape, scaled_sample.shape)
def test_add_noise_device(self):
for scheduler_class in self.scheduler_classes:
@@ -1238,3 +1274,53 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
result_mean = torch.mean(torch.abs(sample))
assert abs(result_mean.item() - 2540529) < 10
class VQDiffusionSchedulerTest(SchedulerCommonTest):
scheduler_classes = (VQDiffusionScheduler,)
def get_scheduler_config(self, **kwargs):
config = {
"num_vec_classes": 4097,
"num_train_timesteps": 100,
}
config.update(**kwargs)
return config
def dummy_sample(self, num_vec_classes):
batch_size = 4
height = 8
width = 8
sample = torch.randint(0, num_vec_classes, (batch_size, height * width))
return sample
@property
def dummy_sample_deter(self):
assert False
def dummy_model(self, num_vec_classes):
def model(sample, t, *args):
batch_size, num_latent_pixels = sample.shape
logits = torch.rand((batch_size, num_vec_classes - 1, num_latent_pixels))
return_value = F.log_softmax(logits.double(), dim=1).float()
return return_value
return model
def test_timesteps(self):
for timesteps in [2, 5, 100, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_num_vec_classes(self):
for num_vec_classes in [5, 100, 1000, 4000]:
self.check_over_configs(num_vec_classes=num_vec_classes)
def test_time_indices(self):
for t in [0, 50, 99]:
self.check_over_forward(time_step=t)
def test_add_noise_device(self):
pass