
DiT Pipeline (#1806)

* added dit model

* import

* initial pipeline

* initial convert script

* initial pipeline

* make style

* raise valueerror

* single function

* rename classes

* use DDIMScheduler

* timesteps embedder

* samples to cpu

* fix var names

* fix numpy type

* use timesteps class for proj

* fix typo

* fix arg name

* flip_sin_to_cos and better var names

* fix C shape cal

* make style

* remove unused imports

* cleanup

* add back patch_size

* initial dit doc

* typo

* Update docs/source/api/pipelines/dit.mdx

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* added copyright license headers

* added example usage and toc

* fix variable names asserts

* remove comment

* added docs

* fix typo

* upstream changes

* set proper device for drop_ids

* added initial dit pipeline test

* update docs

* fix imports

* make fix-copies

* isort

* fix imports

* get rid of more magic numbers

* fix code when guidance is off

* remove block_kwargs

* cleanup script

* removed to_2tuple

* use FeedForward class instead of another MLP

* style

* work on merging DiTBlock with BasicTransformerBlock

* added missing final_dropout and args to BasicTransformerBlock

* use norm from block

* fix arg

* remove unused arg

* fix call to class_embedder

* use timesteps

* make style

* attn_output gets multiplied

* removed commented code

* use Transformer2D

* use self.is_input_patches

* fix flags

* fixed conversion to use Transformer2DModel

* fixes for pipeline

* remove dit.py

* fix timesteps device

* use randn_tensor and fix fp16 inf.

* timesteps_emb already the right dtype

* fix dit test class

* fix test and style

* fix norm2 usage in vq-diffusion

* added author names to pipeline and ImageNet labels link

* fix tests

* use norm_type as string

* rename dit to transformer

* fix name

* fix test

* set norm_type = "layer" by default

* fix tests

* do not skip common tests

* Update src/diffusers/models/attention.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* revert AdaLayerNorm API

* fix norm_type name

* make sure all components are in eval mode

* revert norm2 API

* compact

* finish deprecation

* add slow tests

* remove @

* refactor some stuff

* upload

* Update src/diffusers/pipelines/dit/pipeline_dit.py

* finish more

* finish docs

* improve docs

* finish docs

Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: William Berman <WLBberman@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Kashif Rasul
2023-01-17 23:09:29 +01:00
committed by GitHub
parent 7e29b747f9
commit 37d113cce7
51 changed files with 995 additions and 235 deletions

View File

@@ -106,6 +106,8 @@
title: DDIM
- local: api/pipelines/ddpm
title: DDPM
- local: api/pipelines/dit
title: DiT
- local: api/pipelines/latent_diffusion
title: Latent Diffusion
- local: api/pipelines/paint_by_example

View File

@@ -0,0 +1,59 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->
# [Scalable Diffusion Models with Transformers](https://www.wpeebles.com/DiT) (DiT)
## Overview
[Scalable Diffusion Models with Transformers](https://arxiv.org/abs/2212.09748) (DiT) by William Peebles and Saining Xie.
The abstract of the paper is the following:
*We explore a new class of diffusion models based on the transformer architecture. We train latent diffusion models of images, replacing the commonly-used U-Net backbone with a transformer that operates on latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass complexity as measured by Gflops. We find that DiTs with higher Gflops -- through increased transformer depth/width or increased number of input tokens -- consistently have lower FID. In addition to possessing good scalability properties, our largest DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512x512 and 256x256 benchmarks, achieving a state-of-the-art FID of 2.27 on the latter.*
The original codebase of this paper can be found here: [facebookresearch/dit](https://github.com/facebookresearch/dit).
## Available Pipelines:
| Pipeline | Tasks | Colab |
|---|---|:---:|
| [pipeline_dit.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dit/pipeline_dit.py) | *Conditional Image Generation* | - |
## Usage example
```python
from diffusers import DiTPipeline, DPMSolverMultistepScheduler
import torch
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
# pick words from ImageNet class labels
pipe.labels # to print all available words
# pick words that exist in ImageNet
words = ["white shark", "umbrella"]
class_ids = pipe.get_label_ids(words)
generator = torch.manual_seed(33)
output = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator)
image = output.images[0] # label 'white shark'
```
## DiTPipeline
[[autodoc]] DiTPipeline
- all
- __call__

View File

@@ -37,6 +37,7 @@ To this end, the design of schedulers is such that:
- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
- Schedulers are implemented in PyTorch by default, but are designed to be framework-independent (partial JAX support currently exists).
- Many diffusion pipelines, such as [`StableDiffusionPipeline`] and [`DiTPipeline`], can use any of the [`KarrasDiffusionSchedulers`], as sketched below.
## Schedulers Summary
@@ -80,4 +81,6 @@ The class [`SchedulerOutput`] contains the outputs from any schedulers `step(...
[[autodoc]] schedulers.scheduling_utils.SchedulerOutput
### KarrasDiffusionSchedulers
[[autodoc]] schedulers.scheduling_utils.KarrasDiffusionSchedulers
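Since the [`DiTPipeline`] accepts any of the [`KarrasDiffusionSchedulers`], a different scheduler can be swapped in from the current scheduler's config. A minimal sketch, mirroring the DiT usage example earlier in the docs (the `facebook/DiT-XL-2-256` checkpoint id comes from that example; any other compatible scheduler could stand in for `DPMSolverMultistepScheduler`):

```python
import torch

from diffusers import DiTPipeline, DPMSolverMultistepScheduler

pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
# swap the default DDIM scheduler for any other KarrasDiffusionSchedulers-compatible scheduler
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

image = pipe(class_labels=pipe.get_label_ids(["umbrella"]), num_inference_steps=25).images[0]
```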

View File

@@ -0,0 +1,162 @@
import argparse
import os
import torch
from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, Transformer2DModel
from torchvision.datasets.utils import download_url
pretrained_models = {512: "DiT-XL-2-512x512.pt", 256: "DiT-XL-2-256x256.pt"}
def download_model(model_name):
"""
Downloads a pre-trained DiT model from the web.
"""
local_path = f"pretrained_models/{model_name}"
if not os.path.isfile(local_path):
os.makedirs("pretrained_models", exist_ok=True)
web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{model_name}"
download_url(web_path, "pretrained_models")
model = torch.load(local_path, map_location=lambda storage, loc: storage)
return model
def main(args):
state_dict = download_model(pretrained_models[args.image_size])
state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
state_dict.pop("x_embedder.proj.weight")
state_dict.pop("x_embedder.proj.bias")
for depth in range(28):
state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.weight"] = state_dict[
"t_embedder.mlp.0.weight"
]
state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.bias"] = state_dict[
"t_embedder.mlp.0.bias"
]
state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.weight"] = state_dict[
"t_embedder.mlp.2.weight"
]
state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.bias"] = state_dict[
"t_embedder.mlp.2.bias"
]
state_dict[f"transformer_blocks.{depth}.norm1.emb.class_embedder.embedding_table.weight"] = state_dict[
"y_embedder.embedding_table.weight"
]
state_dict[f"transformer_blocks.{depth}.norm1.linear.weight"] = state_dict[
f"blocks.{depth}.adaLN_modulation.1.weight"
]
state_dict[f"transformer_blocks.{depth}.norm1.linear.bias"] = state_dict[
f"blocks.{depth}.adaLN_modulation.1.bias"
]
q, k, v = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.weight"], 3, dim=0)
q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.bias"], 3, dim=0)
state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias
state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict[
f"blocks.{depth}.attn.proj.weight"
]
state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict[f"blocks.{depth}.attn.proj.bias"]
state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict[f"blocks.{depth}.mlp.fc1.weight"]
state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict[f"blocks.{depth}.mlp.fc1.bias"]
state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict[f"blocks.{depth}.mlp.fc2.weight"]
state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict[f"blocks.{depth}.mlp.fc2.bias"]
state_dict.pop(f"blocks.{depth}.attn.qkv.weight")
state_dict.pop(f"blocks.{depth}.attn.qkv.bias")
state_dict.pop(f"blocks.{depth}.attn.proj.weight")
state_dict.pop(f"blocks.{depth}.attn.proj.bias")
state_dict.pop(f"blocks.{depth}.mlp.fc1.weight")
state_dict.pop(f"blocks.{depth}.mlp.fc1.bias")
state_dict.pop(f"blocks.{depth}.mlp.fc2.weight")
state_dict.pop(f"blocks.{depth}.mlp.fc2.bias")
state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.weight")
state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.bias")
state_dict.pop("t_embedder.mlp.0.weight")
state_dict.pop("t_embedder.mlp.0.bias")
state_dict.pop("t_embedder.mlp.2.weight")
state_dict.pop("t_embedder.mlp.2.bias")
state_dict.pop("y_embedder.embedding_table.weight")
state_dict["proj_out_1.weight"] = state_dict["final_layer.adaLN_modulation.1.weight"]
state_dict["proj_out_1.bias"] = state_dict["final_layer.adaLN_modulation.1.bias"]
state_dict["proj_out_2.weight"] = state_dict["final_layer.linear.weight"]
state_dict["proj_out_2.bias"] = state_dict["final_layer.linear.bias"]
state_dict.pop("final_layer.linear.weight")
state_dict.pop("final_layer.linear.bias")
state_dict.pop("final_layer.adaLN_modulation.1.weight")
state_dict.pop("final_layer.adaLN_modulation.1.bias")
# DiT XL/2
transformer = Transformer2DModel(
sample_size=args.image_size // 8,
num_layers=28,
attention_head_dim=72,
in_channels=4,
out_channels=8,
patch_size=2,
attention_bias=True,
num_attention_heads=16,
activation_fn="gelu-approximate",
num_embeds_ada_norm=1000,
norm_type="ada_norm_zero",
norm_elementwise_affine=False,
)
transformer.load_state_dict(state_dict, strict=True)
scheduler = DDIMScheduler(
num_train_timesteps=1000,
beta_schedule="linear",
prediction_type="epsilon",
clip_sample=False,
)
vae = AutoencoderKL.from_pretrained(args.vae_model)
pipeline = DiTPipeline(transformer=transformer, vae=vae, scheduler=scheduler)
if args.save:
pipeline.save_pretrained(args.checkpoint_path)
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--image_size",
default=256,
type=int,
required=False,
help="Image size of pretrained model, either 256 or 512.",
)
parser.add_argument(
"--vae_model",
default="stabilityai/sd-vae-ft-ema",
type=str,
required=False,
help="Path to pretrained VAE model, either stabilityai/sd-vae-ft-mse or stabilityai/sd-vae-ft-ema.",
)
parser.add_argument(
"--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
)
parser.add_argument(
"--checkpoint_path", default=None, type=str, required=True, help="Path to the output pipeline."
)
args = parser.parse_args()
main(args)
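To make the fused qkv split in the conversion loop above concrete, here is a small standalone sketch using the DiT-XL/2 dimensions configured later in the script (hidden size 1152 = 16 attention heads × 72 head dim); the tensor is a random stand-in, not a real checkpoint weight:

```python
import torch

hidden_size = 16 * 72  # num_attention_heads * attention_head_dim for DiT-XL/2
# the original checkpoint stores a fused qkv projection of shape (3 * hidden, hidden);
# chunking along dim=0 recovers the separate to_q / to_k / to_v weights that
# diffusers' CrossAttention expects
qkv_weight = torch.randn(3 * hidden_size, hidden_size)
q, k, v = torch.chunk(qkv_weight, 3, dim=0)
assert q.shape == k.shape == v.shape == (hidden_size, hidden_size)
```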

View File

@@ -57,6 +57,7 @@ else:
DDIMPipeline,
DDPMPipeline,
DiffusionPipeline,
DiTPipeline,
ImagePipelineOutput,
KarrasVePipeline,
LDMPipeline,

View File

@@ -20,6 +20,7 @@ from torch import nn
from ..utils.import_utils import is_xformers_available
from .cross_attention import CrossAttention
from .embeddings import CombinedTimestepLabelEmbeddings
if is_xformers_available():
@@ -196,10 +197,21 @@ class BasicTransformerBlock(nn.Module):
attention_bias: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
final_dropout: bool = False,
):
super().__init__()
self.only_cross_attention = only_cross_attention
self.use_ada_layer_norm = num_embeds_ada_norm is not None
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
raise ValueError(
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
)
# 1. Self-Attn
self.attn1 = CrossAttention(
@@ -212,7 +224,7 @@ class BasicTransformerBlock(nn.Module):
upcast_attention=upcast_attention,
)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
# 2. Cross-Attn
if cross_attention_dim is not None:
@@ -228,15 +240,27 @@ class BasicTransformerBlock(nn.Module):
else:
self.attn2 = None
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
if self.use_ada_layer_norm:
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
elif self.use_ada_layer_norm_zero:
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
else:
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
if cross_attention_dim is not None:
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
# I.e. the number of returned modulation chunks from AdaLayerNormZero would not make sense if returned during
# the second cross attention block.
self.norm2 = (
AdaLayerNorm(dim, num_embeds_ada_norm)
if self.use_ada_layer_norm
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
)
else:
self.norm2 = None
# 3. Feed-forward
self.norm3 = nn.LayerNorm(dim)
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
def forward(
self,
@@ -245,11 +269,18 @@ class BasicTransformerBlock(nn.Module):
timestep=None,
attention_mask=None,
cross_attention_kwargs=None,
class_labels=None,
):
if self.use_ada_layer_norm:
norm_hidden_states = self.norm1(hidden_states, timestep)
elif self.use_ada_layer_norm_zero:
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
)
else:
norm_hidden_states = self.norm1(hidden_states)
# 1. Self-Attention
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
attn_output = self.attn1(
norm_hidden_states,
@@ -257,13 +288,16 @@ class BasicTransformerBlock(nn.Module):
attention_mask=attention_mask,
**cross_attention_kwargs,
)
if self.use_ada_layer_norm_zero:
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = attn_output + hidden_states
if self.attn2 is not None:
# 2. Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
# 2. Cross-Attention
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
@@ -273,7 +307,17 @@ class BasicTransformerBlock(nn.Module):
hidden_states = attn_output + hidden_states
# 3. Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
norm_hidden_states = self.norm3(hidden_states)
if self.use_ada_layer_norm_zero:
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(norm_hidden_states)
if self.use_ada_layer_norm_zero:
ff_output = gate_mlp.unsqueeze(1) * ff_output
hidden_states = ff_output + hidden_states
return hidden_states
@@ -288,6 +332,7 @@ class FeedForward(nn.Module):
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
final_dropout (`bool`, *optional*, defaults to `False`): Apply a final dropout.
"""
def __init__(
@@ -297,6 +342,7 @@ class FeedForward(nn.Module):
mult: int = 4,
dropout: float = 0.0,
activation_fn: str = "geglu",
final_dropout: bool = False,
):
super().__init__()
inner_dim = int(dim * mult)
@@ -304,6 +350,8 @@ class FeedForward(nn.Module):
if activation_fn == "gelu":
act_fn = GELU(dim, inner_dim)
if activation_fn == "gelu-approximate":
act_fn = GELU(dim, inner_dim, approximate="tanh")
elif activation_fn == "geglu":
act_fn = GEGLU(dim, inner_dim)
elif activation_fn == "geglu-approximate":
@@ -316,6 +364,9 @@ class FeedForward(nn.Module):
self.net.append(nn.Dropout(dropout))
# project out
self.net.append(nn.Linear(inner_dim, dim_out))
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
if final_dropout:
self.net.append(nn.Dropout(dropout))
def forward(self, hidden_states):
for module in self.net:
@@ -325,18 +376,19 @@ class FeedForward(nn.Module):
class GELU(nn.Module):
r"""
GELU activation function
GELU activation function with tanh approximation support via `approximate="tanh"`.
"""
def __init__(self, dim_in: int, dim_out: int):
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
super().__init__()
self.proj = nn.Linear(dim_in, dim_out)
self.approximate = approximate
def gelu(self, gate):
if gate.device.type != "mps":
return F.gelu(gate)
return F.gelu(gate, approximate=self.approximate)
# mps: gelu is not implemented for float16
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
def forward(self, hidden_states):
hidden_states = self.proj(hidden_states)
@@ -344,7 +396,6 @@ class GELU(nn.Module):
return hidden_states
# feedforward
class GEGLU(nn.Module):
r"""
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
@@ -402,3 +453,24 @@ class AdaLayerNorm(nn.Module):
scale, shift = torch.chunk(emb, 2)
x = self.norm(x) * (1 + scale) + shift
return x
class AdaLayerNormZero(nn.Module):
"""
Norm layer adaptive layer norm zero (adaLN-Zero).
"""
def __init__(self, embedding_dim, num_embeddings):
super().__init__()
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
self.silu = nn.SiLU()
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, timestep, class_labels, hidden_dtype=None):
emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
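Read together with the `BasicTransformerBlock` changes above, the adaLN-Zero path applies the six chunks roughly as follows (a schematic restatement of the code, not new behavior):

$$
h \leftarrow h + g_{\mathrm{msa}} \odot \mathrm{Attn}\big(\mathrm{LN}(h)\,(1 + s_{\mathrm{msa}}) + b_{\mathrm{msa}}\big), \qquad
h \leftarrow h + g_{\mathrm{mlp}} \odot \mathrm{FF}\big(\mathrm{LN}(h)\,(1 + s_{\mathrm{mlp}}) + b_{\mathrm{mlp}}\big)
$$

where the shifts $b$, scales $s$, and gates $g$ all come from one linear projection of the combined timestep and class-label embedding.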

View File

@@ -61,6 +61,96 @@ def get_timestep_embedding(
return emb
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
"""
grid_size: int of the grid height and width
return: pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
"""
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.meshgrid(grid_w, grid_h) # here w goes first
grid = np.stack(grid, axis=0)
grid = grid.reshape([2, 1, grid_size, grid_size])
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
if cls_token and extra_tokens > 0:
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
return pos_embed
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
# use half of dimensions to encode grid_h
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
return emb
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
"""
embed_dim: output dimension for each position
pos: a list of positions to be encoded: size (M,)
out: (M, D)
"""
if embed_dim % 2 != 0:
raise ValueError("embed_dim must be divisible by 2")
omega = np.arange(embed_dim // 2, dtype=np.float64)
omega /= embed_dim / 2.0
omega = 1.0 / 10000**omega # (D/2,)
pos = pos.reshape(-1) # (M,)
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
emb_sin = np.sin(out) # (M, D/2)
emb_cos = np.cos(out) # (M, D/2)
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
return emb
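In formula form, the 1D embedding computed above for a position $p$ and embedding dimension $D$ uses frequencies $\omega_i = 10000^{-2i/D}$ for $i = 0, \dots, D/2 - 1$:

$$
\mathrm{emb}(p) = \big[\sin(p\,\omega_0), \dots, \sin(p\,\omega_{D/2 - 1}),\ \cos(p\,\omega_0), \dots, \cos(p\,\omega_{D/2 - 1})\big]
$$

and the 2D embedding simply concatenates the height and width embeddings, each of size $D/2$.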
class PatchEmbed(nn.Module):
"""2D Image to Patch Embedding"""
def __init__(
self,
height=224,
width=224,
patch_size=16,
in_channels=3,
embed_dim=768,
layer_norm=False,
flatten=True,
bias=True,
):
super().__init__()
num_patches = (height // patch_size) * (width // patch_size)
self.flatten = flatten
self.layer_norm = layer_norm
self.proj = nn.Conv2d(
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
)
if layer_norm:
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
else:
self.norm = None
pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
def forward(self, latent):
latent = self.proj(latent)
if self.flatten:
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
if self.layer_norm:
latent = self.norm(latent)
return latent + self.pos_embed
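As a shape sanity check for `PatchEmbed`, a minimal standalone sketch with the DiT-XL/2 numbers from the conversion script (32×32 latents, patch size 2, hidden size 1152); it reproduces only the conv-patchify and flatten steps, not the class itself:

```python
import torch
import torch.nn as nn

latent = torch.randn(1, 4, 32, 32)                   # (B, C, H, W) latent from the VAE
proj = nn.Conv2d(4, 1152, kernel_size=2, stride=2)   # one token per non-overlapping 2x2 patch
tokens = proj(latent).flatten(2).transpose(1, 2)     # (B, 1152, 16, 16) -> (B, 256, 1152)
assert tokens.shape == (1, 256, 1152)                # 256 tokens; PatchEmbed then adds the sincos pos_embed
```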
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
super().__init__()
@@ -198,3 +288,58 @@ class ImagePositionalEmbeddings(nn.Module):
emb = emb + pos_emb[:, : emb.shape[1], :]
return emb
class LabelEmbedding(nn.Module):
"""
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
Args:
num_classes (`int`): The number of classes.
hidden_size (`int`): The size of the vector embeddings.
dropout_prob (`float`): The probability of dropping a label.
"""
def __init__(self, num_classes, hidden_size, dropout_prob):
super().__init__()
use_cfg_embedding = dropout_prob > 0
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
self.num_classes = num_classes
self.dropout_prob = dropout_prob
def token_drop(self, labels, force_drop_ids=None):
"""
Drops labels to enable classifier-free guidance.
"""
if force_drop_ids is None:
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
else:
drop_ids = torch.tensor(force_drop_ids == 1)
labels = torch.where(drop_ids, self.num_classes, labels)
return labels
def forward(self, labels, force_drop_ids=None):
use_dropout = self.dropout_prob > 0
if (self.training and use_dropout) or (force_drop_ids is not None):
labels = self.token_drop(labels, force_drop_ids)
embeddings = self.embedding_table(labels)
return embeddings
class CombinedTimestepLabelEmbeddings(nn.Module):
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
super().__init__()
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
def forward(self, timestep, class_labels, hidden_dtype=None):
timesteps_proj = self.time_proj(timestep)
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
class_labels = self.class_embedder(class_labels) # (N, D)
conditioning = timesteps_emb + class_labels # (N, D)
return conditioning
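One detail worth making explicit: when `dropout_prob > 0`, `LabelEmbedding` allocates one extra embedding row, and that last index doubles as the unconditional (null) class used for classifier-free guidance; the `DiTPipeline` later builds `class_null = torch.tensor([1000] * batch_size)` for ImageNet's 1000 classes for exactly this reason. A small sketch of that convention:

```python
import torch
import torch.nn as nn

num_classes, hidden_size = 1000, 1152
# dropout_prob > 0 adds one extra row, reserved for the "no label" (null) class
embedding_table = nn.Embedding(num_classes + 1, hidden_size)
null_label = torch.tensor([num_classes])   # index 1000 == the unconditional embedding
uncond_emb = embedding_table(null_label)   # shape (1, 1152), used for the CFG half of the batch
```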

View File

@@ -20,8 +20,9 @@ from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import BaseOutput
from ..utils import BaseOutput, deprecate
from .attention import BasicTransformerBlock
from .embeddings import PatchEmbed
from .modeling_utils import ModelMixin
@@ -81,6 +82,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
@@ -88,11 +90,14 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
patch_size: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
norm_type: str = "layer_norm",
norm_elementwise_affine: bool = True,
):
super().__init__()
self.use_linear_projection = use_linear_projection
@@ -102,18 +107,35 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
# 1. Transformer2DModel can process standard continuous images of shape `(batch_size, num_channels, width, height)`, quantized image embeddings of shape `(batch_size, num_image_vectors)`, and patched continuous inputs
# Define whether the input is continuous, discrete, or patched depending on the configuration
self.is_input_continuous = in_channels is not None
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
self.is_input_vectorized = num_vector_embeds is not None
self.is_input_patches = in_channels is not None and patch_size is not None
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
deprecation_message = (
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
)
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
norm_type = "ada_norm"
if self.is_input_continuous and self.is_input_vectorized:
raise ValueError(
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
" sure that either `in_channels` or `num_vector_embeds` is None."
)
elif not self.is_input_continuous and not self.is_input_vectorized:
elif self.is_input_vectorized and self.is_input_patches:
raise ValueError(
f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
" sure that either `in_channels` or `num_vector_embeds` is not None."
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
" sure that either `num_vector_embeds` or `num_patches` is None."
)
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
raise ValueError(
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
)
# 2. Define input layers
@@ -137,6 +159,20 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self.latent_image_embedding = ImagePositionalEmbeddings(
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
)
elif self.is_input_patches:
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
self.height = sample_size
self.width = sample_size
self.patch_size = patch_size
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
@@ -152,13 +188,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
)
for d in range(num_layers)
]
)
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
if self.is_input_continuous:
# TODO: should use out_channels for continuous projections
if use_linear_projection:
self.proj_out = nn.Linear(in_channels, inner_dim)
else:
@@ -166,12 +206,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
elif self.is_input_patches:
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
def forward(
self,
hidden_states,
encoder_hidden_states=None,
timestep=None,
class_labels=None,
cross_attention_kwargs=None,
return_dict: bool = True,
):
@@ -185,6 +230,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
self-attention.
timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in `AdaLayerNorm`. Used to indicate the denoising step.
class_labels (`torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Optional class labels to be applied as an embedding in `AdaLayerNormZero`. Used to indicate class-label conditioning.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
@@ -195,7 +243,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
"""
# 1. Input
if self.is_input_continuous:
batch, channel, height, width = hidden_states.shape
batch, _, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
@@ -209,6 +257,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
hidden_states = self.proj_in(hidden_states)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
hidden_states = self.pos_embed(hidden_states)
# 2. Blocks
for block in self.transformer_blocks:
@@ -217,6 +267,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
# 3. Output
@@ -237,6 +288,24 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
# log(p(x_0))
output = F.log_softmax(logits.double(), dim=1).float()
elif self.is_input_patches:
# TODO: cleanup!
conditioning = self.transformer_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = self.proj_out_2(hidden_states)
# unpatchify
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
if not return_dict:
return (output,)
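To make the unpatchify step concrete, a standalone sketch with DiT-XL/2 numbers at 256×256 resolution (a 16×16 grid of 2×2 patches over the 32×32 latent, 8 output channels); it mirrors the reshape/einsum above outside of the model:

```python
import torch

patch_size, out_channels, grid = 2, 8, 16  # 16x16 tokens, each decoding a 2x2 patch
hidden_states = torch.randn(1, grid * grid, patch_size * patch_size * out_channels)

x = hidden_states.reshape(-1, grid, grid, patch_size, patch_size, out_channels)
x = torch.einsum("nhwpqc->nchpwq", x)
output = x.reshape(-1, out_channels, grid * patch_size, grid * patch_size)
assert output.shape == (1, 8, 32, 32)  # (batch, out_channels, latent_height, latent_width)
```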

View File

@@ -18,6 +18,7 @@ else:
from .dance_diffusion import DanceDiffusionPipeline
from .ddim import DDIMPipeline
from .ddpm import DDPMPipeline
from .dit import DiTPipeline
from .latent_diffusion import LDMSuperResolutionPipeline
from .latent_diffusion_uncond import LDMPipeline
from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput

View File

@@ -23,14 +23,7 @@ from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -91,14 +84,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
text_encoder: RobertaSeriesModelWithTransformation,
tokenizer: XLMRobertaTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,

View File

@@ -25,14 +25,7 @@ from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -129,14 +122,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
text_encoder: RobertaSeriesModelWithTransformation,
tokenizer: XLMRobertaTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,

View File

@@ -0,0 +1 @@
from .pipeline_dit import DiTPipeline

View File

@@ -0,0 +1,199 @@
# Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
# William Peebles and Saining Xie
#
# Copyright (c) 2021 OpenAI
# MIT License
#
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, Tuple, Union
import torch
from ...models import AutoencoderKL, Transformer2DModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
class DiTPipeline(DiffusionPipeline):
r"""
This pipeline inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
Parameters:
transformer ([`Transformer2DModel`]):
A class-conditioned `Transformer2DModel` to denoise the encoded image latents.
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
scheduler ([`DDIMScheduler`]):
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
"""
def __init__(
self,
transformer: Transformer2DModel,
vae: AutoencoderKL,
scheduler: KarrasDiffusionSchedulers,
id2label: Optional[Dict[int, str]] = None,
):
super().__init__()
self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler)
# create an ImageNet label -> id dictionary for easier use
self.labels = {}
if id2label is not None:
for key, value in id2label.items():
for label in value.split(","):
self.labels[label.lstrip().rstrip()] = int(key)
self.labels = dict(sorted(self.labels.items()))
def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
r"""
Map label strings, *e.g.* from ImageNet, to corresponding class ids.
Parameters:
label (`str` or `list` of `str`): label strings to be mapped to class ids.
Returns:
`list` of `int`: Class ids to be processed by pipeline.
"""
if not isinstance(label, list):
label = [label]
for l in label:
if l not in self.labels:
raise ValueError(
f"{l} does not exist. Please make sure to select one of the following labels: \n {self.labels}."
)
return [self.labels[l] for l in label]
@torch.no_grad()
def __call__(
self,
class_labels: List[int],
guidance_scale: float = 4.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
num_inference_steps: int = 50,
output_type: Optional[str] = "pil",
return_dict: bool = True,
) -> Union[ImagePipelineOutput, Tuple]:
r"""
Function invoked when calling the pipeline for generation.
Args:
class_labels (List[int]):
List of ImageNet class labels for the images to be generated.
guidance_scale (`float`, *optional*, defaults to 4.0):
Scale of the guidance signal.
generator (`torch.Generator`, *optional*):
A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
deterministic.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple.
"""
batch_size = len(class_labels)
latent_size = self.transformer.config.sample_size
latent_channels = self.transformer.config.in_channels
latents = randn_tensor(
shape=(batch_size, latent_channels, latent_size, latent_size),
generator=generator,
device=self.device,
dtype=self.transformer.dtype,
)
latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents
class_labels = torch.tensor(class_labels, device=self.device).reshape(-1)
class_null = torch.tensor([1000] * batch_size, device=self.device)
class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels
# set step values
self.scheduler.set_timesteps(num_inference_steps)
for t in self.progress_bar(self.scheduler.timesteps):
if guidance_scale > 1:
half = latent_model_input[: len(latent_model_input) // 2]
latent_model_input = torch.cat([half, half], dim=0)
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
timesteps = t
if not torch.is_tensor(timesteps):
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = latent_model_input.device.type == "mps"
if isinstance(timesteps, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(latent_model_input.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(latent_model_input.shape[0])
# predict noise model_output
noise_pred = self.transformer(
latent_model_input, timestep=timesteps, class_labels=class_labels_input
).sample
# perform guidance
if guidance_scale > 1:
eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:]
cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
eps = torch.cat([half_eps, half_eps], dim=0)
noise_pred = torch.cat([eps, rest], dim=1)
# learned sigma
if self.transformer.config.out_channels // 2 == latent_channels:
model_output, _ = torch.split(noise_pred, latent_channels, dim=1)
else:
model_output = noise_pred
# compute previous image: x_t -> x_t-1
latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample
if guidance_scale > 1:
latents, _ = latent_model_input.chunk(2, dim=0)
else:
latents = latent_model_input
latents = 1 / 0.18215 * latents
samples = self.vae.decode(latents).sample
samples = (samples / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
samples = self.numpy_to_pil(samples)
if not return_dict:
return (samples,)
return ImagePipelineOutput(images=samples)
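For reference, the guidance step above is standard classifier-free guidance applied to the epsilon channels only, with the learned-sigma channels passed through unchanged:

$$
\epsilon = \epsilon_{\mathrm{uncond}} + s \cdot (\epsilon_{\mathrm{cond}} - \epsilon_{\mathrm{uncond}})
$$

where $s$ is `guidance_scale`, the conditional half of the doubled batch uses the requested class ids, and the unconditional half uses the null class id 1000.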

View File

@@ -22,14 +22,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
@@ -88,14 +81,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,

View File

@@ -25,14 +25,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTF
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -91,14 +84,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
scheduler: KarrasDiffusionSchedulers,
depth_estimator: DPTForDepthEstimation,
feature_extractor: DPTFeatureExtractor,
):

View File

@@ -23,14 +23,7 @@ from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
@@ -73,14 +66,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
vae: AutoencoderKL,
image_encoder: CLIPVisionModelWithProjection,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,

View File

@@ -24,14 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
PIL_INTERPOLATION,
deprecate,
@@ -133,14 +126,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,

View File

@@ -24,7 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
@@ -173,7 +173,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,

View File

@@ -24,14 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
@@ -100,14 +93,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,

View File

@@ -22,7 +22,7 @@ import PIL
from transformers import CLIPTextModel, CLIPTokenizer
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
from ...utils import is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -84,7 +84,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
low_res_scheduler: DDPMScheduler,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
scheduler: KarrasDiffusionSchedulers,
max_noise_level: int = 350,
):
super().__init__()

View File

@@ -10,14 +10,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionSafePipelineOutput
@@ -65,14 +58,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
],
scheduler: KarrasDiffusionSchedulers,
safety_checker: SafeStableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,

View File

@@ -7,7 +7,7 @@ import PIL.Image
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import logging
from ..pipeline_utils import DiffusionPipeline
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
@@ -53,7 +53,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
image_unet: UNet2DConditionModel
text_unet: UNet2DConditionModel
vae: AutoencoderKL
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
scheduler: KarrasDiffusionSchedulers
def __init__(
self,
@@ -64,7 +64,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
image_unet: UNet2DConditionModel,
text_unet: UNet2DConditionModel,
vae: AutoencoderKL,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
scheduler: KarrasDiffusionSchedulers,
):
super().__init__()

View File

@@ -28,7 +28,7 @@ from transformers import (
)
from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_text_unet import UNetFlatConditionModel
@@ -62,7 +62,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
image_unet: UNet2DConditionModel
text_unet: UNetFlatConditionModel
vae: AutoencoderKL
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
scheduler: KarrasDiffusionSchedulers
_optional_components = ["text_unet"]
@@ -75,7 +75,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
image_unet: UNet2DConditionModel,
text_unet: UNetFlatConditionModel,
vae: AutoencoderKL,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
scheduler: KarrasDiffusionSchedulers,
):
super().__init__()
self.register_modules(

View File

@@ -23,7 +23,7 @@ import PIL
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
@@ -53,7 +53,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
image_encoder: CLIPVisionModelWithProjection
image_unet: UNet2DConditionModel
vae: AutoencoderKL
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
scheduler: KarrasDiffusionSchedulers
def __init__(
self,
@@ -61,7 +61,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
image_encoder: CLIPVisionModelWithProjection,
image_unet: UNet2DConditionModel,
vae: AutoencoderKL,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
scheduler: KarrasDiffusionSchedulers,
):
super().__init__()
self.register_modules(

View File

@@ -21,7 +21,7 @@ import torch.utils.checkpoint
from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer
from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_text_unet import UNetFlatConditionModel
@@ -54,7 +54,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
image_unet: UNet2DConditionModel
text_unet: UNetFlatConditionModel
vae: AutoencoderKL
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
scheduler: KarrasDiffusionSchedulers
_optional_components = ["text_unet"]
@@ -65,7 +65,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
image_unet: UNet2DConditionModel,
text_unet: UNetFlatConditionModel,
vae: AutoencoderKL,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
scheduler: KarrasDiffusionSchedulers,
):
super().__init__()
self.register_modules(

View File

@@ -39,7 +39,7 @@ else:
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_unclip import UnCLIPScheduler
from .scheduling_utils import SchedulerMixin
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from .scheduling_vq_diffusion import VQDiffusionScheduler
try:
@@ -55,7 +55,12 @@ else:
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
from .scheduling_pndm_flax import FlaxPNDMScheduler
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
from .scheduling_utils_flax import (
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
)
try:

View File

@@ -23,8 +23,8 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor
from .scheduling_utils import SchedulerMixin
from ..utils import BaseOutput, deprecate, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
@dataclass
@@ -112,7 +112,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
https://imagen.research.google/video/paper.pdf)
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
order = 1

View File

@@ -24,8 +24,8 @@ import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
add_noise_common,
@@ -102,7 +102,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
the `dtype` used for params and computation.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
dtype: jnp.dtype

View File

@@ -22,8 +22,8 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor
from .scheduling_utils import SchedulerMixin
from ..utils import BaseOutput, deprecate, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
@dataclass
@@ -105,7 +105,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
https://imagen.research.google/video/paper.pdf)
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
order = 1

View File

@@ -24,8 +24,8 @@ import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
add_noise_common,
@@ -85,7 +85,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
the `dtype` used for params and computation.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
dtype: jnp.dtype

View File

@@ -22,8 +22,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -106,7 +105,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config

View File

@@ -21,8 +21,8 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, deprecate
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from ..utils import deprecate
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -117,7 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
order = 1

View File

@@ -24,8 +24,8 @@ import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
add_noise_common,
@@ -140,7 +140,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
the `dtype` used for params and computation.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
dtype: jnp.dtype

View File

@@ -21,8 +21,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -116,7 +115,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config

View File

@@ -19,8 +19,8 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor
from .scheduling_utils import SchedulerMixin
from ..utils import BaseOutput, logging, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -71,7 +71,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config

View File

@@ -19,8 +19,8 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor
from .scheduling_utils import SchedulerMixin
from ..utils import BaseOutput, logging, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -72,7 +72,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config

View File

@@ -18,8 +18,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
@@ -48,7 +47,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
https://imagen.research.google/video/paper.pdf)
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 2
@register_to_config

View File

@@ -18,8 +18,8 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, randn_tensor
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from ..utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
@@ -49,7 +49,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
https://imagen.research.google/video/paper.pdf)
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 2
@register_to_config

View File

@@ -18,8 +18,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
@@ -49,7 +48,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
https://imagen.research.google/video/paper.pdf)
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 2
@register_to_config

View File

@@ -21,8 +21,8 @@ import torch
from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
from .scheduling_utils import SchedulerMixin
from ..utils import BaseOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
@dataclass
@@ -70,7 +70,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
https://imagen.research.google/video/paper.pdf)
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config

View File

@@ -21,8 +21,8 @@ from scipy import integrate
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
@@ -82,7 +82,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
the `dtype` used for params and computation.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
dtype: jnp.dtype

View File

@@ -21,8 +21,7 @@ import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -92,7 +91,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):
"""
_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1
@register_to_config

View File

@@ -23,8 +23,8 @@ import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
add_noise_common,
@@ -110,7 +110,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
the `dtype` used for params and computation.
"""
_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
dtype: jnp.dtype
pndm_order: int

View File

@@ -14,6 +14,7 @@
import importlib
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Union
import torch
@@ -24,6 +25,21 @@ from ..utils import BaseOutput
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
class KarrasDiffusionSchedulers(Enum):
    DDIMScheduler = 1
    DDPMScheduler = 2
    PNDMScheduler = 3
    LMSDiscreteScheduler = 4
    EulerDiscreteScheduler = 5
    HeunDiscreteScheduler = 6
    EulerAncestralDiscreteScheduler = 7
    DPMSolverMultistepScheduler = 8
    DPMSolverSinglestepScheduler = 9
    KDPM2DiscreteScheduler = 10
    KDPM2AncestralDiscreteScheduler = 11
    DEISMultistepScheduler = 12
@dataclass
class SchedulerOutput(BaseOutput):
"""

View File

@@ -15,16 +15,24 @@ import importlib
import math
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Union
import flax
import jax.numpy as jnp
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
from ..utils import BaseOutput
SCHEDULER_CONFIG_NAME = "scheduler_config.json"
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = ["Flax" + c for c in _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS]
class FlaxKarrasDiffusionSchedulers(Enum):
    FlaxDDIMScheduler = 1
    FlaxDDPMScheduler = 2
    FlaxPNDMScheduler = 3
    FlaxLMSDiscreteScheduler = 4
    FlaxDPMSolverMultistepScheduler = 5
@dataclass

View File

@@ -19,7 +19,6 @@ from packaging import version
from .. import __version__
from .constants import (
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CONFIG_NAME,
DIFFUSERS_CACHE,
DIFFUSERS_DYNAMIC_MODULE_NAME,

View File

@@ -30,18 +30,3 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
"DDIMScheduler",
"DDPMScheduler",
"PNDMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"HeunDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
"DPMSolverSinglestepScheduler",
"KDPM2DiscreteScheduler",
"KDPM2AncestralDiscreteScheduler",
"DEISMultistepScheduler",
]
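
Editor's note: with the hard-coded `_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS` constant removed, scheduler interchangeability is now expressed through the `KarrasDiffusionSchedulers` and `FlaxKarrasDiffusionSchedulers` enums. A minimal sketch, mirroring the integration test further below, of swapping one compatible scheduler for another on a loaded pipeline (the model id is only illustrative):

from diffusers import DiTPipeline, DPMSolverMultistepScheduler

pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256")
# any scheduler named in KarrasDiffusionSchedulers can be re-instantiated
# from the current scheduler's config and dropped in
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)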

View File

@@ -227,6 +227,21 @@ class DiffusionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class DiTPipeline(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class ImagePipelineOutput(metaclass=DummyObject):
_backends = ["torch"]

View File

View File

@@ -0,0 +1,135 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest

import numpy as np
import torch

from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DPMSolverMultistepScheduler, Transformer2DModel
from diffusers.utils import load_numpy, slow
from diffusers.utils.testing_utils import require_torch_gpu

from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = DiTPipeline
    test_cpu_offload = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = Transformer2DModel(
            sample_size=4,
            num_layers=2,
            patch_size=2,
            attention_head_dim=2,
            num_attention_heads=2,
            in_channels=4,
            out_channels=8,
            attention_bias=True,
            activation_fn="gelu-approximate",
            num_embeds_ada_norm=1000,
            norm_type="ada_norm_zero",
            norm_elementwise_affine=False,
        )
        vae = AutoencoderKL()
        scheduler = DDIMScheduler()
        components = {"transformer": transformer.eval(), "vae": vae.eval(), "scheduler": scheduler}
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "class_labels": [1],
            "generator": generator,
            "num_inference_steps": 2,
            "output_type": "numpy",
        }
        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        self.assertEqual(image.shape, (1, 4, 4, 3))
        expected_slice = np.array(
            [0.44405967, 0.33592293, 0.6093237, 0.48981372, 0.79098296, 0.7504172, 0.59413105, 0.49462673, 0.35190058]
        )
        max_diff = np.abs(image_slice.flatten() - expected_slice).max()
        self.assertLessEqual(max_diff, 1e-3)

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(relax_max_difference=True)


@require_torch_gpu
@slow
class DiTPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_dit_256(self):
        generator = torch.manual_seed(0)

        pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256")
        pipe.to("cuda")

        words = ["vase", "umbrella", "white shark", "white wolf"]
        ids = pipe.get_label_ids(words)

        images = pipe(ids, generator=generator, num_inference_steps=40, output_type="np").images

        for word, image in zip(words, images):
            expected_image = load_numpy(
                f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}.npy"
            )
            assert np.abs((expected_image - image).sum()) < 1e-3

    def test_dit_512_fp16(self):
        generator = torch.manual_seed(0)

        pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512", torch_dtype=torch.float16)
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.to("cuda")

        words = ["vase", "umbrella", "white shark", "white wolf"]
        ids = pipe.get_label_ids(words)

        images = pipe(ids, generator=generator, num_inference_steps=25, output_type="np").images

        for word, image in zip(words, images):
            expected_image = load_numpy(
                "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
                f"/dit/{word}_fp16.npy"
            )
            assert np.abs((expected_image - image).sum()) < 1e-3

View File

@@ -36,7 +36,7 @@ class PipelineTesterMixin:
equivalence of dict and tuple outputs, etc.
"""
allowed_required_args = ["source_prompt", "prompt", "image", "mask_image", "example_image"]
allowed_required_args = ["source_prompt", "prompt", "image", "mask_image", "example_image", "class_labels"]
required_optional_params = ["generator", "num_inference_steps", "return_dict"]
num_inference_steps_args = ["num_inference_steps"]
@@ -194,8 +194,8 @@ class PipelineTesterMixin:
):
if self.pipeline_class.__name__ in ["CycleDiffusionPipeline", "RePaintPipeline"]:
# RePaint can hardly be made deterministic since the scheduler is currently always
# indeterministic
# CycleDiffusion is also slighly undeterministic
# nondeterministic
# CycleDiffusion is also slightly nondeterministic
return
if test_max_difference is None:
@@ -515,7 +515,7 @@ class PipelineTesterMixin:
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forward_pass(self):
def test_xformers_attention_forwardGenerator_pass(self):
if not self.test_xformers_attention:
return