Mirror of https://github.com/huggingface/diffusers.git, synced 2026-01-27 17:22:53 +03:00
kakaobrain unCLIP (#1428)
* [wip] attention block updates
* [wip] unCLIP unet decoder and super res
* [wip] unCLIP prior transformer
* [wip] scheduler changes
* [wip] text proj utility class
* [wip] UnCLIPPipeline
* [wip] kakaobrain unCLIP convert script
* [unCLIP pipeline] fixes re: @patrickvonplaten (remove callbacks, move denoising loops into the call function)
* UNCLIPScheduler re: @patrickvonplaten (revert changes to DDPMScheduler; make UnCLIPScheduler a modified DDPM scheduler with changes to support karlo)
* mask -> attention_mask re: @patrickvonplaten
* [DDPMScheduler] remove leftover change
* [docs] PriorTransformer
* [docs] UNet2DConditionModel and UNet2DModel
* [nit] UNCLIPScheduler -> UnCLIPScheduler (matches existing unclip naming better)
* [docs] SchedulingUnCLIP
* [docs] UnCLIPTextProjModel
* refactor
* finish licenses
* rename all to attention_mask and prep in models
* more renaming
* don't expose unused configs
* final renaming fixes
* remove x attn mask when not necessary
* configure kakao script to use new class embedding config
* fix copies
* [tests] UnCLIPScheduler
* finish x attn
* finish
* remove more
* rename condition blocks
* clean more
* Apply suggestions from code review
* up
* fix
* [tests] UnCLIPPipelineFastTests
* remove unused imports
* [tests] UnCLIPPipelineIntegrationTests
* correct
* make style

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
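For orientation, here is a minimal usage sketch of the pipeline this PR introduces. The hub id and generation settings below are assumptions for illustration, not part of the diff; a checkpoint converted locally with scripts/convert_kakao_brain_unclip_to_diffusers.py would be used the same way.

```python
# Hypothetical usage of the new UnCLIPPipeline (the checkpoint id is an assumption).
import torch
from diffusers import UnCLIPPipeline

pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# One call runs the prior (text -> CLIP image embedding), the decoder UNet,
# and the super-resolution stages, and returns PIL images.
image = pipe("a photo of an astronaut riding a horse").images[0]
image.save("unclip_astronaut.png")
```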
scripts/convert_kakao_brain_unclip_to_diffusers.py (new file, 1159 lines)
File diff suppressed because it is too large.
@@ -25,7 +25,15 @@ except OptionalDependencyNotAvailable:
    from .utils.dummy_pt_objects import *  # noqa F403
else:
    from .modeling_utils import ModelMixin
    from .models import AutoencoderKL, Transformer2DModel, UNet1DModel, UNet2DConditionModel, UNet2DModel, VQModel
    from .models import (
        AutoencoderKL,
        PriorTransformer,
        Transformer2DModel,
        UNet1DModel,
        UNet2DConditionModel,
        UNet2DModel,
        VQModel,
    )
    from .optimization import (
        get_constant_schedule,
        get_constant_schedule_with_warmup,
@@ -63,6 +71,7 @@ else:
        RePaintScheduler,
        SchedulerMixin,
        ScoreSdeVeScheduler,
        UnCLIPScheduler,
        VQDiffusionScheduler,
    )
    from .training_utils import EMAModel
@@ -96,6 +105,7 @@ else:
        StableDiffusionPipeline,
        StableDiffusionPipelineSafe,
        StableDiffusionUpscalePipeline,
        UnCLIPPipeline,
        VersatileDiffusionDualGuidedPipeline,
        VersatileDiffusionImageVariationPipeline,
        VersatileDiffusionPipeline,

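With the updated `__init__`, the unCLIP components are exposed at the package root next to the existing classes; a quick import check (assuming an install that includes this commit):

```python
# New public names added by this PR, alongside two pre-existing exports.
from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler
from diffusers import UNet2DConditionModel, UNet2DModel

print(PriorTransformer.__name__, UnCLIPScheduler.__name__, UnCLIPPipeline.__name__)
```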
@@ -17,6 +17,7 @@ from ..utils import is_flax_available, is_torch_available
|
||||
|
||||
if is_torch_available():
|
||||
from .attention import Transformer2DModel
|
||||
from .prior_transformer import PriorTransformer
|
||||
from .unet_1d import UNet1DModel
|
||||
from .unet_2d import UNet2DModel
|
||||
from .unet_2d_condition import UNet2DConditionModel
|
||||
|
||||
@@ -66,8 +66,8 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
in_channels (`int`, *optional*):
|
||||
Pass if the input is continuous. The number of channels in the input and output.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
||||
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
||||
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
||||
`ImagePositionalEmbeddings`.
|
||||
@@ -181,7 +181,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
||||
When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
||||
hidden_states
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.long`, *optional*):
|
||||
@@ -213,7 +213,7 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||
|
||||
# 2. Blocks
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, context=encoder_hidden_states, timestep=timestep)
|
||||
hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep)
|
||||
|
||||
# 3. Output
|
||||
if self.is_input_continuous:
|
||||
@@ -260,6 +260,8 @@ class AttentionBlock(nn.Module):
|
||||
eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
|
||||
"""
|
||||
|
||||
# IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
@@ -369,6 +371,7 @@ class AttentionBlock(nn.Module):
|
||||
|
||||
# compute next hidden_states
|
||||
hidden_states = self.proj_attn(hidden_states)
|
||||
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
|
||||
|
||||
# res connect and rescale
|
||||
@@ -385,7 +388,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
num_attention_heads (`int`): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The size of the context vector for cross attention.
|
||||
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
|
||||
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
|
||||
num_embeds_ada_norm (:
|
||||
obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`.
|
||||
@@ -432,7 +435,7 @@ class BasicTransformerBlock(nn.Module):
|
||||
dropout=dropout,
|
||||
bias=attention_bias,
|
||||
upcast_attention=upcast_attention,
|
||||
) # is self-attn if context is none
|
||||
) # is self-attn if encoder_hidden_states is none
|
||||
else:
|
||||
self.attn2 = None
|
||||
|
||||
@@ -472,23 +475,30 @@ class BasicTransformerBlock(nn.Module):
|
||||
self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
||||
self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
||||
|
||||
def forward(self, hidden_states, context=None, timestep=None):
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
|
||||
# 1. Self-Attention
|
||||
norm_hidden_states = (
|
||||
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
|
||||
)
|
||||
|
||||
if self.only_cross_attention:
|
||||
hidden_states = self.attn1(norm_hidden_states, context) + hidden_states
|
||||
hidden_states = (
|
||||
self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
|
||||
)
|
||||
else:
|
||||
hidden_states = self.attn1(norm_hidden_states) + hidden_states
|
||||
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
|
||||
|
||||
if self.attn2 is not None:
|
||||
# 2. Cross-Attention
|
||||
norm_hidden_states = (
|
||||
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
|
||||
)
|
||||
hidden_states = self.attn2(norm_hidden_states, context=context) + hidden_states
|
||||
hidden_states = (
|
||||
self.attn2(
|
||||
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
||||
)
|
||||
+ hidden_states
|
||||
)
|
||||
|
||||
# 3. Feed-forward
|
||||
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
||||
@@ -503,7 +513,7 @@ class CrossAttention(nn.Module):
|
||||
Parameters:
|
||||
query_dim (`int`): The number of channels in the query.
|
||||
cross_attention_dim (`int`, *optional*):
|
||||
The number of channels in the context. If not given, defaults to `query_dim`.
|
||||
The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
|
||||
heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
|
||||
dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
@@ -520,13 +530,18 @@ class CrossAttention(nn.Module):
|
||||
dropout: float = 0.0,
|
||||
bias=False,
|
||||
upcast_attention: bool = False,
|
||||
upcast_softmax: bool = False,
|
||||
added_kv_proj_dim: Optional[int] = None,
|
||||
norm_num_groups: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
inner_dim = dim_head * heads
|
||||
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
|
||||
self.upcast_attention = upcast_attention
|
||||
self.upcast_softmax = upcast_softmax
|
||||
|
||||
self.scale = dim_head**-0.5
|
||||
|
||||
self.heads = heads
|
||||
# for slice_size > 0 the attention score computation
|
||||
# is split across the batch axis to save memory
|
||||
@@ -534,11 +549,21 @@ class CrossAttention(nn.Module):
|
||||
self.sliceable_head_dim = heads
|
||||
self._slice_size = None
|
||||
self._use_memory_efficient_attention_xformers = False
|
||||
self.added_kv_proj_dim = added_kv_proj_dim
|
||||
|
||||
if norm_num_groups is not None:
|
||||
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
|
||||
else:
|
||||
self.group_norm = None
|
||||
|
||||
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
|
||||
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
||||
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
|
||||
|
||||
if self.added_kv_proj_dim is not None:
|
||||
self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
||||
self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(inner_dim, query_dim))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
@@ -563,40 +588,58 @@ class CrossAttention(nn.Module):
|
||||
|
||||
self._slice_size = slice_size
|
||||
|
||||
def forward(self, hidden_states, context=None, mask=None):
|
||||
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
encoder_hidden_states = encoder_hidden_states
|
||||
|
||||
if self.group_norm is not None:
|
||||
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = self.to_q(hidden_states)
|
||||
context = context if context is not None else hidden_states
|
||||
key = self.to_k(context)
|
||||
value = self.to_v(context)
|
||||
|
||||
dim = query.shape[-1]
|
||||
|
||||
query = self.reshape_heads_to_batch_dim(query)
|
||||
key = self.reshape_heads_to_batch_dim(key)
|
||||
value = self.reshape_heads_to_batch_dim(value)
|
||||
|
||||
# TODO(PVP) - mask is currently never used. Remember to re-implement when used
|
||||
if self.added_kv_proj_dim is not None:
|
||||
key = self.to_k(hidden_states)
|
||||
value = self.to_v(hidden_states)
|
||||
encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
|
||||
encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
|
||||
|
||||
key = self.reshape_heads_to_batch_dim(key)
|
||||
value = self.reshape_heads_to_batch_dim(value)
|
||||
encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
|
||||
encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
|
||||
|
||||
key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
|
||||
value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
|
||||
else:
|
||||
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
|
||||
key = self.to_k(encoder_hidden_states)
|
||||
value = self.to_v(encoder_hidden_states)
|
||||
|
||||
key = self.reshape_heads_to_batch_dim(key)
|
||||
value = self.reshape_heads_to_batch_dim(value)
|
||||
|
||||
# attention, what we cannot get enough of
|
||||
if self._use_memory_efficient_attention_xformers:
|
||||
hidden_states = self._memory_efficient_attention_xformers(query, key, value)
|
||||
hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
|
||||
# Some versions of xformers return output in fp32, cast it back to the dtype of the input
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
else:
|
||||
if self._slice_size is None or query.shape[0] // self._slice_size == 1:
|
||||
hidden_states = self._attention(query, key, value)
|
||||
hidden_states = self._attention(query, key, value, attention_mask)
|
||||
else:
|
||||
hidden_states = self._sliced_attention(query, key, value, sequence_length, dim)
|
||||
|
||||
# linear proj
|
||||
hidden_states = self.to_out[0](hidden_states)
|
||||
|
||||
# dropout
|
||||
hidden_states = self.to_out[1](hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def _attention(self, query, key, value):
|
||||
def _attention(self, query, key, value, attention_mask=None):
|
||||
if self.upcast_attention:
|
||||
query = query.float()
|
||||
key = key.float()
|
||||
@@ -608,6 +651,18 @@ class CrossAttention(nn.Module):
|
||||
beta=0,
|
||||
alpha=self.scale,
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
if attention_mask.shape != attention_scores.shape:
|
||||
target_length = query.shape[1]
|
||||
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
|
||||
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
|
||||
|
||||
attention_scores = attention_scores + attention_mask
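The mask added to the scores above is additive: masked key positions carry a large negative bias so that softmax gives them near-zero weight. A standalone sketch of that convention in plain PyTorch (shapes are arbitrary, not taken from the diff):

```python
import torch

batch, heads, q_len, k_len = 2, 4, 5, 5
scores = torch.randn(batch * heads, q_len, k_len)

# Boolean padding mask per batch item: True = keep this key position.
keep = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]], dtype=torch.bool)

# Convert to an additive bias and broadcast over heads, mirroring the
# (1 - mask) * -10000.0 / repeat_interleave pattern used in this PR.
bias = (1 - keep.float()) * -10000.0                      # (batch, k_len)
bias = bias[:, None, :].repeat_interleave(heads, dim=0)   # (batch*heads, 1, k_len)

probs = (scores + bias).softmax(dim=-1)
print(probs[0, 0])  # the two masked key positions get ~0 probability
```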
|
||||
|
||||
if self.upcast_softmax:
|
||||
attention_scores = attention_scores.float()
|
||||
|
||||
attention_probs = attention_scores.softmax(dim=-1)
|
||||
|
||||
# cast back to the original dtype
|
||||
@@ -656,7 +711,8 @@ class CrossAttention(nn.Module):
|
||||
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
def _memory_efficient_attention_xformers(self, query, key, value):
|
||||
def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
|
||||
# TODO attention_mask
|
||||
query = query.contiguous()
|
||||
key = key.contiguous()
|
||||
value = value.contiguous()
|
||||
@@ -802,7 +858,7 @@ class DualTransformer2DModel(nn.Module):
|
||||
Pass if the input is continuous. The number of channels in the input and output.
|
||||
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
|
||||
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of context dimensions to use.
|
||||
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
|
||||
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
|
||||
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
|
||||
`ImagePositionalEmbeddings`.
|
||||
@@ -867,17 +923,21 @@ class DualTransformer2DModel(nn.Module):
|
||||
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
|
||||
self.transformer_index_for_condition = [1, 0]
|
||||
|
||||
def forward(self, hidden_states, encoder_hidden_states, timestep=None, return_dict: bool = True):
|
||||
def forward(
|
||||
self, hidden_states, encoder_hidden_states, timestep=None, attention_mask=None, return_dict: bool = True
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
|
||||
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
|
||||
hidden_states
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, context dim)`, *optional*):
|
||||
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
|
||||
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
|
||||
self-attention.
|
||||
timestep ( `torch.long`, *optional*):
|
||||
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
|
||||
attention_mask (`torch.FloatTensor`, *optional*):
|
||||
Optional attention mask to be applied in CrossAttention
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
|
||||
|
||||
@@ -894,9 +954,13 @@ class DualTransformer2DModel(nn.Module):
|
||||
# for each of the two transformers, pass the corresponding condition tokens
|
||||
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
|
||||
transformer_index = self.transformer_index_for_condition[i]
|
||||
encoded_state = self.transformers[transformer_index](input_states, condition_state, timestep, return_dict)[
|
||||
0
|
||||
]
|
||||
encoded_state = self.transformers[transformer_index](
|
||||
input_states,
|
||||
encoder_hidden_states=condition_state,
|
||||
timestep=timestep,
|
||||
attention_mask=attention_mask,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
encoded_states.append(encoded_state - input_states)
|
||||
tokens_start += self.condition_lengths[i]
|
||||
|
||||
|
||||
src/diffusers/models/prior_transformer.py (new file, 194 lines)
@@ -0,0 +1,194 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
|
||||
from ..configuration_utils import ConfigMixin, register_to_config
|
||||
from ..modeling_utils import ModelMixin
|
||||
from ..utils import BaseOutput
|
||||
from .attention import BasicTransformerBlock
|
||||
from .embeddings import TimestepEmbedding, Timesteps
|
||||
|
||||
|
||||
@dataclass
|
||||
class PriorTransformerOutput(BaseOutput):
|
||||
"""
|
||||
Args:
|
||||
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
|
||||
"""
|
||||
|
||||
predicted_image_embedding: torch.FloatTensor
|
||||
|
||||
|
||||
class PriorTransformer(ModelMixin, ConfigMixin):
|
||||
"""
|
||||
The prior transformer from unCLIP is used to predict CLIP image embeddings from CLIP text embeddings. Note that the
|
||||
transformer predicts the image embeddings through a denoising diffusion process.
|
||||
|
||||
This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
|
||||
implements for all the models (such as downloading or saving, etc.)
|
||||
|
||||
For more details, see the original paper: https://arxiv.org/abs/2204.06125
|
||||
|
||||
Parameters:
|
||||
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
|
||||
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
|
||||
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
|
||||
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the CLIP embeddings. Note that CLIP
|
||||
image embeddings and text embeddings are both the same dimension.
|
||||
num_embeddings (`int`, *optional*, defaults to 77): The max number of clip embeddings allowed. I.e. the
|
||||
length of the prompt after it has been tokenized.
|
||||
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
|
||||
projected hidden_states. The actual length of the used hidden_states is `num_embeddings +
|
||||
additional_embeddings`.
|
||||
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
|
||||
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
num_attention_heads: int = 32,
|
||||
attention_head_dim: int = 64,
|
||||
num_layers: int = 20,
|
||||
embedding_dim: int = 768,
|
||||
num_embeddings=77,
|
||||
additional_embeddings=4,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.attention_head_dim = attention_head_dim
|
||||
inner_dim = num_attention_heads * attention_head_dim
|
||||
self.additional_embeddings = additional_embeddings
|
||||
|
||||
self.time_proj = Timesteps(inner_dim, True, 0)
|
||||
self.time_embedding = TimestepEmbedding(inner_dim, inner_dim)
|
||||
|
||||
self.proj_in = nn.Linear(embedding_dim, inner_dim)
|
||||
|
||||
self.embedding_proj = nn.Linear(embedding_dim, inner_dim)
|
||||
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
|
||||
|
||||
self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
|
||||
|
||||
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
inner_dim,
|
||||
num_attention_heads,
|
||||
attention_head_dim,
|
||||
dropout=dropout,
|
||||
activation_fn="gelu",
|
||||
attention_bias=True,
|
||||
)
|
||||
for d in range(num_layers)
|
||||
]
|
||||
)
|
||||
|
||||
self.norm_out = nn.LayerNorm(inner_dim)
|
||||
self.proj_to_clip_embeddings = nn.Linear(inner_dim, embedding_dim)
|
||||
|
||||
causal_attention_mask = torch.full(
|
||||
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], float("-inf")
|
||||
)
|
||||
causal_attention_mask.triu_(1)
|
||||
causal_attention_mask = causal_attention_mask[None, ...]
|
||||
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
|
||||
|
||||
self.clip_mean = nn.Parameter(torch.zeros(1, embedding_dim))
|
||||
self.clip_std = nn.Parameter(torch.zeros(1, embedding_dim))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
proj_embedding: torch.FloatTensor,
|
||||
encoder_hidden_states: torch.FloatTensor,
|
||||
attention_mask: Optional[torch.BoolTensor] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
x_t, the currently predicted image embeddings.
|
||||
timestep (`torch.long`):
|
||||
Current denoising step.
|
||||
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
|
||||
Projected embedding vector the denoising process is conditioned on.
|
||||
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
|
||||
Hidden states of the text embeddings the denoising process is conditioned on.
|
||||
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
|
||||
Text mask for the text embeddings.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`models.prior_transformer.PriorTransformerOutput`] instead of a plain
|
||||
tuple.
|
||||
|
||||
Returns:
|
||||
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
|
||||
[`~models.prior_transformer.PriorTransformerOutput`] if `return_dict` is True, otherwise a `tuple`. When
|
||||
returning a tuple, the first element is the sample tensor.
|
||||
"""
|
||||
batch_size = hidden_states.shape[0]
|
||||
|
||||
timesteps = timestep
|
||||
if not torch.is_tensor(timesteps):
|
||||
timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
|
||||
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
|
||||
timesteps = timesteps[None].to(hidden_states.device)
|
||||
|
||||
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
||||
timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
|
||||
|
||||
timesteps_projected = self.time_proj(timesteps)
|
||||
|
||||
# timesteps does not contain any weights and will always return f32 tensors
|
||||
# but time_embedding might be fp16, so we need to cast here.
|
||||
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
|
||||
time_embeddings = self.time_embedding(timesteps_projected)
|
||||
|
||||
proj_embeddings = self.embedding_proj(proj_embedding)
|
||||
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
|
||||
hidden_states = self.proj_in(hidden_states)
|
||||
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
|
||||
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
|
||||
|
||||
hidden_states = torch.cat(
|
||||
[
|
||||
encoder_hidden_states,
|
||||
proj_embeddings[:, None, :],
|
||||
time_embeddings[:, None, :],
|
||||
hidden_states[:, None, :],
|
||||
prd_embedding,
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + positional_embeddings
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
|
||||
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
|
||||
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
|
||||
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
hidden_states = block(hidden_states, attention_mask=attention_mask)
|
||||
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = hidden_states[:, -1]
|
||||
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
|
||||
|
||||
if not return_dict:
|
||||
return (predicted_image_embedding,)
|
||||
|
||||
return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
|
||||
|
||||
def post_process_latents(self, prior_latents):
|
||||
prior_latents = (prior_latents * self.clip_std) + self.clip_mean
|
||||
return prior_latents
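To make the documented shapes concrete, here is a small smoke test of the new model with toy dimensions. It assumes an install that includes this PR; every size below is arbitrary (the real prior defaults to 768-dim CLIP embeddings and 77 text tokens).

```python
import torch
from diffusers import PriorTransformer

prior = PriorTransformer(
    num_attention_heads=2,
    attention_head_dim=4,
    num_layers=2,
    embedding_dim=16,
    num_embeddings=7,
    additional_embeddings=4,
)

batch = 2
noisy_image_embeds = torch.randn(batch, 16)         # x_t being denoised
text_embeds = torch.randn(batch, 16)                # projected text embedding
text_encoder_states = torch.randn(batch, 7, 16)     # per-token text hidden states
text_mask = torch.ones(batch, 7, dtype=torch.bool)  # all tokens valid

out = prior(
    noisy_image_embeds,
    timestep=10,
    proj_embedding=text_embeds,
    encoder_hidden_states=text_encoder_states,
    attention_mask=text_mask,
)
print(out.predicted_image_embedding.shape)  # torch.Size([2, 16])

# After the denoising loop, post_process_latents() maps the result back into
# CLIP image-embedding space using the stored clip_mean / clip_std parameters.
```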
|
||||
@@ -405,7 +405,14 @@ class ResnetBlock2D(nn.Module):
|
||||
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if temb_channels is not None:
|
||||
self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
|
||||
if self.time_embedding_norm == "default":
|
||||
time_emb_proj_out_channels = out_channels
|
||||
elif self.time_embedding_norm == "scale_shift":
|
||||
time_emb_proj_out_channels = out_channels * 2
|
||||
else:
|
||||
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
|
||||
|
||||
self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
|
||||
else:
|
||||
self.time_emb_proj = None
|
||||
|
||||
@@ -465,9 +472,16 @@ class ResnetBlock2D(nn.Module):
|
||||
|
||||
if temb is not None:
|
||||
temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None]
|
||||
|
||||
if temb is not None and self.time_embedding_norm == "default":
|
||||
hidden_states = hidden_states + temb
|
||||
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
|
||||
if temb is not None and self.time_embedding_norm == "scale_shift":
|
||||
scale, shift = torch.chunk(temb, 2, dim=1)
|
||||
hidden_states = hidden_states * (1 + scale) + shift
|
||||
|
||||
hidden_states = self.nonlinearity(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
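The new "scale_shift" branch is an AdaGN / FiLM-style conditioning: the time embedding is projected to twice the channel count, split into a scale and a shift, and applied after the second group norm. A standalone sketch with toy tensors:

```python
import torch
import torch.nn as nn

batch, channels, height, width = 2, 8, 4, 4
temb_channels = 16

hidden_states = torch.randn(batch, channels, height, width)
temb = torch.randn(batch, temb_channels)

# time_emb_proj now outputs 2 * channels when time_embedding_norm == "scale_shift"
time_emb_proj = nn.Linear(temb_channels, channels * 2)
norm2 = nn.GroupNorm(num_groups=4, num_channels=channels)

temb = time_emb_proj(nn.SiLU()(temb))[:, :, None, None]  # (B, 2C, 1, 1)
scale, shift = torch.chunk(temb, 2, dim=1)               # (B, C, 1, 1) each

hidden_states = norm2(hidden_states) * (1 + scale) + shift
print(hidden_states.shape)  # torch.Size([2, 8, 4, 4])
```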
|
||||
|
||||
@@ -55,6 +55,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to :
|
||||
obj:`("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`): Tuple of downsample block
|
||||
types.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
|
||||
The mid block type. Choose from `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to :
|
||||
obj:`("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`): Tuple of upsample block types.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to :
|
||||
@@ -66,6 +68,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
|
||||
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for the normalization.
|
||||
norm_eps (`float`, *optional*, defaults to `1e-5`): The epsilon for the normalization.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
||||
"""
|
||||
|
||||
@register_to_config
|
||||
@@ -88,6 +92,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
attention_head_dim: int = 8,
|
||||
norm_num_groups: int = 32,
|
||||
norm_eps: float = 1e-5,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
add_attention: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -130,6 +136,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
resnet_groups=norm_num_groups,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
@@ -140,9 +147,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift="default",
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
resnet_groups=norm_num_groups,
|
||||
add_attention=add_attention,
|
||||
)
|
||||
|
||||
# up
|
||||
@@ -167,6 +175,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
|
||||
resnet_act_fn=act_fn,
|
||||
resnet_groups=norm_num_groups,
|
||||
attn_num_head_channels=attention_head_dim,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
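The new `resnet_time_scale_shift` and `add_attention` options can be exercised directly on `UNet2DModel`. A minimal sketch, assuming the post-PR constructor; the sizes are arbitrary and are not taken from any unCLIP checkpoint:

```python
import torch
from diffusers import UNet2DModel

unet = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    resnet_time_scale_shift="scale_shift",  # scale/shift time conditioning in the resnets
    add_attention=False,                    # skip the mid-block self-attention
)

sample = torch.randn(1, 3, 32, 32)
out = unet(sample, timestep=10).sample
print(out.shape)  # torch.Size([1, 3, 32, 32])
```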
|
||||
|
||||
@@ -15,7 +15,7 @@ import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
from .attention import AttentionBlock, DualTransformer2DModel, Transformer2DModel
|
||||
from .attention import AttentionBlock, CrossAttention, DualTransformer2DModel, Transformer2DModel
|
||||
from .resnet import Downsample2D, FirDownsample2D, FirUpsample2D, ResnetBlock2D, Upsample2D
|
||||
|
||||
|
||||
@@ -36,6 +36,7 @@ def get_down_block(
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
resnet_time_scale_shift="default",
|
||||
):
|
||||
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
|
||||
if down_block_type == "DownBlock2D":
|
||||
@@ -49,6 +50,19 @@ def get_down_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif down_block_type == "ResnetDownsampleBlock2D":
|
||||
return ResnetDownsampleBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif down_block_type == "AttnDownBlock2D":
|
||||
return AttnDownBlock2D(
|
||||
@@ -62,6 +76,7 @@ def get_down_block(
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif down_block_type == "CrossAttnDownBlock2D":
|
||||
if cross_attention_dim is None:
|
||||
@@ -82,6 +97,23 @@ def get_down_block(
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif down_block_type == "SimpleCrossAttnDownBlock2D":
|
||||
if cross_attention_dim is None:
|
||||
raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
|
||||
return SimpleCrossAttnDownBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
add_downsample=add_downsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif down_block_type == "SkipDownBlock2D":
|
||||
return SkipDownBlock2D(
|
||||
@@ -93,6 +125,7 @@ def get_down_block(
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif down_block_type == "AttnSkipDownBlock2D":
|
||||
return AttnSkipDownBlock2D(
|
||||
@@ -105,6 +138,7 @@ def get_down_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
downsample_padding=downsample_padding,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif down_block_type == "DownEncoderBlock2D":
|
||||
return DownEncoderBlock2D(
|
||||
@@ -116,6 +150,7 @@ def get_down_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif down_block_type == "AttnDownEncoderBlock2D":
|
||||
return AttnDownEncoderBlock2D(
|
||||
@@ -128,6 +163,7 @@ def get_down_block(
|
||||
resnet_groups=resnet_groups,
|
||||
downsample_padding=downsample_padding,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
raise ValueError(f"{down_block_type} does not exist.")
|
||||
|
||||
@@ -149,6 +185,7 @@ def get_up_block(
|
||||
use_linear_projection=False,
|
||||
only_cross_attention=False,
|
||||
upcast_attention=False,
|
||||
resnet_time_scale_shift="default",
|
||||
):
|
||||
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
|
||||
if up_block_type == "UpBlock2D":
|
||||
@@ -162,6 +199,20 @@ def get_up_block(
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif up_block_type == "ResnetUpsampleBlock2D":
|
||||
return ResnetUpsampleBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif up_block_type == "CrossAttnUpBlock2D":
|
||||
if cross_attention_dim is None:
|
||||
@@ -182,6 +233,24 @@ def get_up_block(
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention,
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif up_block_type == "SimpleCrossAttnUpBlock2D":
|
||||
if cross_attention_dim is None:
|
||||
raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
|
||||
return SimpleCrossAttnUpBlock2D(
|
||||
num_layers=num_layers,
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
prev_output_channel=prev_output_channel,
|
||||
temb_channels=temb_channels,
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif up_block_type == "AttnUpBlock2D":
|
||||
return AttnUpBlock2D(
|
||||
@@ -195,6 +264,7 @@ def get_up_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif up_block_type == "SkipUpBlock2D":
|
||||
return SkipUpBlock2D(
|
||||
@@ -206,6 +276,7 @@ def get_up_block(
|
||||
add_upsample=add_upsample,
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif up_block_type == "AttnSkipUpBlock2D":
|
||||
return AttnSkipUpBlock2D(
|
||||
@@ -218,6 +289,7 @@ def get_up_block(
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif up_block_type == "UpDecoderBlock2D":
|
||||
return UpDecoderBlock2D(
|
||||
@@ -228,6 +300,7 @@ def get_up_block(
|
||||
resnet_eps=resnet_eps,
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
elif up_block_type == "AttnUpDecoderBlock2D":
|
||||
return AttnUpDecoderBlock2D(
|
||||
@@ -239,6 +312,7 @@ def get_up_block(
|
||||
resnet_act_fn=resnet_act_fn,
|
||||
resnet_groups=resnet_groups,
|
||||
attn_num_head_channels=attn_num_head_channels,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
raise ValueError(f"{up_block_type} does not exist.")
|
||||
|
||||
@@ -255,14 +329,13 @@ class UNetMidBlock2D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
add_attention: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.attention_type = attention_type
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
self.add_attention = add_attention
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
@@ -282,15 +355,19 @@ class UNetMidBlock2D(nn.Module):
|
||||
attentions = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
AttentionBlock(
|
||||
in_channels,
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
norm_num_groups=resnet_groups,
|
||||
if self.add_attention:
|
||||
attentions.append(
|
||||
AttentionBlock(
|
||||
in_channels,
|
||||
num_head_channels=attn_num_head_channels,
|
||||
rescale_output_factor=output_scale_factor,
|
||||
eps=resnet_eps,
|
||||
norm_num_groups=resnet_groups,
|
||||
)
|
||||
)
|
||||
)
|
||||
else:
|
||||
attentions.append(None)
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
@@ -309,13 +386,11 @@ class UNetMidBlock2D(nn.Module):
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_states=None):
|
||||
def forward(self, hidden_states, temb=None):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
if self.attention_type == "default":
|
||||
if attn is not None:
|
||||
hidden_states = attn(hidden_states)
|
||||
else:
|
||||
hidden_states = attn(hidden_states, encoder_states)
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
@@ -334,7 +409,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
cross_attention_dim=1280,
|
||||
dual_cross_attention=False,
|
||||
@@ -344,7 +418,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attention_type = attention_type
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
@@ -408,10 +481,121 @@ class UNetMidBlock2DCrossAttn(nn.Module):
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
# TODO(Patrick, William) - attention_mask is currently not used. Implement once used
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states).sample
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class UNetMidBlock2DSimpleCrossAttn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
output_scale_factor=1.0,
|
||||
cross_attention_dim=1280,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.has_cross_attention = True
|
||||
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
self.num_heads = in_channels // self.attn_num_head_channels
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
CrossAttention(
|
||||
query_dim=in_channels,
|
||||
cross_attention_dim=in_channels,
|
||||
heads=self.num_heads,
|
||||
dim_head=attn_num_head_channels,
|
||||
added_kv_proj_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
bias=True,
|
||||
upcast_softmax=True,
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
head_dims = self.attn_num_head_channels
|
||||
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
|
||||
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
|
||||
raise ValueError(
|
||||
f"Make sure slice_size {slice_size} is a common divisor of "
|
||||
f"the number of heads used in cross_attention: {head_dims}"
|
||||
)
|
||||
if slice_size is not None and slice_size > min(head_dims):
|
||||
raise ValueError(
|
||||
f"slice_size {slice_size} has to be smaller or equal to "
|
||||
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
|
||||
)
|
||||
|
||||
for attn in self.attentions:
|
||||
attn._set_attention_slice(slice_size)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
# attn
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
# resnet
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
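Per attention layer, this block flattens the spatial feature map into a sequence, runs `CrossAttention` whose keys and values are extended with projected conditioning tokens (`added_kv_proj_dim`), then reshapes back and adds the residual. A standalone sketch; the import path and all sizes are assumptions for illustration:

```python
import torch
from diffusers.models.attention import CrossAttention

batch, channels, height, width = 1, 64, 16, 16
cond_dim, cond_tokens = 32, 77

attn = CrossAttention(
    query_dim=channels,
    cross_attention_dim=channels,
    heads=8,
    dim_head=8,
    added_kv_proj_dim=cond_dim,  # extra K/V projected from the conditioning tokens
    norm_num_groups=32,
    bias=True,
    upcast_softmax=True,
)

feature_map = torch.randn(batch, channels, height, width)
conditioning = torch.randn(batch, cond_tokens, cond_dim)

residual = feature_map
seq = feature_map.view(batch, channels, -1).transpose(1, 2)  # (B, H*W, C)
seq = attn(seq, encoder_hidden_states=conditioning)
out = seq.transpose(-1, -2).reshape(residual.shape) + residual
print(out.shape)  # torch.Size([1, 64, 16, 16])
```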
|
||||
@@ -431,7 +615,6 @@ class AttnDownBlock2D(nn.Module):
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
downsample_padding=1,
|
||||
add_downsample=True,
|
||||
@@ -440,8 +623,6 @@ class AttnDownBlock2D(nn.Module):
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
self.attention_type = attention_type
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
@@ -514,7 +695,6 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
cross_attention_dim=1280,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
downsample_padding=1,
|
||||
add_downsample=True,
|
||||
@@ -528,7 +708,6 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
attentions = []
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attention_type = attention_type
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
|
||||
for i in range(num_layers):
|
||||
@@ -588,7 +767,8 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
# TODO(Patrick, William) - attention mask is not used
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
@@ -605,7 +785,9 @@ class CrossAttnDownBlock2D(nn.Module):
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
create_custom_forward(attn, return_dict=False),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
@@ -847,7 +1029,6 @@ class AttnSkipDownBlock2D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
attention_type="default",
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
downsample_padding=1,
|
||||
add_downsample=True,
|
||||
@@ -856,8 +1037,6 @@ class AttnSkipDownBlock2D(nn.Module):
|
||||
self.attentions = nn.ModuleList([])
|
||||
self.resnets = nn.ModuleList([])
|
||||
|
||||
self.attention_type = attention_type
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
self.resnets.append(
|
||||
@@ -1006,6 +1185,205 @@ class SkipDownBlock2D(nn.Module):
|
||||
return hidden_states, output_states, skip_sample
|
||||
|
||||
|
||||
class ResnetDownsampleBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
output_scale_factor=1.0,
|
||||
add_downsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if add_downsample:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
ResnetBlock2D(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
down=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, temb=None):
|
||||
output_states = ()
|
||||
|
||||
for resnet in self.resnets:
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, temb)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
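The `create_custom_forward` wrapper that keeps recurring in these blocks is the standard trick for `torch.utils.checkpoint`: activations inside the wrapped call are recomputed during the backward pass instead of being stored. A standalone sketch of the pattern:

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


layer = nn.Linear(8, 8)
x = torch.randn(2, 8, requires_grad=True)

out = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x)
out.sum().backward()
print(x.grad.shape)  # torch.Size([2, 8])
```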
|
||||
|
||||
|
||||
class SimpleCrossAttnDownBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
cross_attention_dim=1280,
|
||||
output_scale_factor=1.0,
|
||||
add_downsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.has_cross_attention = True
|
||||
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
self.num_heads = out_channels // self.attn_num_head_channels
|
||||
|
||||
for i in range(num_layers):
|
||||
in_channels = in_channels if i == 0 else out_channels
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
CrossAttention(
|
||||
query_dim=out_channels,
|
||||
cross_attention_dim=out_channels,
|
||||
heads=self.num_heads,
|
||||
dim_head=attn_num_head_channels,
|
||||
added_kv_proj_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
bias=True,
|
||||
upcast_softmax=True,
|
||||
)
|
||||
)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if add_downsample:
|
||||
self.downsamplers = nn.ModuleList(
|
||||
[
|
||||
ResnetBlock2D(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
down=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.downsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# resnet
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
# attn
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
if self.downsamplers is not None:
|
||||
for downsampler in self.downsamplers:
|
||||
hidden_states = downsampler(hidden_states, temb)
|
||||
|
||||
output_states += (hidden_states,)
|
||||
|
||||
return hidden_states, output_states
|
||||
|
||||
|
||||
class AttnUpBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1020,7 +1398,6 @@ class AttnUpBlock2D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attention_type="default",
|
||||
attn_num_head_channels=1,
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
@@ -1029,8 +1406,6 @@ class AttnUpBlock2D(nn.Module):
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
self.attention_type = attention_type
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
@@ -1100,7 +1475,6 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
cross_attention_dim=1280,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
dual_cross_attention=False,
|
||||
@@ -1113,7 +1487,6 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
attentions = []
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attention_type = attention_type
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
|
||||
for i in range(num_layers):
|
||||
@@ -1176,7 +1549,9 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
temb=None,
|
||||
encoder_hidden_states=None,
|
||||
upsample_size=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
# TODO(Patrick, William) - attention mask is not used
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
@@ -1196,7 +1571,9 @@ class CrossAttnUpBlock2D(nn.Module):
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
create_custom_forward(attn, return_dict=False),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
@@ -1418,7 +1795,6 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
attention_type="default",
|
||||
output_scale_factor=np.sqrt(2.0),
|
||||
upsample_padding=1,
|
||||
add_upsample=True,
|
||||
@@ -1427,8 +1803,6 @@ class AttnSkipUpBlock2D(nn.Module):
|
||||
self.attentions = nn.ModuleList([])
|
||||
self.resnets = nn.ModuleList([])
|
||||
|
||||
self.attention_type = attention_type
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
@@ -1608,3 +1982,213 @@ class SkipUpBlock2D(nn.Module):
|
||||
hidden_states = self.resnet_up(hidden_states, temb)
|
||||
|
||||
return hidden_states, skip_sample
|
||||
|
||||
|
||||
class ResnetUpsampleBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
prev_output_channel: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if add_upsample:
|
||||
self.upsamplers = nn.ModuleList(
|
||||
[
|
||||
ResnetBlock2D(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
up=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
|
||||
for resnet in self.resnets:
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, temb)
|
||||
|
||||
return hidden_states
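The `create_custom_forward` wrapper used with `torch.utils.checkpoint` in the forward above trades compute for memory during training. A small self-contained sketch of the same pattern on a toy module (the module and shapes here are illustrative assumptions, not part of the diff):

# Sketch of the gradient-checkpointing pattern used in these blocks (toy module).
import torch
import torch.utils.checkpoint
import torch.nn as nn

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward

toy_resnet = nn.Sequential(nn.Linear(16, 16), nn.SiLU(), nn.Linear(16, 16))
x = torch.randn(4, 16, requires_grad=True)

# Activations inside `toy_resnet` are not stored; they are recomputed during backward.
y = torch.utils.checkpoint.checkpoint(create_custom_forward(toy_resnet), x)
y.sum().backward()
print(x.grad.shape)  # torch.Size([4, 16])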
|
||||
|
||||
|
||||
class SimpleCrossAttnUpBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
prev_output_channel: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
cross_attention_dim=1280,
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
resnets = []
|
||||
attentions = []
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
|
||||
self.num_heads = out_channels // self.attn_num_head_channels
|
||||
|
||||
for i in range(num_layers):
|
||||
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
|
||||
resnet_in_channels = prev_output_channel if i == 0 else out_channels
|
||||
|
||||
resnets.append(
|
||||
ResnetBlock2D(
|
||||
in_channels=resnet_in_channels + res_skip_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
attentions.append(
|
||||
CrossAttention(
|
||||
query_dim=out_channels,
|
||||
cross_attention_dim=out_channels,
|
||||
heads=self.num_heads,
|
||||
dim_head=attn_num_head_channels,
|
||||
added_kv_proj_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
bias=True,
|
||||
upcast_softmax=True,
|
||||
)
|
||||
)
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
if add_upsample:
|
||||
self.upsamplers = nn.ModuleList(
|
||||
[
|
||||
ResnetBlock2D(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
up=True,
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
self.upsamplers = None
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states,
|
||||
res_hidden_states_tuple,
|
||||
temb=None,
|
||||
encoder_hidden_states=None,
|
||||
upsample_size=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# resnet
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
||||
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
||||
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
# attn
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
if self.upsamplers is not None:
|
||||
for upsampler in self.upsamplers:
|
||||
hidden_states = upsampler(hidden_states, temb)
|
||||
|
||||
return hidden_states
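These blocks receive `attention_mask` already converted into an additive bias by `UNet2DConditionModel.forward` (the `(1 - attention_mask.to(sample.dtype)) * -10000.0` hunk further below). A small sketch of that conversion with toy shapes (the shapes are assumptions):

# Sketch of how the tokenizer padding mask becomes an additive attention bias.
import torch

batch, seq_len = 2, 81
# 1 = real token, 0 = padding, as produced by the tokenizer's attention_mask
attention_mask = torch.tensor([[1] * 81, [1] * 60 + [0] * 21])

bias = (1 - attention_mask.to(torch.float32)) * -10000.0  # padded keys get a large negative bias
bias = bias.unsqueeze(1)                                  # (B, S) -> (B, 1, S), broadcast over queries

print(bias.shape)             # torch.Size([2, 1, 81])
print(bias[1, 0, -1].item())  # -10000.0: this key is effectively ignored by the softmax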
|
||||
|
||||
@@ -27,6 +27,7 @@ from .unet_2d_blocks import (
|
||||
CrossAttnUpBlock2D,
|
||||
DownBlock2D,
|
||||
UNetMidBlock2DCrossAttn,
|
||||
UNetMidBlock2DSimpleCrossAttn,
|
||||
UpBlock2D,
|
||||
get_down_block,
|
||||
get_up_block,
|
||||
@@ -66,6 +67,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
|
||||
The tuple of downsample blocks to use.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
|
||||
The mid block type. Choose from `UNetMidBlock2DCrossAttn` or `UNetMidBlock2DSimpleCrossAttn`.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D",)`):
|
||||
The tuple of upsample blocks to use.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
@@ -78,6 +81,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
|
||||
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for resnet blocks, see [`~models.resnet.ResnetBlock2D`]. Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
|
||||
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -97,6 +104,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
"CrossAttnDownBlock2D",
|
||||
"DownBlock2D",
|
||||
),
|
||||
mid_block_type: str = "UNetMidBlock2DCrossAttn",
|
||||
up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
|
||||
only_cross_attention: Union[bool, Tuple[bool]] = False,
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
|
||||
@@ -110,8 +118,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
dual_cross_attention: bool = False,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -128,8 +138,14 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
# class embedding
|
||||
if num_class_embeds is not None:
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
@@ -165,24 +181,40 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift="default",
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
if mid_block_type == "UNetMidBlock2DCrossAttn":
|
||||
self.mid_block = UNetMidBlock2DCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
|
||||
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
||||
|
||||
# count how many layers upsample the images
|
||||
self.num_upsamplers = 0
|
||||
@@ -223,6 +255,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
@@ -307,6 +340,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||
r"""
|
||||
@@ -336,6 +370,11 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
logger.info("Forward upsample size to force interpolation output size.")
|
||||
forward_upsample_size = True
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
@@ -365,9 +404,13 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
if self.config.num_class_embeds is not None:
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
@@ -382,6 +425,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
@@ -389,7 +433,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
||||
sample = self.mid_block(
|
||||
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
||||
)
|
||||
|
||||
# 5. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
@@ -410,6 +456,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
upsample_size=upsample_size,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
else:
|
||||
sample = upsample_block(
|
||||
|
||||
@@ -53,6 +53,7 @@ else:
        StableDiffusionUpscalePipeline,
    )
    from .stable_diffusion_safe import StableDiffusionPipelineSafe
    from .unclip import UnCLIPPipeline
    from .versatile_diffusion import (
        VersatileDiffusionDualGuidedPipeline,
        VersatileDiffusionImageVariationPipeline,
6
src/diffusers/pipelines/unclip/__init__.py
Normal file
@@ -0,0 +1,6 @@
from ...utils import is_torch_available, is_transformers_available


if is_transformers_available() and is_torch_available():
    from .pipeline_unclip import UnCLIPPipeline
    from .text_proj import UnCLIPTextProjModel
370
src/diffusers/pipelines/unclip/pipeline_unclip.py
Normal file
@@ -0,0 +1,370 @@
|
||||
# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
from diffusers import PriorTransformer, UNet2DConditionModel, UNet2DModel
|
||||
from diffusers.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
||||
from diffusers.schedulers import UnCLIPScheduler
|
||||
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
|
||||
|
||||
from ...utils import logging
|
||||
from .text_proj import UnCLIPTextProjModel
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
||||
|
||||
|
||||
class UnCLIPPipeline(DiffusionPipeline):
|
||||
prior: PriorTransformer
|
||||
decoder: UNet2DConditionModel
|
||||
text_proj: UnCLIPTextProjModel
|
||||
text_encoder: CLIPTextModelWithProjection
|
||||
tokenizer: CLIPTokenizer
|
||||
super_res_first: UNet2DModel
|
||||
super_res_last: UNet2DModel
|
||||
|
||||
prior_scheduler: UnCLIPScheduler
|
||||
decoder_scheduler: UnCLIPScheduler
|
||||
super_res_scheduler: UnCLIPScheduler
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
prior: PriorTransformer,
|
||||
decoder: UNet2DConditionModel,
|
||||
text_encoder: CLIPTextModelWithProjection,
|
||||
tokenizer: CLIPTokenizer,
|
||||
text_proj: UnCLIPTextProjModel,
|
||||
super_res_first: UNet2DModel,
|
||||
super_res_last: UNet2DModel,
|
||||
prior_scheduler: UnCLIPScheduler,
|
||||
decoder_scheduler: UnCLIPScheduler,
|
||||
super_res_scheduler: UnCLIPScheduler,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.register_modules(
|
||||
prior=prior,
|
||||
decoder=decoder,
|
||||
text_encoder=text_encoder,
|
||||
tokenizer=tokenizer,
|
||||
text_proj=text_proj,
|
||||
super_res_first=super_res_first,
|
||||
super_res_last=super_res_last,
|
||||
prior_scheduler=prior_scheduler,
|
||||
decoder_scheduler=decoder_scheduler,
|
||||
super_res_scheduler=super_res_scheduler,
|
||||
)
|
||||
|
||||
def prepare_latents(self, shape, dtype, device, generator, latents, scheduler):
|
||||
if latents is None:
|
||||
if device.type == "mps":
|
||||
# randn does not work reproducibly on mps
|
||||
latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
|
||||
else:
|
||||
latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
|
||||
else:
|
||||
if latents.shape != shape:
|
||||
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
||||
latents = latents.to(device)
|
||||
|
||||
latents = latents * scheduler.init_noise_sigma
|
||||
return latents
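A short sketch of why `prepare_latents` draws noise on CPU for `mps`: a seeded CPU `torch.Generator` yields the same latents no matter which device they end up on (the shape below is a toy assumption):

# Sketch: seeded latents drawn on CPU, then moved to the target device, as in prepare_latents.
import torch

shape = (1, 3, 8, 8)  # toy; the pipeline uses (batch, channels, sample_size, sample_size)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

generator = torch.Generator(device="cpu").manual_seed(0)
latents = torch.randn(shape, generator=generator, device="cpu").to(device)

generator = torch.Generator(device="cpu").manual_seed(0)
latents_again = torch.randn(shape, generator=generator, device="cpu").to(device)

assert torch.equal(latents, latents_again)  # reproducible across runs and devices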
|
||||
|
||||
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance):
|
||||
batch_size = len(prompt) if isinstance(prompt, list) else 1
|
||||
|
||||
# get prompt text embeddings
|
||||
text_inputs = self.tokenizer(
|
||||
prompt,
|
||||
padding="max_length",
|
||||
max_length=self.tokenizer.model_max_length,
|
||||
return_tensors="pt",
|
||||
)
|
||||
text_input_ids = text_inputs.input_ids
|
||||
text_mask = text_inputs.attention_mask.bool().to(self.device)
|
||||
|
||||
if text_input_ids.shape[-1] > self.tokenizer.model_max_length:
|
||||
removed_text = self.tokenizer.batch_decode(text_input_ids[:, self.tokenizer.model_max_length :])
|
||||
logger.warning(
|
||||
"The following part of your input was truncated because CLIP can only handle sequences up to"
|
||||
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
|
||||
)
|
||||
text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
|
||||
|
||||
text_encoder_output = self.text_encoder(text_input_ids.to(self.device))
|
||||
|
||||
text_embeddings = text_encoder_output.text_embeds
|
||||
text_encoder_hidden_states = text_encoder_output.last_hidden_state
|
||||
|
||||
text_embeddings = text_embeddings.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
text_encoder_hidden_states = text_encoder_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
uncond_tokens = [""] * batch_size
|
||||
|
||||
max_length = text_input_ids.shape[-1]
|
||||
uncond_input = self.tokenizer(
|
||||
uncond_tokens,
|
||||
padding="max_length",
|
||||
max_length=max_length,
|
||||
truncation=True,
|
||||
return_tensors="pt",
|
||||
)
|
||||
uncond_text_mask = uncond_input.attention_mask.bool().to(self.device)
|
||||
uncond_embeddings_text_encoder_output = self.text_encoder(uncond_input.input_ids.to(self.device))
|
||||
|
||||
uncond_embeddings = uncond_embeddings_text_encoder_output.text_embeds
|
||||
uncond_text_encoder_hidden_states = uncond_embeddings_text_encoder_output.last_hidden_state
|
||||
|
||||
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
||||
|
||||
seq_len = uncond_embeddings.shape[1]
|
||||
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt)
|
||||
uncond_embeddings = uncond_embeddings.view(batch_size * num_images_per_prompt, seq_len)
|
||||
|
||||
seq_len = uncond_text_encoder_hidden_states.shape[1]
|
||||
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.repeat(1, num_images_per_prompt, 1)
|
||||
uncond_text_encoder_hidden_states = uncond_text_encoder_hidden_states.view(
|
||||
batch_size * num_images_per_prompt, seq_len, -1
|
||||
)
|
||||
|
||||
# done duplicates
|
||||
|
||||
# For classifier free guidance, we need to do two forward passes.
|
||||
# Here we concatenate the unconditional and text embeddings into a single batch
|
||||
# to avoid doing two forward passes
|
||||
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
||||
text_encoder_hidden_states = torch.cat([uncond_text_encoder_hidden_states, text_encoder_hidden_states])
|
||||
|
||||
text_mask = torch.cat([uncond_text_mask, text_mask])
|
||||
|
||||
return text_embeddings, text_encoder_hidden_states, text_mask
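A sketch of the classifier-free guidance batching that `_encode_prompt` sets up: unconditional and conditional embeddings are concatenated into one batch, the model output is chunked back apart, and the halves are recombined with the guidance scale. The "model" here is a stand-in, not the real prior or decoder:

# Sketch of classifier-free guidance with a single batched forward pass (toy tensors).
import torch

guidance_scale = 4.0
cond = torch.randn(2, 16)    # embeddings for the actual prompts (batch of 2)
uncond = torch.randn(2, 16)  # embeddings for the empty prompt ""

model_input = torch.cat([uncond, cond])  # one forward pass over 2 * batch samples
model_output = model_input * 0.5         # stand-in for the prior/decoder prediction

pred_uncond, pred_cond = model_output.chunk(2)
guided = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
print(guided.shape)  # torch.Size([2, 16])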
|
||||
|
||||
@torch.no_grad()
|
||||
def __call__(
|
||||
self,
|
||||
prompt: Union[str, List[str]],
|
||||
num_images_per_prompt: int = 1,
|
||||
prior_num_inference_steps: int = 25,
|
||||
decoder_num_inference_steps: int = 25,
|
||||
super_res_num_inference_steps: int = 7,
|
||||
generator: Optional[torch.Generator] = None,
|
||||
prior_latents: Optional[torch.FloatTensor] = None,
|
||||
decoder_latents: Optional[torch.FloatTensor] = None,
|
||||
super_res_latents: Optional[torch.FloatTensor] = None,
|
||||
prior_guidance_scale: float = 4.0,
|
||||
decoder_guidance_scale: float = 8.0,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
):
|
||||
if isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
|
||||
|
||||
batch_size = batch_size * num_images_per_prompt
|
||||
|
||||
do_classifier_free_guidance = prior_guidance_scale > 1.0 or decoder_guidance_scale > 1.0
|
||||
|
||||
text_embeddings, text_encoder_hidden_states, text_mask = self._encode_prompt(
|
||||
prompt, num_images_per_prompt, do_classifier_free_guidance
|
||||
)
|
||||
|
||||
# prior
|
||||
|
||||
self.prior_scheduler.set_timesteps(prior_num_inference_steps, device=self.device)
|
||||
prior_timesteps_tensor = self.prior_scheduler.timesteps
|
||||
|
||||
embedding_dim = self.prior.config.embedding_dim
|
||||
prior_latents = self.prepare_latents(
|
||||
(batch_size, embedding_dim),
|
||||
text_embeddings.dtype,
|
||||
self.device,
|
||||
generator,
|
||||
prior_latents,
|
||||
self.prior_scheduler,
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(prior_timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([prior_latents] * 2) if do_classifier_free_guidance else prior_latents
|
||||
|
||||
predicted_image_embedding = self.prior(
|
||||
latent_model_input,
|
||||
timestep=t,
|
||||
proj_embedding=text_embeddings,
|
||||
encoder_hidden_states=text_encoder_hidden_states,
|
||||
attention_mask=text_mask,
|
||||
).predicted_image_embedding
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
predicted_image_embedding_uncond, predicted_image_embedding_text = predicted_image_embedding.chunk(2)
|
||||
predicted_image_embedding = predicted_image_embedding_uncond + prior_guidance_scale * (
|
||||
predicted_image_embedding_text - predicted_image_embedding_uncond
|
||||
)
|
||||
|
||||
if i + 1 == prior_timesteps_tensor.shape[0]:
|
||||
prev_timestep = None
|
||||
else:
|
||||
prev_timestep = prior_timesteps_tensor[i + 1]
|
||||
|
||||
prior_latents = self.prior_scheduler.step(
|
||||
predicted_image_embedding,
|
||||
timestep=t,
|
||||
sample=prior_latents,
|
||||
generator=generator,
|
||||
prev_timestep=prev_timestep,
|
||||
).prev_sample
|
||||
|
||||
prior_latents = self.prior.post_process_latents(prior_latents)
|
||||
|
||||
image_embeddings = prior_latents
|
||||
|
||||
# done prior
|
||||
|
||||
# decoder
|
||||
|
||||
text_encoder_hidden_states, additive_clip_time_embeddings = self.text_proj(
|
||||
image_embeddings=image_embeddings,
|
||||
text_embeddings=text_embeddings,
|
||||
text_encoder_hidden_states=text_encoder_hidden_states,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
)
|
||||
|
||||
decoder_text_mask = F.pad(text_mask, (self.text_proj.clip_extra_context_tokens, 0), value=1)
|
||||
|
||||
self.decoder_scheduler.set_timesteps(decoder_num_inference_steps, device=self.device)
|
||||
decoder_timesteps_tensor = self.decoder_scheduler.timesteps
|
||||
|
||||
num_channels_latents = self.decoder.in_channels
|
||||
height = self.decoder.sample_size
|
||||
width = self.decoder.sample_size
|
||||
decoder_latents = self.prepare_latents(
|
||||
(batch_size, num_channels_latents, height, width),
|
||||
text_encoder_hidden_states.dtype,
|
||||
self.device,
|
||||
generator,
|
||||
decoder_latents,
|
||||
self.decoder_scheduler,
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(decoder_timesteps_tensor)):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([decoder_latents] * 2) if do_classifier_free_guidance else decoder_latents
|
||||
|
||||
noise_pred = self.decoder(
|
||||
sample=latent_model_input,
|
||||
timestep=t,
|
||||
encoder_hidden_states=text_encoder_hidden_states,
|
||||
class_labels=additive_clip_time_embeddings,
|
||||
attention_mask=decoder_text_mask,
|
||||
).sample
|
||||
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred_uncond, _ = noise_pred_uncond.split(latent_model_input.shape[1], dim=1)
|
||||
noise_pred_text, predicted_variance = noise_pred_text.split(latent_model_input.shape[1], dim=1)
|
||||
noise_pred = noise_pred_uncond + decoder_guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
|
||||
|
||||
if i + 1 == decoder_timesteps_tensor.shape[0]:
|
||||
prev_timestep = None
|
||||
else:
|
||||
prev_timestep = decoder_timesteps_tensor[i + 1]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
decoder_latents = self.decoder_scheduler.step(
|
||||
noise_pred, t, decoder_latents, prev_timestep=prev_timestep
|
||||
).prev_sample
|
||||
|
||||
decoder_latents = decoder_latents.clamp(-1, 1)
|
||||
|
||||
image_small = decoder_latents
|
||||
|
||||
# done decoder
|
||||
|
||||
# super res
|
||||
|
||||
self.super_res_scheduler.set_timesteps(super_res_num_inference_steps, device=self.device)
|
||||
super_res_timesteps_tensor = self.super_res_scheduler.timesteps
|
||||
|
||||
channels = self.super_res_first.in_channels // 2
|
||||
height = self.super_res_first.sample_size
|
||||
width = self.super_res_first.sample_size
|
||||
super_res_latents = self.prepare_latents(
|
||||
(batch_size, channels, height, width),
|
||||
image_small.dtype,
|
||||
self.device,
|
||||
generator,
|
||||
super_res_latents,
|
||||
self.super_res_scheduler,
|
||||
)
|
||||
|
||||
interpolate_antialias = {}
|
||||
if "antialias" in inspect.signature(F.interpolate).parameters:
|
||||
interpolate_antialias["antialias"] = True
|
||||
|
||||
image_upscaled = F.interpolate(
|
||||
image_small, size=[height, width], mode="bicubic", align_corners=False, **interpolate_antialias
|
||||
)
|
||||
|
||||
for i, t in enumerate(self.progress_bar(super_res_timesteps_tensor)):
|
||||
# no classifier free guidance
|
||||
|
||||
if i == super_res_timesteps_tensor.shape[0] - 1:
|
||||
unet = self.super_res_last
|
||||
else:
|
||||
unet = self.super_res_first
|
||||
|
||||
latent_model_input = torch.cat([super_res_latents, image_upscaled], dim=1)
|
||||
|
||||
noise_pred = unet(
|
||||
sample=latent_model_input,
|
||||
timestep=t,
|
||||
).sample
|
||||
|
||||
if i + 1 == super_res_timesteps_tensor.shape[0]:
|
||||
prev_timestep = None
|
||||
else:
|
||||
prev_timestep = super_res_timesteps_tensor[i + 1]
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
super_res_latents = self.super_res_scheduler.step(
|
||||
noise_pred, t, super_res_latents, prev_timestep=prev_timestep
|
||||
).prev_sample
|
||||
|
||||
image = super_res_latents
|
||||
|
||||
# done super res
|
||||
|
||||
# post processing
|
||||
|
||||
image = image * 0.5 + 0.5
|
||||
image = image.clamp(0, 1)
|
||||
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
|
||||
if output_type == "pil":
|
||||
image = self.numpy_to_pil(image)
|
||||
|
||||
if not return_dict:
|
||||
return (image,)
|
||||
|
||||
return ImagePipelineOutput(images=image)
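Once the modules above are wired together, the pipeline is called like any other `DiffusionPipeline`. A minimal usage sketch; the checkpoint id `kakaobrain/karlo-v1-alpha` is an assumed placeholder, substitute whatever repository the converted karlo weights are pushed to:

# Minimal usage sketch for UnCLIPPipeline; the checkpoint id is an assumed placeholder.
import torch
from diffusers import UnCLIPPipeline

pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")  # assumed repo id
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

image = pipe(
    "a high-resolution photograph of a big red frog on a green leaf",
    prior_num_inference_steps=25,
    decoder_num_inference_steps=25,
    super_res_num_inference_steps=7,
).images[0]
image.save("frog.png")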
|
||||
87
src/diffusers/pipelines/unclip/text_proj.py
Normal file
@@ -0,0 +1,87 @@
# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torch import nn

from diffusers.modeling_utils import ModelMixin

from ...configuration_utils import ConfigMixin, register_to_config


class UnCLIPTextProjModel(ModelMixin, ConfigMixin):
    """
    Utility class for CLIP embeddings. Used to combine the image and text embeddings into a format usable by the
    decoder.

    For more details, see the original paper: https://arxiv.org/abs/2204.06125 section 2.1
    """

    @register_to_config
    def __init__(
        self,
        *,
        clip_extra_context_tokens: int = 4,
        clip_embeddings_dim: int = 768,
        time_embed_dim: int,
        cross_attention_dim,
    ):
        super().__init__()

        self.learned_classifier_free_guidance_embeddings = nn.Parameter(torch.zeros(clip_embeddings_dim))

        # parameters for additional clip time embeddings
        self.embedding_proj = nn.Linear(clip_embeddings_dim, time_embed_dim)
        self.clip_image_embeddings_project_to_time_embeddings = nn.Linear(clip_embeddings_dim, time_embed_dim)

        # parameters for encoder hidden states
        self.clip_extra_context_tokens = clip_extra_context_tokens
        self.clip_extra_context_tokens_proj = nn.Linear(
            clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim
        )
        self.encoder_hidden_states_proj = nn.Linear(clip_embeddings_dim, cross_attention_dim)
        self.text_encoder_hidden_states_norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, *, image_embeddings, text_embeddings, text_encoder_hidden_states, do_classifier_free_guidance):
        if do_classifier_free_guidance:
            # Add the classifier free guidance embeddings to the image embeddings
            image_embeddings_batch_size = image_embeddings.shape[0]
            classifier_free_guidance_embeddings = self.learned_classifier_free_guidance_embeddings.unsqueeze(0)
            classifier_free_guidance_embeddings = classifier_free_guidance_embeddings.expand(
                image_embeddings_batch_size, -1
            )
            image_embeddings = torch.cat([classifier_free_guidance_embeddings, image_embeddings], dim=0)

        # The image embeddings batch size and the text embeddings batch size are equal
        assert image_embeddings.shape[0] == text_embeddings.shape[0]

        batch_size = text_embeddings.shape[0]

        # "Specifically, we modify the architecture described in Nichol et al. (2021) by projecting and
        # adding CLIP embeddings to the existing timestep embedding, ...
        time_projected_text_embeddings = self.embedding_proj(text_embeddings)
        time_projected_image_embeddings = self.clip_image_embeddings_project_to_time_embeddings(image_embeddings)
        additive_clip_time_embeddings = time_projected_image_embeddings + time_projected_text_embeddings

        # ... and by projecting CLIP embeddings into four
        # extra tokens of context that are concatenated to the sequence of outputs from the GLIDE text encoder"
        clip_extra_context_tokens = self.clip_extra_context_tokens_proj(image_embeddings)
        clip_extra_context_tokens = clip_extra_context_tokens.reshape(batch_size, -1, self.clip_extra_context_tokens)

        text_encoder_hidden_states = self.encoder_hidden_states_proj(text_encoder_hidden_states)
        text_encoder_hidden_states = self.text_encoder_hidden_states_norm(text_encoder_hidden_states)
        text_encoder_hidden_states = text_encoder_hidden_states.permute(0, 2, 1)
        text_encoder_hidden_states = torch.cat([clip_extra_context_tokens, text_encoder_hidden_states], dim=2)

        return text_encoder_hidden_states, additive_clip_time_embeddings
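A standalone shape check for `UnCLIPTextProjModel` (the dimensions below are toy assumptions, not the karlo config): the image embedding contributes `clip_extra_context_tokens` extra tokens, and the hidden states come back channel-first, which is why the decoder blocks transpose `encoder_hidden_states` before attending.

# Shape sketch for UnCLIPTextProjModel with toy dimensions (not the karlo config).
import torch
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel

batch, seq_len, clip_dim = 2, 77, 32
text_proj = UnCLIPTextProjModel(
    clip_extra_context_tokens=4,
    clip_embeddings_dim=clip_dim,
    time_embed_dim=64,
    cross_attention_dim=48,
)

hidden_states, additive_time_emb = text_proj(
    image_embeddings=torch.randn(batch, clip_dim),
    text_embeddings=torch.randn(batch, clip_dim),
    text_encoder_hidden_states=torch.randn(batch, seq_len, clip_dim),
    do_classifier_free_guidance=False,
)

print(hidden_states.shape)      # torch.Size([2, 48, 81]): (batch, cross_attention_dim, 4 + seq_len)
print(additive_time_emb.shape)  # torch.Size([2, 64]):     summed into the decoder's timestep embedding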
@@ -6,8 +6,9 @@ import torch.nn as nn
|
||||
|
||||
from ...configuration_utils import ConfigMixin, register_to_config
|
||||
from ...modeling_utils import ModelMixin
|
||||
from ...models.attention import DualTransformer2DModel, Transformer2DModel
|
||||
from ...models.attention import CrossAttention, DualTransformer2DModel, Transformer2DModel
|
||||
from ...models.embeddings import TimestepEmbedding, Timesteps
|
||||
from ...models.unet_2d_blocks import UNetMidBlock2DSimpleCrossAttn as UNetMidBlockFlatSimpleCrossAttn
|
||||
from ...models.unet_2d_condition import UNet2DConditionOutput
|
||||
from ...utils import logging
|
||||
|
||||
@@ -141,6 +142,8 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
|
||||
down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "CrossAttnDownBlockFlat", "DownBlockFlat")`):
|
||||
The tuple of downsample blocks to use.
|
||||
mid_block_type (`str`, *optional*, defaults to `"UNetMidBlockFlatCrossAttn"`):
|
||||
The mid block type. Choose from `UNetMidBlockFlatCrossAttn` or `UNetMidBlockFlatSimpleCrossAttn`.
|
||||
up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat", "CrossAttnUpBlockFlat",)`):
|
||||
The tuple of upsample blocks to use.
|
||||
block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
|
||||
@@ -153,6 +156,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
|
||||
cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
|
||||
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
|
||||
resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
|
||||
for resnet blocks, see [`~models.resnet.ResnetBlockFlat`]. Choose from `default` or `scale_shift`.
|
||||
class_embed_type (`str`, *optional*, defaults to None): The type of class embedding to use which is ultimately
|
||||
summed with the time embeddings. Choose from `None`, `"timestep"`, or `"identity"`.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
@@ -172,6 +179,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
"CrossAttnDownBlockFlat",
|
||||
"DownBlockFlat",
|
||||
),
|
||||
mid_block_type: str = "UNetMidBlockFlatCrossAttn",
|
||||
up_block_types: Tuple[str] = (
|
||||
"UpBlockFlat",
|
||||
"CrossAttnUpBlockFlat",
|
||||
@@ -190,8 +198,10 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
attention_head_dim: Union[int, Tuple[int]] = 8,
|
||||
dual_cross_attention: bool = False,
|
||||
use_linear_projection: bool = False,
|
||||
class_embed_type: Optional[str] = None,
|
||||
num_class_embeds: Optional[int] = None,
|
||||
upcast_attention: bool = False,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -208,8 +218,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
|
||||
# class embedding
|
||||
if num_class_embeds is not None:
|
||||
if class_embed_type is None and num_class_embeds is not None:
|
||||
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
|
||||
elif class_embed_type == "timestep":
|
||||
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
|
||||
elif class_embed_type == "identity":
|
||||
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
|
||||
else:
|
||||
self.class_embedding = None
|
||||
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_block = None
|
||||
@@ -245,24 +261,40 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
# mid
|
||||
self.mid_block = UNetMidBlockFlatCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift="default",
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
if mid_block_type == "UNetMidBlockFlatCrossAttn":
|
||||
self.mid_block = UNetMidBlockFlatCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
dual_cross_attention=dual_cross_attention,
|
||||
use_linear_projection=use_linear_projection,
|
||||
upcast_attention=upcast_attention,
|
||||
)
|
||||
elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn":
|
||||
self.mid_block = UNetMidBlockFlatSimpleCrossAttn(
|
||||
in_channels=block_out_channels[-1],
|
||||
temb_channels=time_embed_dim,
|
||||
resnet_eps=norm_eps,
|
||||
resnet_act_fn=act_fn,
|
||||
output_scale_factor=mid_block_scale_factor,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
attn_num_head_channels=attention_head_dim[-1],
|
||||
resnet_groups=norm_num_groups,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
|
||||
|
||||
# count how many layers upsample the images
|
||||
self.num_upsamplers = 0
|
||||
@@ -303,6 +335,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
use_linear_projection=use_linear_projection,
|
||||
only_cross_attention=only_cross_attention[i],
|
||||
upcast_attention=upcast_attention,
|
||||
resnet_time_scale_shift=resnet_time_scale_shift,
|
||||
)
|
||||
self.up_blocks.append(up_block)
|
||||
prev_output_channel = output_channel
|
||||
@@ -387,6 +420,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
timestep: Union[torch.Tensor, float, int],
|
||||
encoder_hidden_states: torch.Tensor,
|
||||
class_labels: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
return_dict: bool = True,
|
||||
) -> Union[UNet2DConditionOutput, Tuple]:
|
||||
r"""
|
||||
@@ -416,6 +450,11 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
logger.info("Forward upsample size to force interpolation output size.")
|
||||
forward_upsample_size = True
|
||||
|
||||
# prepare attention_mask
|
||||
if attention_mask is not None:
|
||||
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
||||
attention_mask = attention_mask.unsqueeze(1)
|
||||
|
||||
# 0. center input if necessary
|
||||
if self.config.center_input_sample:
|
||||
sample = 2 * sample - 1.0
|
||||
@@ -445,9 +484,13 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
t_emb = t_emb.to(dtype=self.dtype)
|
||||
emb = self.time_embedding(t_emb)
|
||||
|
||||
if self.config.num_class_embeds is not None:
|
||||
if self.class_embedding is not None:
|
||||
if class_labels is None:
|
||||
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
||||
|
||||
if self.config.class_embed_type == "timestep":
|
||||
class_labels = self.time_proj(class_labels)
|
||||
|
||||
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
|
||||
emb = emb + class_emb
|
||||
|
||||
@@ -462,6 +505,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
hidden_states=sample,
|
||||
temb=emb,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
else:
|
||||
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
|
||||
@@ -469,7 +513,9 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
down_block_res_samples += res_samples
|
||||
|
||||
# 4. mid
|
||||
sample = self.mid_block(sample, emb, encoder_hidden_states=encoder_hidden_states)
|
||||
sample = self.mid_block(
|
||||
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
|
||||
)
|
||||
|
||||
# 5. up
|
||||
for i, upsample_block in enumerate(self.up_blocks):
|
||||
@@ -490,6 +536,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
|
||||
res_hidden_states_tuple=res_samples,
|
||||
encoder_hidden_states=encoder_hidden_states,
|
||||
upsample_size=upsample_size,
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
else:
|
||||
sample = upsample_block(
|
||||
@@ -715,7 +762,6 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
cross_attention_dim=1280,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
downsample_padding=1,
|
||||
add_downsample=True,
|
||||
@@ -729,7 +775,6 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
attentions = []
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attention_type = attention_type
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
|
||||
for i in range(num_layers):
|
||||
@@ -789,7 +834,8 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
# TODO(Patrick, William) - attention mask is not used
|
||||
output_states = ()
|
||||
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
@@ -806,7 +852,9 @@ class CrossAttnDownBlockFlat(nn.Module):
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
create_custom_forward(attn, return_dict=False),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
@@ -915,7 +963,6 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
cross_attention_dim=1280,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
add_upsample=True,
|
||||
dual_cross_attention=False,
|
||||
@@ -928,7 +975,6 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
attentions = []
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attention_type = attention_type
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
|
||||
for i in range(num_layers):
|
||||
@@ -991,7 +1037,9 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
temb=None,
|
||||
encoder_hidden_states=None,
|
||||
upsample_size=None,
|
||||
attention_mask=None,
|
||||
):
|
||||
# TODO(Patrick, William) - attention mask is not used
|
||||
for resnet, attn in zip(self.resnets, self.attentions):
|
||||
# pop res hidden states
|
||||
res_hidden_states = res_hidden_states_tuple[-1]
|
||||
@@ -1011,7 +1059,9 @@ class CrossAttnUpBlockFlat(nn.Module):
|
||||
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
|
||||
hidden_states = torch.utils.checkpoint.checkpoint(
|
||||
create_custom_forward(attn, return_dict=False), hidden_states, encoder_hidden_states
|
||||
create_custom_forward(attn, return_dict=False),
|
||||
hidden_states,
|
||||
encoder_hidden_states,
|
||||
)[0]
|
||||
else:
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
@@ -1038,7 +1088,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
attention_type="default",
|
||||
output_scale_factor=1.0,
|
||||
cross_attention_dim=1280,
|
||||
dual_cross_attention=False,
|
||||
@@ -1048,7 +1097,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.has_cross_attention = True
|
||||
self.attention_type = attention_type
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
@@ -1112,10 +1160,122 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
# TODO(Patrick, William) - attention_mask is currently not used. Implement once used
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states).sample
|
||||
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Copied from diffusers.models.unet_2d_blocks.UNetMidBlock2DSimpleCrossAttn with UNetMidBlock2DSimpleCrossAttn->UNetMidBlockFlatCrossAttn, ResnetBlock2D->ResnetBlockFlat
|
||||
class UnCLIPUNetMidBlockFlatCrossAttn(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
temb_channels: int,
|
||||
dropout: float = 0.0,
|
||||
num_layers: int = 1,
|
||||
resnet_eps: float = 1e-6,
|
||||
resnet_time_scale_shift: str = "default",
|
||||
resnet_act_fn: str = "swish",
|
||||
resnet_groups: int = 32,
|
||||
resnet_pre_norm: bool = True,
|
||||
attn_num_head_channels=1,
|
||||
output_scale_factor=1.0,
|
||||
cross_attention_dim=1280,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.has_cross_attention = True
|
||||
|
||||
self.attn_num_head_channels = attn_num_head_channels
|
||||
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
|
||||
|
||||
self.num_heads = in_channels // self.attn_num_head_channels
|
||||
|
||||
# there is always at least one resnet
|
||||
resnets = [
|
||||
ResnetBlockFlat(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
]
|
||||
attentions = []
|
||||
|
||||
for _ in range(num_layers):
|
||||
attentions.append(
|
||||
CrossAttention(
|
||||
query_dim=in_channels,
|
||||
cross_attention_dim=in_channels,
|
||||
heads=self.num_heads,
|
||||
dim_head=attn_num_head_channels,
|
||||
added_kv_proj_dim=cross_attention_dim,
|
||||
norm_num_groups=resnet_groups,
|
||||
bias=True,
|
||||
upcast_softmax=True,
|
||||
)
|
||||
)
|
||||
resnets.append(
|
||||
ResnetBlockFlat(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
temb_channels=temb_channels,
|
||||
eps=resnet_eps,
|
||||
groups=resnet_groups,
|
||||
dropout=dropout,
|
||||
time_embedding_norm=resnet_time_scale_shift,
|
||||
non_linearity=resnet_act_fn,
|
||||
output_scale_factor=output_scale_factor,
|
||||
pre_norm=resnet_pre_norm,
|
||||
)
|
||||
)
|
||||
|
||||
self.attentions = nn.ModuleList(attentions)
|
||||
self.resnets = nn.ModuleList(resnets)
|
||||
|
||||
def set_attention_slice(self, slice_size):
|
||||
head_dims = self.attn_num_head_channels
|
||||
head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
|
||||
if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
|
||||
raise ValueError(
|
||||
f"Make sure slice_size {slice_size} is a common divisor of "
|
||||
f"the number of heads used in cross_attention: {head_dims}"
|
||||
)
|
||||
if slice_size is not None and slice_size > min(head_dims):
|
||||
raise ValueError(
|
||||
f"slice_size {slice_size} has to be smaller or equal to "
|
||||
f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
|
||||
)
|
||||
|
||||
for attn in self.attentions:
|
||||
attn._set_attention_slice(slice_size)
|
||||
|
||||
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
|
||||
hidden_states = self.resnets[0](hidden_states, temb)
|
||||
for attn, resnet in zip(self.attentions, self.resnets[1:]):
|
||||
# attn
|
||||
residual = hidden_states
|
||||
hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
|
||||
hidden_states = attn(
|
||||
hidden_states,
|
||||
encoder_hidden_states=encoder_hidden_states.transpose(1, 2),
|
||||
attention_mask=attention_mask,
|
||||
)
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
# resnet
|
||||
hidden_states = resnet(hidden_states, temb)
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -37,6 +37,7 @@ else:
    from .scheduling_repaint import RePaintScheduler
    from .scheduling_sde_ve import ScoreSdeVeScheduler
    from .scheduling_sde_vp import ScoreSdeVpScheduler
    from .scheduling_unclip import UnCLIPScheduler
    from .scheduling_utils import SchedulerMixin
    from .scheduling_vq_diffusion import VQDiffusionScheduler

309
src/diffusers/schedulers/scheduling_unclip.py
Normal file
@@ -0,0 +1,309 @@
|
||||
# Copyright 2022 Kakao Brain and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .scheduling_utils import SchedulerMixin


@dataclass
# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->UnCLIP
class UnCLIPSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's step function output.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.FloatTensor
    pred_original_sample: Optional[torch.FloatTensor] = None


def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999):
    """
    Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
    (1-beta) over time from t = [0,1].

    Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
    to that part of the diffusion process.


    Args:
        num_diffusion_timesteps (`int`): the number of betas to produce.
        max_beta (`float`): the maximum beta to use; use values lower than 1 to
                     prevent singularities.

    Returns:
        betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
    """

    def alpha_bar(time_step):
        return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

    betas = []
    for i in range(num_diffusion_timesteps):
        t1 = i / num_diffusion_timesteps
        t2 = (i + 1) / num_diffusion_timesteps
        betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
    return torch.tensor(betas, dtype=torch.float32)
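As an illustrative aside (not part of the committed file): with this cosine alpha_bar, the betas start out tiny and grow towards the end of the schedule, where they hit the max_beta cap.

import math

def alpha_bar(time_step):
    return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2

num_diffusion_timesteps = 1000
betas = [
    min(1 - alpha_bar((i + 1) / num_diffusion_timesteps) / alpha_bar(i / num_diffusion_timesteps), 0.999)
    for i in range(num_diffusion_timesteps)
]
# betas[0] is on the order of 4e-5; betas[-1] is clipped to 0.999 because alpha_bar(1.0) is numerically zero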
class UnCLIPScheduler(SchedulerMixin, ConfigMixin):
    """
    This is a modified DDPM Scheduler specifically for the karlo unCLIP model.

    This scheduler has some minor variations in how it calculates the learned range variance and dynamically
    re-calculates betas based on the timesteps it is skipping.

    The scheduler also uses a slightly different step ratio when computing timesteps to use for inference.

    See [`~DDPMScheduler`] for more information on DDPM scheduling.

    Args:
        num_train_timesteps (`int`): number of diffusion steps used to train the model.
        variance_type (`str`):
            options for how the variance is computed when adding noise to the denoised sample. Choose from
            `fixed_small_log` or `learned_range`.
        clip_sample (`bool`, default `True`):
            option to clip predicted sample between `-clip_sample_range` and `clip_sample_range` for numerical
            stability.
        clip_sample_range (`float`, default `1.0`):
            The range to clip the sample between. See `clip_sample`.
        prediction_type (`str`, default `epsilon`, optional):
            prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
            process) or `sample` (directly predicting the denoised sample)
    """
    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        variance_type: str = "fixed_small_log",
        clip_sample: bool = True,
        clip_sample_range: Optional[float] = 1.0,
        prediction_type: str = "epsilon",
    ):
        # beta scheduler is "squaredcos_cap_v2"
        self.betas = betas_for_alpha_bar(num_train_timesteps)

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        self.one = torch.tensor(1.0)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # setable values
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy())

        self.variance_type = variance_type

    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
        """
        Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
        current timestep.

        Args:
            sample (`torch.FloatTensor`): input sample
            timestep (`int`, optional): current timestep

        Returns:
            `torch.FloatTensor`: scaled input sample
        """
        return sample

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.

        Note that this scheduler uses a slightly different step ratio than the other diffusers schedulers. The
        different step ratio is to mimic the original karlo implementation and does not affect the quality or accuracy
        of the results.

        Args:
            num_inference_steps (`int`):
                the number of diffusion steps used when generating samples with a pre-trained model.
        """
        self.num_inference_steps = num_inference_steps
        step_ratio = (self.config.num_train_timesteps - 1) / (self.num_inference_steps - 1)
        timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
        self.timesteps = torch.from_numpy(timesteps).to(device)
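To make the docstring's point about the step ratio concrete, here is an illustrative comparison (not part of the file) for 25 inference steps over 1000 training steps: the karlo-style spacing lands exactly on both t=999 and t=0, whereas the floor-division spacing common elsewhere in the library starts below the last training timestep.

import numpy as np

num_train_timesteps, num_inference_steps = 1000, 25

# karlo / UnCLIPScheduler spacing
step_ratio = (num_train_timesteps - 1) / (num_inference_steps - 1)
unclip_timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)

# the more common spacing (roughly what DDPMScheduler does)
common_timesteps = (np.arange(0, num_inference_steps) * (num_train_timesteps // num_inference_steps))[::-1]

print(unclip_timesteps[0], unclip_timesteps[-1])  # 999 0
print(common_timesteps[0], common_timesteps[-1])  # 960 0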
    def _get_variance(self, t, prev_timestep=None, predicted_variance=None, variance_type=None):
        if prev_timestep is None:
            prev_timestep = t - 1

        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        if prev_timestep == t - 1:
            beta = self.betas[t]
        else:
            beta = 1 - alpha_prod_t / alpha_prod_t_prev

        # For t > 0, compute predicted variance βt (see formula (6) and (7) from https://arxiv.org/pdf/2006.11239.pdf)
        # and sample from it to get previous sample
        # x_{t-1} ~ N(pred_prev_sample, variance) == add variance to pred_sample
        variance = beta_prod_t_prev / beta_prod_t * beta

        if variance_type is None:
            variance_type = self.config.variance_type

        # hacks - were probably added for training stability
        if variance_type == "fixed_small_log":
            variance = torch.log(torch.clamp(variance, min=1e-20))
            variance = torch.exp(0.5 * variance)
        elif variance_type == "learned_range":
            # NOTE difference with DDPM scheduler
            min_log = variance.log()
            max_log = beta.log()

            frac = (predicted_variance + 1) / 2
            variance = frac * max_log + (1 - frac) * min_log

        return variance
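For the `learned_range` branch, the second half of the model output (a value in [-1, 1]) is mapped to a weight in [0, 1] that interpolates between the log of the posterior variance and the log of beta; `step` later exponentiates half of it to obtain the noise scale. A small sketch with made-up stand-in values (not part of the file):

import torch

# stand-ins for the posterior variance beta_tilde_t and for beta_t
posterior_variance = torch.tensor(1e-4)
beta = torch.tensor(2e-4)

min_log = posterior_variance.log()
max_log = beta.log()

predicted_variance = torch.tensor(0.5)  # second half of the model's output channels
frac = (predicted_variance + 1) / 2     # map [-1, 1] -> [0, 1]
log_variance = frac * max_log + (1 - frac) * min_log

std = (0.5 * log_variance).exp()        # what `step` uses to scale the noise for learned_range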
    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        prev_timestep: Optional[int] = None,
        generator=None,
        return_dict: bool = True,
    ) -> Union[UnCLIPSchedulerOutput, Tuple]:
        """
        Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
        process from the learned model outputs (most often the predicted noise).

        Args:
            model_output (`torch.FloatTensor`): direct output from learned diffusion model.
            timestep (`int`): current discrete timestep in the diffusion chain.
            sample (`torch.FloatTensor`):
                current instance of sample being created by diffusion process.
            prev_timestep (`int`, *optional*): The previous timestep to predict the previous sample at.
                Used to dynamically compute beta. If not given, `t-1` is used and the pre-computed beta is used.
            generator: random number generator.
            return_dict (`bool`): option for returning tuple rather than UnCLIPSchedulerOutput class

        Returns:
            [`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] or `tuple`:
            [`~schedulers.scheduling_utils.UnCLIPSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
            returning a tuple, the first element is the sample tensor.

        """

        t = timestep

        if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type == "learned_range":
            model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
        else:
            predicted_variance = None

        # 1. compute alphas, betas
        if prev_timestep is None:
            prev_timestep = t - 1

        alpha_prod_t = self.alphas_cumprod[t]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.one
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        if prev_timestep == t - 1:
            beta = self.betas[t]
            alpha = self.alphas[t]
        else:
            beta = 1 - alpha_prod_t / alpha_prod_t_prev
            alpha = 1 - beta

        # 2. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
        if self.config.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.config.prediction_type == "sample":
            pred_original_sample = model_output
        else:
            raise ValueError(
                f"prediction_type given as {self.config.prediction_type} must be one of `epsilon` or `sample`"
                " for the UnCLIPScheduler."
            )

        # 3. Clip "predicted x_0"
        if self.config.clip_sample:
            pred_original_sample = torch.clamp(
                pred_original_sample, -self.config.clip_sample_range, self.config.clip_sample_range
            )

        # 4. Compute coefficients for pred_original_sample x_0 and current sample x_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_original_sample_coeff = (alpha_prod_t_prev ** (0.5) * beta) / beta_prod_t
        current_sample_coeff = alpha ** (0.5) * beta_prod_t_prev / beta_prod_t

        # 5. Compute predicted previous sample µ_t
        # See formula (7) from https://arxiv.org/pdf/2006.11239.pdf
        pred_prev_sample = pred_original_sample_coeff * pred_original_sample + current_sample_coeff * sample

        # 6. Add noise
        variance = 0
        if t > 0:
            device = model_output.device
            if device.type == "mps":
                # randn does not work reproducibly on mps
                variance_noise = torch.randn(model_output.shape, dtype=model_output.dtype, generator=generator)
                variance_noise = variance_noise.to(device)
            else:
                variance_noise = torch.randn(
                    model_output.shape, generator=generator, device=device, dtype=model_output.dtype
                )

            variance = self._get_variance(
                t,
                predicted_variance=predicted_variance,
                prev_timestep=prev_timestep,
            )

            if self.variance_type == "fixed_small_log":
                variance = variance
            elif self.variance_type == "learned_range":
                variance = (0.5 * variance).exp()
            else:
                raise ValueError(
                    f"variance_type given as {self.variance_type} must be one of `fixed_small_log` or `learned_range`"
                    " for the UnCLIPScheduler."
                )

            variance = variance * variance_noise

        pred_prev_sample = pred_prev_sample + variance

        if not return_dict:
            return (pred_prev_sample,)

        return UnCLIPSchedulerOutput(prev_sample=pred_prev_sample, pred_original_sample=pred_original_sample)
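For orientation, a minimal usage sketch of the new scheduler (it mirrors the scheduler tests further down; the random `residual` stands in for a real denoiser's output and is not part of the diff):

import torch
from diffusers import UnCLIPScheduler

scheduler = UnCLIPScheduler(variance_type="fixed_small_log", prediction_type="epsilon")
scheduler.set_timesteps(25)

sample = torch.randn(1, 3, 32, 32)  # stand-in noisy sample
generator = torch.manual_seed(0)

for i, t in enumerate(scheduler.timesteps):
    residual = torch.randn_like(sample)  # stand-in for the model's predicted noise
    # pass the next inference timestep explicitly so betas are re-computed for the skipped range
    prev_timestep = None if i + 1 == len(scheduler.timesteps) else scheduler.timesteps[i + 1]
    sample = scheduler.step(residual, t, sample, prev_timestep=prev_timestep, generator=generator).prev_sample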
@@ -34,6 +34,21 @@ class AutoencoderKL(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class PriorTransformer(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class Transformer2DModel(metaclass=DummyObject):
    _backends = ["torch"]

@@ -512,6 +527,21 @@ class ScoreSdeVeScheduler(metaclass=DummyObject):
        requires_backends(cls, ["torch"])


class UnCLIPScheduler(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


class VQDiffusionScheduler(metaclass=DummyObject):
    _backends = ["torch"]


@@ -199,6 +199,21 @@ class StableDiffusionUpscalePipeline(metaclass=DummyObject):
        requires_backends(cls, ["torch", "transformers"])


class UnCLIPPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch", "transformers"])

    @classmethod
    def from_config(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])


class VersatileDiffusionDualGuidedPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]
0
tests/pipelines/unclip/__init__.py
Normal file
282
tests/pipelines/unclip/test_unclip.py
Normal file
@@ -0,0 +1,282 @@
# coding=utf-8
# Copyright 2022 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import numpy as np
import torch

from diffusers import PriorTransformer, UnCLIPPipeline, UnCLIPScheduler, UNet2DConditionModel, UNet2DModel
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
from diffusers.utils import load_numpy, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu
from transformers import CLIPTextConfig, CLIPTextModelWithProjection, CLIPTokenizer


torch.backends.cuda.matmul.allow_tf32 = False
class UnCLIPPipelineFastTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    @property
    def text_embedder_hidden_size(self):
        return 32

    @property
    def time_input_dim(self):
        return 32

    @property
    def block_out_channels_0(self):
        return self.time_input_dim

    @property
    def time_embed_dim(self):
        return self.time_input_dim * 4

    @property
    def cross_attention_dim(self):
        return 100

    @property
    def dummy_tokenizer(self):
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
        return tokenizer

    @property
    def dummy_text_encoder(self):
        torch.manual_seed(0)
        config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=self.text_embedder_hidden_size,
            projection_dim=self.text_embedder_hidden_size,
            intermediate_size=37,
            layer_norm_eps=1e-05,
            num_attention_heads=4,
            num_hidden_layers=5,
            pad_token_id=1,
            vocab_size=1000,
        )
        return CLIPTextModelWithProjection(config)

    @property
    def dummy_prior(self):
        torch.manual_seed(0)

        model_kwargs = {
            "num_attention_heads": 2,
            "attention_head_dim": 12,
            "embedding_dim": self.text_embedder_hidden_size,
            "num_layers": 1,
        }

        model = PriorTransformer(**model_kwargs)
        return model

    @property
    def dummy_text_proj(self):
        torch.manual_seed(0)

        model_kwargs = {
            "clip_embeddings_dim": self.text_embedder_hidden_size,
            "time_embed_dim": self.time_embed_dim,
            "cross_attention_dim": self.cross_attention_dim,
        }

        model = UnCLIPTextProjModel(**model_kwargs)
        return model

    @property
    def dummy_decoder(self):
        torch.manual_seed(0)

        model_kwargs = {
            "sample_size": 64,
            # RGB in channels
            "in_channels": 3,
            # Out channels is double in channels because predicts mean and variance
            "out_channels": 6,
            "down_block_types": ("ResnetDownsampleBlock2D", "SimpleCrossAttnDownBlock2D"),
            "up_block_types": ("SimpleCrossAttnUpBlock2D", "ResnetUpsampleBlock2D"),
            "mid_block_type": "UNetMidBlock2DSimpleCrossAttn",
            "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
            "layers_per_block": 1,
            "cross_attention_dim": self.cross_attention_dim,
            "attention_head_dim": 4,
            "resnet_time_scale_shift": "scale_shift",
            "class_embed_type": "identity",
        }

        model = UNet2DConditionModel(**model_kwargs)
        return model

    @property
    def dummy_super_res_kwargs(self):
        return {
            "sample_size": 128,
            "layers_per_block": 1,
            "down_block_types": ("ResnetDownsampleBlock2D", "ResnetDownsampleBlock2D"),
            "up_block_types": ("ResnetUpsampleBlock2D", "ResnetUpsampleBlock2D"),
            "block_out_channels": (self.block_out_channels_0, self.block_out_channels_0 * 2),
            "in_channels": 6,
            "out_channels": 3,
        }

    @property
    def dummy_super_res_first(self):
        torch.manual_seed(0)

        model = UNet2DModel(**self.dummy_super_res_kwargs)
        return model

    @property
    def dummy_super_res_last(self):
        # seeded differently to get different unet than `self.dummy_super_res_first`
        torch.manual_seed(1)

        model = UNet2DModel(**self.dummy_super_res_kwargs)
        return model
    def test_unclip(self):
        device = "cpu"

        prior = self.dummy_prior
        decoder = self.dummy_decoder
        text_proj = self.dummy_text_proj
        text_encoder = self.dummy_text_encoder
        tokenizer = self.dummy_tokenizer
        super_res_first = self.dummy_super_res_first
        super_res_last = self.dummy_super_res_last

        prior_scheduler = UnCLIPScheduler(
            variance_type="fixed_small_log",
            prediction_type="sample",
            num_train_timesteps=1000,
            clip_sample_range=5.0,
        )

        decoder_scheduler = UnCLIPScheduler(
            variance_type="learned_range",
            prediction_type="epsilon",
            num_train_timesteps=1000,
        )

        super_res_scheduler = UnCLIPScheduler(
            variance_type="fixed_small_log",
            prediction_type="epsilon",
            num_train_timesteps=1000,
        )

        pipe = UnCLIPPipeline(
            prior=prior,
            decoder=decoder,
            text_proj=text_proj,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            super_res_first=super_res_first,
            super_res_last=super_res_last,
            prior_scheduler=prior_scheduler,
            decoder_scheduler=decoder_scheduler,
            super_res_scheduler=super_res_scheduler,
        )
        pipe = pipe.to(device)

        pipe.set_progress_bar_config(disable=None)

        prompt = "horse"

        generator = torch.Generator(device=device).manual_seed(0)
        output = pipe(
            [prompt],
            generator=generator,
            prior_num_inference_steps=2,
            decoder_num_inference_steps=2,
            super_res_num_inference_steps=2,
            output_type="np",
        )
        image = output.images

        generator = torch.Generator(device=device).manual_seed(0)
        image_from_tuple = pipe(
            [prompt],
            generator=generator,
            prior_num_inference_steps=2,
            decoder_num_inference_steps=2,
            super_res_num_inference_steps=2,
            output_type="np",
            return_dict=False,
        )[0]

        image_slice = image[0, -3:, -3:, -1]
        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]

        assert image.shape == (1, 128, 128, 3)

        expected_slice = np.array(
            [
                0.0009,
                0.9997,
                0.0003,
                0.9991,
                0.9967,
                0.0003,
                0.9997,
                0.0003,
                0.0004,
            ]
        )

        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
@slow
@require_torch_gpu
class UnCLIPPipelineIntegrationTests(unittest.TestCase):
    def tearDown(self):
        # clean up the VRAM after each test
        super().tearDown()
        gc.collect()
        torch.cuda.empty_cache()

    def test_unclip_karlo(self):
        expected_image = load_numpy(
            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
            "/karlo_v1_alpha/horse.npy"
        )

        pipeline = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha")
        pipeline = pipeline.to(torch_device)
        pipeline.set_progress_bar_config(disable=None)

        generator = torch.Generator(device=torch_device).manual_seed(0)
        output = pipeline(
            "horse",
            num_images_per_prompt=1,
            generator=generator,
            output_type="np",
        )

        image = output.images[0]

        assert image.shape == (256, 256, 3)
        assert np.abs(expected_image - image).max() < 1e-2
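For completeness, a minimal end-to-end sketch along the lines of the integration test above (the checkpoint is downloaded from the Hub; prompt and device are arbitrary choices, not part of the diff):

import torch
from diffusers import UnCLIPPipeline

pipe = UnCLIPPipeline.from_pretrained("kakaobrain/karlo-v1-alpha").to("cuda")

generator = torch.Generator(device="cuda").manual_seed(0)
output = pipe(
    "horse",
    num_images_per_prompt=1,
    generator=generator,
    output_type="np",
)
image = output.images[0]  # (256, 256, 3) numpy array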
@@ -38,6 +38,7 @@ from diffusers import (
    LMSDiscreteScheduler,
    PNDMScheduler,
    ScoreSdeVeScheduler,
    UnCLIPScheduler,
    VQDiffusionScheduler,
    logging,
)
@@ -2643,3 +2644,135 @@ class KDPM2AncestralDiscreteSchedulerTest(SchedulerCommonTest):
        # CUDA
        assert abs(result_sum.item() - 13913.0332) < 1e-1
        assert abs(result_mean.item() - 18.1159) < 1e-3


# UnCLIPScheduler is a modified DDPMScheduler with a subset of the configuration.
class UnCLIPSchedulerTest(SchedulerCommonTest):
    scheduler_classes = (UnCLIPScheduler,)

    def get_scheduler_config(self, **kwargs):
        config = {
            "num_train_timesteps": 1000,
            "variance_type": "fixed_small_log",
            "clip_sample": True,
            "clip_sample_range": 1.0,
            "prediction_type": "epsilon",
        }

        config.update(**kwargs)
        return config

    def test_timesteps(self):
        for timesteps in [1, 5, 100, 1000]:
            self.check_over_configs(num_train_timesteps=timesteps)

    def test_variance_type(self):
        for variance in ["fixed_small_log", "learned_range"]:
            self.check_over_configs(variance_type=variance)

    def test_clip_sample(self):
        for clip_sample in [True, False]:
            self.check_over_configs(clip_sample=clip_sample)

    def test_clip_sample_range(self):
        for clip_sample_range in [1, 5, 10, 20]:
            self.check_over_configs(clip_sample_range=clip_sample_range)

    def test_prediction_type(self):
        for prediction_type in ["epsilon", "sample"]:
            self.check_over_configs(prediction_type=prediction_type)

    def test_time_indices(self):
        for time_step in [0, 500, 999]:
            for prev_timestep in [None, 5, 100, 250, 500, 750]:
                if prev_timestep is not None and prev_timestep >= time_step:
                    continue

                self.check_over_forward(time_step=time_step, prev_timestep=prev_timestep)

    def test_variance_fixed_small_log(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(variance_type="fixed_small_log")
        scheduler = scheduler_class(**scheduler_config)

        assert torch.sum(torch.abs(scheduler._get_variance(0) - 1.0000e-10)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(487) - 0.0549625)) < 1e-5
        assert torch.sum(torch.abs(scheduler._get_variance(999) - 0.9994987)) < 1e-5

    def test_variance_learned_range(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config(variance_type="learned_range")
        scheduler = scheduler_class(**scheduler_config)

        predicted_variance = 0.5

        assert scheduler._get_variance(1, predicted_variance=predicted_variance) - -10.1712790 < 1e-5
        assert scheduler._get_variance(487, predicted_variance=predicted_variance) - -5.7998052 < 1e-5
        assert scheduler._get_variance(999, predicted_variance=predicted_variance) - -0.0010011 < 1e-5

    def test_full_loop(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        timesteps = scheduler.timesteps

        model = self.dummy_model()
        sample = self.dummy_sample_deter
        generator = torch.manual_seed(0)

        for i, t in enumerate(timesteps):
            # 1. predict noise residual
            residual = model(sample, t)

            # 2. predict previous mean of sample x_t-1
            pred_prev_sample = scheduler.step(residual, t, sample, generator=generator).prev_sample

            sample = pred_prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 252.2682495) < 1e-2
        assert abs(result_mean.item() - 0.3284743) < 1e-3

    def test_full_loop_skip_timesteps(self):
        scheduler_class = self.scheduler_classes[0]
        scheduler_config = self.get_scheduler_config()
        scheduler = scheduler_class(**scheduler_config)

        scheduler.set_timesteps(25)

        timesteps = scheduler.timesteps

        model = self.dummy_model()
        sample = self.dummy_sample_deter
        generator = torch.manual_seed(0)

        for i, t in enumerate(timesteps):
            # 1. predict noise residual
            residual = model(sample, t)

            if i + 1 == timesteps.shape[0]:
                prev_timestep = None
            else:
                prev_timestep = timesteps[i + 1]

            # 2. predict previous mean of sample x_t-1
            pred_prev_sample = scheduler.step(
                residual, t, sample, prev_timestep=prev_timestep, generator=generator
            ).prev_sample

            sample = pred_prev_sample

        result_sum = torch.sum(torch.abs(sample))
        result_mean = torch.mean(torch.abs(sample))

        assert abs(result_sum.item() - 258.2044983) < 1e-2
        assert abs(result_mean.item() - 0.3362038) < 1e-3

    def test_trained_betas(self):
        pass

    def test_add_noise_device(self):
        pass