mirror of https://github.com/huggingface/diffusers.git
synced 2026-01-27 17:22:53 +03:00
DiT Pipeline (#1806)
* added dit model * import * initial pipeline * initial convert script * initial pipeline * make style * raise valueerror * single function * rename classes * use DDIMScheduler * timesteps embedder * samples to cpu * fix var names * fix numpy type * use timesteps class for proj * fix typo * fix arg name * flip_sin_to_cos and better var names * fix C shape cal * make style * remove unused imports * cleanup * add back patch_size * initial dit doc * typo * Update docs/source/api/pipelines/dit.mdx Co-authored-by: Suraj Patil <surajp815@gmail.com> * added copyright license headers * added example usage and toc * fix variable names asserts * remove comment * added docs * fix typo * upstream changes * set proper device for drop_ids * added initial dit pipeline test * update docs * fix imports * make fix-copies * isort * fix imports * get rid of more magic numbers * fix code when guidance is off * remove block_kwargs * cleanup script * removed to_2tuple * use FeedForward class instead of another MLP * style * work on mergint DiTBlock with BasicTransformerBlock * added missing final_dropout and args to BasicTransformerBlock * use norm from block * fix arg * remove unused arg * fix call to class_embedder * use timesteps * make style * attn_output gets multiplied * removed commented code * use Transformer2D * use self.is_input_patches * fix flags * fixed conversion to use Transformer2DModel * fixes for pipeline * remove dit.py * fix timesteps device * use randn_tensor and fix fp16 inf. * timesteps_emb already the right dtype * fix dit test class * fix test and style * fix norm2 usage in vq-diffusion * added author names to pipeline and lmagenet labels link * fix tests * use norm_type as string * rename dit to transformer * fix name * fix test * set norm_type = "layer" by default * fix tests * do not skip common tests * Update src/diffusers/models/attention.py Co-authored-by: Suraj Patil <surajp815@gmail.com> * revert AdaLayerNorm API * fix norm_type name * make sure all components are in eval mode * revert norm2 API * compact * finish deprecation * add slow tests * remove @ * refactor some stuff * upload * Update src/diffusers/pipelines/dit/pipeline_dit.py * finish more * finish docs * improve docs * finish docs Co-authored-by: Suraj Patil <surajp815@gmail.com> Co-authored-by: William Berman <WLBberman@gmail.com> Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
@@ -106,6 +106,8 @@
    title: DDIM
  - local: api/pipelines/ddpm
    title: DDPM
  - local: api/pipelines/dit
    title: DiT
  - local: api/pipelines/latent_diffusion
    title: Latent Diffusion
  - local: api/pipelines/paint_by_example
59 docs/source/en/api/pipelines/dit.mdx Normal file
@@ -0,0 +1,59 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# [Scalable Diffusion Models with Transformers](https://www.wpeebles.com/DiT) (DiT)

## Overview

[Scalable Diffusion Models with Transformers](https://arxiv.org/abs/2212.09748) (DiT) by William Peebles and Saining Xie.

The abstract of the paper is the following:

*We explore a new class of diffusion models based on the transformer architecture. We train latent diffusion models of images, replacing the commonly-used U-Net backbone with a transformer that operates on latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass complexity as measured by Gflops. We find that DiTs with higher Gflops -- through increased transformer depth/width or increased number of input tokens -- consistently have lower FID. In addition to possessing good scalability properties, our largest DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512x512 and 256x256 benchmarks, achieving a state-of-the-art FID of 2.27 on the latter.*

The original codebase of this paper can be found here: [facebookresearch/dit](https://github.com/facebookresearch/dit).

## Available Pipelines:

| Pipeline | Tasks | Colab |
|---|---|:---:|
| [pipeline_dit.py](https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/dit/pipeline_dit.py) | *Conditional Image Generation* | - |

## Usage example

```python
from diffusers import DiTPipeline, DPMSolverMultistepScheduler
import torch

pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# pick words from ImageNet class labels
pipe.labels  # to print all available words

# pick words that exist in ImageNet
words = ["white shark", "umbrella"]

class_ids = pipe.get_label_ids(words)

generator = torch.manual_seed(33)
output = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator)

image = output.images[0]  # label 'white shark'
```
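The call also accepts several class ids at once. As a follow-up sketch (assuming the same `facebook/DiT-XL-2-256` checkpoint and a CUDA device; the label choices are arbitrary ImageNet classes), one image is generated per label and saved to disk:

```python
from diffusers import DiTPipeline, DPMSolverMultistepScheduler
import torch

pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# one class id per image in the batch; every word must exist in `pipe.labels`
words = ["golden retriever", "volcano", "balloon"]
class_ids = pipe.get_label_ids(words)

generator = torch.manual_seed(0)
output = pipe(class_labels=class_ids, num_inference_steps=25, generator=generator)

# `output.images` is a list of PIL images, one per requested label
for word, image in zip(words, output.images):
    image.save(f"{word.replace(' ', '_')}.png")
```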

## DiTPipeline
[[autodoc]] DiTPipeline
    - all
    - __call__
@@ -37,6 +37,7 @@ To this end, the design of schedulers is such that:

- Schedulers can be used interchangeably between diffusion models in inference to find the preferred trade-off between speed and generation quality.
- Schedulers are currently by default in PyTorch, but are designed to be framework independent (partial JAX support currently exists).
- Many diffusion pipelines, such as [`StableDiffusionPipeline`] and [`DiTPipeline`], can use any of the [`KarrasDiffusionSchedulers`].

## Schedulers Summary

@@ -80,4 +81,6 @@ The class [`SchedulerOutput`] contains the outputs from any schedulers `step(...

[[autodoc]] schedulers.scheduling_utils.SchedulerOutput

### KarrasDiffusionSchedulers

[[autodoc]] schedulers.scheduling_utils.KarrasDiffusionSchedulers
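For instance, because [`DiTPipeline`] accepts any of the [`KarrasDiffusionSchedulers`], switching schedulers is just a matter of replacing the `scheduler` attribute. A brief sketch (checkpoint name and step counts are only illustrative):

```python
import torch
from diffusers import DDIMScheduler, DiTPipeline, DPMSolverMultistepScheduler

pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16).to("cuda")

# the checkpoint ships with a DDIM scheduler; any Karras-style scheduler can be
# built from the same config and dropped in without reloading the model
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
image_fast = pipe(class_labels=pipe.get_label_ids(["umbrella"]), num_inference_steps=25).images[0]

# switch back to DDIM for a longer, higher-step run
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
image_ddim = pipe(class_labels=pipe.get_label_ids(["umbrella"]), num_inference_steps=50).images[0]
```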
162 scripts/convert_dit_to_diffusers.py Normal file
@@ -0,0 +1,162 @@
import argparse
import os

import torch

from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, Transformer2DModel
from torchvision.datasets.utils import download_url


pretrained_models = {512: "DiT-XL-2-512x512.pt", 256: "DiT-XL-2-256x256.pt"}


def download_model(model_name):
    """
    Downloads a pre-trained DiT model from the web.
    """
    local_path = f"pretrained_models/{model_name}"
    if not os.path.isfile(local_path):
        os.makedirs("pretrained_models", exist_ok=True)
        web_path = f"https://dl.fbaipublicfiles.com/DiT/models/{model_name}"
        download_url(web_path, "pretrained_models")
    model = torch.load(local_path, map_location=lambda storage, loc: storage)
    return model


def main(args):
    state_dict = download_model(pretrained_models[args.image_size])

    state_dict["pos_embed.proj.weight"] = state_dict["x_embedder.proj.weight"]
    state_dict["pos_embed.proj.bias"] = state_dict["x_embedder.proj.bias"]
    state_dict.pop("x_embedder.proj.weight")
    state_dict.pop("x_embedder.proj.bias")

    for depth in range(28):
        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.weight"] = state_dict[
            "t_embedder.mlp.0.weight"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_1.bias"] = state_dict[
            "t_embedder.mlp.0.bias"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.weight"] = state_dict[
            "t_embedder.mlp.2.weight"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.emb.timestep_embedder.linear_2.bias"] = state_dict[
            "t_embedder.mlp.2.bias"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.emb.class_embedder.embedding_table.weight"] = state_dict[
            "y_embedder.embedding_table.weight"
        ]

        state_dict[f"transformer_blocks.{depth}.norm1.linear.weight"] = state_dict[
            f"blocks.{depth}.adaLN_modulation.1.weight"
        ]
        state_dict[f"transformer_blocks.{depth}.norm1.linear.bias"] = state_dict[
            f"blocks.{depth}.adaLN_modulation.1.bias"
        ]

        q, k, v = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.weight"], 3, dim=0)
        q_bias, k_bias, v_bias = torch.chunk(state_dict[f"blocks.{depth}.attn.qkv.bias"], 3, dim=0)

        state_dict[f"transformer_blocks.{depth}.attn1.to_q.weight"] = q
        state_dict[f"transformer_blocks.{depth}.attn1.to_q.bias"] = q_bias
        state_dict[f"transformer_blocks.{depth}.attn1.to_k.weight"] = k
        state_dict[f"transformer_blocks.{depth}.attn1.to_k.bias"] = k_bias
        state_dict[f"transformer_blocks.{depth}.attn1.to_v.weight"] = v
        state_dict[f"transformer_blocks.{depth}.attn1.to_v.bias"] = v_bias

        state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.weight"] = state_dict[
            f"blocks.{depth}.attn.proj.weight"
        ]
        state_dict[f"transformer_blocks.{depth}.attn1.to_out.0.bias"] = state_dict[f"blocks.{depth}.attn.proj.bias"]

        state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.weight"] = state_dict[f"blocks.{depth}.mlp.fc1.weight"]
        state_dict[f"transformer_blocks.{depth}.ff.net.0.proj.bias"] = state_dict[f"blocks.{depth}.mlp.fc1.bias"]
        state_dict[f"transformer_blocks.{depth}.ff.net.2.weight"] = state_dict[f"blocks.{depth}.mlp.fc2.weight"]
        state_dict[f"transformer_blocks.{depth}.ff.net.2.bias"] = state_dict[f"blocks.{depth}.mlp.fc2.bias"]

        state_dict.pop(f"blocks.{depth}.attn.qkv.weight")
        state_dict.pop(f"blocks.{depth}.attn.qkv.bias")
        state_dict.pop(f"blocks.{depth}.attn.proj.weight")
        state_dict.pop(f"blocks.{depth}.attn.proj.bias")
        state_dict.pop(f"blocks.{depth}.mlp.fc1.weight")
        state_dict.pop(f"blocks.{depth}.mlp.fc1.bias")
        state_dict.pop(f"blocks.{depth}.mlp.fc2.weight")
        state_dict.pop(f"blocks.{depth}.mlp.fc2.bias")
        state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.weight")
        state_dict.pop(f"blocks.{depth}.adaLN_modulation.1.bias")

    state_dict.pop("t_embedder.mlp.0.weight")
    state_dict.pop("t_embedder.mlp.0.bias")
    state_dict.pop("t_embedder.mlp.2.weight")
    state_dict.pop("t_embedder.mlp.2.bias")
    state_dict.pop("y_embedder.embedding_table.weight")

    state_dict["proj_out_1.weight"] = state_dict["final_layer.adaLN_modulation.1.weight"]
    state_dict["proj_out_1.bias"] = state_dict["final_layer.adaLN_modulation.1.bias"]
    state_dict["proj_out_2.weight"] = state_dict["final_layer.linear.weight"]
    state_dict["proj_out_2.bias"] = state_dict["final_layer.linear.bias"]

    state_dict.pop("final_layer.linear.weight")
    state_dict.pop("final_layer.linear.bias")
    state_dict.pop("final_layer.adaLN_modulation.1.weight")
    state_dict.pop("final_layer.adaLN_modulation.1.bias")

    # DiT XL/2
    transformer = Transformer2DModel(
        sample_size=args.image_size // 8,
        num_layers=28,
        attention_head_dim=72,
        in_channels=4,
        out_channels=8,
        patch_size=2,
        attention_bias=True,
        num_attention_heads=16,
        activation_fn="gelu-approximate",
        num_embeds_ada_norm=1000,
        norm_type="ada_norm_zero",
        norm_elementwise_affine=False,
    )
    transformer.load_state_dict(state_dict, strict=True)

    scheduler = DDIMScheduler(
        num_train_timesteps=1000,
        beta_schedule="linear",
        prediction_type="epsilon",
        clip_sample=False,
    )

    vae = AutoencoderKL.from_pretrained(args.vae_model)

    pipeline = DiTPipeline(transformer=transformer, vae=vae, scheduler=scheduler)

    if args.save:
        pipeline.save_pretrained(args.checkpoint_path)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--image_size",
        default=256,
        type=int,
        required=False,
        help="Image size of pretrained model, either 256 or 512.",
    )
    parser.add_argument(
        "--vae_model",
        default="stabilityai/sd-vae-ft-ema",
        type=str,
        required=False,
        help="Path to pretrained VAE model, either stabilityai/sd-vae-ft-mse or stabilityai/sd-vae-ft-ema.",
    )
    parser.add_argument(
        "--save", default=True, type=bool, required=False, help="Whether to save the converted pipeline or not."
    )
    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the output pipeline."
    )

    args = parser.parse_args()
    main(args)
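Assuming the script above was run as, say, `python scripts/convert_dit_to_diffusers.py --image_size 256 --checkpoint_path dit-xl-2-256` (the output directory name is only an example), the converted pipeline can be loaded back like any other saved diffusers pipeline:

```python
import torch
from diffusers import DiTPipeline

# load the locally converted pipeline instead of the Hub checkpoint
pipe = DiTPipeline.from_pretrained("dit-xl-2-256", torch_dtype=torch.float16).to("cuda")

image = pipe(class_labels=pipe.get_label_ids(["umbrella"]), num_inference_steps=25).images[0]
image.save("umbrella.png")
```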
@@ -57,6 +57,7 @@ else:
|
||||
DDIMPipeline,
|
||||
DDPMPipeline,
|
||||
DiffusionPipeline,
|
||||
DiTPipeline,
|
||||
ImagePipelineOutput,
|
||||
KarrasVePipeline,
|
||||
LDMPipeline,
|
||||
|
||||
@@ -20,6 +20,7 @@ from torch import nn
|
||||
|
||||
from ..utils.import_utils import is_xformers_available
|
||||
from .cross_attention import CrossAttention
|
||||
from .embeddings import CombinedTimestepLabelEmbeddings
|
||||
|
||||
|
||||
if is_xformers_available():
|
||||
@@ -196,10 +197,21 @@ class BasicTransformerBlock(nn.Module):
|
||||
attention_bias: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_elementwise_affine: bool = True,
|
||||
norm_type: str = "layer_norm",
|
||||
final_dropout: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.only_cross_attention = only_cross_attention
|
||||
self.use_ada_layer_norm = num_embeds_ada_norm is not None
|
||||
|
||||
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
|
||||
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
|
||||
|
||||
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
|
||||
raise ValueError(
|
||||
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
|
||||
f" define `num_embeds_ada_norm` if setting `norm_type` to {norm_type}."
|
||||
)
|
||||
|
||||
# 1. Self-Attn
|
||||
self.attn1 = CrossAttention(
|
||||
@@ -212,7 +224,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
|
||||
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
|
||||
|
||||
# 2. Cross-Attn
|
||||
if cross_attention_dim is not None:
|
||||
@@ -228,15 +240,27 @@ class BasicTransformerBlock(nn.Module):
|
||||
else:
|
||||
self.attn2 = None
|
||||
|
||||
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
||||
if self.use_ada_layer_norm:
|
||||
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
elif self.use_ada_layer_norm_zero:
|
||||
self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
|
||||
else:
|
||||
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
||||
|
||||
if cross_attention_dim is not None:
|
||||
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
|
||||
# We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
|
||||
# I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
|
||||
# the second cross attention block.
|
||||
self.norm2 = (
|
||||
AdaLayerNorm(dim, num_embeds_ada_norm)
|
||||
if self.use_ada_layer_norm
|
||||
else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
||||
)
|
||||
else:
|
||||
self.norm2 = None
|
||||
|
||||
# 3. Feed-forward
|
||||
self.norm3 = nn.LayerNorm(dim)
|
||||
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
@@ -245,11 +269,18 @@ class BasicTransformerBlock(nn.Module):
|
||||
timestep=None,
|
||||
attention_mask=None,
|
||||
cross_attention_kwargs=None,
|
||||
class_labels=None,
|
||||
):
|
||||
if self.use_ada_layer_norm:
|
||||
norm_hidden_states = self.norm1(hidden_states, timestep)
|
||||
elif self.use_ada_layer_norm_zero:
|
||||
norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
|
||||
hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
else:
|
||||
norm_hidden_states = self.norm1(hidden_states)
|
||||
|
||||
# 1. Self-Attention
|
||||
norm_hidden_states = (
|
||||
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
||||
)
|
||||
cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
|
||||
attn_output = self.attn1(
|
||||
norm_hidden_states,
|
||||
@@ -257,13 +288,16 @@ class BasicTransformerBlock(nn.Module):
|
||||
attention_mask=attention_mask,
|
||||
**cross_attention_kwargs,
|
||||
)
|
||||
if self.use_ada_layer_norm_zero:
|
||||
attn_output = gate_msa.unsqueeze(1) * attn_output
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
if self.attn2 is not None:
|
||||
# 2. Cross-Attention
|
||||
norm_hidden_states = (
|
||||
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
||||
)
|
||||
|
||||
# 2. Cross-Attention
|
||||
attn_output = self.attn2(
|
||||
norm_hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
@@ -273,7 +307,17 @@ class BasicTransformerBlock(nn.Module):
|
||||
hidden_states = attn_output + hidden_states
|
||||
|
||||
# 3. Feed-forward
|
||||
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
||||
norm_hidden_states = self.norm3(hidden_states)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
|
||||
ff_output = self.ff(norm_hidden_states)
|
||||
|
||||
if self.use_ada_layer_norm_zero:
|
||||
ff_output = gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
hidden_states = ff_output + hidden_states
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -288,6 +332,7 @@ class FeedForward(nn.Module):
|
||||
mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -297,6 +342,7 @@ class FeedForward(nn.Module):
|
||||
mult: int = 4,
|
||||
dropout: float = 0.0,
|
||||
activation_fn: str = "geglu",
|
||||
final_dropout: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
@@ -304,6 +350,8 @@ class FeedForward(nn.Module):
|
||||
|
||||
if activation_fn == "gelu":
|
||||
act_fn = GELU(dim, inner_dim)
|
||||
if activation_fn == "gelu-approximate":
|
||||
act_fn = GELU(dim, inner_dim, approximate="tanh")
|
||||
elif activation_fn == "geglu":
|
||||
act_fn = GEGLU(dim, inner_dim)
|
||||
elif activation_fn == "geglu-approximate":
|
||||
@@ -316,6 +364,9 @@ class FeedForward(nn.Module):
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
# project out
|
||||
self.net.append(nn.Linear(inner_dim, dim_out))
|
||||
# FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
|
||||
if final_dropout:
|
||||
self.net.append(nn.Dropout(dropout))
|
||||
|
||||
def forward(self, hidden_states):
|
||||
for module in self.net:
|
||||
@@ -325,18 +376,19 @@ class FeedForward(nn.Module):
|
||||
|
||||
class GELU(nn.Module):
|
||||
r"""
|
||||
GELU activation function
|
||||
GELU activation function with tanh approximation support with `approximate="tanh"`.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int, dim_out: int):
|
||||
def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
|
||||
super().__init__()
|
||||
self.proj = nn.Linear(dim_in, dim_out)
|
||||
self.approximate = approximate
|
||||
|
||||
def gelu(self, gate):
|
||||
if gate.device.type != "mps":
|
||||
return F.gelu(gate)
|
||||
return F.gelu(gate, approximate=self.approximate)
|
||||
# mps: gelu is not implemented for float16
|
||||
return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
|
||||
return F.gelu(gate.to(dtype=torch.float32), approximate=self.approximate).to(dtype=gate.dtype)
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.proj(hidden_states)
|
||||
@@ -344,7 +396,6 @@ class GELU(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# feedforward
|
||||
class GEGLU(nn.Module):
|
||||
r"""
|
||||
A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
|
||||
@@ -402,3 +453,24 @@ class AdaLayerNorm(nn.Module):
|
||||
scale, shift = torch.chunk(emb, 2)
|
||||
x = self.norm(x) * (1 + scale) + shift
|
||||
return x
|
||||
|
||||
|
||||
class AdaLayerNormZero(nn.Module):
|
||||
"""
|
||||
Norm layer adaptive layer norm zero (adaLN-Zero).
|
||||
"""
|
||||
|
||||
def __init__(self, embedding_dim, num_embeddings):
|
||||
super().__init__()
|
||||
|
||||
self.emb = CombinedTimestepLabelEmbeddings(num_embeddings, embedding_dim)
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
|
||||
self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
def forward(self, x, timestep, class_labels, hidden_dtype=None):
|
||||
emb = self.linear(self.silu(self.emb(timestep, class_labels, hidden_dtype=hidden_dtype)))
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
|
||||
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||
|
||||
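For readers tracing the adaLN-Zero path above, a minimal, shape-only sketch of how the six modulation chunks are used around the two sub-layers (this is an illustration, not the diffusers module itself):

```python
import torch

# Shape-level sketch of how the six adaLN-Zero chunks are consumed by a block.
# In diffusers the *_msa shift/scale are applied inside `AdaLayerNormZero`;
# they are written out inline here, and attention/feed-forward are stand-ins.
batch, seq_len, dim = 2, 16, 8
x = torch.randn(batch, seq_len, dim)
emb = torch.randn(batch, 6 * dim)  # linear(silu(timestep + class conditioning))

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = emb.chunk(6, dim=1)
norm = torch.nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)

# attention branch: modulated norm in, gated residual out
h = norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
attn_out = h  # stand-in for self-attention(h)
x = x + gate_msa.unsqueeze(1) * attn_out

# feed-forward branch reuses the same pattern with the remaining three chunks
h = norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_out = h  # stand-in for feed-forward(h)
x = x + gate_mlp.unsqueeze(1) * ff_out
```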
@@ -61,6 +61,96 @@ def get_timestep_embedding(
|
||||
return emb
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
|
||||
"""
|
||||
grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or
|
||||
[1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
|
||||
"""
|
||||
grid_h = np.arange(grid_size, dtype=np.float32)
|
||||
grid_w = np.arange(grid_size, dtype=np.float32)
|
||||
grid = np.meshgrid(grid_w, grid_h) # here w goes first
|
||||
grid = np.stack(grid, axis=0)
|
||||
|
||||
grid = grid.reshape([2, 1, grid_size, grid_size])
|
||||
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
|
||||
if cls_token and extra_tokens > 0:
|
||||
pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
|
||||
return pos_embed
|
||||
|
||||
|
||||
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
|
||||
if embed_dim % 2 != 0:
|
||||
raise ValueError("embed_dim must be divisible by 2")
|
||||
|
||||
# use half of dimensions to encode grid_h
|
||||
emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
|
||||
emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
|
||||
|
||||
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
|
||||
return emb
|
||||
|
||||
|
||||
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
||||
"""
|
||||
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
|
||||
"""
|
||||
if embed_dim % 2 != 0:
|
||||
raise ValueError("embed_dim must be divisible by 2")
|
||||
|
||||
omega = np.arange(embed_dim // 2, dtype=np.float64)
|
||||
omega /= embed_dim / 2.0
|
||||
omega = 1.0 / 10000**omega # (D/2,)
|
||||
|
||||
pos = pos.reshape(-1) # (M,)
|
||||
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
||||
|
||||
emb_sin = np.sin(out) # (M, D/2)
|
||||
emb_cos = np.cos(out) # (M, D/2)
|
||||
|
||||
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
||||
return emb
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
"""2D Image to Patch Embedding"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
height=224,
|
||||
width=224,
|
||||
patch_size=16,
|
||||
in_channels=3,
|
||||
embed_dim=768,
|
||||
layer_norm=False,
|
||||
flatten=True,
|
||||
bias=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
num_patches = (height // patch_size) * (width // patch_size)
|
||||
self.flatten = flatten
|
||||
self.layer_norm = layer_norm
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
|
||||
)
|
||||
if layer_norm:
|
||||
self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
|
||||
else:
|
||||
self.norm = None
|
||||
|
||||
pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
|
||||
self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
|
||||
|
||||
def forward(self, latent):
|
||||
latent = self.proj(latent)
|
||||
if self.flatten:
|
||||
latent = latent.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||
if self.layer_norm:
|
||||
latent = self.norm(latent)
|
||||
return latent + self.pos_embed
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
|
||||
super().__init__()
|
||||
@@ -198,3 +288,58 @@ class ImagePositionalEmbeddings(nn.Module):
|
||||
emb = emb + pos_emb[:, : emb.shape[1], :]
|
||||
|
||||
return emb
|
||||
|
||||
|
||||
class LabelEmbedding(nn.Module):
|
||||
"""
|
||||
Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
|
||||
|
||||
Args:
|
||||
num_classes (`int`): The number of classes.
|
||||
hidden_size (`int`): The size of the vector embeddings.
|
||||
dropout_prob (`float`): The probability of dropping a label.
|
||||
"""
|
||||
|
||||
def __init__(self, num_classes, hidden_size, dropout_prob):
|
||||
super().__init__()
|
||||
use_cfg_embedding = dropout_prob > 0
|
||||
self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
|
||||
self.num_classes = num_classes
|
||||
self.dropout_prob = dropout_prob
|
||||
|
||||
def token_drop(self, labels, force_drop_ids=None):
|
||||
"""
|
||||
Drops labels to enable classifier-free guidance.
|
||||
"""
|
||||
if force_drop_ids is None:
|
||||
drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
|
||||
else:
|
||||
drop_ids = torch.tensor(force_drop_ids == 1)
|
||||
labels = torch.where(drop_ids, self.num_classes, labels)
|
||||
return labels
|
||||
|
||||
def forward(self, labels, force_drop_ids=None):
|
||||
use_dropout = self.dropout_prob > 0
|
||||
if (self.training and use_dropout) or (force_drop_ids is not None):
|
||||
labels = self.token_drop(labels, force_drop_ids)
|
||||
embeddings = self.embedding_table(labels)
|
||||
return embeddings
|
||||
|
||||
|
||||
class CombinedTimestepLabelEmbeddings(nn.Module):
|
||||
def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
|
||||
super().__init__()
|
||||
|
||||
self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
|
||||
self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
|
||||
self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)
|
||||
|
||||
def forward(self, timestep, class_labels, hidden_dtype=None):
|
||||
timesteps_proj = self.time_proj(timestep)
|
||||
timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D)
|
||||
|
||||
class_labels = self.class_embedder(class_labels) # (N, D)
|
||||
|
||||
conditioning = timesteps_emb + class_labels # (N, D)
|
||||
|
||||
return conditioning
|
||||
|
||||
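The label dropout in `LabelEmbedding` is what later enables classifier-free guidance for a class-conditional model: with `num_classes=1000`, the extra index 1000 acts as the learned null embedding. A toy sketch of the convention (plain tensors, not the module itself):

```python
import torch

# assuming 1000 real classes: index 1000 is the extra, learned "null" class
num_classes, dropout_prob = 1000, 0.1
labels = torch.tensor([207, 980, 417])

# training: each label is randomly replaced by the null class with prob 0.1
drop_ids = torch.rand(labels.shape[0]) < dropout_prob
train_labels = torch.where(drop_ids, torch.tensor(num_classes), labels)

# inference (see DiTPipeline below): the unconditional half is built explicitly
class_null = torch.full((labels.shape[0],), num_classes)
cfg_labels = torch.cat([labels, class_null], 0)
```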
@@ -20,8 +20,9 @@ from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..models.embeddings import ImagePositionalEmbeddings
|
||||
from ..utils import BaseOutput
|
||||
from ..utils import BaseOutput, deprecate
|
||||
from .attention import BasicTransformerBlock
|
||||
from .embeddings import PatchEmbed
|
||||
from .modeling_utils import ModelMixin
|
||||
|
||||
|
||||
@@ -81,6 +82,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
num_attention_heads: int = 16,
|
||||
attention_head_dim: int = 88,
|
||||
in_channels: Optional[int] = None,
|
||||
out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
@@ -88,11 +90,14 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
attention_bias: bool = False,
|
||||
sample_size: Optional[int] = None,
|
||||
num_vector_embeds: Optional[int] = None,
|
||||
patch_size: Optional[int] = None,
|
||||
activation_fn: str = "geglu",
|
||||
num_embeds_ada_norm: Optional[int] = None,
|
||||
use_linear_projection: bool = False,
|
||||
only_cross_attention: bool = False,
|
||||
upcast_attention: bool = False,
|
||||
norm_type: str = "layer_norm",
|
||||
norm_elementwise_affine: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.use_linear_projection = use_linear_projection
|
||||
@@ -102,18 +107,35 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
|
||||
# Define whether input is continuous or discrete depending on configuration
|
||||
self.is_input_continuous = in_channels is not None
|
||||
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
|
||||
self.is_input_vectorized = num_vector_embeds is not None
|
||||
self.is_input_patches = in_channels is not None and patch_size is not None
|
||||
|
||||
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
|
||||
deprecation_message = (
|
||||
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
|
||||
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
|
||||
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
|
||||
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
|
||||
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
|
||||
)
|
||||
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
|
||||
norm_type = "ada_norm"
|
||||
|
||||
if self.is_input_continuous and self.is_input_vectorized:
|
||||
raise ValueError(
|
||||
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
|
||||
" sure that either `in_channels` or `num_vector_embeds` is None."
|
||||
)
|
||||
elif not self.is_input_continuous and not self.is_input_vectorized:
|
||||
elif self.is_input_vectorized and self.is_input_patches:
|
||||
raise ValueError(
|
||||
f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
|
||||
" sure that either `in_channels` or `num_vector_embeds` is not None."
|
||||
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
|
||||
" sure that either `num_vector_embeds` or `num_patches` is None."
|
||||
)
|
||||
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
|
||||
raise ValueError(
|
||||
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
|
||||
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
|
||||
)
|
||||
|
||||
# 2. Define input layers
|
||||
@@ -137,6 +159,20 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
self.latent_image_embedding = ImagePositionalEmbeddings(
|
||||
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
|
||||
)
|
||||
elif self.is_input_patches:
|
||||
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
|
||||
|
||||
self.height = sample_size
|
||||
self.width = sample_size
|
||||
|
||||
self.patch_size = patch_size
|
||||
self.pos_embed = PatchEmbed(
|
||||
height=sample_size,
|
||||
width=sample_size,
|
||||
patch_size=patch_size,
|
||||
in_channels=in_channels,
|
||||
embed_dim=inner_dim,
|
||||
)
|
||||
|
||||
# 3. Define transformers blocks
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
@@ -152,13 +188,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
attention_bias=attention_bias,
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
norm_type=norm_type,
|
||||
norm_elementwise_affine=norm_elementwise_affine,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
# 4. Define output layers
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
if self.is_input_continuous:
|
||||
# TODO: should use out_channels for continous projections
|
||||
if use_linear_projection:
|
||||
self.proj_out = nn.Linear(in_channels, inner_dim)
|
||||
else:
|
||||
@@ -166,12 +206,17 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
elif self.is_input_vectorized:
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
|
||||
elif self.is_input_patches:
|
||||
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
|
||||
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
|
||||
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
timestep=None,
|
||||
class_labels=None,
|
||||
cross_attention_kwargs=None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
@@ -185,6 +230,9 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
self-attention.
|
||||
timestep ( `torch.long`, *optional*):
|
||||
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
||||
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
|
||||
Optional class labels to be applied as an embedding in AdaLayerZeroNorm. Used to indicate class labels
|
||||
conditioning.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
|
||||
@@ -195,7 +243,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
# 1. Input
|
||||
if self.is_input_continuous:
|
||||
batch, channel, height, width = hidden_states.shape
|
||||
batch, _, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
@@ -209,6 +257,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
elif self.is_input_vectorized:
|
||||
hidden_states = self.latent_image_embedding(hidden_states)
|
||||
elif self.is_input_patches:
|
||||
hidden_states = self.pos_embed(hidden_states)
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
@@ -217,6 +267,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
timestep=timestep,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
class_labels=class_labels,
|
||||
)
|
||||
|
||||
# 3. Output
|
||||
@@ -237,6 +288,24 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# log(p(x_0))
|
||||
output = F.log_softmax(logits.double(), dim=1).float()
|
||||
elif self.is_input_patches:
|
||||
# TODO: cleanup!
|
||||
conditioning = self.transformer_blocks[0].norm1.emb(
|
||||
timestep, class_labels, hidden_dtype=hidden_states.dtype
|
||||
)
|
||||
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
|
||||
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
|
||||
hidden_states = self.proj_out_2(hidden_states)
|
||||
|
||||
# unpatchify
|
||||
height = width = int(hidden_states.shape[1] ** 0.5)
|
||||
hidden_states = hidden_states.reshape(
|
||||
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
|
||||
)
|
||||
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
|
||||
output = hidden_states.reshape(
|
||||
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
|
||||
)
|
||||
|
||||
if not return_dict:
|
||||
return (output,)
|
||||
|
||||
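The unpatchify step above can be checked at the shape level with a tiny standalone example (hypothetical small sizes, unrelated to the real model config):

```python
import torch

# 2x2 grid of 2x2 patches, 4 output channels -> a 4x4 image with 4 channels
batch, height, width, patch, channels = 1, 2, 2, 2, 4
tokens = torch.randn(batch, height * width, patch * patch * channels)

x = tokens.reshape(-1, height, width, patch, patch, channels)
x = torch.einsum("nhwpqc->nchpwq", x)
image = x.reshape(-1, channels, height * patch, width * patch)
print(image.shape)  # torch.Size([1, 4, 4, 4])
```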
@@ -18,6 +18,7 @@ else:
|
||||
from .dance_diffusion import DanceDiffusionPipeline
|
||||
from .ddim import DDIMPipeline
|
||||
from .ddpm import DDPMPipeline
|
||||
from .dit import DiTPipeline
|
||||
from .latent_diffusion import LDMSuperResolutionPipeline
|
||||
from .latent_diffusion_uncond import LDMPipeline
|
||||
from .pipeline_utils import AudioPipelineOutput, DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
@@ -23,14 +23,7 @@ from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import deprecate, logging, randn_tensor, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -91,14 +84,7 @@ class AltDiffusionPipeline(DiffusionPipeline):
|
||||
text_encoder: RobertaSeriesModelWithTransformation,
|
||||
tokenizer: XLMRobertaTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
],
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
|
||||
@@ -25,14 +25,7 @@ from transformers import CLIPFeatureExtractor, XLMRobertaTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import PIL_INTERPOLATION, deprecate, logging, randn_tensor, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
@@ -129,14 +122,7 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
|
||||
text_encoder: RobertaSeriesModelWithTransformation,
|
||||
tokenizer: XLMRobertaTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
],
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
|
||||
1 src/diffusers/pipelines/dit/__init__.py Normal file
@@ -0,0 +1 @@
from .pipeline_dit import DiTPipeline
199 src/diffusers/pipelines/dit/pipeline_dit.py Normal file
@@ -0,0 +1,199 @@
# Attribution-NonCommercial 4.0 International (CC BY-NC 4.0)
# William Peebles and Saining Xie
#
# Copyright (c) 2021 OpenAI
# MIT License
#
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Tuple, Union

import torch

from ...models import AutoencoderKL, Transformer2DModel
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput


class DiTPipeline(DiffusionPipeline):
    r"""
    This pipeline inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Parameters:
        transformer ([`Transformer2DModel`]):
            A class conditioned Transformer (DiT) to denoise the encoded image latents.
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        scheduler ([`DDIMScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
    """

    def __init__(
        self,
        transformer: Transformer2DModel,
        vae: AutoencoderKL,
        scheduler: KarrasDiffusionSchedulers,
        id2label: Optional[Dict[int, str]] = None,
    ):
        super().__init__()
        self.register_modules(transformer=transformer, vae=vae, scheduler=scheduler)

        # create an ImageNet label -> id dictionary for easier use
        self.labels = {}
        if id2label is not None:
            for key, value in id2label.items():
                for label in value.split(","):
                    self.labels[label.lstrip().rstrip()] = int(key)
            self.labels = dict(sorted(self.labels.items()))

    def get_label_ids(self, label: Union[str, List[str]]) -> List[int]:
        r"""

        Map label strings, *e.g.* from ImageNet, to corresponding class ids.

        Parameters:
            label (`str` or `list` of `str`): label strings to be mapped to class ids.

        Returns:
            `list` of `int`: Class ids to be processed by pipeline.
        """

        if not isinstance(label, list):
            label = [label]

        for l in label:
            if l not in self.labels:
                raise ValueError(
                    f"{l} does not exist. Please make sure to select one of the following labels: \n {self.labels}."
                )

        return [self.labels[l] for l in label]

    @torch.no_grad()
    def __call__(
        self,
        class_labels: List[int],
        guidance_scale: float = 4.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        num_inference_steps: int = 50,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
    ) -> Union[ImagePipelineOutput, Tuple]:
        r"""
        Function invoked when calling the pipeline for generation.

        Args:
            class_labels (List[int]):
                List of ImageNet class labels for the images to be generated.
            guidance_scale (`float`, *optional*, defaults to 4.0):
                Scale of the guidance signal.
            generator (`torch.Generator`, *optional*):
                A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation
                deterministic.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`ImagePipelineOutput`] instead of a plain tuple.
        """

        batch_size = len(class_labels)
        latent_size = self.transformer.config.sample_size
        latent_channels = self.transformer.config.in_channels

        latents = randn_tensor(
            shape=(batch_size, latent_channels, latent_size, latent_size),
            generator=generator,
            device=self.device,
            dtype=self.transformer.dtype,
        )
        latent_model_input = torch.cat([latents] * 2) if guidance_scale > 1 else latents

        class_labels = torch.tensor(class_labels, device=self.device).reshape(-1)
        class_null = torch.tensor([1000] * batch_size, device=self.device)
        class_labels_input = torch.cat([class_labels, class_null], 0) if guidance_scale > 1 else class_labels

        # set step values
        self.scheduler.set_timesteps(num_inference_steps)

        for t in self.progress_bar(self.scheduler.timesteps):
            if guidance_scale > 1:
                half = latent_model_input[: len(latent_model_input) // 2]
                latent_model_input = torch.cat([half, half], dim=0)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            timesteps = t
            if not torch.is_tensor(timesteps):
                # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                # This would be a good case for the `match` statement (Python 3.10+)
                is_mps = latent_model_input.device.type == "mps"
                if isinstance(timesteps, float):
                    dtype = torch.float32 if is_mps else torch.float64
                else:
                    dtype = torch.int32 if is_mps else torch.int64
                timesteps = torch.tensor([timesteps], dtype=dtype, device=latent_model_input.device)
            elif len(timesteps.shape) == 0:
                timesteps = timesteps[None].to(latent_model_input.device)
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timesteps = timesteps.expand(latent_model_input.shape[0])
            # predict noise model_output
            noise_pred = self.transformer(
                latent_model_input, timestep=timesteps, class_labels=class_labels_input
            ).sample

            # perform guidance
            if guidance_scale > 1:
                eps, rest = noise_pred[:, :latent_channels], noise_pred[:, latent_channels:]
                cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)

                half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
                eps = torch.cat([half_eps, half_eps], dim=0)

                noise_pred = torch.cat([eps, rest], dim=1)

            # learned sigma
            if self.transformer.config.out_channels // 2 == latent_channels:
                model_output, _ = torch.split(noise_pred, latent_channels, dim=1)
            else:
                model_output = noise_pred

            # compute previous image: x_t -> x_t-1
            latent_model_input = self.scheduler.step(model_output, t, latent_model_input).prev_sample

        if guidance_scale > 1:
            latents, _ = latent_model_input.chunk(2, dim=0)
        else:
            latents = latent_model_input

        latents = 1 / 0.18215 * latents
        samples = self.vae.decode(latents).sample

        samples = (samples / 2 + 0.5).clamp(0, 1)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        samples = samples.cpu().permute(0, 2, 3, 1).float().numpy()

        if output_type == "pil":
            samples = self.numpy_to_pil(samples)

        if not return_dict:
            return (samples,)

        return ImagePipelineOutput(images=samples)
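A note on the guidance branch above: when `guidance_scale` is greater than 1, the pipeline duplicates the latents and appends the null class id (1000) so that conditional and unconditional predictions come out of the same forward pass; setting it to 1.0 skips that duplication entirely. A small sketch, reusing the `pipe` object from the usage example earlier:

```python
import torch

# classifier-free guidance disabled: one forward pass of the transformer per step
output = pipe(
    class_labels=pipe.get_label_ids(["umbrella"]),
    guidance_scale=1.0,
    num_inference_steps=25,
    generator=torch.manual_seed(0),
)
image = output.images[0]
```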
@@ -22,14 +22,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor, replace_example_docstring
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
@@ -88,14 +81,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
],
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
|
||||
@@ -25,14 +25,7 @@ from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTF
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
|
||||
@@ -91,14 +84,7 @@ class StableDiffusionDepth2ImgPipeline(DiffusionPipeline):
|
||||
text_encoder: CLIPTextModel,
|
||||
tokenizer: CLIPTokenizer,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
],
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
depth_estimator: DPTForDepthEstimation,
|
||||
feature_extractor: DPTFeatureExtractor,
|
||||
):
|
||||
|
||||
@@ -23,14 +23,7 @@ from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection
|
||||
|
||||
from ...configuration_utils import FrozenDict
|
||||
from ...models import AutoencoderKL, UNet2DConditionModel
|
||||
from ...schedulers import (
|
||||
DDIMScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
PNDMScheduler,
|
||||
)
|
||||
from ...schedulers import KarrasDiffusionSchedulers
|
||||
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
|
||||
from ..pipeline_utils import DiffusionPipeline
|
||||
from . import StableDiffusionPipelineOutput
|
||||
@@ -73,14 +66,7 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
|
||||
vae: AutoencoderKL,
|
||||
image_encoder: CLIPVisionModelWithProjection,
|
||||
unet: UNet2DConditionModel,
|
||||
scheduler: Union[
|
||||
DDIMScheduler,
|
||||
PNDMScheduler,
|
||||
LMSDiscreteScheduler,
|
||||
EulerDiscreteScheduler,
|
||||
EulerAncestralDiscreteScheduler,
|
||||
DPMSolverMultistepScheduler,
|
||||
],
|
||||
scheduler: KarrasDiffusionSchedulers,
|
||||
safety_checker: StableDiffusionSafetyChecker,
|
||||
feature_extractor: CLIPFeatureExtractor,
|
||||
requires_safety_checker: bool = True,
|
||||
|
||||
@@ -24,14 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import (
PIL_INTERPOLATION,
deprecate,
@@ -133,14 +126,7 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,

@@ -24,7 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
@@ -173,7 +173,7 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,

@@ -24,14 +24,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionPipelineOutput
@@ -100,14 +93,7 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,

@@ -22,7 +22,7 @@ import PIL
from transformers import CLIPTextModel, CLIPTokenizer

from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...schedulers import DDPMScheduler, KarrasDiffusionSchedulers
from ...utils import is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

@@ -84,7 +84,7 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
low_res_scheduler: DDPMScheduler,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
scheduler: KarrasDiffusionSchedulers,
max_noise_level: int = 350,
):
super().__init__()
@@ -10,14 +10,7 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ...configuration_utils import FrozenDict
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import deprecate, is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline
from . import StableDiffusionSafePipelineOutput
@@ -65,14 +58,7 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
scheduler: Union[
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
],
scheduler: KarrasDiffusionSchedulers,
safety_checker: SafeStableDiffusionSafetyChecker,
feature_extractor: CLIPFeatureExtractor,
requires_safety_checker: bool = True,

@@ -7,7 +7,7 @@ import PIL.Image
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer, CLIPVisionModel

from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import logging
from ..pipeline_utils import DiffusionPipeline
from .pipeline_versatile_diffusion_dual_guided import VersatileDiffusionDualGuidedPipeline
@@ -53,7 +53,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
image_unet: UNet2DConditionModel
text_unet: UNet2DConditionModel
vae: AutoencoderKL
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
scheduler: KarrasDiffusionSchedulers

def __init__(
self,
@@ -64,7 +64,7 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
image_unet: UNet2DConditionModel,
text_unet: UNet2DConditionModel,
vae: AutoencoderKL,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
scheduler: KarrasDiffusionSchedulers,
):
super().__init__()
@@ -28,7 +28,7 @@ from transformers import (
)

from ...models import AutoencoderKL, DualTransformer2DModel, Transformer2DModel, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_text_unet import UNetFlatConditionModel
@@ -62,7 +62,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
image_unet: UNet2DConditionModel
text_unet: UNetFlatConditionModel
vae: AutoencoderKL
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
scheduler: KarrasDiffusionSchedulers

_optional_components = ["text_unet"]

@@ -75,7 +75,7 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):
image_unet: UNet2DConditionModel,
text_unet: UNetFlatConditionModel,
vae: AutoencoderKL,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
scheduler: KarrasDiffusionSchedulers,
):
super().__init__()
self.register_modules(

@@ -23,7 +23,7 @@ import PIL
from transformers import CLIPFeatureExtractor, CLIPVisionModelWithProjection

from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput

@@ -53,7 +53,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
image_encoder: CLIPVisionModelWithProjection
image_unet: UNet2DConditionModel
vae: AutoencoderKL
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
scheduler: KarrasDiffusionSchedulers

def __init__(
self,
@@ -61,7 +61,7 @@ class VersatileDiffusionImageVariationPipeline(DiffusionPipeline):
image_encoder: CLIPVisionModelWithProjection,
image_unet: UNet2DConditionModel,
vae: AutoencoderKL,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
scheduler: KarrasDiffusionSchedulers,
):
super().__init__()
self.register_modules(

@@ -21,7 +21,7 @@ import torch.utils.checkpoint
from transformers import CLIPFeatureExtractor, CLIPTextModelWithProjection, CLIPTokenizer

from ...models import AutoencoderKL, Transformer2DModel, UNet2DConditionModel
from ...schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_accelerate_available, logging, randn_tensor
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from .modeling_text_unet import UNetFlatConditionModel
@@ -54,7 +54,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
image_unet: UNet2DConditionModel
text_unet: UNetFlatConditionModel
vae: AutoencoderKL
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler]
scheduler: KarrasDiffusionSchedulers

_optional_components = ["text_unet"]

@@ -65,7 +65,7 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
image_unet: UNet2DConditionModel,
text_unet: UNetFlatConditionModel,
vae: AutoencoderKL,
scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
scheduler: KarrasDiffusionSchedulers,
):
super().__init__()
self.register_modules(

@@ -39,7 +39,7 @@ else:
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_unclip import UnCLIPScheduler
from .scheduling_utils import SchedulerMixin
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
from .scheduling_vq_diffusion import VQDiffusionScheduler

try:
@@ -55,7 +55,12 @@ else:
from .scheduling_lms_discrete_flax import FlaxLMSDiscreteScheduler
from .scheduling_pndm_flax import FlaxPNDMScheduler
from .scheduling_sde_ve_flax import FlaxScoreSdeVeScheduler
from .scheduling_utils_flax import FlaxSchedulerMixin, FlaxSchedulerOutput, broadcast_to_shape_from_left
from .scheduling_utils_flax import (
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
)

try:
@@ -23,8 +23,8 @@ import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor
from .scheduling_utils import SchedulerMixin
from ..utils import BaseOutput, deprecate, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


@dataclass
@@ -112,7 +112,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
https://imagen.research.google/video/paper.pdf)
"""

_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
order = 1


@@ -24,8 +24,8 @@ import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
add_noise_common,
@@ -102,7 +102,7 @@ class FlaxDDIMScheduler(FlaxSchedulerMixin, ConfigMixin):
the `dtype` used for params and computation.
"""

_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]

dtype: jnp.dtype

@@ -22,8 +22,8 @@ import numpy as np
import torch

from ..configuration_utils import ConfigMixin, FrozenDict, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, deprecate, randn_tensor
from .scheduling_utils import SchedulerMixin
from ..utils import BaseOutput, deprecate, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


@dataclass
@@ -105,7 +105,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
https://imagen.research.google/video/paper.pdf)
"""

_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
order = 1


@@ -24,8 +24,8 @@ import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
add_noise_common,
@@ -85,7 +85,7 @@ class FlaxDDPMScheduler(FlaxSchedulerMixin, ConfigMixin):
the `dtype` used for params and computation.
"""

_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]

dtype: jnp.dtype
@@ -22,8 +22,7 @@ import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -106,7 +105,7 @@ class DEISMultistepScheduler(SchedulerMixin, ConfigMixin):

"""

_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1

@register_to_config

@@ -21,8 +21,8 @@ import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, deprecate
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from ..utils import deprecate
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -117,7 +117,7 @@ class DPMSolverMultistepScheduler(SchedulerMixin, ConfigMixin):

"""

_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]
order = 1


@@ -24,8 +24,8 @@ import jax.numpy as jnp
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import deprecate
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
add_noise_common,
@@ -140,7 +140,7 @@ class FlaxDPMSolverMultistepScheduler(FlaxSchedulerMixin, ConfigMixin):
the `dtype` used for params and computation.
"""

_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]
_deprecated_kwargs = ["predict_epsilon"]

dtype: jnp.dtype

@@ -21,8 +21,7 @@ import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -116,7 +115,7 @@ class DPMSolverSinglestepScheduler(SchedulerMixin, ConfigMixin):

"""

_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1

@register_to_config

@@ -19,8 +19,8 @@ import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor
from .scheduling_utils import SchedulerMixin
from ..utils import BaseOutput, logging, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -71,7 +71,7 @@ class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):

"""

_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1

@register_to_config
@@ -19,8 +19,8 @@ import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput, logging, randn_tensor
from .scheduling_utils import SchedulerMixin
from ..utils import BaseOutput, logging, randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -72,7 +72,7 @@ class EulerDiscreteScheduler(SchedulerMixin, ConfigMixin):

"""

_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1

@register_to_config

@@ -18,8 +18,7 @@ import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
@@ -48,7 +47,7 @@ class HeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
https://imagen.research.google/video/paper.pdf)
"""

_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 2

@register_to_config

@@ -18,8 +18,8 @@ import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, randn_tensor
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from ..utils import randn_tensor
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
@@ -49,7 +49,7 @@ class KDPM2AncestralDiscreteScheduler(SchedulerMixin, ConfigMixin):
https://imagen.research.google/video/paper.pdf)
"""

_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 2

@register_to_config

@@ -18,8 +18,7 @@ import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
@@ -49,7 +48,7 @@ class KDPM2DiscreteScheduler(SchedulerMixin, ConfigMixin):
https://imagen.research.google/video/paper.pdf)
"""

_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 2

@register_to_config

@@ -21,8 +21,8 @@ import torch
from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
from .scheduling_utils import SchedulerMixin
from ..utils import BaseOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin


@dataclass
@@ -70,7 +70,7 @@ class LMSDiscreteScheduler(SchedulerMixin, ConfigMixin):
https://imagen.research.google/video/paper.pdf)
"""

_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1

@register_to_config
@@ -21,8 +21,8 @@ from scipy import integrate

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
broadcast_to_shape_from_left,
@@ -82,7 +82,7 @@ class FlaxLMSDiscreteScheduler(FlaxSchedulerMixin, ConfigMixin):
the `dtype` used for params and computation.
"""

_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]

dtype: jnp.dtype


@@ -21,8 +21,7 @@ import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS
from .scheduling_utils import SchedulerMixin, SchedulerOutput
from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
@@ -92,7 +91,7 @@ class PNDMScheduler(SchedulerMixin, ConfigMixin):

"""

_compatibles = _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in KarrasDiffusionSchedulers]
order = 1

@register_to_config

@@ -23,8 +23,8 @@ import jax.numpy as jnp

from ..configuration_utils import ConfigMixin, register_to_config
from .scheduling_utils_flax import (
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CommonSchedulerState,
FlaxKarrasDiffusionSchedulers,
FlaxSchedulerMixin,
FlaxSchedulerOutput,
add_noise_common,
@@ -110,7 +110,7 @@ class FlaxPNDMScheduler(FlaxSchedulerMixin, ConfigMixin):
the `dtype` used for params and computation.
"""

_compatibles = _FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS.copy()
_compatibles = [e.name for e in FlaxKarrasDiffusionSchedulers]

dtype: jnp.dtype
pndm_order: int

@@ -14,6 +14,7 @@
import importlib
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Union

import torch
@@ -24,6 +25,21 @@ from ..utils import BaseOutput
SCHEDULER_CONFIG_NAME = "scheduler_config.json"


class KarrasDiffusionSchedulers(Enum):
DDIMScheduler = 1
DDPMScheduler = 2
PNDMScheduler = 3
LMSDiscreteScheduler = 4
EulerDiscreteScheduler = 5
HeunDiscreteScheduler = 6
EulerAncestralDiscreteScheduler = 7
DPMSolverMultistepScheduler = 8
DPMSolverSinglestepScheduler = 9
KDPM2DiscreteScheduler = 10
KDPM2AncestralDiscreteScheduler = 11
DEISMultistepScheduler = 12


@dataclass
class SchedulerOutput(BaseOutput):
"""
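With the enum in place, each Karras-style scheduler declares its compatibility set as `_compatibles = [e.name for e in KarrasDiffusionSchedulers]` instead of copying a hand-maintained string list. A small sketch of what that exposes, assuming only what the hunks above define:

```python
from diffusers import DDIMScheduler
from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers

scheduler = DDIMScheduler()

# Every Karras-style scheduler advertises the same compatibility set,
# derived from the enum rather than from a hand-maintained string list.
print(scheduler._compatibles)                        # ["DDIMScheduler", "DDPMScheduler", ...]
print([e.name for e in KarrasDiffusionSchedulers])   # the same twelve names, in enum order
```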
@@ -15,16 +15,24 @@ import importlib
import math
import os
from dataclasses import dataclass
from enum import Enum
from typing import Any, Dict, Optional, Tuple, Union

import flax
import jax.numpy as jnp

from ..utils import _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS, BaseOutput
from ..utils import BaseOutput


SCHEDULER_CONFIG_NAME = "scheduler_config.json"
_FLAX_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = ["Flax" + c for c in _COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS]


class FlaxKarrasDiffusionSchedulers(Enum):
FlaxDDIMScheduler = 1
FlaxDDPMScheduler = 2
FlaxPNDMScheduler = 3
FlaxLMSDiscreteScheduler = 4
FlaxDPMSolverMultistepScheduler = 5


@dataclass

@@ -19,7 +19,6 @@ from packaging import version

from .. import __version__
from .constants import (
_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS,
CONFIG_NAME,
DIFFUSERS_CACHE,
DIFFUSERS_DYNAMIC_MODULE_NAME,

@@ -30,18 +30,3 @@ HUGGINGFACE_CO_RESOLVE_ENDPOINT = "https://huggingface.co"
DIFFUSERS_CACHE = default_cache_path
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
HF_MODULES_CACHE = os.getenv("HF_MODULES_CACHE", os.path.join(hf_cache_home, "modules"))

_COMPATIBLE_STABLE_DIFFUSION_SCHEDULERS = [
"DDIMScheduler",
"DDPMScheduler",
"PNDMScheduler",
"LMSDiscreteScheduler",
"EulerDiscreteScheduler",
"HeunDiscreteScheduler",
"EulerAncestralDiscreteScheduler",
"DPMSolverMultistepScheduler",
"DPMSolverSinglestepScheduler",
"KDPM2DiscreteScheduler",
"KDPM2AncestralDiscreteScheduler",
"DEISMultistepScheduler",
]

@@ -227,6 +227,21 @@ class DiffusionPipeline(metaclass=DummyObject):
requires_backends(cls, ["torch"])


class DiTPipeline(metaclass=DummyObject):
_backends = ["torch"]

def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])

@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])

@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])


class ImagePipelineOutput(metaclass=DummyObject):
_backends = ["torch"]
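The dummy object gives `DiTPipeline` a torch-less placeholder, so the class can always be imported and only fails, with a pointer to install PyTorch, once it is actually used. A rough sketch of that behaviour; the printed message is illustrative, not the exact `requires_backends` wording:

```python
# Without torch installed, `diffusers` exposes this dummy class instead of the real pipeline.
from diffusers.utils import dummy_pt_objects

try:
    dummy_pt_objects.DiTPipeline()
except ImportError as err:  # requires_backends raises only when the class is actually used
    print(err)  # e.g. "DiTPipeline requires the PyTorch library but it was not found ..."
```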
0
tests/pipelines/dit/__init__.py
Normal file
135
tests/pipelines/dit/test_dit.py
Normal file
@@ -0,0 +1,135 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch

from diffusers import AutoencoderKL, DDIMScheduler, DiTPipeline, DPMSolverMultistepScheduler, Transformer2DModel
from diffusers.utils import load_numpy, slow
from diffusers.utils.testing_utils import require_torch_gpu

from ...test_pipelines_common import PipelineTesterMixin


torch.backends.cuda.matmul.allow_tf32 = False


class DiTPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = DiTPipeline
    test_cpu_offload = False

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = Transformer2DModel(
            sample_size=4,
            num_layers=2,
            patch_size=2,
            attention_head_dim=2,
            num_attention_heads=2,
            in_channels=4,
            out_channels=8,
            attention_bias=True,
            activation_fn="gelu-approximate",
            num_embeds_ada_norm=1000,
            norm_type="ada_norm_zero",
            norm_elementwise_affine=False,
        )
        vae = AutoencoderKL()
        scheduler = DDIMScheduler()
        components = {"transformer": transformer.eval(), "vae": vae.eval(), "scheduler": scheduler}
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "class_labels": [1],
            "generator": generator,
            "num_inference_steps": 2,
            "output_type": "numpy",
        }
        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

        self.assertEqual(image.shape, (1, 4, 4, 3))
        expected_slice = np.array(
            [0.44405967, 0.33592293, 0.6093237, 0.48981372, 0.79098296, 0.7504172, 0.59413105, 0.49462673, 0.35190058]
        )
        max_diff = np.abs(image_slice.flatten() - expected_slice).max()
        self.assertLessEqual(max_diff, 1e-3)

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(relax_max_difference=True)


@require_torch_gpu
@slow
class DiTPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_dit_256(self):
        generator = torch.manual_seed(0)

        pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256")
        pipe.to("cuda")

        words = ["vase", "umbrella", "white shark", "white wolf"]
        ids = pipe.get_label_ids(words)

        images = pipe(ids, generator=generator, num_inference_steps=40, output_type="np").images

        for word, image in zip(words, images):
            expected_image = load_numpy(
                f"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/dit/{word}.npy"
            )
            assert np.abs((expected_image - image).sum()) < 1e-3

    def test_dit_512_fp16(self):
        generator = torch.manual_seed(0)

        pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-512", torch_dtype=torch.float16)
        pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.to("cuda")

        words = ["vase", "umbrella", "white shark", "white wolf"]
        ids = pipe.get_label_ids(words)

        images = pipe(ids, generator=generator, num_inference_steps=25, output_type="np").images

        for word, image in zip(words, images):
            expected_image = load_numpy(
                "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
                f"/dit/{word}_fp16.npy"
            )
            assert np.abs((expected_image - image).sum()) < 1e-3
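The slow tests above double as a usage recipe for the new pipeline: labels are plain ImageNet class names, converted to ids with `get_label_ids`, and any Karras scheduler can be swapped in before sampling. A condensed sketch for local use; the seed, step count, and class names are arbitrary choices, and the default PIL output is assumed:

```python
import torch

from diffusers import DiTPipeline, DPMSolverMultistepScheduler

pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256", torch_dtype=torch.float16)
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe.to("cuda")

# Map human-readable ImageNet class names to the label ids the model was trained on.
class_ids = pipe.get_label_ids(["white shark", "umbrella"])

generator = torch.manual_seed(33)
images = pipe(class_ids, generator=generator, num_inference_steps=25).images
images[0].save("white_shark.png")
```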
@@ -36,7 +36,7 @@ class PipelineTesterMixin:
equivalence of dict and tuple outputs, etc.
"""

allowed_required_args = ["source_prompt", "prompt", "image", "mask_image", "example_image"]
allowed_required_args = ["source_prompt", "prompt", "image", "mask_image", "example_image", "class_labels"]
required_optional_params = ["generator", "num_inference_steps", "return_dict"]
num_inference_steps_args = ["num_inference_steps"]

@@ -194,8 +194,8 @@ class PipelineTesterMixin:
):
if self.pipeline_class.__name__ in ["CycleDiffusionPipeline", "RePaintPipeline"]:
# RePaint can hardly be made deterministic since the scheduler is currently always
# indeterministic
# CycleDiffusion is also slighly undeterministic
# nondeterministic
# CycleDiffusion is also slightly nondeterministic
return

if test_max_difference is None:
@@ -515,7 +515,7 @@ class PipelineTesterMixin:
torch_device != "cuda" or not is_xformers_available(),
reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_attention_forward_pass(self):
def test_xformers_attention_forwardGenerator_pass(self):
if not self.test_xformers_attention:
return