VQ-diffusion (#658)

* Changes for VQ-diffusion VQVAE

  Add the ability to specify the dimension of embeddings to `VQModel`: `VQModel` will by default set the dimension of embeddings to the number of latent channels. The VQ-diffusion VQVAE has a smaller embedding dimension (128) than its number of latent channels (256).

  Add AttnDownEncoderBlock2D and AttnUpDecoderBlock2D to the up and down unet block helpers. VQ-diffusion's VQVAE uses those two block types.

* Changes for VQ-diffusion transformer

  Modify attention.py so SpatialTransformer can be used for VQ-diffusion's transformer.

  SpatialTransformer:
  - Can now operate over discrete inputs (classes of vector embeddings) as well as continuous ones.
  - `in_channels` was made optional in the constructor, so two locations where it was passed as a positional arg were moved to kwargs.
  - Modified the forward pass to take optional timestep embeddings.

  ImagePositionalEmbeddings:
  - Added to provide positional embeddings to discrete inputs for latent pixels.

  BasicTransformerBlock:
  - Norm layers were made configurable so that VQ-diffusion can use AdaLayerNorm with timestep embeddings.
  - Modified the forward pass to take optional timestep embeddings.

  CrossAttention:
  - May now optionally take a bias parameter for its query, key, and value linear layers.

  FeedForward:
  - Internal layers are now configurable.

  ApproximateGELU:
  - Activation function used in VQ-diffusion's feedforward layer.

  AdaLayerNorm:
  - Norm layer modified to incorporate timestep embeddings.

* Add VQ-diffusion scheduler
* Add VQ-diffusion pipeline
* Add VQ-diffusion convert script to diffusers
* Add VQ-diffusion dummy objects
* Add VQ-diffusion markdown docs
* Add VQ-diffusion tests
* Renaming, typo and weight fixes, test fixes, and code-review suggestions applied

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Anton Lozhkov <aglozhkov@gmail.com>
Co-authored-by: Pedro Cuenca <pedro@huggingface.co>
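For orientation, the AdaLayerNorm change described above boils down to predicting a LayerNorm's scale and shift from a learned timestep embedding. Below is a minimal sketch of that idea; the class and variable names are illustrative, not the exact diffusers implementation.

```python
import torch
from torch import nn


class AdaLayerNormSketch(nn.Module):
    """LayerNorm whose affine scale/shift are predicted from a timestep embedding."""

    def __init__(self, embedding_dim: int, num_embeddings: int):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        # The affine parameters come from the timestep embedding instead of the norm itself.
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim); timestep: (batch,) integer diffusion steps
        emb = self.linear(self.silu(self.emb(timestep)))   # (batch, 2 * dim)
        scale, shift = emb.chunk(2, dim=-1)                # (batch, dim) each
        return self.norm(x) * (1 + scale[:, None]) + shift[:, None]
```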
@@ -96,6 +96,8 @@
      title: "Stochastic Karras VE"
    - local: api/pipelines/dance_diffusion
      title: "Dance Diffusion"
    - local: api/pipelines/vq_diffusion
      title: "VQ Diffusion"
    - local: api/pipelines/repaint
      title: "RePaint"
    title: "Pipelines"
@@ -49,6 +49,12 @@ The models are built on the base class ['ModelMixin'] that is a `torch.nn.module

## AutoencoderKL
[[autodoc]] AutoencoderKL

## Transformer2DModel
[[autodoc]] Transformer2DModel

## Transformer2DModelOutput
[[autodoc]] models.attention.Transformer2DModelOutput

## FlaxModelMixin
[[autodoc]] FlaxModelMixin
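The new `Transformer2DModel` can be constructed in either of its two input modes. A minimal sketch follows; the argument values are illustrative rather than taken from any particular checkpoint (the discrete settings mirror the ones the conversion script below passes for VQ-diffusion).

```python
from diffusers import Transformer2DModel

# Continuous mode: pass `in_channels` for image-like (b, c, h, w) latents.
continuous_transformer = Transformer2DModel(
    num_attention_heads=16,
    attention_head_dim=88,
    in_channels=320,
)

# Discrete mode: pass `num_vector_embeds` and `sample_size` instead of `in_channels`,
# as the VQ-diffusion transformer does.
discrete_transformer = Transformer2DModel(
    num_attention_heads=16,
    attention_head_dim=64,
    num_vector_embeds=4097,   # vector embedding classes, including the masked-pixel class
    sample_size=32,           # width/height of the latent image in latent pixels
    num_embeds_ada_norm=100,  # diffusion steps learned by the AdaLayerNorm layers
    attention_bias=True,
    activation_fn="geglu-approximate",
)
```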
@@ -41,22 +41,22 @@ If you are looking for *official* training examples, please have a look at [exam

The following table summarizes all officially supported pipelines, their corresponding paper, and if
available a colab notebook to directly try them out.
| Pipeline | Paper | Tasks | Colab
|---|---|:---:|:---:|
| [ddpm](./ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [latent_diffusion](./latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Text-to-Image Generation |
| [latent_diffusion_uncond](./latent_diffusion_uncond) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Unconditional Image Generation |
| [pndm](./pndm) | [**Pseudo Numerical Methods for Diffusion Models on Manifolds**](https://arxiv.org/abs/2202.09778) | Unconditional Image Generation |
| [score_sde_ve](./score_sde_ve) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [score_sde_vp](./score_sde_vp) | [**Score-Based Generative Modeling through Stochastic Differential Equations**](https://openreview.net/forum?id=PxTIG12RRHS) | Unconditional Image Generation |
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-to-Image Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](./stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [vq_diffusion](./vq_diffusion) | [**Vector Quantized Diffusion Model for Text-to-Image Synthesis**](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |
| [repaint](./repaint) | [**RePaint: Inpainting using Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2201.09865) | Image Inpainting |

**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.

However, most of them can be adapted to use different scheduler components or even different model components. Some pipeline examples are shown in the [Examples](#examples) below.
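As a concrete illustration of the scheduler-swap pattern mentioned above, here is a minimal sketch assuming a recent diffusers version; the checkpoint id `google/ddpm-cat-256` and the step count are only examples, and the same pattern applies to the other pipelines.

```python
from diffusers import DDPMPipeline, DDIMScheduler

# Load a pipeline, then replace its scheduler with a compatible one
# configured from the original scheduler's config.
pipe = DDPMPipeline.from_pretrained("google/ddpm-cat-256")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# Fewer inference steps are now enough for a reasonable sample.
image = pipe(num_inference_steps=50).images[0]
image.save("swap_scheduler_sample.png")
```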
docs/source/api/pipelines/vq_diffusion.mdx (new file, 34 lines)
@@ -0,0 +1,34 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# VQDiffusion

## Overview

[Vector Quantized Diffusion Model for Text-to-Image Synthesis](https://arxiv.org/abs/2111.14822) by Shuyang Gu, Dong Chen, Jianmin Bao, Fang Wen, Bo Zhang, Dongdong Chen, Lu Yuan, Baining Guo

The abstract of the paper is the following:

We present the vector quantized diffusion (VQ-Diffusion) model for text-to-image generation. This method is based on a vector quantized variational autoencoder (VQ-VAE) whose latent space is modeled by a conditional variant of the recently developed Denoising Diffusion Probabilistic Model (DDPM). We find that this latent-space method is well-suited for text-to-image generation tasks because it not only eliminates the unidirectional bias with existing methods but also allows us to incorporate a mask-and-replace diffusion strategy to avoid the accumulation of errors, which is a serious problem with existing methods. Our experiments show that the VQ-Diffusion produces significantly better text-to-image generation results when compared with conventional autoregressive (AR) models with similar numbers of parameters. Compared with previous GAN-based text-to-image methods, our VQ-Diffusion can handle more complex scenes and improve the synthesized image quality by a large margin. Finally, we show that the image generation computation in our method can be made highly efficient by reparameterization. With traditional AR methods, the text-to-image generation time increases linearly with the output image resolution and hence is quite time consuming even for normal size images. The VQ-Diffusion allows us to achieve a better trade-off between quality and speed. Our experiments indicate that the VQ-Diffusion model with the reparameterization is fifteen times faster than traditional AR methods while achieving a better image quality.

The original codebase can be found [here](https://github.com/microsoft/VQ-Diffusion).

## Available Pipelines:

| Pipeline | Tasks | Colab
|---|---|:---:|
| [pipeline_vq_diffusion.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py) | *Text-to-Image Generation* | - |

## VQDiffusionPipeline
[[autodoc]] pipelines.vq_diffusion.pipeline_vq_diffusion.VQDiffusionPipeline
  - __call__
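For reference, text-to-image generation with the converted weights looks roughly like the sketch below. It is hedged: the Hub id `microsoft/vq-diffusion-ithq` is assumed to be where the converted ITHQ checkpoint is published, and the prompt is arbitrary.

```python
import torch
from diffusers import VQDiffusionPipeline

# Assumed Hub location of the converted ITHQ checkpoint.
pipe = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
pipe = pipe.to("cuda") if torch.cuda.is_available() else pipe

image = pipe("teddy bear playing in the pool", num_inference_steps=100).images[0]
image.save("teddy_bear.png")
```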
@@ -113,7 +113,6 @@ Score SDE-VP is under construction.

[[autodoc]] schedulers.scheduling_sde_vp.ScoreSdeVpScheduler

#### Euler scheduler

Euler scheduler (Algorithm 2) from the paper [Elucidating the Design Space of Diffusion-Based Generative Models](https://arxiv.org/abs/2206.00364) by Karras et al. (2022). Based on the original [k-diffusion](https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L51) implementation by Katherine Crowson.

@@ -130,6 +129,12 @@ Fast scheduler which often times generates good outputs with 20-30 steps.

[[autodoc]] EulerAncestralDiscreteScheduler

#### VQDiffusionScheduler

Original paper can be found [here](https://arxiv.org/abs/2111.14822)

[[autodoc]] VQDiffusionScheduler

#### RePaint scheduler

DDPM-based inpainting scheduler for unsupervised inpainting with extreme masks.
@@ -137,4 +142,4 @@ Intended for use with [`RePaintPipeline`].
Based on the paper [RePaint: Inpainting using Denoising Diffusion Probabilistic Models](https://arxiv.org/abs/2201.09865)
and the original implementation by Andreas Lugmayr et al.: https://github.com/andreas128/RePaint

[[autodoc]] RePaintScheduler
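As a quick orientation on the new scheduler, the sketch below constructs it the same way the conversion script later in this commit does (one class per VQVAE codebook vector plus one for the masked latent pixel); the concrete number of classes is illustrative.

```python
from diffusers import VQDiffusionScheduler

# One class per vector embedding in the VQVAE codebook, plus the masked-pixel class.
scheduler = VQDiffusionScheduler(num_vec_classes=4097)
scheduler.set_timesteps(100)

# During sampling, `scheduler.step(...)` takes the transformer's log-probabilities over
# the un-masked classes together with the current latent-pixel classes (`sample`) and
# returns the previous, less noisy classes as `prev_sample`.
print(scheduler.timesteps[:5])
```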
@@ -34,6 +34,7 @@ available a colab notebook to directly try them out.

| Pipeline | Paper | Tasks | Colab
|---|---|:---:|:---:|
| [dance_diffusion](./api/pipelines/dance_diffusion) | [**Dance Diffusion**](https://github.com/williamberman/diffusers.git) | Unconditional Audio Generation |
| [ddpm](./api/pipelines/ddpm) | [**Denoising Diffusion Probabilistic Models**](https://arxiv.org/abs/2006.11239) | Unconditional Image Generation |
| [ddim](./api/pipelines/ddim) | [**Denoising Diffusion Implicit Models**](https://arxiv.org/abs/2010.02502) | Unconditional Image Generation |
| [latent_diffusion](./api/pipelines/latent_diffusion) | [**High-Resolution Image Synthesis with Latent Diffusion Models**](https://arxiv.org/abs/2112.10752) | Text-to-Image Generation |
@@ -45,5 +46,6 @@ available a colab notebook to directly try them out.
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Image-to-Image Text-Guided Generation | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/image_2_image_using_diffusers.ipynb)
| [stable_diffusion](./api/pipelines/stable_diffusion) | [**Stable Diffusion**](https://stability.ai/blog/stable-diffusion-public-release) | Text-Guided Image Inpainting | [](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/in_painting_with_stable_diffusion_using_diffusers.ipynb)
| [stochastic_karras_ve](./api/pipelines/stochastic_karras_ve) | [**Elucidating the Design Space of Diffusion-Based Generative Models**](https://arxiv.org/abs/2206.00364) | Unconditional Image Generation |
| [vq_diffusion](./api/pipelines/vq_diffusion) | [**Vector Quantized Diffusion Model for Text-to-Image Synthesis**](https://arxiv.org/abs/2111.14822) | Text-to-Image Generation |

**Note**: Pipelines are simple examples of how to play around with the diffusion systems as described in the corresponding papers.
scripts/convert_vq_diffusion_to_diffusers.py (new file, 885 lines)
@@ -0,0 +1,885 @@
"""
This script ports models from VQ-diffusion (https://github.com/microsoft/VQ-Diffusion) to diffusers.

It currently only supports porting the ITHQ dataset.

ITHQ dataset:
```sh
# From the root directory of diffusers.

# Download the VQVAE checkpoint. Note that the URL must be quoted so the shell
# does not interpret the `&` characters in the query string.
$ wget 'https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_vqvae.pth?sv=2020-10-02&st=2022-05-30T15%3A17%3A18Z&se=2030-05-31T15%3A17%3A00Z&sr=b&sp=r&sig=1jVavHFPpUjDs%2FTO1V3PTezaNbPp2Nx8MxiWI7y6fEY%3D' -O ithq_vqvae.pth

# Download the VQVAE config
# NOTE that in VQ-diffusion the documented file is `configs/ithq.yaml` but the target class
# `image_synthesis.modeling.codecs.image_codec.ema_vqvae.PatchVQVAE`
# loads `OUTPUT/pretrained_model/taming_dvae/config.yaml`
$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/OUTPUT/pretrained_model/taming_dvae/config.yaml -O ithq_vqvae.yaml

# Download the main model checkpoint (quoted for the same reason as above)
$ wget 'https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_learnable.pth?sv=2020-10-02&st=2022-05-30T10%3A22%3A06Z&se=2030-05-31T10%3A22%3A00Z&sr=b&sp=r&sig=GOE%2Bza02%2FPnGxYVOOPtwrTR4RA3%2F5NVgMxdW4kjaEZ8%3D' -O ithq_learnable.pth

# Download the main model config
$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/configs/ithq.yaml -O ithq.yaml

# run the convert script
$ python ./scripts/convert_vq_diffusion_to_diffusers.py \
    --checkpoint_path ./ithq_learnable.pth \
    --original_config_file ./ithq.yaml \
    --vqvae_checkpoint_path ./ithq_vqvae.pth \
    --vqvae_original_config_file ./ithq_vqvae.yaml \
    --dump_path <path to save pre-trained `VQDiffusionPipeline`>
```
"""
|
||||
|
||||
import argparse
|
||||
import tempfile
|
||||
|
||||
import torch
|
||||
|
||||
import yaml
|
||||
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
|
||||
from diffusers import VQDiffusionPipeline, VQDiffusionScheduler, VQModel
|
||||
from diffusers.models.attention import Transformer2DModel
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
from yaml.loader import FullLoader
|
||||
|
||||
|
||||
try:
|
||||
from omegaconf import OmegaConf
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"OmegaConf is required to convert the VQ Diffusion checkpoints. Please install it with `pip install"
|
||||
" OmegaConf`."
|
||||
)
|
||||
|
||||
# vqvae model
|
||||
|
||||
PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"]
|
||||
|
||||
|
||||
def vqvae_model_from_original_config(original_config):
|
||||
assert original_config.target in PORTED_VQVAES, f"{original_config.target} has not yet been ported to diffusers."
|
||||
|
||||
original_config = original_config.params
|
||||
|
||||
original_encoder_config = original_config.encoder_config.params
|
||||
original_decoder_config = original_config.decoder_config.params
|
||||
|
||||
in_channels = original_encoder_config.in_channels
|
||||
out_channels = original_decoder_config.out_ch
|
||||
|
||||
down_block_types = get_down_block_types(original_encoder_config)
|
||||
up_block_types = get_up_block_types(original_decoder_config)
|
||||
|
||||
assert original_encoder_config.ch == original_decoder_config.ch
|
||||
assert original_encoder_config.ch_mult == original_decoder_config.ch_mult
|
||||
block_out_channels = tuple(
|
||||
[original_encoder_config.ch * a_ch_mult for a_ch_mult in original_encoder_config.ch_mult]
|
||||
)
|
||||
|
||||
assert original_encoder_config.num_res_blocks == original_decoder_config.num_res_blocks
|
||||
layers_per_block = original_encoder_config.num_res_blocks
|
||||
|
||||
assert original_encoder_config.z_channels == original_decoder_config.z_channels
|
||||
latent_channels = original_encoder_config.z_channels
|
||||
|
||||
num_vq_embeddings = original_config.n_embed
|
||||
|
||||
# Hard-coded value for ResnetBlock.GroupNorm(num_groups) in VQ-diffusion
|
||||
norm_num_groups = 32
|
||||
|
||||
e_dim = original_config.embed_dim
|
||||
|
||||
model = VQModel(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
down_block_types=down_block_types,
|
||||
up_block_types=up_block_types,
|
||||
block_out_channels=block_out_channels,
|
||||
layers_per_block=layers_per_block,
|
||||
latent_channels=latent_channels,
|
||||
num_vq_embeddings=num_vq_embeddings,
|
||||
norm_num_groups=norm_num_groups,
|
||||
vq_embed_dim=e_dim,
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def get_down_block_types(original_encoder_config):
|
||||
attn_resolutions = coerce_attn_resolutions(original_encoder_config.attn_resolutions)
|
||||
num_resolutions = len(original_encoder_config.ch_mult)
|
||||
resolution = coerce_resolution(original_encoder_config.resolution)
|
||||
|
||||
curr_res = resolution
|
||||
down_block_types = []
|
||||
|
||||
for _ in range(num_resolutions):
|
||||
if curr_res in attn_resolutions:
|
||||
down_block_type = "AttnDownEncoderBlock2D"
|
||||
else:
|
||||
down_block_type = "DownEncoderBlock2D"
|
||||
|
||||
down_block_types.append(down_block_type)
|
||||
|
||||
curr_res = [r // 2 for r in curr_res]
|
||||
|
||||
return down_block_types
|
||||
|
||||
|
||||
def get_up_block_types(original_decoder_config):
|
||||
attn_resolutions = coerce_attn_resolutions(original_decoder_config.attn_resolutions)
|
||||
num_resolutions = len(original_decoder_config.ch_mult)
|
||||
resolution = coerce_resolution(original_decoder_config.resolution)
|
||||
|
||||
curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution]
|
||||
up_block_types = []
|
||||
|
||||
for _ in reversed(range(num_resolutions)):
|
||||
if curr_res in attn_resolutions:
|
||||
up_block_type = "AttnUpDecoderBlock2D"
|
||||
else:
|
||||
up_block_type = "UpDecoderBlock2D"
|
||||
|
||||
up_block_types.append(up_block_type)
|
||||
|
||||
curr_res = [r * 2 for r in curr_res]
|
||||
|
||||
return up_block_types
|
||||
|
||||
|
||||
def coerce_attn_resolutions(attn_resolutions):
|
||||
attn_resolutions = OmegaConf.to_object(attn_resolutions)
|
||||
attn_resolutions_ = []
|
||||
for ar in attn_resolutions:
|
||||
if isinstance(ar, (list, tuple)):
|
||||
attn_resolutions_.append(list(ar))
|
||||
else:
|
||||
attn_resolutions_.append([ar, ar])
|
||||
return attn_resolutions_
|
||||
|
||||
|
||||
def coerce_resolution(resolution):
|
||||
resolution = OmegaConf.to_object(resolution)
|
||||
if isinstance(resolution, int):
|
||||
resolution = [resolution, resolution] # H, W
|
||||
elif isinstance(resolution, (tuple, list)):
|
||||
resolution = list(resolution)
|
||||
else:
|
||||
raise ValueError("Unknown type of resolution:", resolution)
|
||||
return resolution
|
||||
|
||||
|
||||
# done vqvae model
|
||||
|
||||
# vqvae checkpoint
|
||||
|
||||
|
||||
def vqvae_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
diffusers_checkpoint.update(vqvae_encoder_to_diffusers_checkpoint(model, checkpoint))
|
||||
|
||||
# quant_conv
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"quant_conv.weight": checkpoint["quant_conv.weight"],
|
||||
"quant_conv.bias": checkpoint["quant_conv.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# quantize
|
||||
diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding"]})
|
||||
|
||||
# post_quant_conv
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"post_quant_conv.weight": checkpoint["post_quant_conv.weight"],
|
||||
"post_quant_conv.bias": checkpoint["post_quant_conv.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# decoder
|
||||
diffusers_checkpoint.update(vqvae_decoder_to_diffusers_checkpoint(model, checkpoint))
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
def vqvae_encoder_to_diffusers_checkpoint(model, checkpoint):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
# conv_in
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"],
|
||||
"encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# down_blocks
|
||||
for down_block_idx, down_block in enumerate(model.encoder.down_blocks):
|
||||
diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}"
|
||||
down_block_prefix = f"encoder.down.{down_block_idx}"
|
||||
|
||||
# resnets
|
||||
for resnet_idx, resnet in enumerate(down_block.resnets):
|
||||
diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}"
|
||||
resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_resnet_to_diffusers_checkpoint(
|
||||
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# downsample
|
||||
|
||||
# do not include the downsample when on the last down block
|
||||
# There is no downsample on the last down block
|
||||
if down_block_idx != len(model.encoder.down_blocks) - 1:
|
||||
# There's a single downsample in the original checkpoint but a list of downsamples
|
||||
# in the diffusers model.
|
||||
diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv"
|
||||
downsample_prefix = f"{down_block_prefix}.downsample.conv"
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
|
||||
f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# attentions
|
||||
|
||||
if hasattr(down_block, "attentions"):
|
||||
for attention_idx, _ in enumerate(down_block.attentions):
|
||||
diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}"
|
||||
attention_prefix = f"{down_block_prefix}.attn.{attention_idx}"
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_attention_to_diffusers_checkpoint(
|
||||
checkpoint,
|
||||
diffusers_attention_prefix=diffusers_attention_prefix,
|
||||
attention_prefix=attention_prefix,
|
||||
)
|
||||
)
|
||||
|
||||
# mid block
|
||||
|
||||
# mid block attentions
|
||||
|
||||
# There is a single hardcoded attention block in the middle of the VQ-diffusion encoder
|
||||
diffusers_attention_prefix = "encoder.mid_block.attentions.0"
|
||||
attention_prefix = "encoder.mid.attn_1"
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_attention_to_diffusers_checkpoint(
|
||||
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# mid block resnets
|
||||
|
||||
for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets):
|
||||
diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}"
|
||||
|
||||
# the hardcoded prefixes to `block_` are 1 and 2
|
||||
orig_resnet_idx = diffusers_resnet_idx + 1
|
||||
# There are two hardcoded resnets in the middle of the VQ-diffusion encoder
|
||||
resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_resnet_to_diffusers_checkpoint(
|
||||
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
|
||||
)
|
||||
)
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
# conv_norm_out
|
||||
"encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"],
|
||||
"encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"],
|
||||
# conv_out
|
||||
"encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"],
|
||||
"encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
def vqvae_decoder_to_diffusers_checkpoint(model, checkpoint):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
# conv in
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
"decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"],
|
||||
"decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# up_blocks
|
||||
|
||||
for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks):
|
||||
# up_blocks are stored in reverse order in the VQ-diffusion checkpoint
|
||||
orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx
|
||||
|
||||
diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}"
|
||||
up_block_prefix = f"decoder.up.{orig_up_block_idx}"
|
||||
|
||||
# resnets
|
||||
for resnet_idx, resnet in enumerate(up_block.resnets):
|
||||
diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
|
||||
resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_resnet_to_diffusers_checkpoint(
|
||||
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# upsample
|
||||
|
||||
# there is no upsample on the last up block
if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1:
# There's a single upsample in the VQ-diffusion checkpoint but a list of upsamples
# in the diffusers model.
|
||||
diffusers_downsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv"
|
||||
downsample_prefix = f"{up_block_prefix}.upsample.conv"
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
|
||||
f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# attentions
|
||||
|
||||
if hasattr(up_block, "attentions"):
|
||||
for attention_idx, _ in enumerate(up_block.attentions):
|
||||
diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}"
|
||||
attention_prefix = f"{up_block_prefix}.attn.{attention_idx}"
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_attention_to_diffusers_checkpoint(
|
||||
checkpoint,
|
||||
diffusers_attention_prefix=diffusers_attention_prefix,
|
||||
attention_prefix=attention_prefix,
|
||||
)
|
||||
)
|
||||
|
||||
# mid block
|
||||
|
||||
# mid block attentions
|
||||
|
||||
# There is a single hardcoded attention block in the middle of the VQ-diffusion decoder
|
||||
diffusers_attention_prefix = "decoder.mid_block.attentions.0"
|
||||
attention_prefix = "decoder.mid.attn_1"
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_attention_to_diffusers_checkpoint(
|
||||
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# mid block resnets
|
||||
|
||||
for diffusers_resnet_idx, resnet in enumerate(model.decoder.mid_block.resnets):
|
||||
diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}"
|
||||
|
||||
# the hardcoded prefixes to `block_` are 1 and 2
|
||||
orig_resnet_idx = diffusers_resnet_idx + 1
|
||||
# There are two hardcoded resnets in the middle of the VQ-diffusion decoder
|
||||
resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
vqvae_resnet_to_diffusers_checkpoint(
|
||||
resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
|
||||
)
|
||||
)
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
# conv_norm_out
|
||||
"decoder.conv_norm_out.weight": checkpoint["decoder.norm_out.weight"],
|
||||
"decoder.conv_norm_out.bias": checkpoint["decoder.norm_out.bias"],
|
||||
# conv_out
|
||||
"decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"],
|
||||
"decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
def vqvae_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
|
||||
rv = {
|
||||
# norm1
|
||||
f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"],
|
||||
f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"],
|
||||
# conv1
|
||||
f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"],
|
||||
f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"],
|
||||
# norm2
|
||||
f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"],
|
||||
f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"],
|
||||
# conv2
|
||||
f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"],
|
||||
f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"],
|
||||
}
|
||||
|
||||
if resnet.conv_shortcut is not None:
|
||||
rv.update(
|
||||
{
|
||||
f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"],
|
||||
f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
return rv
|
||||
|
||||
|
||||
def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
|
||||
return {
|
||||
# group_norm
|
||||
f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
|
||||
f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],
|
||||
# query
|
||||
f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0],
|
||||
f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"],
|
||||
# key
|
||||
f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0],
|
||||
f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"],
|
||||
# value
|
||||
f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0],
|
||||
f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"],
|
||||
# proj_attn
|
||||
f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][
|
||||
:, :, 0, 0
|
||||
],
|
||||
f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
|
||||
}
|
||||
|
||||
|
||||
# done vqvae checkpoint
|
||||
|
||||
# transformer model
|
||||
|
||||
PORTED_DIFFUSIONS = ["image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer"]
|
||||
PORTED_TRANSFORMERS = ["image_synthesis.modeling.transformers.transformer_utils.Text2ImageTransformer"]
|
||||
PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding"]
|
||||
|
||||
|
||||
def transformer_model_from_original_config(
|
||||
original_diffusion_config, original_transformer_config, original_content_embedding_config
|
||||
):
|
||||
assert (
|
||||
original_diffusion_config.target in PORTED_DIFFUSIONS
|
||||
), f"{original_diffusion_config.target} has not yet been ported to diffusers."
|
||||
assert (
|
||||
original_transformer_config.target in PORTED_TRANSFORMERS
|
||||
), f"{original_transformer_config.target} has not yet been ported to diffusers."
|
||||
assert (
|
||||
original_content_embedding_config.target in PORTED_CONTENT_EMBEDDINGS
|
||||
), f"{original_content_embedding_config.target} has not yet been ported to diffusers."
|
||||
|
||||
original_diffusion_config = original_diffusion_config.params
|
||||
original_transformer_config = original_transformer_config.params
|
||||
original_content_embedding_config = original_content_embedding_config.params
|
||||
|
||||
inner_dim = original_transformer_config["n_embd"]
|
||||
|
||||
n_heads = original_transformer_config["n_head"]
|
||||
|
||||
# VQ-Diffusion gives the dimension of the multi-headed attention layers as the
# number of attention heads times the dimension of a single head. We want to
# specify our attention blocks with those two values separately.
|
||||
assert inner_dim % n_heads == 0
|
||||
d_head = inner_dim // n_heads
|
||||
|
||||
depth = original_transformer_config["n_layer"]
|
||||
context_dim = original_transformer_config["condition_dim"]
|
||||
|
||||
num_embed = original_content_embedding_config["num_embed"]
|
||||
# the number of embeddings in the transformer includes the mask embedding.
|
||||
# the content embedding (the vqvae) does not include the mask embedding.
|
||||
num_embed = num_embed + 1
|
||||
|
||||
height = original_transformer_config["content_spatial_size"][0]
|
||||
width = original_transformer_config["content_spatial_size"][1]
|
||||
|
||||
assert width == height, "width has to be equal to height"
|
||||
dropout = original_transformer_config["resid_pdrop"]
|
||||
num_embeds_ada_norm = original_diffusion_config["diffusion_step"]
|
||||
|
||||
model_kwargs = {
|
||||
"attention_bias": True,
|
||||
"cross_attention_dim": context_dim,
|
||||
"attention_head_dim": d_head,
|
||||
"num_layers": depth,
|
||||
"dropout": dropout,
|
||||
"num_attention_heads": n_heads,
|
||||
"num_vector_embeds": num_embed,
|
||||
"num_embeds_ada_norm": num_embeds_ada_norm,
|
||||
"norm_num_groups": 32,
|
||||
"sample_size": width,
|
||||
"activation_fn": "geglu-approximate",
|
||||
}
|
||||
|
||||
model = Transformer2DModel(**model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
# done transformer model
|
||||
|
||||
# transformer checkpoint
|
||||
|
||||
|
||||
def transformer_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
|
||||
diffusers_checkpoint = {}
|
||||
|
||||
transformer_prefix = "transformer.transformer"
|
||||
|
||||
diffusers_latent_image_embedding_prefix = "latent_image_embedding"
|
||||
latent_image_embedding_prefix = f"{transformer_prefix}.content_emb"
|
||||
|
||||
# DalleMaskImageEmbedding
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_latent_image_embedding_prefix}.emb.weight": checkpoint[
|
||||
f"{latent_image_embedding_prefix}.emb.weight"
|
||||
],
|
||||
f"{diffusers_latent_image_embedding_prefix}.height_emb.weight": checkpoint[
|
||||
f"{latent_image_embedding_prefix}.height_emb.weight"
|
||||
],
|
||||
f"{diffusers_latent_image_embedding_prefix}.width_emb.weight": checkpoint[
|
||||
f"{latent_image_embedding_prefix}.width_emb.weight"
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
# transformer blocks
|
||||
for transformer_block_idx, transformer_block in enumerate(model.transformer_blocks):
|
||||
diffusers_transformer_block_prefix = f"transformer_blocks.{transformer_block_idx}"
|
||||
transformer_block_prefix = f"{transformer_prefix}.blocks.{transformer_block_idx}"
|
||||
|
||||
# ada norm block
|
||||
diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm1"
|
||||
ada_norm_prefix = f"{transformer_block_prefix}.ln1"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
transformer_ada_norm_to_diffusers_checkpoint(
|
||||
checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# attention block
|
||||
diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn1"
|
||||
attention_prefix = f"{transformer_block_prefix}.attn1"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
transformer_attention_to_diffusers_checkpoint(
|
||||
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# ada norm block
|
||||
diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm2"
|
||||
ada_norm_prefix = f"{transformer_block_prefix}.ln1_1"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
transformer_ada_norm_to_diffusers_checkpoint(
|
||||
checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# attention block
|
||||
diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn2"
|
||||
attention_prefix = f"{transformer_block_prefix}.attn2"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
transformer_attention_to_diffusers_checkpoint(
|
||||
checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
|
||||
)
|
||||
)
|
||||
|
||||
# norm block
|
||||
diffusers_norm_block_prefix = f"{diffusers_transformer_block_prefix}.norm3"
|
||||
norm_block_prefix = f"{transformer_block_prefix}.ln2"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_norm_block_prefix}.weight": checkpoint[f"{norm_block_prefix}.weight"],
|
||||
f"{diffusers_norm_block_prefix}.bias": checkpoint[f"{norm_block_prefix}.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
# feedforward block
|
||||
diffusers_feedforward_prefix = f"{diffusers_transformer_block_prefix}.ff"
|
||||
feedforward_prefix = f"{transformer_block_prefix}.mlp"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
transformer_feedforward_to_diffusers_checkpoint(
|
||||
checkpoint,
|
||||
diffusers_feedforward_prefix=diffusers_feedforward_prefix,
|
||||
feedforward_prefix=feedforward_prefix,
|
||||
)
|
||||
)
|
||||
|
||||
# to logits
|
||||
|
||||
diffusers_norm_out_prefix = "norm_out"
|
||||
norm_out_prefix = f"{transformer_prefix}.to_logits.0"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_norm_out_prefix}.weight": checkpoint[f"{norm_out_prefix}.weight"],
|
||||
f"{diffusers_norm_out_prefix}.bias": checkpoint[f"{norm_out_prefix}.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
diffusers_out_prefix = "out"
|
||||
out_prefix = f"{transformer_prefix}.to_logits.1"
|
||||
|
||||
diffusers_checkpoint.update(
|
||||
{
|
||||
f"{diffusers_out_prefix}.weight": checkpoint[f"{out_prefix}.weight"],
|
||||
f"{diffusers_out_prefix}.bias": checkpoint[f"{out_prefix}.bias"],
|
||||
}
|
||||
)
|
||||
|
||||
return diffusers_checkpoint
|
||||
|
||||
|
||||
def transformer_ada_norm_to_diffusers_checkpoint(checkpoint, *, diffusers_ada_norm_prefix, ada_norm_prefix):
|
||||
return {
|
||||
f"{diffusers_ada_norm_prefix}.emb.weight": checkpoint[f"{ada_norm_prefix}.emb.weight"],
|
||||
f"{diffusers_ada_norm_prefix}.linear.weight": checkpoint[f"{ada_norm_prefix}.linear.weight"],
|
||||
f"{diffusers_ada_norm_prefix}.linear.bias": checkpoint[f"{ada_norm_prefix}.linear.bias"],
|
||||
}
|
||||
|
||||
|
||||
def transformer_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
|
||||
return {
|
||||
# key
|
||||
f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.key.weight"],
|
||||
f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.key.bias"],
|
||||
# query
|
||||
f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.query.weight"],
|
||||
f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.query.bias"],
|
||||
# value
|
||||
f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.value.weight"],
|
||||
f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.value.bias"],
|
||||
# linear out
|
||||
f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj.weight"],
|
||||
f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj.bias"],
|
||||
}
|
||||
|
||||
|
||||
def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_feedforward_prefix, feedforward_prefix):
|
||||
return {
|
||||
f"{diffusers_feedforward_prefix}.net.0.proj.weight": checkpoint[f"{feedforward_prefix}.0.weight"],
|
||||
f"{diffusers_feedforward_prefix}.net.0.proj.bias": checkpoint[f"{feedforward_prefix}.0.bias"],
|
||||
f"{diffusers_feedforward_prefix}.net.2.weight": checkpoint[f"{feedforward_prefix}.2.weight"],
|
||||
f"{diffusers_feedforward_prefix}.net.2.bias": checkpoint[f"{feedforward_prefix}.2.bias"],
|
||||
}
|
||||
|
||||
|
||||
# done transformer checkpoint
|
||||
|
||||
|
||||
def read_config_file(filename):
|
||||
# The yaml file contains annotations that certain values should be
# loaded as tuples. By default, OmegaConf will panic when reading
# these. Instead, we can manually read the yaml with the FullLoader and then
# construct the OmegaConf object.
|
||||
with open(filename) as f:
|
||||
original_config = yaml.load(f, FullLoader)
|
||||
|
||||
return OmegaConf.create(original_config)
|
||||
|
||||
|
||||
# We take separate arguments for the vqvae because the ITHQ vqvae config file
|
||||
# is separate from the config file for the rest of the model.
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--vqvae_checkpoint_path",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="Path to the vqvae checkpoint to convert.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--vqvae_original_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The YAML config file corresponding to the original architecture for the vqvae.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--original_config_file",
|
||||
default=None,
|
||||
type=str,
|
||||
required=True,
|
||||
help="The YAML config file corresponding to the original architecture.",
|
||||
)
|
||||
|
||||
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
||||
|
||||
parser.add_argument(
|
||||
"--checkpoint_load_device",
|
||||
default="cpu",
|
||||
type=str,
|
||||
required=False,
|
||||
help="The device passed to `map_location` when loading checkpoints.",
|
||||
)
|
||||
|
||||
# See link for how ema weights are always selected
|
||||
# https://github.com/microsoft/VQ-Diffusion/blob/3c98e77f721db7c787b76304fa2c96a36c7b00af/inference_VQ_Diffusion.py#L65
|
||||
parser.add_argument(
|
||||
"--no_use_ema",
|
||||
action="store_true",
|
||||
required=False,
|
||||
help=(
|
||||
"Set to not use the ema weights from the original VQ-Diffusion checkpoint. You probably do not want to set"
|
||||
" it as the original VQ-Diffusion always uses the ema weights when loading models."
|
||||
),
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
use_ema = not args.no_use_ema
|
||||
|
||||
print(f"loading checkpoints to {args.checkpoint_load_device}")
|
||||
|
||||
checkpoint_map_location = torch.device(args.checkpoint_load_device)
|
||||
|
||||
# vqvae_model
|
||||
|
||||
print(f"loading vqvae, config: {args.vqvae_original_config_file}, checkpoint: {args.vqvae_checkpoint_path}")
|
||||
|
||||
vqvae_original_config = read_config_file(args.vqvae_original_config_file).model
|
||||
vqvae_checkpoint = torch.load(args.vqvae_checkpoint_path, map_location=checkpoint_map_location)["model"]
|
||||
|
||||
with init_empty_weights():
|
||||
vqvae_model = vqvae_model_from_original_config(vqvae_original_config)
|
||||
|
||||
vqvae_diffusers_checkpoint = vqvae_original_checkpoint_to_diffusers_checkpoint(vqvae_model, vqvae_checkpoint)
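# The model was created under `init_empty_weights()`, so its parameters are meta tensors.
# Saving the remapped state dict to a temporary file lets accelerate's
# `load_checkpoint_and_dispatch` materialize real weights from it onto devices.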
|
||||
|
||||
with tempfile.NamedTemporaryFile() as vqvae_diffusers_checkpoint_file:
|
||||
torch.save(vqvae_diffusers_checkpoint, vqvae_diffusers_checkpoint_file.name)
|
||||
del vqvae_diffusers_checkpoint
|
||||
del vqvae_checkpoint
|
||||
load_checkpoint_and_dispatch(vqvae_model, vqvae_diffusers_checkpoint_file.name, device_map="auto")
|
||||
|
||||
print("done loading vqvae")
|
||||
|
||||
# done vqvae_model
|
||||
|
||||
# transformer_model
|
||||
|
||||
print(
|
||||
f"loading transformer, config: {args.original_config_file}, checkpoint: {args.checkpoint_path}, use ema:"
|
||||
f" {use_ema}"
|
||||
)
|
||||
|
||||
original_config = read_config_file(args.original_config_file).model
|
||||
|
||||
diffusion_config = original_config.params.diffusion_config
|
||||
transformer_config = original_config.params.diffusion_config.params.transformer_config
|
||||
content_embedding_config = original_config.params.diffusion_config.params.content_emb_config
|
||||
|
||||
pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location)
|
||||
|
||||
if use_ema:
|
||||
if "ema" in pre_checkpoint:
|
||||
checkpoint = {}
|
||||
for k, v in pre_checkpoint["model"].items():
|
||||
checkpoint[k] = v
|
||||
|
||||
for k, v in pre_checkpoint["ema"].items():
|
||||
# The ema weights are only used on the transformer. To mimic their key as if they came
|
||||
# from the state_dict for the top level model, we prefix with an additional "transformer."
|
||||
# See the source linked in the args.use_ema config for more information.
|
||||
checkpoint[f"transformer.{k}"] = v
|
||||
else:
|
||||
print("attempted to load ema weights but no ema weights are specified in the loaded checkpoint.")
|
||||
checkpoint = pre_checkpoint["model"]
|
||||
else:
|
||||
checkpoint = pre_checkpoint["model"]
|
||||
|
||||
del pre_checkpoint
|
||||
|
||||
with init_empty_weights():
|
||||
transformer_model = transformer_model_from_original_config(
|
||||
diffusion_config, transformer_config, content_embedding_config
|
||||
)
|
||||
|
||||
diffusers_transformer_checkpoint = transformer_original_checkpoint_to_diffusers_checkpoint(
|
||||
transformer_model, checkpoint
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
|
||||
torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
|
||||
del diffusers_transformer_checkpoint
|
||||
del checkpoint
|
||||
load_checkpoint_and_dispatch(transformer_model, diffusers_transformer_checkpoint_file.name, device_map="auto")
|
||||
|
||||
print("done loading transformer")
|
||||
|
||||
# done transformer_model
|
||||
|
||||
# text encoder
|
||||
|
||||
print("loading CLIP text encoder")
|
||||
|
||||
clip_name = "openai/clip-vit-base-patch32"
|
||||
|
||||
# The original VQ-Diffusion specifies the pad value by the int used in the
|
||||
# returned tokens. Each model uses `0` as the pad value. The transformers clip api
|
||||
# specifies the pad value via the token before it has been tokenized. The `!` pad
|
||||
# token is the same as padding with the `0` pad value.
|
||||
pad_token = "!"
|
||||
|
||||
tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto")
|
||||
|
||||
assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0
|
||||
|
||||
text_encoder_model = CLIPTextModel.from_pretrained(
|
||||
clip_name,
|
||||
# `CLIPTextModel` does not support device_map="auto"
|
||||
# device_map="auto"
|
||||
)
|
||||
|
||||
print("done loading CLIP text encoder")
|
||||
|
||||
# done text encoder
|
||||
|
||||
# scheduler
|
||||
|
||||
scheduler_model = VQDiffusionScheduler(
|
||||
# the scheduler has the same number of embeddings as the transformer
|
||||
num_vec_classes=transformer_model.num_vector_embeds
|
||||
)
|
||||
|
||||
# done scheduler
|
||||
|
||||
print(f"saving VQ diffusion model, path: {args.dump_path}")
|
||||
|
||||
pipe = VQDiffusionPipeline(
|
||||
vqvae=vqvae_model,
|
||||
transformer=transformer_model,
|
||||
tokenizer=tokenizer_model,
|
||||
text_encoder=text_encoder_model,
|
||||
scheduler=scheduler_model,
|
||||
)
|
||||
pipe.save_pretrained(args.dump_path)
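# The saved pipeline can later be reloaded with `VQDiffusionPipeline.from_pretrained(args.dump_path)`.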
|
||||
|
||||
print("done writing VQ diffusion model")
|
||||
@@ -18,7 +18,7 @@ from .utils import logging
|
||||
|
||||
if is_torch_available():
|
||||
from .modeling_utils import ModelMixin
|
||||
from .models import AutoencoderKL, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
|
||||
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
|
||||
from .optimization import (
|
||||
get_constant_schedule,
|
||||
get_constant_schedule_with_warmup,
|
||||
@@ -38,6 +38,7 @@ if is_torch_available():
|
||||
PNDMPipeline,
|
||||
RePaintPipeline,
|
||||
ScoreSdeVePipeline,
|
||||
VQDiffusionPipeline,
|
||||
)
|
||||
from .schedulers import (
|
||||
DDIMScheduler,
|
||||
@@ -50,6 +51,7 @@ if is_torch_available():
|
||||
RePaintScheduler,
|
||||
SchedulerMixin,
|
||||
ScoreSdeVeScheduler,
|
||||
VQDiffusionScheduler,
|
||||
)
|
||||
from .training_utils import EMAModel
|
||||
else:
|
||||
|
||||
@@ -16,6 +16,7 @@ from ..utils import is_flax_available, is_torch_available
|
||||
|
||||
|
||||
if is_torch_available():
|
||||
from .attention import Transformer2DModel
|
||||
from .unet_1d import UNet1DModel
|
||||
from .unet_2d import UNet2DModel
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
@@ -12,13 +12,30 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..models.embeddings import ImagePositionalEmbeddings
|
||||
from ..utils import BaseOutput
|
||||
from ..utils.import_utils import is_xformers_available
|
||||
|
||||
|
||||
@dataclass
|
||||
class Transformer2DModelOutput(BaseOutput):
|
||||
"""
|
||||
Args:
|
||||
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
|
||||
Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
|
||||
for the unnoised latent pixels.
|
||||
"""
|
||||
|
||||
sample: torch.FloatTensor
|
||||
|
||||
|
||||
if is_xformers_available():
|
||||
@@ -28,6 +45,186 @@ else:
|
||||
xformers = None
|
||||
|
||||
|
||||
class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
|
||||
embeddings) inputs.
|
||||
|
||||
When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
|
||||
transformer action. Finally, reshape to image.
|
||||
|
||||
When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
|
||||
embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
|
||||
classes of unnoised image.
|
||||
|
||||
Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
|
||||
image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
|
||||
in_channels (`int`, *optional*):
|
||||
Pass if the input is continuous. The number of channels in the input and output.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
|
||||
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
||||
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
||||
`ImagePositionalEmbeddings`.
|
||||
num_vector_embeds (`int`, *optional*):
|
||||
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
|
||||
Includes the class for the masked latent pixel.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
|
||||
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more steps than `num_embeds_ada_norm`.
|
||||
attention_bias (`bool`, *optional*):
|
||||
Configure if the TransformerBlocks' attention should contain a bias parameter.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
num_vector_embeds: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
|
||||
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, height, width)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
||||
# Define whether input is continuous or discrete depending on configuration
|
||||
self.is_input_continuous = in_channels is not None
|
||||
self.is_input_vectorized = num_vector_embeds is not None
|
||||
|
||||
if self.is_input_continuous and self.is_input_vectorized:
|
||||
raise ValueError(
|
||||
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
||||
" sure that either `in_channels` or `num_vector_embeds` is None."
|
||||
)
|
||||
elif not self.is_input_continuous and not self.is_input_vectorized:
|
||||
raise ValueError(
|
||||
f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
|
||||
" sure that either `in_channels` or `num_vector_embeds` is not None."
|
||||
)
|
||||
|
||||
# 2. Define input layers
|
||||
if self.is_input_continuous:
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
|
||||
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_vector_embeds"
|
||||
|
||||
self.height = sample_size
|
||||
self.width = sample_size
|
||||
self.num_vector_embeds = num_vector_embeds
|
||||
self.num_latent_pixels = self.height * self.width
|
||||
|
||||
self.latent_image_embedding = ImagePositionalEmbeddings(
|
||||
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
||||
)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
activation_fn=activation_fn,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
attention_bias=attention_bias,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Define output layers
|
||||
if self.is_input_continuous:
|
||||
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
elif self.is_input_vectorized:
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
||||
|
||||
def _set_attention_slice(self, slice_size):
|
||||
for block in self.transformer_blocks:
|
||||
block._set_attention_slice(slice_size)
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input hidden_states
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, context dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep (`torch.long`, *optional*):
Optional timestep to be applied as an embedding in `AdaLayerNorm`. Used to indicate the denoising step.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~models.attention.Transformer2DModelOutput`] instead of a plain tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
|
||||
if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
|
||||
tensor.
|
||||
"""
|
||||
# 1. Input
|
||||
if self.is_input_continuous:
|
||||
batch, channel, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.latent_image_embedding(hidden_states)
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
|
||||
|
||||
# 3. Output
|
||||
if self.is_input_continuous:
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
output = hidden_states + residual
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
logits = self.out(hidden_states)
|
||||
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
|
||||
logits = logits.permute(0, 2, 1)
|
||||
|
||||
# log(p(x_0))
|
||||
output = F.log_softmax(logits.double(), dim=1).float()
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
return Transformer2DModelOutput(sample=output)
|
||||
|
||||
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
for block in self.transformer_blocks:
|
||||
block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
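# Hedged usage sketch (not part of this diff): how the two Transformer2DModel input modes
# documented above might be exercised. Shapes follow the docstring; the top-level import is
# an assumption based on the other changes in this PR, and all sizes are illustrative.
import torch

from diffusers import Transformer2DModel

# Continuous mode: `in_channels` is set, input is (batch, channels, height, width).
continuous_model = Transformer2DModel(num_attention_heads=2, attention_head_dim=8, in_channels=32)
image_like = torch.randn(1, 32, 8, 8)
out = continuous_model(image_like).sample  # same shape as the input; a residual is added internally

# Discrete mode: `num_vector_embeds` and `sample_size` are set, input holds latent pixel classes.
discrete_model = Transformer2DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    num_vector_embeds=10,     # includes the masked class
    sample_size=4,            # 4 x 4 = 16 latent pixels
    num_embeds_ada_norm=100,  # AdaLayerNorm conditioned on up to 100 diffusion steps
)
latent_classes = torch.randint(0, 10, (1, 16))
log_p_x_0 = discrete_model(latent_classes, timestep=torch.tensor(3)).sample  # (1, 9, 16) log-probabilities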
|
||||
|
||||
|
||||
class AttentionBlock(nn.Module):
|
||||
"""
|
||||
An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
|
||||
@@ -36,19 +233,19 @@ class AttentionBlock(nn.Module):
|
||||
Uses three q, k, v linear layers to compute attention.
|
||||
|
||||
Parameters:
|
||||
channels (:obj:`int`): The number of channels in the input and output.
|
||||
num_head_channels (:obj:`int`, *optional*):
|
||||
channels (`int`): The number of channels in the input and output.
|
||||
num_head_channels (`int`, *optional*):
|
||||
The number of channels in each head. If None, then `num_heads` = 1.
|
||||
num_groups (:obj:`int`, *optional*, defaults to 32): The number of groups to use for group norm.
|
||||
rescale_output_factor (:obj:`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
|
||||
eps (:obj:`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
|
||||
norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
|
||||
rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
|
||||
eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
num_head_channels: Optional[int] = None,
|
||||
num_groups: int = 32,
|
||||
norm_num_groups: int = 32,
|
||||
rescale_output_factor: float = 1.0,
|
||||
eps: float = 1e-5,
|
||||
):
|
||||
@@ -57,7 +254,7 @@ class AttentionBlock(nn.Module):
|
||||
|
||||
self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
|
||||
self.num_head_size = num_head_channels
|
||||
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=num_groups, eps=eps, affine=True)
|
||||
self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
|
||||
|
||||
# define q,k,v as linear layers
|
||||
self.query = nn.Linear(channels, channels)
|
||||
@@ -113,107 +310,61 @@ class AttentionBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class SpatialTransformer(nn.Module):
|
||||
"""
|
||||
Transformer block for image-like data. First, project the input (aka embedding) and reshape to b, t, d. Then apply
|
||||
standard transformer action. Finally, reshape to image.
|
||||
|
||||
Parameters:
|
||||
in_channels (:obj:`int`): The number of channels in the input and output.
|
||||
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
|
||||
d_head (:obj:`int`): The number of channels in each head.
|
||||
depth (:obj:`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
||||
context_dim (:obj:`int`, *optional*): The number of context dimensions to use.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
n_heads: int,
|
||||
d_head: int,
|
||||
depth: int = 1,
|
||||
dropout: float = 0.0,
|
||||
num_groups: int = 32,
|
||||
context_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_heads = n_heads
|
||||
self.d_head = d_head
|
||||
self.in_channels = in_channels
|
||||
inner_dim = n_heads * d_head
|
||||
self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim)
|
||||
for d in range(depth)
|
||||
]
|
||||
)
|
||||
|
||||
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def _set_attention_slice(self, slice_size):
|
||||
for block in self.transformer_blocks:
|
||||
block._set_attention_slice(slice_size)
|
||||
|
||||
def _set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
for block in self.transformer_blocks:
|
||||
block._set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers)
|
||||
|
||||
def forward(self, hidden_states, context=None):
|
||||
# note: if no context is given, cross-attention defaults to self-attention
|
||||
batch, channel, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm(hidden_states)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
inner_dim = hidden_states.shape[1]
|
||||
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, context=context)
|
||||
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
return hidden_states + residual
|
||||
|
||||
|
||||
class BasicTransformerBlock(nn.Module):
|
||||
r"""
|
||||
A basic Transformer block.
|
||||
|
||||
Parameters:
|
||||
dim (:obj:`int`): The number of channels in the input and output.
|
||||
n_heads (:obj:`int`): The number of heads to use for multi-head attention.
|
||||
d_head (:obj:`int`): The number of channels in each head.
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
context_dim (:obj:`int`, *optional*): The size of the context vector for cross attention.
|
||||
gated_ff (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use a gated feed-forward network.
|
||||
checkpoint (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use checkpointing.
|
||||
dim (`int`): The number of channels in the input and output.
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (`int`, *optional*):
The number of diffusion steps used during training. See `Transformer2DModel`.
attention_bias (`bool`, *optional*, defaults to `False`):
Configure if the attentions should contain a bias parameter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
n_heads: int,
|
||||
d_head: int,
|
||||
num_attention_heads: int,
|
||||
attention_head_dim: int,
|
||||
dropout=0.0,
|
||||
context_dim: Optional[int] = None,
|
||||
gated_ff: bool = True,
|
||||
checkpoint: bool = True,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
attention_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.attn1 = CrossAttention(
|
||||
query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||
query_dim=dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
) # is a self-attention
|
||||
self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
||||
self.attn2 = CrossAttention(
|
||||
query_dim=dim, context_dim=context_dim, heads=n_heads, dim_head=d_head, dropout=dropout
|
||||
query_dim=dim,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
heads=num_attention_heads,
|
||||
dim_head=attention_head_dim,
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
) # is self-attn if context is none
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
|
||||
# layer norms
|
||||
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
||||
if self.use_ada_layer_norm:
|
||||
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
def _set_attention_slice(self, slice_size):
|
||||
self.attn1._slice_size = slice_size
|
||||
@@ -245,10 +396,22 @@ class BasicTransformerBlock(nn.Module):
|
||||
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
||||
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
||||
|
||||
def forward(self, hidden_states, context=None):
|
||||
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
|
||||
hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
|
||||
def forward(self, hidden_states, context=None, timestep=None):
|
||||
# 1. Self-Attention
|
||||
norm_hidden_states = (
|
||||
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
||||
)
|
||||
hidden_states = self.attn1(norm_hidden_states) + hidden_states
|
||||
|
||||
# 2. Cross-Attention
|
||||
norm_hidden_states = (
|
||||
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
||||
)
|
||||
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
|
||||
|
||||
# 3. Feed-forward
|
||||
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
||||
|
||||
return hidden_states
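# Hedged sketch (not part of this diff) of the pre-norm residual pattern implemented above:
# self-attention, cross-attention, then feed-forward, each normed first and added back as a
# residual, with AdaLayerNorm used when `num_embeds_ada_norm` is passed. Sizes are illustrative
# and the module path is an assumption based on this file.
import torch

from diffusers.models.attention import BasicTransformerBlock

block = BasicTransformerBlock(
    dim=16,
    num_attention_heads=2,
    attention_head_dim=8,
    num_embeds_ada_norm=100,  # switches norm1/norm2 to AdaLayerNorm(timestep)
    attention_bias=True,      # q/k/v linear layers get a bias term
)
hidden = torch.randn(1, 64, 16)   # (batch, sequence, dim)
context = torch.randn(1, 77, 16)  # conditioning; cross_attention_dim defaults to `dim`
out = block(hidden, context=context, timestep=torch.tensor(5))
assert out.shape == hidden.shape  # every sub-layer is residual, so the shape is preserved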
|
||||
|
||||
|
||||
@@ -257,20 +420,28 @@ class CrossAttention(nn.Module):
|
||||
A cross attention layer.
|
||||
|
||||
Parameters:
|
||||
query_dim (:obj:`int`): The number of channels in the query.
|
||||
context_dim (:obj:`int`, *optional*):
|
||||
query_dim (`int`): The number of channels in the query.
|
||||
cross_attention_dim (`int`, *optional*):
|
||||
The number of channels in the context. If not given, defaults to `query_dim`.
|
||||
heads (:obj:`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
||||
dim_head (:obj:`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
||||
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
bias (`bool`, *optional*, defaults to False):
|
||||
Set to `True` for the query, key, and value linear layers to contain a bias parameter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, query_dim: int, context_dim: Optional[int] = None, heads: int = 8, dim_head: int = 64, dropout: int = 0.0
|
||||
self,
|
||||
query_dim: int,
|
||||
cross_attention_dim: Optional[int] = None,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
bias=False,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
context_dim = context_dim if context_dim is not None else query_dim
|
||||
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
self.heads = heads
|
||||
@@ -280,9 +451,9 @@ class CrossAttention(nn.Module):
|
||||
self._slice_size = None
|
||||
self._use_memory_efficient_attention_xformers = False
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
|
||||
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
||||
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
||||
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
||||
@@ -394,23 +565,33 @@ class FeedForward(nn.Module):
|
||||
A feed-forward layer.
|
||||
|
||||
Parameters:
|
||||
dim (:obj:`int`): The number of channels in the input.
|
||||
dim_out (:obj:`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
||||
mult (:obj:`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
||||
glu (:obj:`bool`, *optional*, defaults to :obj:`False`): Whether to use GLU activation.
|
||||
dropout (:obj:`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
dim (`int`): The number of channels in the input.
|
||||
dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
|
||||
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim: int, dim_out: Optional[int] = None, mult: int = 4, glu: bool = False, dropout: float = 0.0
|
||||
self,
|
||||
dim: int,
|
||||
dim_out: Optional[int] = None,
|
||||
mult: int = 4,
|
||||
dropout: float = 0.0,
|
||||
activation_fn: str = "geglu",
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
self.net = nn.ModuleList([])
|
||||
|
||||
if activation_fn == "geglu":
|
||||
geglu = GEGLU(dim, inner_dim)
|
||||
elif activation_fn == "geglu-approximate":
|
||||
geglu = ApproximateGELU(dim, inner_dim)
|
||||
|
||||
self.net = nn.ModuleList([])
|
||||
# project in
|
||||
self.net.append(GEGLU(dim, inner_dim))
|
||||
self.net.append(geglu)
|
||||
# project dropout
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
# project out
|
||||
@@ -428,8 +609,8 @@ class GEGLU(nn.Module):
|
||||
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
||||
|
||||
Parameters:
|
||||
dim_in (:obj:`int`): The number of channels in the input.
|
||||
dim_out (:obj:`int`): The number of channels in the output.
|
||||
dim_in (`int`): The number of channels in the input.
|
||||
dim_out (`int`): The number of channels in the output.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int):
|
||||
@@ -445,3 +626,38 @@ class GEGLU(nn.Module):
|
||||
def forward(self, hidden_states):
|
||||
hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
|
||||
return hidden_states * self.gelu(gate)
|
||||
|
||||
|
||||
class ApproximateGELU(nn.Module):
|
||||
"""
|
||||
The approximate form of Gaussian Error Linear Unit (GELU)
|
||||
|
||||
For more details, see section 2: https://arxiv.org/abs/1606.08415
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.proj(x)
|
||||
return x * torch.sigmoid(1.702 * x)
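# Hedged numeric check (illustrative only): the activation used by ApproximateGELU above is the
# sigmoid approximation x * sigmoid(1.702 * x) from section 2 of https://arxiv.org/abs/1606.08415,
# so with an identity projection it stays close to the exact GELU.
import torch
import torch.nn.functional as F

x = torch.linspace(-3, 3, steps=7)
approx = x * torch.sigmoid(1.702 * x)
exact = F.gelu(x)
print(torch.max(torch.abs(approx - exact)))  # small, on the order of 1e-2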
|
||||
|
||||
|
||||
class AdaLayerNorm(nn.Module):
|
||||
"""
|
||||
Norm layer modified to incorporate timestep embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim, num_embeddings):
|
||||
super().__init__()
|
||||
self.emb = nn.Embedding(num_embeddings, embedding_dim)
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
|
||||
|
||||
def forward(self, x, timestep):
|
||||
emb = self.linear(self.silu(self.emb(timestep)))
|
||||
scale, shift = torch.chunk(emb, 2)
|
||||
x = self.norm(x) * (1 + scale) + shift
|
||||
return x
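# Hedged sketch of AdaLayerNorm above: the timestep selects a learned embedding, which is projected
# to a per-channel (scale, shift) applied on top of a LayerNorm without affine parameters.
# The import path is an assumption; sizes are illustrative.
import torch

from diffusers.models.attention import AdaLayerNorm

norm = AdaLayerNorm(embedding_dim=16, num_embeddings=100)  # 100 = number of training diffusion steps
x = torch.randn(2, 64, 16)                                 # (batch, sequence, channels)
t = torch.tensor(42)                                       # current diffusion step
y = norm(x, t)
assert y.shape == x.shape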
|
||||
|
||||
@@ -142,7 +142,7 @@ class FlaxBasicTransformerBlock(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
class FlaxSpatialTransformer(nn.Module):
|
||||
class FlaxTransformer2DModel(nn.Module):
|
||||
r"""
|
||||
A Spatial Transformer layer with Gated Linear Unit (GLU) activation function as described in:
|
||||
https://arxiv.org/pdf/1506.02025.pdf
|
||||
|
||||
@@ -126,3 +126,68 @@ class GaussianFourierProjection(nn.Module):
|
||||
else:
|
||||
out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
|
||||
return out
|
||||
|
||||
|
||||
class ImagePositionalEmbeddings(nn.Module):
|
||||
"""
|
||||
Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
|
||||
height and width of the latent space.
|
||||
|
||||
For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092
|
||||
|
||||
For VQ-diffusion:
|
||||
|
||||
Output vector embeddings are used as input for the transformer.
|
||||
|
||||
Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.
|
||||
|
||||
Args:
|
||||
num_embed (`int`):
|
||||
Number of embeddings for the latent pixels embeddings.
|
||||
height (`int`):
|
||||
Height of the latent image i.e. the number of height embeddings.
|
||||
width (`int`):
|
||||
Width of the latent image i.e. the number of width embeddings.
|
||||
embed_dim (`int`):
|
||||
Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_embed: int,
|
||||
height: int,
|
||||
width: int,
|
||||
embed_dim: int,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.num_embed = num_embed
|
||||
self.embed_dim = embed_dim
|
||||
|
||||
self.emb = nn.Embedding(self.num_embed, embed_dim)
|
||||
self.height_emb = nn.Embedding(self.height, embed_dim)
|
||||
self.width_emb = nn.Embedding(self.width, embed_dim)
|
||||
|
||||
def forward(self, index):
|
||||
emb = self.emb(index)
|
||||
|
||||
height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))
|
||||
|
||||
# 1 x H x D -> 1 x H x 1 x D
|
||||
height_emb = height_emb.unsqueeze(2)
|
||||
|
||||
width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))
|
||||
|
||||
# 1 x W x D -> 1 x 1 x W x D
|
||||
width_emb = width_emb.unsqueeze(1)
|
||||
|
||||
pos_emb = height_emb + width_emb
|
||||
|
||||
# 1 x H x W x D -> 1 x L x D
|
||||
pos_emb = pos_emb.view(1, self.height * self.width, -1)
|
||||
|
||||
emb = emb + pos_emb[:, : emb.shape[1], :]
|
||||
|
||||
return emb
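# Hedged shape sketch for ImagePositionalEmbeddings above: latent pixel class indices go in,
# content embeddings plus broadcast height/width positional embeddings come out.
# The import path is an assumption; sizes are illustrative.
import torch

from diffusers.models.embeddings import ImagePositionalEmbeddings

pos_embed = ImagePositionalEmbeddings(num_embed=10, height=4, width=4, embed_dim=16)
index = torch.randint(0, 10, (2, 16))  # (batch, height * width) latent pixel classes
emb = pos_embed(index)
assert emb.shape == (2, 16, 16)        # (batch, num latent pixels, embed_dim)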
|
||||
|
||||
@@ -15,7 +15,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .attention import AttentionBlock, SpatialTransformer
|
||||
from .attention import AttentionBlock, Transformer2DModel
|
||||
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
|
||||
|
||||
|
||||
@@ -109,6 +109,19 @@ def get_down_block(
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
)
|
||||
elif down_block_type == "AttnDownEncoderBlock2D":
|
||||
return AttnDownEncoderBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
add_downsample=add_downsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
)
|
||||
raise ValueError(f"{down_block_type} does not exist.")
|
||||
|
||||
|
||||
def get_up_block(
|
||||
@@ -200,6 +213,17 @@ def get_up_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
)
|
||||
elif up_block_type == "AttnUpDecoderBlock2D":
|
||||
return AttnUpDecoderBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
|
||||
|
||||
@@ -249,7 +273,7 @@ class UNetMidBlock2D(nn.Module):
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
num_groups=resnet_groups,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
@@ -325,13 +349,13 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
SpatialTransformer(
|
||||
in_channels,
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
in_channels // attn_num_head_channels,
|
||||
depth=1,
|
||||
context_dim=cross_attention_dim,
|
||||
num_groups=resnet_groups,
|
||||
in_channels=in_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
@@ -374,7 +398,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states).sample
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
@@ -427,7 +451,7 @@ class AttnDownBlock2D(nn.Module):
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
num_groups=resnet_groups,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -506,13 +530,13 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
SpatialTransformer(
|
||||
out_channels,
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
depth=1,
|
||||
context_dim=cross_attention_dim,
|
||||
num_groups=resnet_groups,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
@@ -556,19 +580,22 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn), hidden_states, encoder_hidden_states
|
||||
)
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, context=encoder_hidden_states)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
@@ -763,7 +790,7 @@ class AttnDownEncoderBlock2D(nn.Module):
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
num_groups=resnet_groups,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1014,7 +1041,7 @@ class AttnUpBlock2D(nn.Module):
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
num_groups=resnet_groups,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1089,13 +1116,13 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
SpatialTransformer(
|
||||
out_channels,
|
||||
Transformer2DModel(
|
||||
attn_num_head_channels,
|
||||
out_channels // attn_num_head_channels,
|
||||
depth=1,
|
||||
context_dim=cross_attention_dim,
|
||||
num_groups=resnet_groups,
|
||||
in_channels=out_channels,
|
||||
num_layers=1,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
@@ -1145,19 +1172,22 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def create_custom_forward(module, return_dict=None):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
if return_dict is not None:
|
||||
return module(*inputs, return_dict=return_dict)
|
||||
else:
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn), hidden_states, encoder_hidden_states
|
||||
)
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
hidden_states = attn(hidden_states, context=encoder_hidden_states)
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
@@ -1337,7 +1367,7 @@ class AttnUpDecoderBlock2D(nn.Module):
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
num_groups=resnet_groups,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@
|
||||
import flax.linen as nn
|
||||
import jax.numpy as jnp
|
||||
|
||||
from .attention_flax import FlaxSpatialTransformer
|
||||
from .attention_flax import FlaxTransformer2DModel
|
||||
from .resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
|
||||
)
|
||||
resnets.append(res_block)
|
||||
|
||||
attn_block = FlaxSpatialTransformer(
|
||||
attn_block = FlaxTransformer2DModel(
|
||||
in_channels=self.out_channels,
|
||||
n_heads=self.attn_num_head_channels,
|
||||
d_head=self.out_channels // self.attn_num_head_channels,
|
||||
@@ -196,7 +196,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
|
||||
)
|
||||
resnets.append(res_block)
|
||||
|
||||
attn_block = FlaxSpatialTransformer(
|
||||
attn_block = FlaxTransformer2DModel(
|
||||
in_channels=self.out_channels,
|
||||
n_heads=self.attn_num_head_channels,
|
||||
d_head=self.out_channels // self.attn_num_head_channels,
|
||||
@@ -326,7 +326,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
|
||||
attentions = []
|
||||
|
||||
for _ in range(self.num_layers):
|
||||
attn_block = FlaxSpatialTransformer(
|
||||
attn_block = FlaxTransformer2DModel(
|
||||
in_channels=self.in_channels,
|
||||
n_heads=self.attn_num_head_channels,
|
||||
d_head=self.in_channels // self.attn_num_head_channels,
|
||||
|
||||
@@ -233,14 +233,16 @@ class VectorQuantizer(nn.Module):
|
||||
# NOTE: due to a bug the beta term was applied to the wrong term. for
|
||||
# backwards compatibility we use the buggy version by default, but you can
|
||||
# specify legacy=False to fix it.
|
||||
def __init__(self, n_e, e_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True):
|
||||
def __init__(
|
||||
self, n_e, vq_embed_dim, beta, remap=None, unknown_index="random", sane_index_shape=False, legacy=True
|
||||
):
|
||||
super().__init__()
|
||||
self.n_e = n_e
|
||||
self.e_dim = e_dim
|
||||
self.vq_embed_dim = vq_embed_dim
|
||||
self.beta = beta
|
||||
self.legacy = legacy
|
||||
|
||||
self.embedding = nn.Embedding(self.n_e, self.e_dim)
|
||||
self.embedding = nn.Embedding(self.n_e, self.vq_embed_dim)
|
||||
self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
|
||||
|
||||
self.remap = remap
|
||||
@@ -287,7 +289,7 @@ class VectorQuantizer(nn.Module):
|
||||
def forward(self, z):
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
z = z.permute(0, 2, 3, 1).contiguous()
|
||||
z_flattened = z.view(-1, self.e_dim)
|
||||
z_flattened = z.view(-1, self.vq_embed_dim)
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
|
||||
d = (
|
||||
@@ -409,6 +411,7 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||
latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
|
||||
sample_size (`int`, *optional*, defaults to `32`): TODO
|
||||
num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
|
||||
vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
@@ -425,6 +428,7 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||
sample_size: int = 32,
|
||||
num_vq_embeddings: int = 256,
|
||||
norm_num_groups: int = 32,
|
||||
vq_embed_dim: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -440,11 +444,11 @@ class VQModel(ModelMixin, ConfigMixin):
|
||||
double_z=False,
|
||||
)
|
||||
|
||||
self.quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
|
||||
self.quantize = VectorQuantizer(
|
||||
num_vq_embeddings, latent_channels, beta=0.25, remap=None, sane_index_shape=False
|
||||
)
|
||||
self.post_quant_conv = torch.nn.Conv2d(latent_channels, latent_channels, 1)
|
||||
vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
|
||||
|
||||
self.quant_conv = torch.nn.Conv2d(latent_channels, vq_embed_dim, 1)
|
||||
self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
|
||||
self.post_quant_conv = torch.nn.Conv2d(vq_embed_dim, latent_channels, 1)
|
||||
|
||||
# pass init params to Decoder
|
||||
self.decoder = Decoder(
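# Hedged sketch of the VQModel change above: `vq_embed_dim` decouples the codebook dimension from
# `latent_channels`; when it is not passed it falls back to `latent_channels`, keeping the previous
# behaviour. The top-level import and the sizes below are illustrative assumptions.
from diffusers import VQModel

vqvae = VQModel(latent_channels=64, num_vq_embeddings=512, vq_embed_dim=16)
assert vqvae.quant_conv.out_channels == 16      # latent_channels -> vq_embed_dim
assert vqvae.post_quant_conv.in_channels == 16  # vq_embed_dim -> latent_channels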
|
||||
|
||||
@@ -21,6 +21,7 @@ if is_torch_available() and is_transformers_available():
|
||||
StableDiffusionInpaintPipelineLegacy,
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from .vq_diffusion import VQDiffusionPipeline
|
||||
|
||||
if is_transformers_available() and is_onnx_available():
|
||||
from .stable_diffusion import (
|
||||
|
||||
src/diffusers/pipelines/vq_diffusion/__init__.py
@@ -0,0 +1 @@
|
||||
from .pipeline_vq_diffusion import VQDiffusionPipeline
|
||||
src/diffusers/pipelines/vq_diffusion/pipeline_vq_diffusion.py
@@ -0,0 +1,253 @@
|
||||
# Copyright 2022 Microsoft and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from diffusers import Transformer2DModel, VQModel
|
||||
from diffusers.schedulers.scheduling_vq_diffusion import VQDiffusionScheduler
|
||||
from transformers import CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from ...utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class VQDiffusionPipeline(DiffusionPipeline):
|
||||
r"""
|
||||
Pipeline for text-to-image generation using VQ Diffusion
|
||||
|
||||
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
|
||||
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
|
||||
|
||||
Args:
|
||||
vqvae ([`VQModel`]):
|
||||
Vector Quantized Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent
|
||||
representations.
|
||||
text_encoder ([`CLIPTextModel`]):
|
||||
Frozen text-encoder. VQ Diffusion uses the text portion of
|
||||
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
|
||||
the [clip-vit-base-patch32](https://huggingface.co/openai/clip-vit-base-patch32) variant.
|
||||
tokenizer (`CLIPTokenizer`):
|
||||
Tokenizer of class
|
||||
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
|
||||
transformer ([`Transformer2DModel`]):
|
||||
Conditional transformer to denoise the encoded image latents.
|
||||
scheduler ([`VQDiffusionScheduler`]):
|
||||
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
|
||||
"""
|
||||
|
||||
vqvae: VQModel
|
||||
text_encoder: CLIPTextModel
|
||||
tokenizer: CLIPTokenizer
|
||||
transformer: Transformer2DModel
|
||||
scheduler: VQDiffusionScheduler
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vqvae: VQModel,
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
transformer: Transformer2DModel,
|
||||
scheduler: VQDiffusionScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
vqvae=vqvae,
|
||||
transformer=transformer,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=scheduler,
|
||||
)
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_inference_steps: int = 100,
|
||||
truncation_rate: float = 1.0,
|
||||
num_images_per_prompt: int = 1,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: Optional[int] = 1,
|
||||
) -> Union[ImagePipelineOutput, Tuple]:
|
||||
"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`):
|
||||
The prompt or prompts to guide the image generation.
|
||||
num_inference_steps (`int`, *optional*, defaults to 100):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
truncation_rate (`float`, *optional*, defaults to 1.0 (equivalent to no truncation)):
|
||||
Used to "truncate" the predicted classes for x_0 such that the cumulative probability for a pixel is at
|
||||
most `truncation_rate`. The lowest probabilities that would increase the cumulative probability above
|
||||
`truncation_rate` are set to zero.
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
generator (`torch.Generator`, *optional*):
|
||||
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
|
||||
deterministic.
|
||||
latents (`torch.FloatTensor` of shape (batch), *optional*):
|
||||
Pre-generated noisy latents to be used as inputs for image generation. Must be valid embedding indices.
|
||||
Can be used to tweak the same generation with different prompts. If not provided, a latents tensor will
|
||||
be generated of completely masked latent pixels.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipeline_utils.ImagePipelineOutput`] instead of a plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
|
||||
Returns:
|
||||
[`~pipeline_utils.ImagePipelineOutput`] or `tuple`: [`~pipeline_utils.ImagePipelineOutput`] if
`return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the
generated images.
|
||||
"""
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
if (callback_steps is None) or (
|
||||
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
|
||||
):
|
||||
raise ValueError(
|
||||
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
|
||||
f" {type(callback_steps)}."
|
||||
)
|
||||
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
|
||||
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
|
||||
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
|
||||
|
||||
# NOTE: This additional step of normalizing the text embeddings is from VQ-Diffusion.
|
||||
# While CLIP does normalize the pooled output of the text transformer when combining
|
||||
# the image and text embeddings, CLIP does not directly normalize the last hidden state.
|
||||
#
|
||||
# CLIP normalizing the pooled output.
|
||||
# https://github.com/huggingface/transformers/blob/d92e22d1f28324f513f3080e5c47c071a3916721/src/transformers/models/clip/modeling_clip.py#L1052-L1053
|
||||
text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)
|
||||
|
||||
# duplicate text embeddings for each generation per prompt
|
||||
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
# get the initial completely masked latents unless the user supplied it
|
||||
|
||||
latents_shape = (batch_size, self.transformer.num_latent_pixels)
|
||||
if latents is None:
|
||||
mask_class = self.transformer.num_vector_embeds - 1
|
||||
latents = torch.full(latents_shape, mask_class).to(self.device)
|
||||
else:
|
||||
if latents.shape != latents_shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
|
||||
if (latents < 0).any() or (latents >= self.transformer.num_vector_embeds).any():
|
||||
raise ValueError(
|
||||
"Unexpected latents value(s). All latents be valid embedding indices i.e. in the range 0,"
|
||||
f" {self.transformer.num_vector_embeds - 1} (inclusive)."
|
||||
)
|
||||
latents = latents.to(self.device)
|
||||
|
||||
# set timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=self.device)
|
||||
|
||||
timesteps_tensor = self.scheduler.timesteps.to(self.device)
|
||||
|
||||
sample = latents
|
||||
|
||||
for i, t in enumerate(self.progress_bar(timesteps_tensor)):
|
||||
# predict the un-noised image
|
||||
# model_output == `log_p_x_0`
|
||||
model_output = self.transformer(sample, encoder_hidden_states=text_embeddings, timestep=t).sample
|
||||
|
||||
model_output = self.truncate(model_output, truncation_rate)
|
||||
|
||||
# remove `log(0)`'s (`-inf`s)
|
||||
model_output = model_output.clamp(-70)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
sample = self.scheduler.step(model_output, timestep=t, sample=sample, generator=generator).prev_sample
|
||||
|
||||
# call the callback, if provided
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, sample)
|
||||
|
||||
embedding_channels = self.vqvae.config.vq_embed_dim
|
||||
embeddings_shape = (batch_size, self.transformer.height, self.transformer.width, embedding_channels)
|
||||
embeddings = self.vqvae.quantize.get_codebook_entry(sample, shape=embeddings_shape)
|
||||
image = self.vqvae.decode(embeddings, force_not_quantize=True).sample
|
||||
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
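# Hedged usage sketch for VQDiffusionPipeline (not part of this diff). The checkpoint id below is
# an assumption for illustration only; substitute any repository that ships a compatible vqvae,
# text encoder, tokenizer, transformer and VQDiffusionScheduler.
import torch

from diffusers import VQDiffusionPipeline

pipe = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")  # hypothetical hub id
if torch.cuda.is_available():
    pipe = pipe.to("cuda")

image = pipe("a painting of a teddy bear riding a skateboard", num_inference_steps=100).images[0]
image.save("teddy_bear.png")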
|
||||
|
||||
def truncate(self, log_p_x_0: torch.FloatTensor, truncation_rate: float) -> torch.FloatTensor:
|
||||
"""
|
||||
Truncates `log_p_x_0` such that for each column vector, the total cumulative probability is at most
`truncation_rate`. The lowest probabilities that would increase the cumulative probability above `truncation_rate` are set to zero.
|
||||
"""
|
||||
sorted_log_p_x_0, indices = torch.sort(log_p_x_0, 1, descending=True)
|
||||
sorted_p_x_0 = torch.exp(sorted_log_p_x_0)
|
||||
keep_mask = sorted_p_x_0.cumsum(dim=1) < truncation_rate
|
||||
|
||||
# Ensure that at least the largest probability is not zeroed out
|
||||
all_true = torch.full_like(keep_mask[:, 0:1, :], True)
|
||||
keep_mask = torch.cat((all_true, keep_mask), dim=1)
|
||||
keep_mask = keep_mask[:, :-1, :]
|
||||
|
||||
keep_mask = keep_mask.gather(1, indices.argsort(1))
|
||||
|
||||
rv = log_p_x_0.clone()
|
||||
|
||||
rv[~keep_mask] = -torch.inf # -inf = log(0)
|
||||
|
||||
return rv
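# Hedged toy check of `truncate` above (reusing `pipe` from the earlier usage sketch): a class is
# kept while the cumulative probability of the strictly more likely classes stays below
# `truncation_rate`; everything else is set to log(0) = -inf. The most likely class always survives.
import torch

log_p = torch.log(torch.tensor([[[0.7], [0.2], [0.1]]]))  # (batch, num classes, num latent pixels)
truncated = pipe.truncate(log_p, truncation_rate=0.9)
# the 0.7 and 0.2 classes are kept; the 0.1 class is zeroed out (-inf in log space)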
|
||||
@@ -28,6 +28,7 @@ if is_torch_available():
|
||||
from .scheduling_sde_ve import ScoreSdeVeScheduler
|
||||
from .scheduling_sde_vp import ScoreSdeVpScheduler
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
from .scheduling_vq_diffusion import VQDiffusionScheduler
|
||||
else:
|
||||
from ..utils.dummy_pt_objects import * # noqa F403
|
||||
|
||||
|
||||
src/diffusers/schedulers/scheduling_vq_diffusion.py
@@ -0,0 +1,494 @@
|
||||
# Copyright 2022 Microsoft and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..utils import BaseOutput
|
||||
from .scheduling_utils import SchedulerMixin
|
||||
|
||||
|
||||
@dataclass
|
||||
class VQDiffusionSchedulerOutput(BaseOutput):
|
||||
"""
|
||||
Output class for the scheduler's step function output.
|
||||
|
||||
Args:
|
||||
prev_sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
|
||||
Computed sample x_{t-1} of previous timestep. `prev_sample` should be used as next model input in the
|
||||
denoising loop.
|
||||
"""
|
||||
|
||||
prev_sample: torch.LongTensor
|
||||
|
||||
|
||||
def index_to_log_onehot(x: torch.LongTensor, num_classes: int) -> torch.FloatTensor:
|
||||
"""
|
||||
Convert batch of vector of class indices into batch of log onehot vectors
|
||||
|
||||
Args:
|
||||
x (`torch.LongTensor` of shape `(batch size, vector length)`):
|
||||
Batch of class indices
|
||||
|
||||
num_classes (`int`):
|
||||
number of classes to be used for the onehot vectors
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor` of shape `(batch size, num classes, vector length)`:
|
||||
Log onehot vectors
|
||||
"""
|
||||
x_onehot = F.one_hot(x, num_classes)
|
||||
x_onehot = x_onehot.permute(0, 2, 1)
|
||||
log_x = torch.log(x_onehot.float().clamp(min=1e-30))
|
||||
return log_x
|
||||
|
||||
|
||||
def gumbel_noised(logits: torch.FloatTensor, generator: Optional[torch.Generator]) -> torch.FloatTensor:
|
||||
"""
|
||||
Apply gumbel noise to `logits`
|
||||
"""
|
||||
uniform = torch.rand(logits.shape, device=logits.device, generator=generator)
|
||||
gumbel_noise = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
|
||||
noised = gumbel_noise + logits
|
||||
return noised
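# Hedged note on `gumbel_noised` above: adding Gumbel noise to log-probabilities and taking the
# argmax (as `step` does below) is the Gumbel-max trick, i.e. it draws a sample from the
# categorical distribution instead of always picking the most likely class.
import torch

log_p = torch.log(torch.tensor([0.1, 0.6, 0.3]))
generator = torch.Generator().manual_seed(0)
uniform = torch.rand(log_p.shape, generator=generator)
gumbel = -torch.log(-torch.log(uniform + 1e-30) + 1e-30)
sample = torch.argmax(log_p + gumbel)  # distributed (approximately) as Categorical([0.1, 0.6, 0.3])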
|
||||
|
||||
|
||||
def alpha_schedules(num_diffusion_timesteps: int, alpha_cum_start=0.99999, alpha_cum_end=0.000009):
|
||||
"""
|
||||
Cumulative and non-cumulative alpha schedules.
|
||||
|
||||
See section 4.1.
|
||||
"""
|
||||
att = (
|
||||
np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (alpha_cum_end - alpha_cum_start)
|
||||
+ alpha_cum_start
|
||||
)
|
||||
att = np.concatenate(([1], att))
|
||||
at = att[1:] / att[:-1]
|
||||
att = np.concatenate((att[1:], [1]))
|
||||
return at, att
|
||||
|
||||
|
||||
def gamma_schedules(num_diffusion_timesteps: int, gamma_cum_start=0.000009, gamma_cum_end=0.99999):
|
||||
"""
|
||||
Cumulative and non-cumulative gamma schedules.
|
||||
|
||||
See section 4.1.
|
||||
"""
|
||||
ctt = (
|
||||
np.arange(0, num_diffusion_timesteps) / (num_diffusion_timesteps - 1) * (gamma_cum_end - gamma_cum_start)
|
||||
+ gamma_cum_start
|
||||
)
|
||||
ctt = np.concatenate(([0], ctt))
|
||||
one_minus_ctt = 1 - ctt
|
||||
one_minus_ct = one_minus_ctt[1:] / one_minus_ctt[:-1]
|
||||
ct = 1 - one_minus_ct
|
||||
ctt = np.concatenate((ctt[1:], [0]))
|
||||
return ct, ctt
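# Hedged sanity check on the schedules above: both interpolate linearly over the training steps,
# so almost all probability is kept at the first step and almost all mass has moved to the mask
# class by the last step (each array carries one extra sentinel value at the end).
# The import path follows this new file and is an assumption.
from diffusers.schedulers.scheduling_vq_diffusion import alpha_schedules, gamma_schedules

at, att = alpha_schedules(100)
ct, ctt = gamma_schedules(100)

print(att[0], att[-2])  # ~0.99999 at the first step, ~0.000009 at the last real step
print(ctt[0], ctt[-2])  # ~0.000009 at the first step, ~0.99999 at the last real step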
|
||||
|
||||
|
||||
class VQDiffusionScheduler(SchedulerMixin, ConfigMixin):
|
||||
"""
|
||||
The VQ-diffusion transformer outputs predicted probabilities of the initial unnoised image.
|
||||
|
||||
The VQ-diffusion scheduler converts the transformer's output into a sample for the unnoised image at the previous
|
||||
diffusion timestep.
|
||||
|
||||
[`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
|
||||
function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
|
||||
[`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and
|
||||
[`~ConfigMixin.from_config`] functions.
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/abs/2111.14822
|
||||
|
||||
Args:
|
||||
num_vec_classes (`int`):
|
||||
The number of classes of the vector embeddings of the latent pixels. Includes the class for the masked
|
||||
latent pixel.
|
||||
|
||||
num_train_timesteps (`int`):
|
||||
Number of diffusion steps used to train the model.
|
||||
|
||||
alpha_cum_start (`float`):
|
||||
The starting cumulative alpha value.
|
||||
|
||||
alpha_cum_end (`float`):
|
||||
The ending cumulative alpha value.
|
||||
|
||||
gamma_cum_start (`float`):
|
||||
The starting cumulative gamma value.
|
||||
|
||||
gamma_cum_end (`float`):
|
||||
The ending cumulative gamma value.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_vec_classes: int,
|
||||
num_train_timesteps: int = 100,
|
||||
alpha_cum_start: float = 0.99999,
|
||||
alpha_cum_end: float = 0.000009,
|
||||
gamma_cum_start: float = 0.000009,
|
||||
gamma_cum_end: float = 0.99999,
|
||||
):
|
||||
self.num_embed = num_vec_classes
|
||||
|
||||
# By convention, the index for the mask class is the last class index
|
||||
self.mask_class = self.num_embed - 1
|
||||
|
||||
at, att = alpha_schedules(num_train_timesteps, alpha_cum_start=alpha_cum_start, alpha_cum_end=alpha_cum_end)
|
||||
ct, ctt = gamma_schedules(num_train_timesteps, gamma_cum_start=gamma_cum_start, gamma_cum_end=gamma_cum_end)
|
||||
|
||||
num_non_mask_classes = self.num_embed - 1
|
||||
bt = (1 - at - ct) / num_non_mask_classes
|
||||
btt = (1 - att - ctt) / num_non_mask_classes
|
||||
|
||||
at = torch.tensor(at.astype("float64"))
|
||||
bt = torch.tensor(bt.astype("float64"))
|
||||
ct = torch.tensor(ct.astype("float64"))
|
||||
log_at = torch.log(at)
|
||||
log_bt = torch.log(bt)
|
||||
log_ct = torch.log(ct)
|
||||
|
||||
att = torch.tensor(att.astype("float64"))
|
||||
btt = torch.tensor(btt.astype("float64"))
|
||||
ctt = torch.tensor(ctt.astype("float64"))
|
||||
log_cumprod_at = torch.log(att)
|
||||
log_cumprod_bt = torch.log(btt)
|
||||
log_cumprod_ct = torch.log(ctt)
|
||||
|
||||
self.log_at = log_at.float()
|
||||
self.log_bt = log_bt.float()
|
||||
self.log_ct = log_ct.float()
|
||||
self.log_cumprod_at = log_cumprod_at.float()
|
||||
self.log_cumprod_bt = log_cumprod_bt.float()
|
||||
self.log_cumprod_ct = log_cumprod_ct.float()
|
||||
|
||||
# setable values
|
||||
self.num_inference_steps = None
|
||||
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
|
||||
|
||||
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
|
||||
"""
|
||||
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
|
||||
|
||||
Args:
|
||||
num_inference_steps (`int`):
|
||||
the number of diffusion steps used when generating samples with a pre-trained model.
|
||||
|
||||
device (`str` or `torch.device`):
|
||||
device to place the timesteps and the diffusion process parameters (alpha, beta, gamma) on.
|
||||
"""
|
||||
self.num_inference_steps = num_inference_steps
|
||||
timesteps = np.arange(0, self.num_inference_steps)[::-1].copy()
|
||||
self.timesteps = torch.from_numpy(timesteps).to(device)
|
||||
|
||||
self.log_at = self.log_at.to(device)
|
||||
self.log_bt = self.log_bt.to(device)
|
||||
self.log_ct = self.log_ct.to(device)
|
||||
self.log_cumprod_at = self.log_cumprod_at.to(device)
|
||||
self.log_cumprod_bt = self.log_cumprod_bt.to(device)
|
||||
self.log_cumprod_ct = self.log_cumprod_ct.to(device)
|
||||
|
||||
def step(
|
||||
self,
|
||||
model_output: torch.FloatTensor,
|
||||
timestep: torch.long,
|
||||
sample: torch.LongTensor,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[VQDiffusionSchedulerOutput, Tuple]:
|
||||
"""
|
||||
Predict the sample at the previous timestep via the reverse transition distribution, i.e. Equation (11). See the
docstring for `self.q_posterior` for more in-depth documentation on how Equation (11) is computed.

Args:
model_output (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
The log probabilities for the predicted classes of the initial latent pixels. Does not include a
prediction for the masked class as the initial unnoised image cannot be masked.

timestep (`torch.long`):
The timestep that determines which transition matrices are used.

sample (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
The classes of each latent pixel at time `timestep`.

generator (`torch.Generator` or `None`):
RNG for the noise applied to `p(x_{t-1} | x_t)` before it is sampled from.

return_dict (`bool`):
Option for returning a tuple rather than a `VQDiffusionSchedulerOutput` class.
|
||||
|
||||
Returns:
|
||||
[`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] or `tuple`:
|
||||
[`~schedulers.scheduling_utils.VQDiffusionSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
if timestep == 0:
|
||||
log_p_x_t_min_1 = model_output
|
||||
else:
|
||||
log_p_x_t_min_1 = self.q_posterior(model_output, sample, timestep)
|
||||
|
||||
log_p_x_t_min_1 = gumbel_noised(log_p_x_t_min_1, generator)
|
||||
|
||||
x_t_min_1 = log_p_x_t_min_1.argmax(dim=1)
|
||||
|
||||
if not return_dict:
|
||||
return (x_t_min_1,)
|
||||
|
||||
return VQDiffusionSchedulerOutput(prev_sample=x_t_min_1)
|
||||
|
||||
def q_posterior(self, log_p_x_0, x_t, t):
|
||||
"""
|
||||
Calculates the log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11).
|
||||
|
||||
Instead of directly computing equation (11), we use Equation (5) to restate Equation (11) in terms of only
|
||||
forward probabilities.
|
||||
|
||||
Equation (11) stated in terms of forward probabilities via Equation (5):

p(x_{t-1} | x_t) = sum( q(x_t | x_{t-1}) * q(x_{t-1} | x_0) * p(x_0) / q(x_t | x_0) )

Where:
- the sum is over x_0 = {C_0 ... C_{k-1}} (classes for x_0)
|
||||
|
||||
Args:
|
||||
log_p_x_0: (`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`):
|
||||
The log probabilities for the predicted classes of the initial latent pixels. Does not include a
|
||||
prediction for the masked class as the initial unnoised image cannot be masked.
|
||||
|
||||
x_t: (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
|
||||
The classes of each latent pixel at time `t`
|
||||
|
||||
t (`torch.long`):
|
||||
The timestep that determines which transition matrix is used.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`:
|
||||
The log probabilities for the predicted classes of the image at timestep `t-1`. I.e. Equation (11).
|
||||
"""
|
||||
log_onehot_x_t = index_to_log_onehot(x_t, self.num_embed)
|
||||
|
||||
log_q_x_t_given_x_0 = self.log_Q_t_transitioning_to_known_class(
|
||||
t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=True
|
||||
)
|
||||
|
||||
log_q_t_given_x_t_min_1 = self.log_Q_t_transitioning_to_known_class(
|
||||
t=t, x_t=x_t, log_onehot_x_t=log_onehot_x_t, cumulative=False
|
||||
)
|
||||
|
||||
# p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0)
|
||||
# . . .
|
||||
# . . .
|
||||
# . . .
|
||||
# p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1})
|
||||
q = log_p_x_0 - log_q_x_t_given_x_0
|
||||
|
||||
# sum_0 = p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}), ... ,
|
||||
# sum_n = p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) + ... + p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1})
|
||||
q_log_sum_exp = torch.logsumexp(q, dim=1, keepdim=True)
|
||||
|
||||
# p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0 ... p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n
|
||||
# . . .
|
||||
# . . .
|
||||
# . . .
|
||||
# p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0 ... p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n
|
||||
q = q - q_log_sum_exp
|
||||
|
||||
# (p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}
|
||||
# . . .
|
||||
# . . .
|
||||
# . . .
|
||||
# (p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1} ... (p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}
|
||||
# c_cumulative_{t-1} ... c_cumulative_{t-1}
|
||||
q = self.apply_cumulative_transitions(q, t - 1)
|
||||
|
||||
# ((p_0(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_0 ... ((p_n(x_0=C_0 | x_t) / q(x_t | x_0=C_0) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_0) * sum_n
|
||||
# . . .
|
||||
# . . .
|
||||
# . . .
|
||||
# ((p_0(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_0) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_0 ... ((p_n(x_0=C_{k-1} | x_t) / q(x_t | x_0=C_{k-1}) / sum_n) * a_cumulative_{t-1} + b_cumulative_{t-1}) * q(x_t | x_{t-1}=C_{k-1}) * sum_n
|
||||
# c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_0 ... c_cumulative_{t-1} * q(x_t | x_{t-1}=C_k) * sum_n
|
||||
log_p_x_t_min_1 = q + log_q_t_given_x_t_min_1 + q_log_sum_exp
|
||||
|
||||
# For each column, there are two possible cases.
|
||||
#
|
||||
# Where:
|
||||
# - sum(p_n(x_0)) is summing over all classes for x_0
|
||||
# - C_i is the class transitioning from (not to be confused with c_t and c_cumulative_t being used for gamma's)
|
||||
# - C_j is the class transitioning to
|
||||
#
|
||||
# 1. x_t is masked i.e. x_t = c_k
|
||||
#
|
||||
# Simplifying the expression, the column vector is:
|
||||
# .
|
||||
# .
|
||||
# .
|
||||
# (c_t / c_cumulative_t) * (a_cumulative_{t-1} * p_n(x_0 = C_i | x_t) + b_cumulative_{t-1} * sum(p_n(x_0)))
|
||||
# .
|
||||
# .
|
||||
# .
|
||||
# (c_cumulative_{t-1} / c_cumulative_t) * sum(p_n(x_0))
|
||||
#
|
||||
# From equation (11) stated in terms of forward probabilities, the last row is trivially verified.
|
||||
#
|
||||
# For the other rows, we can state the equation as ...
|
||||
#
|
||||
# (c_t / c_cumulative_t) * [b_cumulative_{t-1} * p(x_0=c_0) + ... + (a_cumulative_{t-1} + b_cumulative_{t-1}) * p(x_0=C_i) + ... + b_cumulative_{t-1} * p(x_0=c_{k-1})]
|
||||
#
|
||||
# This verifies the other rows.
|
||||
#
|
||||
# 2. x_t is not masked
|
||||
#
|
||||
# Simplifying the expression, there are two cases for the rows of the column vector, where C_j = C_i and where C_j != C_i:
|
||||
# .
|
||||
# .
|
||||
# .
|
||||
# C_j != C_i: b_t * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / b_cumulative_t) * p_n(x_0 = C_i) + ... + (b_cumulative_{t-1} / (a_cumulative_t + b_cumulative_t)) * p_n(x_0 = C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1}))
|
||||
# .
|
||||
# .
|
||||
# .
|
||||
# C_j = C_i: (a_t + b_t) * ((b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_0) + ... + ((a_cumulative_{t-1} + b_cumulative_{t-1}) / (a_cumulative_t + b_cumulative_t)) * p_n(x_0 = C_i = C_j) + ... + (b_cumulative_{t-1} / b_cumulative_t) * p_n(x_0 = c_{k-1}))
|
||||
# .
|
||||
# .
|
||||
# .
|
||||
# 0
|
||||
#
|
||||
# The last row is trivially verified. The other rows can be verified by directly expanding equation (11) stated in terms of forward probabilities.
|
||||
return log_p_x_t_min_1
|
||||
|
||||
def log_Q_t_transitioning_to_known_class(
|
||||
self, *, t: torch.int, x_t: torch.LongTensor, log_onehot_x_t: torch.FloatTensor, cumulative: bool
|
||||
):
|
||||
"""
|
||||
Returns the log probabilities of the rows from the (cumulative or non-cumulative) transition matrix for each
|
||||
latent pixel in `x_t`.
|
||||
|
||||
See equation (7) for the complete non-cumulative transition matrix. The complete cumulative transition matrix
|
||||
is the same structure except the parameters (alpha, beta, gamma) are the cumulative analogs.
|
||||
|
||||
Args:
|
||||
t (`torch.long`):
|
||||
The timestep that determines which transition matrix is used.
|
||||
|
||||
x_t (`torch.LongTensor` of shape `(batch size, num latent pixels)`):
|
||||
The classes of each latent pixel at time `t`.
|
||||
|
||||
log_onehot_x_t (`torch.FloatTensor` of shape `(batch size, num classes, num latent pixels)`):
|
||||
The log one-hot vectors of `x_t`
|
||||
|
||||
cumulative (`bool`):
|
||||
If cumulative is `False`, we use the single step transition matrix `t-1`->`t`. If cumulative is `True`,
|
||||
we use the cumulative transition matrix `0`->`t`.
|
||||
|
||||
Returns:
|
||||
`torch.FloatTensor` of shape `(batch size, num classes - 1, num latent pixels)`:
|
||||
Each _column_ of the returned matrix is a _row_ of log probabilities of the complete probability
|
||||
transition matrix.
|
||||
|
||||
When cumulative, the returned matrix has `self.num_embed - 1` rows because the initial latent pixel (`x_0`)
cannot be masked.
|
||||
|
||||
Where:
|
||||
- `q_n` is the probability distribution for the forward process of the `n`th latent pixel.
|
||||
- C_0 is a class of a latent pixel embedding
|
||||
- C_k is the class of the masked latent pixel
|
||||
|
||||
non-cumulative result (omitting logarithms):
|
||||
```
|
||||
q_0(x_t | x_{t-1} = C_0) ... q_n(x_t | x_{t-1} = C_0)
|
||||
. . .
|
||||
. . .
|
||||
. . .
|
||||
q_0(x_t | x_{t-1} = C_k) ... q_n(x_t | x_{t-1} = C_k)
|
||||
```
|
||||
|
||||
cumulative result (omitting logarithms):
|
||||
```
|
||||
q_0_cumulative(x_t | x_0 = C_0) ... q_n_cumulative(x_t | x_0 = C_0)
|
||||
. . .
|
||||
. . .
|
||||
. . .
|
||||
q_0_cumulative(x_t | x_0 = C_{k-1}) ... q_n_cumulative(x_t | x_0 = C_{k-1})
|
||||
```
|
||||
"""
|
||||
if cumulative:
|
||||
a = self.log_cumprod_at[t]
|
||||
b = self.log_cumprod_bt[t]
|
||||
c = self.log_cumprod_ct[t]
|
||||
else:
|
||||
a = self.log_at[t]
|
||||
b = self.log_bt[t]
|
||||
c = self.log_ct[t]
|
||||
|
||||
if not cumulative:
|
||||
# The values in the onehot vector can also be used as the logprobs for transitioning
|
||||
# from masked latent pixels. If we are not calculating the cumulative transitions,
|
||||
# we need to save these vectors to be re-appended to the final matrix so the values
|
||||
# aren't overwritten.
|
||||
#
|
||||
# `P(x_t != mask | x_{t-1} = mask) = 0` and 0 will be the value of the last row of the onehot vector
# if x_t is not masked
#
# `P(x_t = mask | x_{t-1} = mask) = 1` and 1 will be the value of the last row of the onehot vector
# if x_t is masked
|
||||
log_onehot_x_t_transitioning_from_masked = log_onehot_x_t[:, -1, :].unsqueeze(1)
|
||||
|
||||
# `index_to_log_onehot` will add onehot vectors for masked pixels,
|
||||
# so the default one hot matrix has one too many rows. See the doc string
|
||||
# for an explanation of the dimensionality of the returned matrix.
|
||||
log_onehot_x_t = log_onehot_x_t[:, :-1, :]
|
||||
|
||||
# this is a cheeky trick to produce the transition probabilities using log one-hot vectors.
|
||||
#
|
||||
# Don't worry about what values this sets in the columns that mark transitions
|
||||
# to masked latent pixels. They are overwritten later with the `mask_class_mask`.
|
||||
#
|
||||
# Looking at the below logspace formula in non-logspace, each value will evaluate to either
|
||||
# `1 * a + b = a + b` where `log_Q_t` has the one hot value in the column
|
||||
# or
|
||||
# `0 * a + b = b` where `log_Q_t` has the 0 values in the column.
|
||||
#
|
||||
# See equation 7 for more details.
|
||||
log_Q_t = (log_onehot_x_t + a).logaddexp(b)
|
||||
|
||||
# The whole column of each masked pixel is `c`
|
||||
mask_class_mask = x_t == self.mask_class
|
||||
mask_class_mask = mask_class_mask.unsqueeze(1).expand(-1, self.num_embed - 1, -1)
|
||||
log_Q_t[mask_class_mask] = c
|
||||
|
||||
if not cumulative:
|
||||
log_Q_t = torch.cat((log_Q_t, log_onehot_x_t_transitioning_from_masked), dim=1)
|
||||
|
||||
return log_Q_t
|
||||
|
||||
def apply_cumulative_transitions(self, q, t):
|
||||
bsz = q.shape[0]
|
||||
a = self.log_cumprod_at[t]
|
||||
b = self.log_cumprod_bt[t]
|
||||
c = self.log_cumprod_ct[t]
|
||||
|
||||
num_latent_pixels = q.shape[2]
|
||||
c = c.expand(bsz, 1, num_latent_pixels)
|
||||
|
||||
q = (q + a).logaddexp(b)
|
||||
q = torch.cat((q, c), dim=1)
|
||||
|
||||
return q
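Editorial sketch (not part of the diff): a bare-bones view of how the scheduler above drives sampling. The `transformer` below is a random stand-in for the denoising model, mirroring the dummy model used in the scheduler tests further down; `VQDiffusionPipeline` runs the same loop with the real VQVAE, CLIP text encoder, and transformer.

```python
import torch
from diffusers import VQDiffusionScheduler

num_vec_classes = 4097
scheduler = VQDiffusionScheduler(num_vec_classes=num_vec_classes, num_train_timesteps=100)
scheduler.set_timesteps(100, device="cpu")


def transformer(sample, t):
    # hypothetical stand-in: random log-probabilities over the non-masked classes,
    # shaped like the real model output (batch, num_vec_classes - 1, num_latent_pixels)
    logits = torch.rand(sample.shape[0], num_vec_classes - 1, sample.shape[1])
    return torch.log_softmax(logits.double(), dim=1).float()


batch_size, num_latent_pixels = 1, 32 * 32
# sampling starts from the fully masked sample: every latent pixel is the mask class
sample = torch.full((batch_size, num_latent_pixels), scheduler.mask_class, dtype=torch.long)

generator = torch.Generator().manual_seed(0)
for t in scheduler.timesteps:
    log_p_x_0 = transformer(sample, t)
    sample = scheduler.step(log_p_x_0, t, sample, generator=generator).prev_sample
```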
|
||||
@@ -34,6 +34,21 @@ class AutoencoderKL(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class Transformer2DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class UNet1DModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -257,6 +272,21 @@ class ScoreSdeVePipeline(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class VQDiffusionPipeline(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class DDIMScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
@@ -407,6 +437,21 @@ class ScoreSdeVeScheduler(metaclass=DummyObject):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class VQDiffusionScheduler(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
requires_backends(self, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_config(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, *args, **kwargs):
|
||||
requires_backends(cls, ["torch"])
|
||||
|
||||
|
||||
class EMAModel(metaclass=DummyObject):
|
||||
_backends = ["torch"]
|
||||
|
||||
|
||||
0	tests/pipelines/vq_diffusion/__init__.py	Normal file
175	tests/pipelines/vq_diffusion/test_vq_diffusion.py	Normal file
@@ -0,0 +1,175 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2022 HuggingFace Inc.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import gc
|
||||
import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
|
||||
from diffusers.utils import load_image, slow, torch_device
|
||||
from diffusers.utils.testing_utils import require_torch_gpu
|
||||
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...test_pipelines_common import PipelineTesterMixin
|
||||
|
||||
|
||||
torch.backends.cuda.matmul.allow_tf32 = False
|
||||
|
||||
|
||||
class VQDiffusionPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@property
|
||||
def num_embed(self):
|
||||
return 12
|
||||
|
||||
@property
|
||||
def num_embeds_ada_norm(self):
|
||||
return 12
|
||||
|
||||
@property
|
||||
def dummy_vqvae(self):
|
||||
torch.manual_seed(0)
|
||||
model = VQModel(
|
||||
block_out_channels=[32, 64],
|
||||
in_channels=3,
|
||||
out_channels=3,
|
||||
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
|
||||
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
|
||||
latent_channels=3,
|
||||
num_vq_embeddings=self.num_embed,
|
||||
vq_embed_dim=3,
|
||||
)
|
||||
return model
|
||||
|
||||
@property
|
||||
def dummy_tokenizer(self):
|
||||
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
|
||||
return tokenizer
|
||||
|
||||
@property
|
||||
def dummy_text_encoder(self):
|
||||
torch.manual_seed(0)
|
||||
config = CLIPTextConfig(
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
hidden_size=32,
|
||||
intermediate_size=37,
|
||||
layer_norm_eps=1e-05,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=5,
|
||||
pad_token_id=1,
|
||||
vocab_size=1000,
|
||||
)
|
||||
return CLIPTextModel(config)
|
||||
|
||||
@property
|
||||
def dummy_transformer(self):
|
||||
torch.manual_seed(0)
|
||||
|
||||
height = 12
|
||||
width = 12
|
||||
|
||||
model_kwargs = {
|
||||
"attention_bias": True,
|
||||
"cross_attention_dim": 32,
|
||||
"attention_head_dim": height * width,
|
||||
"num_attention_heads": 1,
|
||||
"num_vector_embeds": self.num_embed,
|
||||
"num_embeds_ada_norm": self.num_embeds_ada_norm,
|
||||
"norm_num_groups": 32,
|
||||
"sample_size": width,
|
||||
"activation_fn": "geglu-approximate",
|
||||
}
|
||||
|
||||
model = Transformer2DModel(**model_kwargs)
|
||||
return model
|
||||
|
||||
def test_vq_diffusion(self):
|
||||
device = "cpu"
|
||||
|
||||
vqvae = self.dummy_vqvae
|
||||
text_encoder = self.dummy_text_encoder
|
||||
tokenizer = self.dummy_tokenizer
|
||||
transformer = self.dummy_transformer
|
||||
scheduler = VQDiffusionScheduler(self.num_embed)
|
||||
|
||||
pipe = VQDiffusionPipeline(
|
||||
vqvae=vqvae, text_encoder=text_encoder, tokenizer=tokenizer, transformer=transformer, scheduler=scheduler
|
||||
)
|
||||
pipe = pipe.to(device)
|
||||
pipe.set_progress_bar_config(disable=None)
|
||||
|
||||
prompt = "teddy bear playing in the pool"
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
output = pipe([prompt], generator=generator, num_inference_steps=2, output_type="np")
|
||||
image = output.images
|
||||
|
||||
generator = torch.Generator(device=device).manual_seed(0)
|
||||
image_from_tuple = pipe(
|
||||
[prompt], generator=generator, output_type="np", return_dict=False, num_inference_steps=2
|
||||
)[0]
|
||||
|
||||
image_slice = image[0, -3:, -3:, -1]
|
||||
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
|
||||
|
||||
assert image.shape == (1, 24, 24, 3)
|
||||
|
||||
expected_slice = np.array([0.6583, 0.6410, 0.5325, 0.5635, 0.5563, 0.4234, 0.6008, 0.5491, 0.4880])
|
||||
|
||||
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
|
||||
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
|
||||
|
||||
|
||||
@slow
|
||||
@require_torch_gpu
|
||||
class VQDiffusionPipelineIntegrationTests(unittest.TestCase):
|
||||
def tearDown(self):
|
||||
# clean up the VRAM after each test
|
||||
super().tearDown()
|
||||
gc.collect()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
def test_vq_diffusion(self):
|
||||
expected_image = load_image(
|
||||
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
|
||||
"/vq_diffusion/teddy_bear_pool.png"
|
||||
)
|
||||
expected_image = np.array(expected_image, dtype=np.float32) / 255.0
|
||||
|
||||
pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq")
|
||||
pipeline = pipeline.to(torch_device)
|
||||
pipeline.set_progress_bar_config(disable=None)
|
||||
|
||||
generator = torch.Generator(device=torch_device).manual_seed(0)
|
||||
output = pipeline(
|
||||
"teddy bear playing in the pool",
|
||||
truncation_rate=0.86,
|
||||
num_images_per_prompt=1,
|
||||
generator=generator,
|
||||
output_type="np",
|
||||
)
|
||||
|
||||
image = output.images[0]
|
||||
|
||||
assert image.shape == (256, 256, 3)
|
||||
assert np.abs(expected_image - image).max() < 1e-2
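Outside the test suite, the same checkpoint is exercised with just a few lines. Editorial sketch mirroring the integration test above (output type and step count left at their defaults):

```python
import torch
from diffusers import VQDiffusionPipeline

pipeline = VQDiffusionPipeline.from_pretrained("microsoft/vq-diffusion-ithq").to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
image = pipeline(
    "teddy bear playing in the pool",
    truncation_rate=0.86,
    generator=generator,
).images[0]  # default output is a PIL image
image.save("teddy_bear_pool.png")
```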
|
||||
@@ -18,8 +18,9 @@ import unittest
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from diffusers.models.attention import AttentionBlock, SpatialTransformer
|
||||
from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU, AttentionBlock, Transformer2DModel
|
||||
from diffusers.models.embeddings import get_timestep_embedding
|
||||
from diffusers.models.resnet import Downsample2D, Upsample2D
|
||||
from diffusers.utils import torch_device
|
||||
@@ -235,7 +236,7 @@ class AttentionBlockTests(unittest.TestCase):
|
||||
num_head_channels=1,
|
||||
rescale_output_factor=1.0,
|
||||
eps=1e-6,
|
||||
num_groups=32,
|
||||
norm_num_groups=32,
|
||||
).to(torch_device)
|
||||
with torch.no_grad():
|
||||
attention_scores = attentionBlock(sample)
|
||||
@@ -259,7 +260,7 @@ class AttentionBlockTests(unittest.TestCase):
|
||||
channels=512,
|
||||
rescale_output_factor=1.0,
|
||||
eps=1e-6,
|
||||
num_groups=32,
|
||||
norm_num_groups=32,
|
||||
).to(torch_device)
|
||||
with torch.no_grad():
|
||||
attention_scores = attentionBlock(sample)
|
||||
@@ -273,22 +274,22 @@ class AttentionBlockTests(unittest.TestCase):
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
|
||||
class SpatialTransformerTests(unittest.TestCase):
|
||||
class Transformer2DModelTests(unittest.TestCase):
|
||||
def test_spatial_transformer_default(self):
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
sample = torch.randn(1, 32, 64, 64).to(torch_device)
|
||||
spatial_transformer_block = SpatialTransformer(
|
||||
spatial_transformer_block = Transformer2DModel(
|
||||
in_channels=32,
|
||||
n_heads=1,
|
||||
d_head=32,
|
||||
num_attention_heads=1,
|
||||
attention_head_dim=32,
|
||||
dropout=0.0,
|
||||
context_dim=None,
|
||||
cross_attention_dim=None,
|
||||
).to(torch_device)
|
||||
with torch.no_grad():
|
||||
attention_scores = spatial_transformer_block(sample)
|
||||
attention_scores = spatial_transformer_block(sample).sample
|
||||
|
||||
assert attention_scores.shape == (1, 32, 64, 64)
|
||||
output_slice = attention_scores[0, -1, -3:, -3:]
|
||||
@@ -298,22 +299,22 @@ class SpatialTransformerTests(unittest.TestCase):
|
||||
)
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
def test_spatial_transformer_context_dim(self):
|
||||
def test_spatial_transformer_cross_attention_dim(self):
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
sample = torch.randn(1, 64, 64, 64).to(torch_device)
|
||||
spatial_transformer_block = SpatialTransformer(
|
||||
spatial_transformer_block = Transformer2DModel(
|
||||
in_channels=64,
|
||||
n_heads=2,
|
||||
d_head=32,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=32,
|
||||
dropout=0.0,
|
||||
context_dim=64,
|
||||
cross_attention_dim=64,
|
||||
).to(torch_device)
|
||||
with torch.no_grad():
|
||||
context = torch.randn(1, 4, 64).to(torch_device)
|
||||
attention_scores = spatial_transformer_block(sample, context)
|
||||
attention_scores = spatial_transformer_block(sample, context).sample
|
||||
|
||||
assert attention_scores.shape == (1, 64, 64, 64)
|
||||
output_slice = attention_scores[0, -1, -3:, -3:]
|
||||
@@ -323,6 +324,44 @@ class SpatialTransformerTests(unittest.TestCase):
|
||||
)
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
def test_spatial_transformer_timestep(self):
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
num_embeds_ada_norm = 5
|
||||
|
||||
sample = torch.randn(1, 64, 64, 64).to(torch_device)
|
||||
spatial_transformer_block = Transformer2DModel(
|
||||
in_channels=64,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=32,
|
||||
dropout=0.0,
|
||||
cross_attention_dim=64,
|
||||
num_embeds_ada_norm=num_embeds_ada_norm,
|
||||
).to(torch_device)
|
||||
with torch.no_grad():
|
||||
timestep_1 = torch.tensor(1, dtype=torch.long).to(torch_device)
|
||||
timestep_2 = torch.tensor(2, dtype=torch.long).to(torch_device)
|
||||
attention_scores_1 = spatial_transformer_block(sample, timestep=timestep_1).sample
|
||||
attention_scores_2 = spatial_transformer_block(sample, timestep=timestep_2).sample
|
||||
|
||||
assert attention_scores_1.shape == (1, 64, 64, 64)
|
||||
assert attention_scores_2.shape == (1, 64, 64, 64)
|
||||
|
||||
output_slice_1 = attention_scores_1[0, -1, -3:, -3:]
|
||||
output_slice_2 = attention_scores_2[0, -1, -3:, -3:]
|
||||
|
||||
expected_slice_1 = torch.tensor(
|
||||
[-0.1874, -0.9704, -1.4290, -1.3357, 1.5138, 0.3036, -0.0976, -1.1667, 0.1283], device=torch_device
|
||||
)
|
||||
expected_slice_2 = torch.tensor(
|
||||
[-0.3493, -1.0924, -1.6161, -1.5016, 1.4245, 0.1367, -0.2526, -1.3109, -0.0547], device=torch_device
|
||||
)
|
||||
|
||||
assert torch.allclose(output_slice_1.flatten(), expected_slice_1, atol=1e-3)
|
||||
assert torch.allclose(output_slice_2.flatten(), expected_slice_2, atol=1e-3)
|
||||
|
||||
def test_spatial_transformer_dropout(self):
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
@@ -330,18 +369,18 @@ class SpatialTransformerTests(unittest.TestCase):
|
||||
|
||||
sample = torch.randn(1, 32, 64, 64).to(torch_device)
|
||||
spatial_transformer_block = (
|
||||
SpatialTransformer(
|
||||
Transformer2DModel(
|
||||
in_channels=32,
|
||||
n_heads=2,
|
||||
d_head=16,
|
||||
num_attention_heads=2,
|
||||
attention_head_dim=16,
|
||||
dropout=0.3,
|
||||
context_dim=None,
|
||||
cross_attention_dim=None,
|
||||
)
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
)
|
||||
with torch.no_grad():
|
||||
attention_scores = spatial_transformer_block(sample)
|
||||
attention_scores = spatial_transformer_block(sample).sample
|
||||
|
||||
assert attention_scores.shape == (1, 32, 64, 64)
|
||||
output_slice = attention_scores[0, -1, -3:, -3:]
|
||||
@@ -350,3 +389,107 @@ class SpatialTransformerTests(unittest.TestCase):
|
||||
[-1.2448, -0.0190, -0.9471, -1.5140, 0.7069, -1.0144, -2.1077, 0.9099, -1.0091], device=torch_device
|
||||
)
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "MPS does not support float64")
|
||||
def test_spatial_transformer_discrete(self):
|
||||
torch.manual_seed(0)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(0)
|
||||
|
||||
num_embed = 5
|
||||
|
||||
sample = torch.randint(0, num_embed, (1, 32)).to(torch_device)
|
||||
spatial_transformer_block = (
|
||||
Transformer2DModel(
|
||||
num_attention_heads=1,
|
||||
attention_head_dim=32,
|
||||
num_vector_embeds=num_embed,
|
||||
sample_size=16,
|
||||
)
|
||||
.to(torch_device)
|
||||
.eval()
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
attention_scores = spatial_transformer_block(sample).sample
|
||||
|
||||
assert attention_scores.shape == (1, num_embed - 1, 32)
|
||||
|
||||
output_slice = attention_scores[0, -2:, -3:]
|
||||
|
||||
expected_slice = torch.tensor([-0.8957, -1.8370, -1.3390, -0.9152, -0.5187, -1.1702], device=torch_device)
|
||||
assert torch.allclose(output_slice.flatten(), expected_slice, atol=1e-3)
|
||||
|
||||
def test_spatial_transformer_default_norm_layers(self):
|
||||
spatial_transformer_block = Transformer2DModel(num_attention_heads=1, attention_head_dim=32, in_channels=32)
|
||||
|
||||
assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == nn.LayerNorm
|
||||
assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == nn.LayerNorm
|
||||
assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm
|
||||
|
||||
def test_spatial_transformer_ada_norm_layers(self):
|
||||
spatial_transformer_block = Transformer2DModel(
|
||||
num_attention_heads=1,
|
||||
attention_head_dim=32,
|
||||
in_channels=32,
|
||||
num_embeds_ada_norm=5,
|
||||
)
|
||||
|
||||
assert spatial_transformer_block.transformer_blocks[0].norm1.__class__ == AdaLayerNorm
|
||||
assert spatial_transformer_block.transformer_blocks[0].norm2.__class__ == AdaLayerNorm
|
||||
assert spatial_transformer_block.transformer_blocks[0].norm3.__class__ == nn.LayerNorm
|
||||
|
||||
def test_spatial_transformer_default_ff_layers(self):
|
||||
spatial_transformer_block = Transformer2DModel(
|
||||
num_attention_heads=1,
|
||||
attention_head_dim=32,
|
||||
in_channels=32,
|
||||
)
|
||||
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == GEGLU
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
|
||||
|
||||
dim = 32
|
||||
inner_dim = 128
|
||||
|
||||
# First dimension change
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.in_features == dim
|
||||
# NOTE: inner_dim * 2 because GEGLU
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.out_features == inner_dim * 2
|
||||
|
||||
# Second dimension change
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim
|
||||
|
||||
def test_spatial_transformer_geglu_approx_ff_layers(self):
|
||||
spatial_transformer_block = Transformer2DModel(
|
||||
num_attention_heads=1,
|
||||
attention_head_dim=32,
|
||||
in_channels=32,
|
||||
activation_fn="geglu-approximate",
|
||||
)
|
||||
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].__class__ == ApproximateGELU
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[1].__class__ == nn.Dropout
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].__class__ == nn.Linear
|
||||
|
||||
dim = 32
|
||||
inner_dim = 128
|
||||
|
||||
# First dimension change
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.in_features == dim
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[0].proj.out_features == inner_dim
|
||||
|
||||
# Second dimension change
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].in_features == inner_dim
|
||||
assert spatial_transformer_block.transformer_blocks[0].ff.net[2].out_features == dim
|
||||
|
||||
def test_spatial_transformer_attention_bias(self):
|
||||
spatial_transformer_block = Transformer2DModel(
|
||||
num_attention_heads=1, attention_head_dim=32, in_channels=32, attention_bias=True
|
||||
)
|
||||
|
||||
assert spatial_transformer_block.transformer_blocks[0].attn1.to_q.bias is not None
|
||||
assert spatial_transformer_block.transformer_blocks[0].attn1.to_k.bias is not None
|
||||
assert spatial_transformer_block.transformer_blocks[0].attn1.to_v.bias is not None
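Editorial sketch (not part of the diff) tying the two new components together: the discrete `Transformer2DModel` maps latent-pixel class indices to log-probabilities over the non-masked classes, which is exactly the `model_output` that `VQDiffusionScheduler.step` consumes. Sizes below are illustrative.

```python
import torch
from diffusers import Transformer2DModel, VQDiffusionScheduler

num_embed = 5  # tiny vocabulary, including the mask class
model = Transformer2DModel(
    num_attention_heads=1,
    attention_head_dim=32,
    num_vector_embeds=num_embed,
    sample_size=4,  # 4x4 grid of latent pixels
)
scheduler = VQDiffusionScheduler(num_vec_classes=num_embed, num_train_timesteps=10)
scheduler.set_timesteps(10)

x_t = torch.randint(0, num_embed, (1, 16))  # one class index per latent pixel
t = scheduler.timesteps[0]

with torch.no_grad():
    log_p_x_0 = model(x_t).sample  # (1, num_embed - 1, 16): log-probs over non-masked classes
x_t_minus_1 = scheduler.step(log_p_x_0, t, x_t).prev_sample
```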
|
||||
|
||||
@@ -19,6 +19,7 @@ from typing import Dict, List, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from diffusers import (
|
||||
DDIMScheduler,
|
||||
@@ -29,6 +30,7 @@ from diffusers import (
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
ScoreSdeVeScheduler,
|
||||
VQDiffusionScheduler,
|
||||
)
|
||||
from diffusers.utils import torch_device
|
||||
|
||||
@@ -85,12 +87,18 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
|
||||
time_step = float(time_step)
|
||||
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_config = self.get_scheduler_config(**config)
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
if scheduler_class == VQDiffusionScheduler:
|
||||
num_vec_classes = scheduler_config["num_vec_classes"]
|
||||
sample = self.dummy_sample(num_vec_classes)
|
||||
model = self.dummy_model(num_vec_classes)
|
||||
residual = model(sample, time_step)
|
||||
else:
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
@@ -122,12 +130,18 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
|
||||
time_step = float(time_step)
|
||||
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
if scheduler_class == VQDiffusionScheduler:
|
||||
num_vec_classes = scheduler_config["num_vec_classes"]
|
||||
sample = self.dummy_sample(num_vec_classes)
|
||||
model = self.dummy_model(num_vec_classes)
|
||||
residual = model(sample, time_step)
|
||||
else:
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
new_scheduler = scheduler_class.from_config(tmpdirname)
|
||||
@@ -154,15 +168,21 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
num_inference_steps = kwargs.pop("num_inference_steps", None)
|
||||
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
timestep = 1
|
||||
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
|
||||
timestep = float(timestep)
|
||||
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
timestep = 1
|
||||
if scheduler_class in (EulerAncestralDiscreteScheduler, EulerDiscreteScheduler, LMSDiscreteScheduler):
|
||||
timestep = float(timestep)
|
||||
if scheduler_class == VQDiffusionScheduler:
|
||||
num_vec_classes = scheduler_config["num_vec_classes"]
|
||||
sample = self.dummy_sample(num_vec_classes)
|
||||
model = self.dummy_model(num_vec_classes)
|
||||
residual = model(sample, timestep)
|
||||
else:
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdirname:
|
||||
scheduler.save_config(tmpdirname)
|
||||
@@ -200,8 +220,14 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
if scheduler_class == VQDiffusionScheduler:
|
||||
num_vec_classes = scheduler_config["num_vec_classes"]
|
||||
sample = self.dummy_sample(num_vec_classes)
|
||||
model = self.dummy_model(num_vec_classes)
|
||||
residual = model(sample, timestep_0)
|
||||
else:
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
@@ -255,8 +281,14 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
if scheduler_class == VQDiffusionScheduler:
|
||||
num_vec_classes = scheduler_config["num_vec_classes"]
|
||||
sample = self.dummy_sample(num_vec_classes)
|
||||
model = self.dummy_model(num_vec_classes)
|
||||
residual = model(sample, timestep)
|
||||
else:
|
||||
sample = self.dummy_sample
|
||||
residual = 0.1 * sample
|
||||
|
||||
if num_inference_steps is not None and hasattr(scheduler, "set_timesteps"):
|
||||
scheduler.set_timesteps(num_inference_steps)
|
||||
@@ -284,22 +316,26 @@ class SchedulerCommonTest(unittest.TestCase):
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
scheduler_config = self.get_scheduler_config()
|
||||
scheduler = scheduler_class(**scheduler_config)
|
||||
self.assertTrue(
|
||||
hasattr(scheduler, "init_noise_sigma"),
|
||||
f"{scheduler_class} does not implement a required attribute `init_noise_sigma`",
|
||||
)
|
||||
self.assertTrue(
|
||||
hasattr(scheduler, "scale_model_input"),
|
||||
f"{scheduler_class} does not implement a required class method `scale_model_input(sample, timestep)`",
|
||||
)
|
||||
|
||||
if scheduler_class != VQDiffusionScheduler:
|
||||
self.assertTrue(
|
||||
hasattr(scheduler, "init_noise_sigma"),
|
||||
f"{scheduler_class} does not implement a required attribute `init_noise_sigma`",
|
||||
)
|
||||
self.assertTrue(
|
||||
hasattr(scheduler, "scale_model_input"),
|
||||
f"{scheduler_class} does not implement a required class method `scale_model_input(sample,"
|
||||
" timestep)`",
|
||||
)
|
||||
self.assertTrue(
|
||||
hasattr(scheduler, "step"),
|
||||
f"{scheduler_class} does not implement a required class method `step(...)`",
|
||||
)
|
||||
|
||||
sample = self.dummy_sample
|
||||
scaled_sample = scheduler.scale_model_input(sample, 0.0)
|
||||
self.assertEqual(sample.shape, scaled_sample.shape)
|
||||
if scheduler_class != VQDiffusionScheduler:
|
||||
sample = self.dummy_sample
|
||||
scaled_sample = scheduler.scale_model_input(sample, 0.0)
|
||||
self.assertEqual(sample.shape, scaled_sample.shape)
|
||||
|
||||
def test_add_noise_device(self):
|
||||
for scheduler_class in self.scheduler_classes:
|
||||
@@ -1238,3 +1274,53 @@ class IPNDMSchedulerTest(SchedulerCommonTest):
|
||||
result_mean = torch.mean(torch.abs(sample))
|
||||
|
||||
assert abs(result_mean.item() - 2540529) < 10
|
||||
|
||||
|
||||
class VQDiffusionSchedulerTest(SchedulerCommonTest):
|
||||
scheduler_classes = (VQDiffusionScheduler,)
|
||||
|
||||
def get_scheduler_config(self, **kwargs):
|
||||
config = {
|
||||
"num_vec_classes": 4097,
|
||||
"num_train_timesteps": 100,
|
||||
}
|
||||
|
||||
config.update(**kwargs)
|
||||
return config
|
||||
|
||||
def dummy_sample(self, num_vec_classes):
|
||||
batch_size = 4
|
||||
height = 8
|
||||
width = 8
|
||||
|
||||
sample = torch.randint(0, num_vec_classes, (batch_size, height * width))
|
||||
|
||||
return sample
|
||||
|
||||
@property
|
||||
def dummy_sample_deter(self):
|
||||
assert False
|
||||
|
||||
def dummy_model(self, num_vec_classes):
|
||||
def model(sample, t, *args):
|
||||
batch_size, num_latent_pixels = sample.shape
|
||||
logits = torch.rand((batch_size, num_vec_classes - 1, num_latent_pixels))
|
||||
return_value = F.log_softmax(logits.double(), dim=1).float()
|
||||
return return_value
|
||||
|
||||
return model
|
||||
|
||||
def test_timesteps(self):
|
||||
for timesteps in [2, 5, 100, 1000]:
|
||||
self.check_over_configs(num_train_timesteps=timesteps)
|
||||
|
||||
def test_num_vec_classes(self):
|
||||
for num_vec_classes in [5, 100, 1000, 4000]:
|
||||
self.check_over_configs(num_vec_classes=num_vec_classes)
|
||||
|
||||
def test_time_indices(self):
|
||||
for t in [0, 50, 99]:
|
||||
self.check_over_forward(time_step=t)
|
||||
|
||||
def test_add_noise_device(self):
|
||||
pass
|
||||
|
||||