
kakaobrain unCLIP (#1428)

* [wip] attention block updates

* [wip] unCLIP unet decoder and super res

* [wip] unCLIP prior transformer

* [wip] scheduler changes

* [wip] text proj utility class

* [wip] UnCLIPPipeline

* [wip] kakaobrain unCLIP convert script

* [unCLIP pipeline] fixes re: @patrickvonplaten

remove callbacks

move denoising loops into call function

* UNCLIPScheduler re: @patrickvonplaten

Revert changes to DDPMScheduler. Make UNCLIPScheduler a modified
DDPM scheduler with changes to support Karlo.

* mask -> attention_mask re: @patrickvonplaten

* [DDPMScheduler] remove leftover change

* [docs] PriorTransformer

* [docs] UNet2DConditionModel and UNet2DModel

* [nit] UNCLIPScheduler -> UnCLIPScheduler

matches existing unclip naming better

* [docs] SchedulingUnCLIP

* [docs] UnCLIPTextProjModel

* refactor

* finish licenses

* rename all to attention_mask and prep in models

* more renaming

* don't expose unused configs

* final renaming fixes

* remove x attn mask when not necessary

* configure kakao script to use new class embedding config

* fix copies

* [tests] UnCLIPScheduler

* finish x attn

* finish

* remove more

* rename condition blocks

* clean more

* Apply suggestions from code review

* up

* fix

* [tests] UnCLIPPipelineFastTests

* remove unused imports

* [tests] UnCLIPPipelineIntegrationTests

* correct

* make style

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Author: Will Berman
Date: 2022-12-18 15:15:30 -08:00 (committed by GitHub)
Commit: 2dcf64b72a
Parent: 402b9560b2
21 changed files with 3594 additions and 118 deletions
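
A minimal usage sketch of the new UnCLIPPipeline added by this commit. The checkpoint id below is assumed to be the kakaobrain Karlo release on the Hub and is not part of this diff; only the pipeline class and the standard `from_pretrained` / `.images` API are taken from the changes themselves.

    import torch
    from diffusers import UnCLIPPipeline

    # Assumed checkpoint id for the kakaobrain Karlo weights; not taken from this diff.
    pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
    pipe = pipe.to("cuda")

    image = pipe("a high-resolution photograph of a corgi").images[0]
    image.save("unclip_corgi.png")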

File diff suppressed because it is too large.


@@ -25,7 +25,15 @@ except OptionalDependencyNotAvailable:
from .utils.dummy_pt_objects import * # noqa F403
else:
from .modeling_utils import ModelMixin
from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
from .models import (
AutoencoderKL,
PriorTransformer,
Transformer2DModel,
UNet1DModel,
UNet2DConditionModel,
UNet2DModel,
VQModel,
)
from .optimization import (
get_constant_schedule,
get_constant_schedule_with_warmup,
@@ -63,6 +71,7 @@ else:
RePaintScheduler,
SchedulerMixin,
ScoreSdeVeScheduler,
UnCLIPScheduler,
VQDiffusionScheduler,
)
from .training_utils import EMAModel
@@ -96,6 +105,7 @@ else:
StableDiffusionPipeline,
StableDiffusionPipelineSafe,
StableDiffusionUpscalePipeline,
UnCLIPPipeline,
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,
VersatileDiffusionPipeline,


@@ -17,6 +17,7 @@ from ..utils import is_flax_available, is_torch_available
if is_torch_available():
from .attention import Transformer2DModel
from .prior_transformer import PriorTransformer
from .unet_1d import UNet1DModel
from .unet_2d import UNet2DModel
from .unet_2d_condition import UNet2DConditionModel


@@ -66,8 +66,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input and output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
@@ -181,7 +181,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
hidden_states
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
@@ -213,7 +213,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep)
# 3. Output
if self.is_input_continuous:
@@ -260,6 +260,8 @@ class AttentionBlock(nn.Module):
eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
"""
# IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
def __init__(
self,
channels: int,
@@ -369,6 +371,7 @@ class AttentionBlock(nn.Module):
# compute next hidden_states
hidden_states = self.proj_attn(hidden_states)
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
# res connect and rescale
@@ -385,7 +388,7 @@ class BasicTransformerBlock(nn.Module):
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm (:
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
@@ -432,7 +435,7 @@ class BasicTransformerBlock(nn.Module):
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
) # is self-attn if context is none
) # is self-attn if encoder_hidden_states is none
else:
self.attn2 = None
@@ -472,23 +475,30 @@ class BasicTransformerBlock(nn.Module):
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
def forward(self, hidden_states, context=None, timestep=None):
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
# 1. Self-Attention
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
if self.only_cross_attention:
hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
hidden_states = (
self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
)
else:
hidden_states = self.attn1(norm_hidden_states) + hidden_states
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
if self.attn2 is not None:
# 2. Cross-Attention
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
hidden_states = (
self.attn2(
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
+ hidden_states
)
# 3. Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
@@ -503,7 +513,7 @@ class CrossAttention(nn.Module):
Parameters:
query_dim (`int`): The number of channels in the query.
cross_attention_dim (`int`, *optional*):
The number of channels in the context. If not given, defaults to `query_dim`.
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
@@ -520,13 +530,18 @@ class CrossAttention(nn.Module):
dropout: float = 0.0,
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
added_kv_proj_dim: Optional[int] = None,
norm_num_groups: Optional[int] = None,
):
super().__init__()
inner_dim = dim_head * heads
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.scale = dim_head**-0.5
self.heads = heads
# for slice_size > 0 the attention score computation
# is split across the batch axis to save memory
@@ -534,11 +549,21 @@ class CrossAttention(nn.Module):
self.sliceable_head_dim = heads
self._slice_size = None
self._use_memory_efficient_attention_xformers = False
self.added_kv_proj_dim = added_kv_proj_dim
if norm_num_groups is not None:
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
else:
self.group_norm = None
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
if self.added_kv_proj_dim is not None:
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim))
self.to_out.append(nn.Dropout(dropout))
@@ -563,40 +588,58 @@ class CrossAttention(nn.Module):
self._slice_size = slice_size
def forward(self, hidden_states, context=None, mask=None):
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
encoder_hidden_states = encoder_hidden_states
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states)
context = context if context is not None else hidden_states
key = self.to_k(context)
value = self.to_v(context)
dim = query.shape[-1]
query = self.reshape_heads_to_batch_dim(query)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
if self.added_kv_proj_dim is not None:
key = self.to_k(hidden_states)
value = self.to_v(hidden_states)
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
else:
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
key = self.reshape_heads_to_batch_dim(key)
value = self.reshape_heads_to_batch_dim(value)
# attention, what we cannot get enough of
if self._use_memory_efficient_attention_xformers:
hidden_states = self._memory_efficient_attention_xformers(query, key, value)
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
hidden_states = hidden_states.to(query.dtype)
else:
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
hidden_states = self._attention(query, key, value)
hidden_states = self._attention(query, key, value, attention_mask)
else:
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
def _attention(self, query, key, value):
def _attention(self, query, key, value, attention_mask=None):
if self.upcast_attention:
query = query.float()
key = key.float()
@@ -608,6 +651,18 @@ class CrossAttention(nn.Module):
beta=0,
alpha=self.scale,
)
if attention_mask is not None:
if attention_mask.shape != attention_scores.shape:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
attention_scores = attention_scores + attention_mask
if self.upcast_softmax:
attention_scores = attention_scores.float()
attention_probs = attention_scores.softmax(dim=-1)
# cast back to the original dtype
@@ -656,7 +711,8 @@ class CrossAttention(nn.Module):
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
def _memory_efficient_attention_xformers(self, query, key, value):
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
# TODO attention_mask
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
@@ -802,7 +858,7 @@ class DualTransformer2DModel(nn.Module):
Pass if the input is continuous. The number of channels in the input and output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
@@ -867,17 +923,21 @@ class DualTransformer2DModel(nn.Module):
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
self.transformer_index_for_condition = [1, 0]
def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True):
def forward(
self, hidden_states, encoder_hidden_states, timestep=None, attention_mask=None, return_dict: bool = True
):
"""
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
hidden_states
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
attention_mask (`torch.FloatTensor`, *optional*):
Optional attention mask to be applied in CrossAttention
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
@@ -894,9 +954,13 @@ class DualTransformer2DModel(nn.Module):
# for each of the two transformers, pass the corresponding condition tokens
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
transformer_index = self.transformer_index_for_condition[i]
encoded_state = self.transformers[transformer_index](input_states, condition_state, timestep, return_dict)[
0
]
encoded_state = self.transformers[transformer_index](
input_states,
encoder_hidden_states=condition_state,
timestep=timestep,
attention_mask=attention_mask,
return_dict=False,
)[0]
encoded_states.append(encoded_state - input_states)
tokens_start += self.condition_lengths[i]
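
The new additive mask path in `CrossAttention._attention` above simply biases the attention scores before the softmax. Stripped down to the tensor arithmetic, a standalone sketch of that pattern (shapes chosen only for illustration):

    import torch

    heads, batch, q_len, kv_len, dim_head = 2, 1, 4, 6, 8
    query = torch.randn(batch * heads, q_len, dim_head)
    key = torch.randn(batch * heads, kv_len, dim_head)

    scores = torch.baddbmm(
        torch.empty(query.shape[0], q_len, kv_len, dtype=query.dtype),
        query,
        key.transpose(-1, -2),
        beta=0,
        alpha=dim_head**-0.5,
    )

    # 0.0 where a key token may be attended to, -10000.0 where it is masked out;
    # one row per batch entry, repeated across heads as in the diff.
    attention_mask = torch.zeros(batch, 1, kv_len)
    attention_mask[..., 4:] = -10000.0
    attention_mask = attention_mask.repeat_interleave(heads, dim=0)

    probs = (scores + attention_mask).softmax(dim=-1)  # masked keys get ~zero probability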


@@ -0,0 +1,194 @@
from dataclasses import dataclass
from typing import Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .attention import BasicTransformerBlock
from .embeddings import TimestepEmbedding, Timesteps
@dataclass
class PriorTransformerOutput(BaseOutput):
"""
Args:
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
"""
predicted_image_embedding: torch.FloatTensor
class PriorTransformer(ModelMixin, ConfigMixin):
"""
The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the
transformer predicts the image embeddings through a denoising diffusion process.
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
implements for all the models (such as downloading or saving, etc.)
For more details, see the original paper: https://arxiv.org/abs/2204.06125
Parameters:
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP
image embeddings and text embeddings are both the same dimension.
num_embeddings (`int`, *optional*, defaults to 77): The maximum number of CLIP embeddings allowed, i.e. the
length of the prompt after it has been tokenized.
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
projected hidden_states. The actual length of the used hidden_states is `num_embeddings +
additional_embeddings`.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
"""
@register_to_config
def __init__(
self,
num_attention_heads: int = 32,
attention_head_dim: int = 64,
num_layers: int = 20,
embedding_dim: int = 768,
num_embeddings=77,
additional_embeddings=4,
dropout: float = 0.0,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.additional_embeddings = additional_embeddings
self.time_proj = Timesteps(inner_dim, True, 0)
self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)
self.proj_in = nn.Linear(embedding_dim, inner_dim)
self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
activation_fn="gelu",
attention_bias=True,
)
for d in range(num_layers)
]
)
self.norm_out = nn.LayerNorm(inner_dim)
self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
causal_attention_mask = torch.full(
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], float("-inf")
)
causal_attention_mask.triu_(1)
causal_attention_mask = causal_attention_mask[None, ...]
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))
def forward(
self,
hidden_states,
timestep: Union[torch.Tensor, float, int],
proj_embedding: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor,
attention_mask: Optional[torch.BoolTensor] = None,
return_dict: bool = True,
):
"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
x_t, the currently predicted image embeddings.
timestep (`torch.long`):
Current denoising step.
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
Projected embedding vector the denoising process is conditioned on.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
Hidden states of the text embeddings the denoising process is conditioned on.
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
Text mask for the text embeddings.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain
tuple.
Returns:
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
[`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
batch_size = hidden_states.shape[0]
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(hidden_states.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
timesteps_projected = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might be fp16, so we need to cast here.
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
time_embeddings = self.time_embedding(timesteps_projected)
proj_embeddings = self.embedding_proj(proj_embedding)
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
hidden_states = self.proj_in(hidden_states)
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
hidden_states = torch.cat(
[
encoder_hidden_states,
proj_embeddings[:, None, :],
time_embeddings[:, None, :],
hidden_states[:, None, :],
prd_embedding,
],
dim=1,
)
hidden_states = hidden_states + positional_embeddings
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
for block in self.transformer_blocks:
hidden_states = block(hidden_states, attention_mask=attention_mask)
hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states[:, -1]
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
if not return_dict:
return (predicted_image_embedding,)
return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
def post_process_latents(self, prior_latents):
prior_latents = (prior_latents * self.clip_std) + self.clip_mean
return prior_latents
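
A shape-check sketch of the new prior. The model is randomly initialized here, so only the tensor shapes are meaningful; the dimensions follow the defaults documented above.

    import torch
    from diffusers import PriorTransformer

    prior = PriorTransformer()  # 32 heads x 64 dims, 20 layers, embedding_dim=768, num_embeddings=77

    batch = 2
    noisy_image_emb = torch.randn(batch, 768)          # x_t, the image embedding being denoised
    text_emb = torch.randn(batch, 768)                 # projected CLIP text embedding
    text_hidden_states = torch.randn(batch, 77, 768)   # per-token CLIP hidden states
    text_mask = torch.ones(batch, 77, dtype=torch.bool)

    out = prior(
        noisy_image_emb,
        timestep=10,
        proj_embedding=text_emb,
        encoder_hidden_states=text_hidden_states,
        attention_mask=text_mask,
    )
    print(out.predicted_image_embedding.shape)  # torch.Size([2, 768])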


@@ -405,7 +405,14 @@ class ResnetBlock2D(nn.Module):
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
if self.time_embedding_norm == "default":
time_emb_proj_out_channels = out_channels
elif self.time_embedding_norm == "scale_shift":
time_emb_proj_out_channels = out_channels * 2
else:
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
else:
self.time_emb_proj = None
@@ -465,9 +472,16 @@ class ResnetBlock2D(nn.Module):
if temb is not None:
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
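
The `scale_shift` branch is FiLM-style conditioning: the time embedding is projected to twice the channel count, split into a scale and a shift, and applied after the second group norm. In isolation the arithmetic looks like this (shapes are illustrative only):

    import torch
    from torch import nn

    out_channels, temb_channels = 8, 32
    time_emb_proj = nn.Linear(temb_channels, out_channels * 2)  # "scale_shift" doubles the width

    hidden_states = torch.randn(2, out_channels, 16, 16)
    temb = time_emb_proj(torch.randn(2, temb_channels))[:, :, None, None]

    scale, shift = torch.chunk(temb, 2, dim=1)
    hidden_states = hidden_states * (1 + scale) + shift  # matches the new ResnetBlock2D branch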


@@ -55,6 +55,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
down_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
types.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
The mid block type. Choose from `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
up_block_types (`Tuple[str]`, *optional*, defaults to :
obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types.
block_out_channels (`Tuple[int]`, *optional*, defaults to :
@@ -66,6 +68,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
"""
@register_to_config
@@ -88,6 +92,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
attention_head_dim: int = 8,
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
resnet_time_scale_shift: str = "default",
add_attention: bool = True,
):
super().__init__()
@@ -130,6 +136,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.down_blocks.append(down_block)
@@ -140,9 +147,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
resnet_time_scale_shift=resnet_time_scale_shift,
attn_num_head_channels=attention_head_dim,
resnet_groups=norm_num_groups,
add_attention=add_attention,
)
# up
@@ -167,6 +175,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attn_num_head_channels=attention_head_dim,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
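
For reference, a sketch of wiring the two new `UNet2DModel` options into a small model; all sizes here are arbitrary and only for illustration.

    import torch
    from diffusers import UNet2DModel

    unet = UNet2DModel(
        sample_size=64,
        in_channels=3,
        out_channels=3,
        block_out_channels=(32, 64),
        down_block_types=("DownBlock2D", "AttnDownBlock2D"),
        up_block_types=("AttnUpBlock2D", "UpBlock2D"),
        resnet_time_scale_shift="scale_shift",  # forwarded to every ResnetBlock2D
        add_attention=True,                     # keeps the mid-block attention layer
    )

    noisy_sample = torch.randn(1, 3, 64, 64)
    pred = unet(noisy_sample, timestep=1).sample  # (1, 3, 64, 64)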


@@ -15,7 +15,7 @@ import numpy as np
import torch
from torch import nn
from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel
from .attention import AttentionBlock, CrossAttention, DualTransformer2DModel, Transformer2DModel
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
@@ -36,6 +36,7 @@ def get_down_block(
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock2D":
@@ -49,6 +50,19 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "ResnetDownsampleBlock2D":
return ResnetDownsampleBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "AttnDownBlock2D":
return AttnDownBlock2D(
@@ -62,6 +76,7 @@ def get_down_block(
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "CrossAttnDownBlock2D":
if cross_attention_dim is None:
@@ -82,6 +97,23 @@ def get_down_block(
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "SimpleCrossAttnDownBlock2D":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
return SimpleCrossAttnDownBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "SkipDownBlock2D":
return SkipDownBlock2D(
@@ -93,6 +125,7 @@ def get_down_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "AttnSkipDownBlock2D":
return AttnSkipDownBlock2D(
@@ -105,6 +138,7 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "DownEncoderBlock2D":
return DownEncoderBlock2D(
@@ -116,6 +150,7 @@ def get_down_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif down_block_type == "AttnDownEncoderBlock2D":
return AttnDownEncoderBlock2D(
@@ -128,6 +163,7 @@ def get_down_block(
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
attn_num_head_channels=attn_num_head_channels,
resnet_time_scale_shift=resnet_time_scale_shift,
)
raise ValueError(f"{down_block_type} does not exist.")
@@ -149,6 +185,7 @@ def get_up_block(
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock2D":
@@ -162,6 +199,20 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif up_block_type == "ResnetUpsampleBlock2D":
return ResnetUpsampleBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif up_block_type == "CrossAttnUpBlock2D":
if cross_attention_dim is None:
@@ -182,6 +233,24 @@ def get_up_block(
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif up_block_type == "SimpleCrossAttnUpBlock2D":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
return SimpleCrossAttnUpBlock2D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif up_block_type == "AttnUpBlock2D":
return AttnUpBlock2D(
@@ -195,6 +264,7 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
attn_num_head_channels=attn_num_head_channels,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif up_block_type == "SkipUpBlock2D":
return SkipUpBlock2D(
@@ -206,6 +276,7 @@ def get_up_block(
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif up_block_type == "AttnSkipUpBlock2D":
return AttnSkipUpBlock2D(
@@ -218,6 +289,7 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
attn_num_head_channels=attn_num_head_channels,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif up_block_type == "UpDecoderBlock2D":
return UpDecoderBlock2D(
@@ -228,6 +300,7 @@ def get_up_block(
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
)
elif up_block_type == "AttnUpDecoderBlock2D":
return AttnUpDecoderBlock2D(
@@ -239,6 +312,7 @@ def get_up_block(
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
attn_num_head_channels=attn_num_head_channels,
resnet_time_scale_shift=resnet_time_scale_shift,
)
raise ValueError(f"{up_block_type} does not exist.")
@@ -255,14 +329,13 @@ class UNetMidBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
add_attention: bool = True,
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=1.0,
):
super().__init__()
self.attention_type = attention_type
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.add_attention = add_attention
# there is always at least one resnet
resnets = [
@@ -282,15 +355,19 @@ class UNetMidBlock2D(nn.Module):
attentions = []
for _ in range(num_layers):
attentions.append(
AttentionBlock(
in_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
if self.add_attention:
attentions.append(
AttentionBlock(
in_channels,
num_head_channels=attn_num_head_channels,
rescale_output_factor=output_scale_factor,
eps=resnet_eps,
norm_num_groups=resnet_groups,
)
)
)
else:
attentions.append(None)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
@@ -309,13 +386,11 @@ class UNetMidBlock2D(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_states=None):
def forward(self, hidden_states, temb=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
if self.attention_type == "default":
if attn is not None:
hidden_states = attn(hidden_states)
else:
hidden_states = attn(hidden_states, encoder_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
@@ -334,7 +409,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
@@ -344,7 +418,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
super().__init__()
self.has_cross_attention = True
self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -408,10 +481,121 @@ class UNetMidBlock2DCrossAttn(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
# TODO(Patrick, William) - attention_mask is currently not used. Implement once used
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states).sample
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
hidden_states = resnet(hidden_states, temb)
return hidden_states
class UNetMidBlock2DSimpleCrossAttn(nn.Module):
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
):
super().__init__()
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.num_heads = in_channels // self.attn_num_head_channels
# there is always at least one resnet
resnets = [
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
for _ in range(num_layers):
attentions.append(
CrossAttention(
query_dim=in_channels,
cross_attention_dim=in_channels,
heads=self.num_heads,
dim_head=attn_num_head_channels,
added_kv_proj_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
bias=True,
upcast_softmax=True,
)
)
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def set_attention_slice(self, slice_size):
head_dims = self.attn_num_head_channels
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
raise ValueError(
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
)
if slice_size is not None and slice_size > min(head_dims):
raise ValueError(
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
)
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
# attn
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
attention_mask=attention_mask,
)
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual
# resnet
hidden_states = resnet(hidden_states, temb)
return hidden_states
@@ -431,7 +615,6 @@ class AttnDownBlock2D(nn.Module):
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
@@ -440,8 +623,6 @@ class AttnDownBlock2D(nn.Module):
resnets = []
attentions = []
self.attention_type = attention_type
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
@@ -514,7 +695,6 @@ class CrossAttnDownBlock2D(nn.Module):
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
attention_type="default",
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
@@ -528,7 +708,6 @@ class CrossAttnDownBlock2D(nn.Module):
attentions = []
self.has_cross_attention = True
self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
for i in range(num_layers):
@@ -588,7 +767,8 @@ class CrossAttnDownBlock2D(nn.Module):
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
# TODO(Patrick, William) - attention mask is not used
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
@@ -605,7 +785,9 @@ class CrossAttnDownBlock2D(nn.Module):
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
@@ -847,7 +1029,6 @@ class AttnSkipDownBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=np.sqrt(2.0),
downsample_padding=1,
add_downsample=True,
@@ -856,8 +1037,6 @@ class AttnSkipDownBlock2D(nn.Module):
self.attentions = nn.ModuleList([])
self.resnets = nn.ModuleList([])
self.attention_type = attention_type
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
self.resnets.append(
@@ -1006,6 +1185,205 @@ class SkipDownBlock2D(nn.Module):
return hidden_states, output_states, skip_sample
class ResnetDownsampleBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_downsample=True,
):
super().__init__()
resnets = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
ResnetBlock2D(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
down=True,
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None):
output_states = ()
for resnet in self.resnets:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
output_states += (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, temb)
output_states += (hidden_states,)
return hidden_states, output_states
class SimpleCrossAttnDownBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
add_downsample=True,
):
super().__init__()
self.has_cross_attention = True
resnets = []
attentions = []
self.attn_num_head_channels = attn_num_head_channels
self.num_heads = out_channels // self.attn_num_head_channels
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
CrossAttention(
query_dim=out_channels,
cross_attention_dim=out_channels,
heads=self.num_heads,
dim_head=attn_num_head_channels,
added_kv_proj_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
bias=True,
upcast_softmax=True,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
ResnetBlock2D(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
down=True,
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
# resnet
hidden_states = resnet(hidden_states, temb)
# attn
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
attention_mask=attention_mask,
)
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual
output_states += (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states, temb)
output_states += (hidden_states,)
return hidden_states, output_states
class AttnUpBlock2D(nn.Module):
def __init__(
self,
@@ -1020,7 +1398,6 @@ class AttnUpBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attention_type="default",
attn_num_head_channels=1,
output_scale_factor=1.0,
add_upsample=True,
@@ -1029,8 +1406,6 @@ class AttnUpBlock2D(nn.Module):
resnets = []
attentions = []
self.attention_type = attention_type
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -1100,7 +1475,6 @@ class CrossAttnUpBlock2D(nn.Module):
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
attention_type="default",
output_scale_factor=1.0,
add_upsample=True,
dual_cross_attention=False,
@@ -1113,7 +1487,6 @@ class CrossAttnUpBlock2D(nn.Module):
attentions = []
self.has_cross_attention = True
self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
for i in range(num_layers):
@@ -1176,7 +1549,9 @@ class CrossAttnUpBlock2D(nn.Module):
temb=None,
encoder_hidden_states=None,
upsample_size=None,
attention_mask=None,
):
# TODO(Patrick, William) - attention mask is not used
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
@@ -1196,7 +1571,9 @@ class CrossAttnUpBlock2D(nn.Module):
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
@@ -1418,7 +1795,6 @@ class AttnSkipUpBlock2D(nn.Module):
resnet_act_fn: str = "swish",
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=np.sqrt(2.0),
upsample_padding=1,
add_upsample=True,
@@ -1427,8 +1803,6 @@ class AttnSkipUpBlock2D(nn.Module):
self.attentions = nn.ModuleList([])
self.resnets = nn.ModuleList([])
self.attention_type = attention_type
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
@@ -1608,3 +1982,213 @@ class SkipUpBlock2D(nn.Module):
hidden_states = self.resnet_up(hidden_states, temb)
return hidden_states, skip_sample
class ResnetUpsampleBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
prev_output_channel: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_upsample=True,
):
super().__init__()
resnets = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList(
[
ResnetBlock2D(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
up=True,
)
]
)
else:
self.upsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
for resnet in self.resnets:
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, temb)
return hidden_states
class SimpleCrossAttnUpBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
prev_output_channel: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
add_upsample=True,
):
super().__init__()
resnets = []
attentions = []
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
self.num_heads = out_channels // self.attn_num_head_channels
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock2D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
attentions.append(
CrossAttention(
query_dim=out_channels,
cross_attention_dim=out_channels,
heads=self.num_heads,
dim_head=attn_num_head_channels,
added_kv_proj_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
bias=True,
upcast_softmax=True,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
if add_upsample:
self.upsamplers = nn.ModuleList(
[
ResnetBlock2D(
in_channels=out_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
up=True,
)
]
)
else:
self.upsamplers = None
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
upsample_size=None,
attention_mask=None,
):
for resnet, attn in zip(self.resnets, self.attentions):
# resnet
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
# attn
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
attention_mask=attention_mask,
)
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, temb)
return hidden_states
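
The Simple* blocks above all repeat one pattern: flatten the spatial map to a token sequence, run CrossAttention with the text states wired in through the added key/value projections, then fold the result back and add the residual. Stripped of the surrounding block, the reshape bookkeeping is (shapes for illustration only):

    import torch

    batch, channels, height, width = 1, 8, 4, 4
    feature_map = torch.randn(batch, channels, height, width)

    residual = feature_map
    # (B, C, H, W) -> (B, H*W, C): every spatial position becomes a token
    tokens = feature_map.view(batch, channels, -1).transpose(1, 2)

    # ... CrossAttention over `tokens`, conditioned on encoder_hidden_states, happens here ...
    attended = tokens  # stand-in for the attention output, same shape (B, H*W, C)

    # (B, H*W, C) -> (B, C, H*W) -> (B, C, H, W), then the residual connection
    out = attended.transpose(-1, -2).reshape(residual.shape) + residual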


@@ -27,6 +27,7 @@ from .unet_2d_blocks import (
CrossAttnUpBlock2D,
DownBlock2D,
UNetMidBlock2DCrossAttn,
UNetMidBlock2DSimpleCrossAttn,
UpBlock2D,
get_down_block,
get_up_block,
@@ -66,6 +67,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
The tuple of upsample blocks to use.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
@@ -78,6 +81,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
"""
_supports_gradient_checkpointing = True
@@ -97,6 +104,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
"CrossAttnDownBlock2D",
"DownBlock2D",
),
mid_block_type: str = "UNetMidBlock2DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
@@ -110,8 +118,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
):
super().__init__()
@@ -128,8 +138,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
# class embedding
if num_class_embeds is not None:
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
else:
self.class_embedding = None
self.down_blocks = nn.ModuleList([])
self.mid_block = None
@@ -165,24 +181,40 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
if mid_block_type == "UNetMidBlock2DCrossAttn":
self.mid_block = UNetMidBlock2DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
)
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
# count how many layers upsample the images
self.num_upsamplers = 0
@@ -223,6 +255,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -307,6 +340,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
@@ -336,6 +370,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
@@ -365,9 +404,13 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.config.num_class_embeds is not None:
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
@@ -382,6 +425,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
@@ -389,7 +433,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
sample = self.mid_block(
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
@@ -410,6 +456,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
attention_mask=attention_mask,
)
else:
sample = upsample_block(

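The unCLIP decoder exercises the new `mid_block_type`, `class_embed_type`, `resnet_time_scale_shift`, and `attention_mask` arguments together. As a rough sketch, the configuration below mirrors the small dummy decoder from the fast tests further down in this diff; the released karlo checkpoints use larger settings, and at inference the pipeline supplies `class_labels` and `attention_mask` itself:

from diffusers import UNet2DConditionModel

# small decoder configured like the dummy model in the unCLIP fast tests (illustrative only)
decoder = UNet2DConditionModel(
    sample_size=64,
    in_channels=3,   # RGB input
    out_channels=6,  # mean and variance
    down_block_types=("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"),
    up_block_types=("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"),
    mid_block_type="UNetMidBlock2DSimpleCrossAttn",  # new in this commit
    block_out_channels=(32, 64),
    layers_per_block=1,
    cross_attention_dim=100,
    attention_head_dim=4,
    resnet_time_scale_shift="scale_shift",  # new in this commit
    class_embed_type="identity",            # new in this commit
)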
View File

@@ -53,6 +53,7 @@ else:
StableDiffusionUpscalePipeline,
)
from .stable_diffusion_safe import StableDiffusionPipelineSafe
from .unclip import UnCLIPPipeline
from .versatile_diffusion import (
VersatileDiffusionDualGuidedPipeline,
VersatileDiffusionImageVariationPipeline,

View File

@@ -0,0 +1,6 @@
from ...utils import is_torch_available, is_transformers_available
if is_transformers_available() and is_torch_available():
from .pipeline_unclip import UnCLIPPipeline
from .text_proj import UnCLIPTextProjModel

View File

@@ -0,0 +1,370 @@
# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import List, Optional, Union
import torch
from torch.nn import functional as F
from diffusers import PriorTransformer, UNet2DConditionModel, UNet2DModel
from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
from diffusers.schedulers import UnCLIPScheduler
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from ...utils import logging
from .text_proj import UnCLIPTextProjModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class UnCLIPPipeline(DiffusionPipeline):
prior: PriorTransformer
decoder: UNet2DConditionModel
text_proj: UnCLIPTextProjModel
text_encoder: CLIPTextModelWithProjection
tokenizer: CLIPTokenizer
super_res_first: UNet2DModel
super_res_last: UNet2DModel
prior_scheduler: UnCLIPScheduler
decoder_scheduler: UnCLIPScheduler
super_res_scheduler: UnCLIPScheduler
def __init__(
self,
prior: PriorTransformer,
decoder: UNet2DConditionModel,
text_encoder: CLIPTextModelWithProjection,
tokenizer: CLIPTokenizer,
text_proj: UnCLIPTextProjModel,
super_res_first: UNet2DModel,
super_res_last: UNet2DModel,
prior_scheduler: UnCLIPScheduler,
decoder_scheduler: UnCLIPScheduler,
super_res_scheduler: UnCLIPScheduler,
):
super().__init__()
self.register_modules(
prior=prior,
decoder=decoder,
text_encoder=text_encoder,
tokenizer=tokenizer,
text_proj=text_proj,
super_res_first=super_res_first,
super_res_last=super_res_last,
prior_scheduler=prior_scheduler,
decoder_scheduler=decoder_scheduler,
super_res_scheduler=super_res_scheduler,
)
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
if latents is None:
if device.type == "mps":
# randn does not work reproducibly on mps
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
else:
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
else:
if latents.shape != shape:
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
latents = latents.to(device)
latents = latents * scheduler.init_noise_sigma
return latents
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
batch_size = len(prompt) if isinstance(prompt, list) else 1
# get prompt text embeddings
text_inputs = self.tokenizer(
prompt,
padding="max_length",
max_length=self.tokenizer.model_max_length,
return_tensors="pt",
)
text_input_ids = text_inputs.input_ids
text_mask = text_inputs.attention_mask.bool().to(self.device)
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
logger.warning(
"The following part of your input was truncated because CLIP can only handle sequences up to"
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
)
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
text_encoder_output = self.text_encoder(text_input_ids.to(self.device))
text_embeddings = text_encoder_output.text_embeds
text_encoder_hidden_states = text_encoder_output.last_hidden_state
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
if do_classifier_free_guidance:
uncond_tokens = [""] * batch_size
max_length = text_input_ids.shape[-1]
uncond_input = self.tokenizer(
uncond_tokens,
padding="max_length",
max_length=max_length,
truncation=True,
return_tensors="pt",
)
uncond_text_mask = uncond_input.attention_mask.bool().to(self.device)
uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(self.device))
uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
seq_len = uncond_embeddings.shape[1]
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len)
seq_len = uncond_text_encoder_hidden_states.shape[1]
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
batch_size * num_images_per_prompt, seq_len, -1
)
# done duplicates
# For classifier free guidance, we need to do two forward passes.
# Here we concatenate the unconditional and text embeddings into a single batch
# to avoid doing two forward passes
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
text_mask = torch.cat([uncond_text_mask, text_mask])
return text_embeddings, text_encoder_hidden_states, text_mask
@torch.no_grad()
def __call__(
self,
prompt: Union[str, List[str]],
num_images_per_prompt: int = 1,
prior_num_inference_steps: int = 25,
decoder_num_inference_steps: int = 25,
super_res_num_inference_steps: int = 7,
generator: Optional[torch.Generator] = None,
prior_latents: Optional[torch.FloatTensor] = None,
decoder_latents: Optional[torch.FloatTensor] = None,
super_res_latents: Optional[torch.FloatTensor] = None,
prior_guidance_scale: float = 4.0,
decoder_guidance_scale: float = 8.0,
output_type: Optional[str] = "pil",
return_dict: bool = True,
):
if isinstance(prompt, str):
batch_size = 1
elif isinstance(prompt, list):
batch_size = len(prompt)
else:
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
batch_size = batch_size * num_images_per_prompt
do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
prompt, num_images_per_prompt, do_classifier_free_guidance
)
# prior
self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=self.device)
prior_timesteps_tensor = self.prior_scheduler.timesteps
embedding_dim = self.prior.config.embedding_dim
prior_latents = self.prepare_latents(
(batch_size, embedding_dim),
text_embeddings.dtype,
self.device,
generator,
prior_latents,
self.prior_scheduler,
)
for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents
predicted_image_embedding = self.prior(
latent_model_input,
timestep=t,
proj_embedding=text_embeddings,
encoder_hidden_states=text_encoder_hidden_states,
attention_mask=text_mask,
).predicted_image_embedding
if do_classifier_free_guidance:
predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
predicted_image_embedding_text - predicted_image_embedding_uncond
)
if i + 1 == prior_timesteps_tensor.shape[0]:
prev_timestep = None
else:
prev_timestep = prior_timesteps_tensor[i + 1]
prior_latents = self.prior_scheduler.step(
predicted_image_embedding,
timestep=t,
sample=prior_latents,
generator=generator,
prev_timestep=prev_timestep,
).prev_sample
prior_latents = self.prior.post_process_latents(prior_latents)
image_embeddings = prior_latents
# done prior
# decoder
text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
image_embeddings=image_embeddings,
text_embeddings=text_embeddings,
text_encoder_hidden_states=text_encoder_hidden_states,
do_classifier_free_guidance=do_classifier_free_guidance,
)
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=self.device)
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
num_channels_latents = self.decoder.in_channels
height = self.decoder.sample_size
width = self.decoder.sample_size
decoder_latents = self.prepare_latents(
(batch_size, num_channels_latents, height, width),
text_encoder_hidden_states.dtype,
self.device,
generator,
decoder_latents,
self.decoder_scheduler,
)
for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents
noise_pred = self.decoder(
sample=latent_model_input,
timestep=t,
encoder_hidden_states=text_encoder_hidden_states,
class_labels=additive_clip_time_embeddings,
attention_mask=decoder_text_mask,
).sample
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
if i + 1 == decoder_timesteps_tensor.shape[0]:
prev_timestep = None
else:
prev_timestep = decoder_timesteps_tensor[i + 1]
# compute the previous noisy sample x_t -> x_t-1
decoder_latents = self.decoder_scheduler.step(
noise_pred, t, decoder_latents, prev_timestep=prev_timestep
).prev_sample
decoder_latents = decoder_latents.clamp(-1, 1)
image_small = decoder_latents
# done decoder
# super res
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=self.device)
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
channels = self.super_res_first.in_channels // 2
height = self.super_res_first.sample_size
width = self.super_res_first.sample_size
super_res_latents = self.prepare_latents(
(batch_size, channels, height, width),
image_small.dtype,
self.device,
generator,
super_res_latents,
self.super_res_scheduler,
)
interpolate_antialias = {}
if "antialias" in inspect.signature(F.interpolate).parameters:
interpolate_antialias["antialias"] = True
image_upscaled = F.interpolate(
image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
)
for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
# no classifier free guidance
if i == super_res_timesteps_tensor.shape[0] - 1:
unet = self.super_res_last
else:
unet = self.super_res_first
latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)
noise_pred = unet(
sample=latent_model_input,
timestep=t,
).sample
if i + 1 == super_res_timesteps_tensor.shape[0]:
prev_timestep = None
else:
prev_timestep = super_res_timesteps_tensor[i + 1]
# compute the previous noisy sample x_t -> x_t-1
super_res_latents = self.super_res_scheduler.step(
noise_pred, t, super_res_latents, prev_timestep=prev_timestep
).prev_sample
image = super_res_latents
# done super res
# post processing
image = image * 0.5 + 0.5
image = image.clamp(0, 1)
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
if output_type == "pil":
image = self.numpy_to_pil(image)
if not return_dict:
return (image,)
return ImagePipelineOutput(images=image)
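
For reference, a minimal usage sketch of the pipeline against the `kakaobrain/karlo-v1-alpha` checkpoint referenced in the integration test below; the prompt is illustrative and the step counts shown are simply the defaults:

import torch
from diffusers import UnCLIPPipeline

pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
pipe = pipe.to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
output = pipe(
    "a horse in the snow",
    num_images_per_prompt=1,
    prior_num_inference_steps=25,
    decoder_num_inference_steps=25,
    super_res_num_inference_steps=7,
    generator=generator,
    output_type="pil",
)
output.images[0].save("unclip_horse.png")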

View File

@@ -0,0 +1,87 @@
# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import nn
from diffusers.modeling_utils import ModelMixin
from ...configuration_utils import ConfigMixin, register_to_config
class UnCLIPTextProjModel(ModelMixin, ConfigMixin):
"""
Utility class for CLIP embeddings. Used to combine the image and text embeddings into a format usable by the
decoder.
For more details, see the original paper: https://arxiv.org/abs/2204.06125 section 2.1
"""
@register_to_config
def __init__(
self,
*,
clip_extra_context_tokens: int = 4,
clip_embeddings_dim: int = 768,
time_embed_dim: int,
cross_attention_dim,
):
super().__init__()
self.learned_classifier_free_guidance_embeddings = nn.Parameter(torch.zeros(clip_embeddings_dim))
# parameters for additional clip time embeddings
self.embedding_proj = nn.Linear(clip_embeddings_dim, time_embed_dim)
self.clip_image_embeddings_project_to_time_embeddings = nn.Linear(clip_embeddings_dim, time_embed_dim)
# parameters for encoder hidden states
self.clip_extra_context_tokens = clip_extra_context_tokens
self.clip_extra_context_tokens_proj = nn.Linear(
clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
)
self.encoder_hidden_states_proj = nn.Linear(clip_embeddings_dim, cross_attention_dim)
self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim)
def forward(self, *, image_embeddings, text_embeddings, text_encoder_hidden_states, do_classifier_free_guidance):
if do_classifier_free_guidance:
# Add the classifier free guidance embeddings to the image embeddings
image_embeddings_batch_size = image_embeddings.shape[0]
classifier_free_guidance_embeddings = self.learned_classifier_free_guidance_embeddings.unsqueeze(0)
classifier_free_guidance_embeddings = classifier_free_guidance_embeddings.expand(
image_embeddings_batch_size, -1
)
image_embeddings = torch.cat([classifier_free_guidance_embeddings, image_embeddings], dim=0)
# The image embeddings batch size and the text embeddings batch size are equal
assert image_embeddings.shape[0] == text_embeddings.shape[0]
batch_size = text_embeddings.shape[0]
# "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and
# adding CLIP embeddings to the existing timestep embedding, ...
time_projected_text_embeddings = self.embedding_proj(text_embeddings)
time_projected_image_embeddings = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings)
additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_text_embeddings
# ... and by projecting CLIP embeddings into four
# extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder"
clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings)
clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens)
text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states)
text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states)
text_encoder_hidden_states = text_encoder_hidden_states.permute(0, 2, 1)
text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=2)
return text_encoder_hidden_states, additive_clip_time_embeddings
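
To make the tensor layout concrete, here is a small shape check with made-up dimensions (the karlo checkpoint uses CLIP-sized embeddings). Note that the returned hidden states are channel-first, which is why the simple cross-attention blocks elsewhere in this diff transpose `encoder_hidden_states` before attending:

import torch
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel

text_proj = UnCLIPTextProjModel(
    clip_extra_context_tokens=4,
    clip_embeddings_dim=32,
    time_embed_dim=128,
    cross_attention_dim=100,
)

batch, seq = 2, 77
hidden_states, additive_time_emb = text_proj(
    image_embeddings=torch.randn(batch, 32),
    text_embeddings=torch.randn(batch, 32),
    text_encoder_hidden_states=torch.randn(batch, seq, 32),
    do_classifier_free_guidance=False,
)
# 4 image-derived tokens are concatenated in front of the 77 projected text tokens
print(hidden_states.shape)      # torch.Size([2, 100, 81])
print(additive_time_emb.shape)  # torch.Size([2, 128])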

View File

@@ -6,8 +6,9 @@ import torch.nn as nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...modeling_utils import ModelMixin
from ...models.attention import DualTransformer2DModel, Transformer2DModel
from ...models.attention import CrossAttention, DualTransformer2DModel, Transformer2DModel
from ...models.embeddings import TimestepEmbedding, Timesteps
from ...models.unet_2d_blocks import UNetMidBlock2DSimpleCrossAttn as UNetMidBlockFlatSimpleCrossAttn
from ...models.unet_2d_condition import UNet2DConditionOutput
from ...utils import logging
@@ -141,6 +142,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
The tuple of downsample blocks to use.
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`):
The mid block type. Choose from `UNetMidBlockFlatCrossAttn` or `UNetMidBlockFlatSimpleCrossAttn`.
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`):
The tuple of upsample blocks to use.
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
@@ -153,6 +156,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
"""
_supports_gradient_checkpointing = True
@@ -172,6 +179,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
"CrossAttnDownBlockFlat",
"DownBlockFlat",
),
mid_block_type: str = "UNetMidBlockFlatCrossAttn",
up_block_types: Tuple[str] = (
"UpBlockFlat",
"CrossAttnUpBlockFlat",
@@ -190,8 +198,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
):
super().__init__()
@@ -208,8 +218,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
# class embedding
if num_class_embeds is not None:
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
else:
self.class_embedding = None
self.down_blocks = nn.ModuleList([])
self.mid_block = None
@@ -245,24 +261,40 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.down_blocks.append(down_block)
# mid
self.mid_block = UNetMidBlockFlatCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift="default",
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
if mid_block_type == "UNetMidBlockFlatCrossAttn":
self.mid_block = UNetMidBlockFlatCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
)
elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn":
self.mid_block = UNetMidBlockFlatSimpleCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
)
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
# count how many layers upsample the images
self.num_upsamplers = 0
@@ -303,6 +335,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
@@ -387,6 +420,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: torch.Tensor,
class_labels: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet2DConditionOutput, Tuple]:
r"""
@@ -416,6 +450,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# 0. center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
@@ -445,9 +484,13 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.config.num_class_embeds is not None:
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
@@ -462,6 +505,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
@@ -469,7 +513,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
down_block_res_samples += res_samples
# 4. mid
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
sample = self.mid_block(
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
# 5. up
for i, upsample_block in enumerate(self.up_blocks):
@@ -490,6 +536,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
attention_mask=attention_mask,
)
else:
sample = upsample_block(
@@ -715,7 +762,6 @@ class CrossAttnDownBlockFlat(nn.Module):
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
attention_type="default",
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
@@ -729,7 +775,6 @@ class CrossAttnDownBlockFlat(nn.Module):
attentions = []
self.has_cross_attention = True
self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
for i in range(num_layers):
@@ -789,7 +834,8 @@ class CrossAttnDownBlockFlat(nn.Module):
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
# TODO(Patrick, William) - attention mask is not used
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
@@ -806,7 +852,9 @@ class CrossAttnDownBlockFlat(nn.Module):
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
@@ -915,7 +963,6 @@ class CrossAttnUpBlockFlat(nn.Module):
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
attention_type="default",
output_scale_factor=1.0,
add_upsample=True,
dual_cross_attention=False,
@@ -928,7 +975,6 @@ class CrossAttnUpBlockFlat(nn.Module):
attentions = []
self.has_cross_attention = True
self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
for i in range(num_layers):
@@ -991,7 +1037,9 @@ class CrossAttnUpBlockFlat(nn.Module):
temb=None,
encoder_hidden_states=None,
upsample_size=None,
attention_mask=None,
):
# TODO(Patrick, William) - attention mask is not used
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
@@ -1011,7 +1059,9 @@ class CrossAttnUpBlockFlat(nn.Module):
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
)[0]
else:
hidden_states = resnet(hidden_states, temb)
@@ -1038,7 +1088,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
attention_type="default",
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
@@ -1048,7 +1097,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
super().__init__()
self.has_cross_attention = True
self.attention_type = attention_type
self.attn_num_head_channels = attn_num_head_channels
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -1112,10 +1160,122 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
# TODO(Patrick, William) - attention_mask is currently not used. Implement once used
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
hidden_states = attn(hidden_states, encoder_hidden_states).sample
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
hidden_states = resnet(hidden_states, temb)
return hidden_states
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
class UnCLIPUNetMidBlockFlatCrossAttn(nn.Module):
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
):
super().__init__()
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
self.num_heads = in_channels // self.attn_num_head_channels
# there is always at least one resnet
resnets = [
ResnetBlockFlat(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
]
attentions = []
for _ in range(num_layers):
attentions.append(
CrossAttention(
query_dim=in_channels,
cross_attention_dim=in_channels,
heads=self.num_heads,
dim_head=attn_num_head_channels,
added_kv_proj_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
bias=True,
upcast_softmax=True,
)
)
resnets.append(
ResnetBlockFlat(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
def set_attention_slice(self, slice_size):
head_dims = self.attn_num_head_channels
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
raise ValueError(
f"Make sure slice_size {slice_size} is a common divisor of "
f"the number of heads used in cross_attention: {head_dims}"
)
if slice_size is not None and slice_size > min(head_dims):
raise ValueError(
f"slice_size {slice_size} has to be smaller or equal to "
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
)
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet in zip(self.attentions, self.resnets[1:]):
# attn
residual = hidden_states
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
attention_mask=attention_mask,
)
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
hidden_states = hidden_states + residual
# resnet
hidden_states = resnet(hidden_states, temb)
return hidden_states

View File

@@ -37,6 +37,7 @@ else:
from .scheduling_repaint import RePaintScheduler
from .scheduling_sde_ve import ScoreSdeVeScheduler
from .scheduling_sde_vp import ScoreSdeVpScheduler
from .scheduling_unclip import UnCLIPScheduler
from .scheduling_utils import SchedulerMixin
from .scheduling_vq_diffusion import VQDiffusionScheduler

View File

@@ -0,0 +1,309 @@
# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union
import numpy as np
import torch
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin
@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->UnCLIP
class UnCLIPSchedulerOutput(BaseOutput):
"""
Output class for the scheduler's step function output.
Args:
prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
denoising loop.
pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
The predicted denoised sample (x_{0}) based on the model output from the current timestep.
`pred_original_sample` can be used to preview progress or for guidance.
"""
prev_sample: torch.FloatTensor
pred_original_sample: Optional[torch.FloatTensor] = None
def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
"""
Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
(1-beta) over time from t = [0,1].
Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
to that part of the diffusion process.
Args:
num_diffusion_timesteps (`int`): the number of betas to produce.
max_beta (`float`): the maximum beta to use; use values lower than 1 to
prevent singularities.
Returns:
betas (`torch.Tensor`): the betas used by the scheduler to step the model outputs
"""
def alpha_bar(time_step):
return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
betas = []
for i in range(num_diffusion_timesteps):
t1 = i / num_diffusion_timesteps
t2 = (i + 1) / num_diffusion_timesteps
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
return torch.tensor(betas, dtype=torch.float32)
class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
"""
This is a modified DDPM Scheduler specifically for the karlo unCLIP model.
This scheduler has some minor variations in how it calculates the learned range variance and dynamically
re-calculates betas based on the timesteps it is skipping.
The scheduler also uses a slightly different step ratio when computing timesteps to use for inference.
See [`~DDPMScheduler`] for more information on DDPM scheduling.
Args:
num_train_timesteps (`int`): number of diffusion steps used to train the model.
variance_type (`str`):
options to clip the variance used when adding noise to the denoised sample. Choose from `fixed_small_log`
or `learned_range`.
clip_sample (`bool`, default `True`):
option to clip predicted sample between `-clip_sample_range` and `clip_sample_range` for numerical
stability.
clip_sample_range (`float`, default `1.0`):
The range to clip the sample between. See `clip_sample`.
prediction_type (`str`, default `epsilon`, optional):
prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion process)
or `sample` (directly predicting the denoised sample)
"""
@register_to_config
def __init__(
self,
num_train_timesteps: int = 1000,
variance_type: str = "fixed_small_log",
clip_sample: bool = True,
clip_sample_range: Optional[float] = 1.0,
prediction_type: str = "epsilon",
):
# beta scheduler is "squaredcos_cap_v2"
self.betas = betas_for_alpha_bar(num_train_timesteps)
self.alphas = 1.0 - self.betas
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
self.one = torch.tensor(1.0)
# standard deviation of the initial noise distribution
self.init_noise_sigma = 1.0
# setable values
self.num_inference_steps = None
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())
self.variance_type = variance_type
def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
"""
Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
current timestep.
Args:
sample (`torch.FloatTensor`): input sample
timestep (`int`, optional): current timestep
Returns:
`torch.FloatTensor`: scaled input sample
"""
return sample
def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
"""
Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
Note that this scheduler uses a slightly different step ratio than the other diffusers schedulers. The
different step ratio is to mimic the original karlo implementation and does not affect the quality or accuracy
of the results.
Args:
num_inference_steps (`int`):
the number of diffusion steps used when generating samples with a pre-trained model.
"""
self.num_inference_steps = num_inference_steps
step_ratio = (self.config.num_train_timesteps - 1) / (self.num_inference_steps - 1)
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
self.timesteps = torch.from_numpy(timesteps).to(device)
def _get_variance(self, t, prev_timestep=None, predicted_variance=None, variance_type=None):
if prev_timestep is None:
prev_timestep = t - 1
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
if prev_timestep == t - 1:
beta = self.betas[t]
else:
beta = 1 - alpha_prod_t / alpha_prod_t_prev
# For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
# and sample from it to get previous sample
# x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
variance = beta_prod_t_prev / beta_prod_t * beta
if variance_type is None:
variance_type = self.config.variance_type
# hacks - were probably added for training stability
if variance_type == "fixed_small_log":
variance = torch.log(torch.clamp(variance, min=1e-20))
variance = torch.exp(0.5 * variance)
elif variance_type == "learned_range":
# NOTE difference with DDPM scheduler
min_log = variance.log()
max_log = beta.log()
frac = (predicted_variance + 1) / 2
variance = frac * max_log + (1 - frac) * min_log
return variance
def step(
self,
model_output: torch.FloatTensor,
timestep: int,
sample: torch.FloatTensor,
prev_timestep: Optional[int] = None,
generator=None,
return_dict: bool = True,
) -> Union[UnCLIPSchedulerOutput, Tuple]:
"""
Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
process from the learned model outputs (most often the predicted noise).
Args:
model_output (`torch.FloatTensor`): direct output from learned diffusion model.
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`torch.FloatTensor`):
current instance of sample being created by diffusion process.
prev_timestep (`int`, *optional*): The previous timestep to predict the previous sample at.
Used to dynamically compute beta. If not given, `t-1` is used and the pre-computed beta is used.
generator: random number generator.
return_dict (`bool`): option for returning tuple rather than UnCLIPSchedulerOutput class
Returns:
[`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] or `tuple`:
[`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
t = timestep
if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type == "learned_range":
model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
else:
predicted_variance = None
# 1. compute alphas, betas
if prev_timestep is None:
prev_timestep = t - 1
alpha_prod_t = self.alphas_cumprod[t]
alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
beta_prod_t = 1 - alpha_prod_t
beta_prod_t_prev = 1 - alpha_prod_t_prev
if prev_timestep == t - 1:
beta = self.betas[t]
alpha = self.alphas[t]
else:
beta = 1 - alpha_prod_t / alpha_prod_t_prev
alpha = 1 - beta
# 2. compute predicted original sample from predicted noise also called
# "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
if self.config.prediction_type == "epsilon":
pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
elif self.config.prediction_type == "sample":
pred_original_sample = model_output
else:
raise ValueError(
f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `sample`"
" for the UnCLIPScheduler."
)
# 3. Clip "predicted x_0"
if self.config.clip_sample:
pred_original_sample = torch.clamp(
pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range
)
# 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * beta) / beta_prod_t
current_sample_coeff = alpha ** (0.5) * beta_prod_t_prev / beta_prod_t
# 5. Compute predicted previous sample µ_t
# See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample
# 6. Add noise
variance = 0
if t > 0:
device = model_output.device
if device.type == "mps":
# randn does not work reproducibly on mps
variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
variance_noise = variance_noise.to(device)
else:
variance_noise = torch.randn(
model_output.shape, generator=generator, device=device, dtype=model_output.dtype
)
variance = self._get_variance(
t,
predicted_variance=predicted_variance,
prev_timestep=prev_timestep,
)
if self.variance_type == "fixed_small_log":
variance = variance
elif self.variance_type == "learned_range":
variance = (0.5 * variance).exp()
else:
raise ValueError(
f"variance_type given as {self.variance_type} must be one of `fixed_small_log` or `learned_range`"
" for the UnCLIPScheduler."
)
variance = variance * variance_noise
pred_prev_sample = pred_prev_sample + variance
if not return_dict:
return (pred_prev_sample,)
return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
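
As the `set_timesteps` docstring notes, the step ratio is `(num_train_timesteps - 1) / (num_inference_steps - 1)`, so the inference schedule always starts at the last trained timestep and ends at 0. A quick standalone numeric check of the spacing, using the same formula as above:

import numpy as np

num_train_timesteps = 1000
num_inference_steps = 25

step_ratio = (num_train_timesteps - 1) / (num_inference_steps - 1)  # 41.625
timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)

print(timesteps[:4])   # [999 957 916 874]
print(timesteps[-4:])  # [125  83  42   0]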

View File

@@ -34,6 +34,21 @@ class AutoencoderKL(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class PriorTransformer(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class Transformer2DModel(metaclass=DummyObject):
_backends = ["torch"]
@@ -512,6 +527,21 @@ class ScoreSdeVeScheduler(metaclass=DummyObject):
requires_backends(cls, ["torch"])
class UnCLIPScheduler(metaclass=DummyObject):
_backends = ["torch"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch"])
class VQDiffusionScheduler(metaclass=DummyObject):
_backends = ["torch"]

View File

@@ -199,6 +199,21 @@ class StableDiffusionUpscalePipeline(metaclass=DummyObject):
requires_backends(cls, ["torch", "transformers"])
class UnCLIPPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]
def __init__(self, *args, **kwargs):
requires_backends(self, ["torch", "transformers"])
@classmethod
def from_config(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
@classmethod
def from_pretrained(cls, *args, **kwargs):
requires_backends(cls, ["torch", "transformers"])
class VersatileDiffusionDualGuidedPipeline(metaclass=DummyObject):
_backends = ["torch", "transformers"]

View File


@@ -0,0 +1,282 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import unittest
import numpy as np
import torch
from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
from diffusers.utils import load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer
torch.backends.cuda.matmul.allow_tf32 = False
class UnCLIPPipelineFastTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
@property
def text_embedder_hidden_size(self):
return 32
@property
def time_input_dim(self):
return 32
@property
def block_out_channels_0(self):
return self.time_input_dim
@property
def time_embed_dim(self):
return self.time_input_dim * 4
@property
def cross_attention_dim(self):
return 100
@property
def dummy_tokenizer(self):
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
return tokenizer
@property
def dummy_text_encoder(self):
torch.manual_seed(0)
config = CLIPTextConfig(
bos_token_id=0,
eos_token_id=2,
hidden_size=self.text_embedder_hidden_size,
projection_dim=self.text_embedder_hidden_size,
intermediate_size=37,
layer_norm_eps=1e-05,
num_attention_heads=4,
num_hidden_layers=5,
pad_token_id=1,
vocab_size=1000,
)
return CLIPTextModelWithProjection(config)
@property
def dummy_prior(self):
torch.manual_seed(0)
model_kwargs = {
"num_attention_heads": 2,
"attention_head_dim": 12,
"embedding_dim": self.text_embedder_hidden_size,
"num_layers": 1,
}
model = PriorTransformer(**model_kwargs)
return model
@property
def dummy_text_proj(self):
torch.manual_seed(0)
model_kwargs = {
"clip_embeddings_dim": self.text_embedder_hidden_size,
"time_embed_dim": self.time_embed_dim,
"cross_attention_dim": self.cross_attention_dim,
}
model = UnCLIPTextProjModel(**model_kwargs)
return model
@property
def dummy_decoder(self):
torch.manual_seed(0)
model_kwargs = {
"sample_size": 64,
# RGB in channels
"in_channels": 3,
# Out channels is double the in channels because the model predicts both mean and variance
"out_channels": 6,
"down_block_types": ("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"),
"up_block_types": ("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"),
"mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
"block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
"layers_per_block": 1,
"cross_attention_dim": self.cross_attention_dim,
"attention_head_dim": 4,
"resnet_time_scale_shift": "scale_shift",
"class_embed_type": "identity",
}
model = UNet2DConditionModel(**model_kwargs)
return model
@property
def dummy_super_res_kwargs(self):
return {
"sample_size": 128,
"layers_per_block": 1,
"down_block_types": ("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"),
"up_block_types": ("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"),
"block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
"in_channels": 6,
"out_channels": 3,
}
@property
def dummy_super_res_first(self):
torch.manual_seed(0)
model = UNet2DModel(**self.dummy_super_res_kwargs)
return model
@property
def dummy_super_res_last(self):
# seeded differently to get a different unet than `self.dummy_super_res_first`
torch.manual_seed(1)
model = UNet2DModel(**self.dummy_super_res_kwargs)
return model
def test_unclip(self):
device = "cpu"
prior = self.dummy_prior
decoder = self.dummy_decoder
text_proj = self.dummy_text_proj
text_encoder = self.dummy_text_encoder
tokenizer = self.dummy_tokenizer
super_res_first = self.dummy_super_res_first
super_res_last = self.dummy_super_res_last
prior_scheduler = UnCLIPScheduler(
variance_type="fixed_small_log",
prediction_type="sample",
num_train_timesteps=1000,
clip_sample_range=5.0,
)
decoder_scheduler = UnCLIPScheduler(
variance_type="learned_range",
prediction_type="epsilon",
num_train_timesteps=1000,
)
super_res_scheduler = UnCLIPScheduler(
variance_type="fixed_small_log",
prediction_type="epsilon",
num_train_timesteps=1000,
)
pipe = UnCLIPPipeline(
prior=prior,
decoder=decoder,
text_proj=text_proj,
text_encoder=text_encoder,
tokenizer=tokenizer,
super_res_first=super_res_first,
super_res_last=super_res_last,
prior_scheduler=prior_scheduler,
decoder_scheduler=decoder_scheduler,
super_res_scheduler=super_res_scheduler,
)
pipe = pipe.to(device)
pipe.set_progress_bar_config(disable=None)
prompt = "horse"
generator = torch.Generator(device=device).manual_seed(0)
output = pipe(
[prompt],
generator=generator,
prior_num_inference_steps=2,
decoder_num_inference_steps=2,
super_res_num_inference_steps=2,
output_type="np",
)
image = output.images
generator = torch.Generator(device=device).manual_seed(0)
image_from_tuple = pipe(
[prompt],
generator=generator,
prior_num_inference_steps=2,
decoder_num_inference_steps=2,
super_res_num_inference_steps=2,
output_type="np",
return_dict=False,
)[0]
image_slice = image[0, -3:, -3:, -1]
image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
assert image.shape == (1, 128, 128, 3)
expected_slice = np.array(
[
0.0009,
0.9997,
0.0003,
0.9991,
0.9967,
0.0003,
0.9997,
0.0003,
0.0004,
]
)
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@slow
@require_torch_gpu
class UnCLIPPipelineIntegrationTests(unittest.TestCase):
def tearDown(self):
# clean up the VRAM after each test
super().tearDown()
gc.collect()
torch.cuda.empty_cache()
def test_unclip_karlo(self):
expected_image = load_numpy(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
"/karlo_v1_alpha/horse.npy"
)
pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
pipeline = pipeline.to(torch_device)
pipeline.set_progress_bar_config(disable=None)
generator = torch.Generator(device=torch_device).manual_seed(0)
output = pipeline(
"horse",
num_images_per_prompt=1,
generator=generator,
output_type="np",
)
image = output.images[0]
assert image.shape == (256, 256, 3)
assert np.abs(expected_image - image).max() < 1e-2

View File

@@ -38,6 +38,7 @@ from diffusers import (
LMSDiscreteScheduler,
PNDMScheduler,
ScoreSdeVeScheduler,
UnCLIPScheduler,
VQDiffusionScheduler,
logging,
)
@@ -2643,3 +2644,135 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
# CUDA
assert abs(result_sum.item() - 13913.0332) < 1e-1
assert abs(result_mean.item() - 18.1159) < 1e-3
# UnCLIPScheduler is a modified DDPMScheduler with a subset of the configuration.
class UnCLIPSchedulerTest(SchedulerCommonTest):
scheduler_classes = (UnCLIPScheduler,)
def get_scheduler_config(self, **kwargs):
config = {
"num_train_timesteps": 1000,
"variance_type": "fixed_small_log",
"clip_sample": True,
"clip_sample_range": 1.0,
"prediction_type": "epsilon",
}
config.update(**kwargs)
return config
def test_timesteps(self):
for timesteps in [1, 5, 100, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)
def test_variance_type(self):
for variance in ["fixed_small_log", "learned_range"]:
self.check_over_configs(variance_type=variance)
def test_clip_sample(self):
for clip_sample in [True, False]:
self.check_over_configs(clip_sample=clip_sample)
def test_clip_sample_range(self):
for clip_sample_range in [1, 5, 10, 20]:
self.check_over_configs(clip_sample_range=clip_sample_range)
def test_prediction_type(self):
for prediction_type in ["epsilon", "sample"]:
self.check_over_configs(prediction_type=prediction_type)
def test_time_indices(self):
for time_step in [0, 500, 999]:
for prev_timestep in [None, 5, 100, 250, 500, 750]:
if prev_timestep is not None and prev_timestep >= time_step:
continue
self.check_over_forward(time_step=time_step, prev_timestep=prev_timestep)
def test_variance_fixed_small_log(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(variance_type="fixed_small_log")
scheduler = scheduler_class(**scheduler_config)
assert torch.sum(torch.abs(scheduler._get_variance(0) - 1.0000e-10)) < 1e-5
assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.0549625)) < 1e-5
assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.9994987)) < 1e-5
def test_variance_learned_range(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config(variance_type="learned_range")
scheduler = scheduler_class(**scheduler_config)
predicted_variance = 0.5
assert abs(scheduler._get_variance(1, predicted_variance=predicted_variance) - -10.1712790) < 1e-5
assert abs(scheduler._get_variance(487, predicted_variance=predicted_variance) - -5.7998052) < 1e-5
assert abs(scheduler._get_variance(999, predicted_variance=predicted_variance) - -0.0010011) < 1e-5
def test_full_loop(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
timesteps = scheduler.timesteps
model = self.dummy_model()
sample = self.dummy_sample_deter
generator = torch.manual_seed(0)
for i, t in enumerate(timesteps):
# 1. predict noise residual
residual = model(sample, t)
# 2. predict previous mean of sample x_t-1
pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample
sample = pred_prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 252.2682495) < 1e-2
assert abs(result_mean.item() - 0.3284743) < 1e-3
def test_full_loop_skip_timesteps(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)
scheduler.set_timesteps(25)
timesteps = scheduler.timesteps
model = self.dummy_model()
sample = self.dummy_sample_deter
generator = torch.manual_seed(0)
for i, t in enumerate(timesteps):
# 1. predict noise residual
residual = model(sample, t)
if i + 1 == timesteps.shape[0]:
prev_timestep = None
else:
prev_timestep = timesteps[i + 1]
# 2. predict previous mean of sample x_t-1
pred_prev_sample = scheduler.step(
residual, t, sample, prev_timestep=prev_timestep, generator=generator
).prev_sample
sample = pred_prev_sample
result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))
assert abs(result_sum.item() - 258.2044983) < 1e-2
assert abs(result_mean.item() - 0.3362038) < 1e-3
def test_trained_betas(self):
pass
def test_add_noise_device(self):
pass