1
0
mirror of https://github.com/huggingface/diffusers.git synced 2026-01-27 17:22:53 +03:00

[Core] move transformer scripts to transformers modules (#6747)

* move transformer scripts to transformers modules

* move transformer model test

* move prior transformer test to  directory

* fix doc path

* correct doc path

* add: __init__.py
This commit is contained in:
Sayak Paul
2024-01-29 22:28:28 +05:30
committed by GitHub
parent 5d8b1987ec
commit 09b7bfce91
28 changed files with 1925 additions and 1754 deletions

View File

@@ -24,4 +24,4 @@ The abstract from the paper is:
## PriorTransformerOutput
[[autodoc]] models.prior_transformer.PriorTransformerOutput
[[autodoc]] models.transformers.prior_transformer.PriorTransformerOutput

View File

@@ -38,4 +38,4 @@ It is assumed one of the input classes is the masked latent pixel. The predicted
## Transformer2DModelOutput
[[autodoc]] models.transformer_2d.Transformer2DModelOutput
[[autodoc]] models.transformers.transformer_2d.Transformer2DModelOutput

View File

@@ -16,8 +16,8 @@ A Transformer model for video-like data.
## TransformerTemporalModel
[[autodoc]] models.transformer_temporal.TransformerTemporalModel
[[autodoc]] models.transformers.transformer_temporal.TransformerTemporalModel
## TransformerTemporalModelOutput
[[autodoc]] models.transformer_temporal.TransformerTemporalModelOutput
[[autodoc]] models.transformers.transformer_temporal.TransformerTemporalModelOutput

View File

@@ -6,7 +6,7 @@ from accelerate import load_checkpoint_and_dispatch
from transformers import CLIPTextModelWithProjection, CLIPTokenizer
from diffusers import UnCLIPPipeline, UNet2DConditionModel, UNet2DModel
from diffusers.models.prior_transformer import PriorTransformer
from diffusers.models.transformers.prior_transformer import PriorTransformer
from diffusers.pipelines.unclip.text_proj import UnCLIPTextProjModel
from diffusers.schedulers.scheduling_unclip import UnCLIPScheduler

View File

@@ -6,7 +6,7 @@ import torch
from accelerate import load_checkpoint_and_dispatch
from diffusers import UNet2DConditionModel
from diffusers.models.prior_transformer import PriorTransformer
from diffusers.models.transformers.prior_transformer import PriorTransformer
from diffusers.models.vq_model import VQModel

View File

@@ -4,7 +4,7 @@ import tempfile
import torch
from accelerate import load_checkpoint_and_dispatch
from diffusers.models.prior_transformer import PriorTransformer
from diffusers.models.transformers.prior_transformer import PriorTransformer
from diffusers.pipelines.shap_e import ShapERenderer

View File

@@ -35,10 +35,10 @@ if is_torch_available():
_import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
_import_structure["embeddings"] = ["ImageProjection"]
_import_structure["modeling_utils"] = ["ModelMixin"]
_import_structure["prior_transformer"] = ["PriorTransformer"]
_import_structure["t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
_import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
_import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
_import_structure["unets.unet_1d"] = ["UNet1DModel"]
_import_structure["unets.unet_2d"] = ["UNet2DModel"]
_import_structure["unets.unet_2d_condition"] = ["UNet2DConditionModel"]
@@ -66,13 +66,15 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
ConsistencyDecoderVAE,
)
from .controlnet import ControlNetModel
from .dual_transformer_2d import DualTransformer2DModel
from .embeddings import ImageProjection
from .modeling_utils import ModelMixin
from .prior_transformer import PriorTransformer
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_temporal import TransformerTemporalModel
from .transformers import (
DualTransformer2DModel,
PriorTransformer,
T5FilmDecoder,
Transformer2DModel,
TransformerTemporalModel,
)
from .unets import (
Kandinsky3UNet,
MotionAdapter,

View File

@@ -11,145 +11,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from torch import nn
from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
from ..utils import deprecate
from .transformers.dual_transformer_2d import DualTransformer2DModel
class DualTransformer2DModel(nn.Module):
"""
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input and output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
num_vector_embeds (`int`, *optional*):
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more than steps than `num_embeds_ada_norm`.
attention_bias (`bool`, *optional*):
Configure if the TransformerBlocks' attention should contain a bias parameter.
"""
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
):
super().__init__()
self.transformers = nn.ModuleList(
[
Transformer2DModel(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=in_channels,
num_layers=num_layers,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_bias=attention_bias,
sample_size=sample_size,
num_vector_embeds=num_vector_embeds,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
)
for _ in range(2)
]
)
# Variables that can be set by a pipeline:
# The ratio of transformer1 to transformer2's output states to be combined during inference
self.mix_ratio = 0.5
# The shape of `encoder_hidden_states` is expected to be
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
self.condition_lengths = [77, 257]
# Which transformer to use to encode which condition.
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
self.transformer_index_for_condition = [1, 0]
def forward(
self,
hidden_states,
encoder_hidden_states,
timestep=None,
attention_mask=None,
cross_attention_kwargs=None,
return_dict: bool = True,
):
"""
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
hidden_states.
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
attention_mask (`torch.FloatTensor`, *optional*):
Optional attention mask to be applied in Attention.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Returns:
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
input_states = hidden_states
encoded_states = []
tokens_start = 0
# attention_mask is not used yet
for i in range(2):
# for each of the two transformers, pass the corresponding condition tokens
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
transformer_index = self.transformer_index_for_condition[i]
encoded_state = self.transformers[transformer_index](
input_states,
encoder_hidden_states=condition_state,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
encoded_states.append(encoded_state - input_states)
tokens_start += self.condition_lengths[i]
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
output_states = output_states + input_states
if not return_dict:
return (output_states,)
return Transformer2DModelOutput(sample=output_states)
class DualTransformer2DModel(DualTransformer2DModel):
deprecation_message = "Importing `DualTransformer2DModel` from `diffusers.models.dual_transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel`, instead."
deprecate("DualTransformer2DModel", "0.29", deprecation_message)

View File

@@ -1,380 +1,12 @@
from dataclasses import dataclass
from typing import Dict, Optional, Union
from ..utils import deprecate
from .transformers.prior_transformer import PriorTransformer, PriorTransformerOutput
import torch
import torch.nn.functional as F
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ..utils import BaseOutput
from .attention import BasicTransformerBlock
from .attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
class PriorTransformerOutput(PriorTransformerOutput):
deprecation_message = "Importing `PriorTransformerOutput` from `diffusers.models.prior_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.prior_transformer import PriorTransformerOutput`, instead."
deprecate("PriorTransformerOutput", "0.29", deprecation_message)
@dataclass
class PriorTransformerOutput(BaseOutput):
"""
The output of [`PriorTransformer`].
Args:
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
"""
predicted_image_embedding: torch.FloatTensor
class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
"""
A Prior Transformer model.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
num_embeddings (`int`, *optional*, defaults to 77):
The number of embeddings of the model input `hidden_states`
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
additional_embeddings`.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
The activation function to use to create timestep embeddings.
norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
passing to Transformer blocks. Set it to `None` if normalization is not needed.
embedding_proj_norm_type (`str`, *optional*, defaults to None):
The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
needed.
encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
`encoder_hidden_states` is `None`.
added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
product between the text embedding and image embedding as proposed in the unclip paper
https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
If None, will be set to `num_attention_heads * attention_head_dim`
embedding_proj_dim (`int`, *optional*, default to None):
The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
clip_embed_dim (`int`, *optional*, default to None):
The dimension of the output. If None, will be set to `embedding_dim`.
"""
@register_to_config
def __init__(
self,
num_attention_heads: int = 32,
attention_head_dim: int = 64,
num_layers: int = 20,
embedding_dim: int = 768,
num_embeddings=77,
additional_embeddings=4,
dropout: float = 0.0,
time_embed_act_fn: str = "silu",
norm_in_type: Optional[str] = None, # layer
embedding_proj_norm_type: Optional[str] = None, # layer
encoder_hid_proj_type: Optional[str] = "linear", # linear
added_emb_type: Optional[str] = "prd", # prd
time_embed_dim: Optional[int] = None,
embedding_proj_dim: Optional[int] = None,
clip_embed_dim: Optional[int] = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.additional_embeddings = additional_embeddings
time_embed_dim = time_embed_dim or inner_dim
embedding_proj_dim = embedding_proj_dim or embedding_dim
clip_embed_dim = clip_embed_dim or embedding_dim
self.time_proj = Timesteps(inner_dim, True, 0)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
self.proj_in = nn.Linear(embedding_dim, inner_dim)
if embedding_proj_norm_type is None:
self.embedding_proj_norm = None
elif embedding_proj_norm_type == "layer":
self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
else:
raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
if encoder_hid_proj_type is None:
self.encoder_hidden_states_proj = None
elif encoder_hid_proj_type == "linear":
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
else:
raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
if added_emb_type == "prd":
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
elif added_emb_type is None:
self.prd_embedding = None
else:
raise ValueError(
f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
activation_fn="gelu",
attention_bias=True,
)
for d in range(num_layers)
]
)
if norm_in_type == "layer":
self.norm_in = nn.LayerNorm(inner_dim)
elif norm_in_type is None:
self.norm_in = None
else:
raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
self.norm_out = nn.LayerNorm(inner_dim)
self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
causal_attention_mask = torch.full(
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
)
causal_attention_mask.triu_(1)
causal_attention_mask = causal_attention_mask[None, ...]
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
def forward(
self,
hidden_states,
timestep: Union[torch.Tensor, float, int],
proj_embedding: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
return_dict: bool = True,
):
"""
The [`PriorTransformer`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
The currently predicted image embeddings.
timestep (`torch.LongTensor`):
Current denoising step.
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
Projected embedding vector the denoising process is conditioned on.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
Hidden states of the text embeddings the denoising process is conditioned on.
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
Text mask for the text embeddings.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
tuple.
Returns:
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
batch_size = hidden_states.shape[0]
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(hidden_states.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
timesteps_projected = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might be fp16, so we need to cast here.
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
time_embeddings = self.time_embedding(timesteps_projected)
if self.embedding_proj_norm is not None:
proj_embedding = self.embedding_proj_norm(proj_embedding)
proj_embeddings = self.embedding_proj(proj_embedding)
if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
hidden_states = self.proj_in(hidden_states)
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
additional_embeds = []
additional_embeddings_len = 0
if encoder_hidden_states is not None:
additional_embeds.append(encoder_hidden_states)
additional_embeddings_len += encoder_hidden_states.shape[1]
if len(proj_embeddings.shape) == 2:
proj_embeddings = proj_embeddings[:, None, :]
if len(hidden_states.shape) == 2:
hidden_states = hidden_states[:, None, :]
additional_embeds = additional_embeds + [
proj_embeddings,
time_embeddings[:, None, :],
hidden_states,
]
if self.prd_embedding is not None:
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
additional_embeds.append(prd_embedding)
hidden_states = torch.cat(
additional_embeds,
dim=1,
)
# Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
if positional_embeddings.shape[1] < hidden_states.shape[1]:
positional_embeddings = F.pad(
positional_embeddings,
(
0,
0,
additional_embeddings_len,
self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
),
value=0.0,
)
hidden_states = hidden_states + positional_embeddings
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
if self.norm_in is not None:
hidden_states = self.norm_in(hidden_states)
for block in self.transformer_blocks:
hidden_states = block(hidden_states, attention_mask=attention_mask)
hidden_states = self.norm_out(hidden_states)
if self.prd_embedding is not None:
hidden_states = hidden_states[:, -1]
else:
hidden_states = hidden_states[:, additional_embeddings_len:]
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
if not return_dict:
return (predicted_image_embedding,)
return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
def post_process_latents(self, prior_latents):
prior_latents = (prior_latents * self.clip_std) + self.clip_mean
return prior_latents
class PriorTransformer(PriorTransformer):
deprecation_message = "Importing `PriorTransformer` from `diffusers.models.prior_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.prior_transformer import PriorTransformer`, instead."
deprecate("PriorTransformer", "0.29", deprecation_message)

View File

@@ -11,428 +11,60 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple
from ..utils import deprecate
from .transformers.t5_film_transformer import (
DecoderLayer,
NewGELUActivation,
T5DenseGatedActDense,
T5FilmDecoder,
T5FiLMLayer,
T5LayerCrossAttention,
T5LayerFFCond,
T5LayerNorm,
T5LayerSelfAttentionCond,
)
import torch
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from .attention_processor import Attention
from .embeddings import get_timestep_embedding
from .modeling_utils import ModelMixin
class T5FilmDecoder(T5FilmDecoder):
deprecation_message = "Importing `T5FilmDecoder` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5FilmDecoder`, instead."
deprecate("T5FilmDecoder", "0.29", deprecation_message)
class T5FilmDecoder(ModelMixin, ConfigMixin):
r"""
T5 style decoder with FiLM conditioning.
class DecoderLayer(DecoderLayer):
deprecation_message = "Importing `DecoderLayer` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import DecoderLayer`, instead."
deprecate("DecoderLayer", "0.29", deprecation_message)
Args:
input_dims (`int`, *optional*, defaults to `128`):
The number of input dimensions.
targets_length (`int`, *optional*, defaults to `256`):
The length of the targets.
d_model (`int`, *optional*, defaults to `768`):
Size of the input hidden states.
num_layers (`int`, *optional*, defaults to `12`):
The number of `DecoderLayer`'s to use.
num_heads (`int`, *optional*, defaults to `12`):
The number of attention heads to use.
d_kv (`int`, *optional*, defaults to `64`):
Size of the key-value projection vectors.
d_ff (`int`, *optional*, defaults to `2048`):
The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s.
dropout_rate (`float`, *optional*, defaults to `0.1`):
Dropout probability.
"""
@register_to_config
def __init__(
self,
input_dims: int = 128,
targets_length: int = 256,
max_decoder_noise_time: float = 2000.0,
d_model: int = 768,
num_layers: int = 12,
num_heads: int = 12,
d_kv: int = 64,
d_ff: int = 2048,
dropout_rate: float = 0.1,
):
super().__init__()
class T5LayerSelfAttentionCond(T5LayerSelfAttentionCond):
deprecation_message = "Importing `T5LayerSelfAttentionCond` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerSelfAttentionCond`, instead."
deprecate("T5LayerSelfAttentionCond", "0.29", deprecation_message)
self.conditioning_emb = nn.Sequential(
nn.Linear(d_model, d_model * 4, bias=False),
nn.SiLU(),
nn.Linear(d_model * 4, d_model * 4, bias=False),
nn.SiLU(),
)
self.position_encoding = nn.Embedding(targets_length, d_model)
self.position_encoding.weight.requires_grad = False
class T5LayerCrossAttention(T5LayerCrossAttention):
deprecation_message = "Importing `T5LayerCrossAttention` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerCrossAttention`, instead."
deprecate("T5LayerCrossAttention", "0.29", deprecation_message)
self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)
self.dropout = nn.Dropout(p=dropout_rate)
class T5LayerFFCond(T5LayerFFCond):
deprecation_message = "Importing `T5LayerFFCond` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerFFCond`, instead."
deprecate("T5LayerFFCond", "0.29", deprecation_message)
self.decoders = nn.ModuleList()
for lyr_num in range(num_layers):
# FiLM conditional T5 decoder
lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
self.decoders.append(lyr)
self.decoder_norm = T5LayerNorm(d_model)
class T5DenseGatedActDense(T5DenseGatedActDense):
deprecation_message = "Importing `T5DenseGatedActDense` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5DenseGatedActDense`, instead."
deprecate("T5DenseGatedActDense", "0.29", deprecation_message)
self.post_dropout = nn.Dropout(p=dropout_rate)
self.spec_out = nn.Linear(d_model, input_dims, bias=False)
def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
return mask.unsqueeze(-3)
class T5LayerNorm(T5LayerNorm):
deprecation_message = "Importing `T5LayerNorm` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5LayerNorm`, instead."
deprecate("T5LayerNorm", "0.29", deprecation_message)
def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
batch, _, _ = decoder_input_tokens.shape
assert decoder_noise_time.shape == (batch,)
# decoder_noise_time is in [0, 1), so rescale to expected timing range.
time_steps = get_timestep_embedding(
decoder_noise_time * self.config.max_decoder_noise_time,
embedding_dim=self.config.d_model,
max_period=self.config.max_decoder_noise_time,
).to(dtype=self.dtype)
class NewGELUActivation(NewGELUActivation):
deprecation_message = "Importing `T5LayerNorm` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import NewGELUActivation`, instead."
deprecate("NewGELUActivation", "0.29", deprecation_message)
conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)
assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)
seq_length = decoder_input_tokens.shape[1]
# If we want to use relative positions for audio context, we can just offset
# this sequence by the length of encodings_and_masks.
decoder_positions = torch.broadcast_to(
torch.arange(seq_length, device=decoder_input_tokens.device),
(batch, seq_length),
)
position_encodings = self.position_encoding(decoder_positions)
inputs = self.continuous_inputs_projection(decoder_input_tokens)
inputs += position_encodings
y = self.dropout(inputs)
# decoder: No padding present.
decoder_mask = torch.ones(
decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
)
# Translate encoding masks to encoder-decoder masks.
encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]
# cross attend style: concat encodings
encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)
for lyr in self.decoders:
y = lyr(
y,
conditioning_emb=conditioning_emb,
encoder_hidden_states=encoded,
encoder_attention_mask=encoder_decoder_mask,
)[0]
y = self.decoder_norm(y)
y = self.post_dropout(y)
spec_out = self.spec_out(y)
return spec_out
class DecoderLayer(nn.Module):
r"""
T5 decoder layer.
Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(
self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6
):
super().__init__()
self.layer = nn.ModuleList()
# cond self attention: layer 0
self.layer.append(
T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate)
)
# cross attention: layer 1
self.layer.append(
T5LayerCrossAttention(
d_model=d_model,
d_kv=d_kv,
num_heads=num_heads,
dropout_rate=dropout_rate,
layer_norm_epsilon=layer_norm_epsilon,
)
)
# Film Cond MLP + dropout: last layer
self.layer.append(
T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon)
)
def forward(
self,
hidden_states: torch.FloatTensor,
conditioning_emb: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
encoder_decoder_position_bias=None,
) -> Tuple[torch.FloatTensor]:
hidden_states = self.layer[0](
hidden_states,
conditioning_emb=conditioning_emb,
attention_mask=attention_mask,
)
if encoder_hidden_states is not None:
encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
encoder_hidden_states.dtype
)
hidden_states = self.layer[1](
hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_extended_attention_mask,
)
# Apply Film Conditional Feed Forward layer
hidden_states = self.layer[-1](hidden_states, conditioning_emb)
return (hidden_states,)
class T5LayerSelfAttentionCond(nn.Module):
r"""
T5 style self-attention layer with conditioning.
Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
dropout_rate (`float`):
Dropout probability.
"""
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float):
super().__init__()
self.layer_norm = T5LayerNorm(d_model)
self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
self.dropout = nn.Dropout(dropout_rate)
def forward(
self,
hidden_states: torch.FloatTensor,
conditioning_emb: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
# pre_self_attention_layer_norm
normed_hidden_states = self.layer_norm(hidden_states)
if conditioning_emb is not None:
normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)
# Self-attention block
attention_output = self.attention(normed_hidden_states)
hidden_states = hidden_states + self.dropout(attention_output)
return hidden_states
class T5LayerCrossAttention(nn.Module):
r"""
T5 style cross-attention layer.
Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
super().__init__()
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
self.dropout = nn.Dropout(dropout_rate)
def forward(
self,
hidden_states: torch.FloatTensor,
key_value_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.attention(
normed_hidden_states,
encoder_hidden_states=key_value_states,
attention_mask=attention_mask.squeeze(1),
)
layer_output = hidden_states + self.dropout(attention_output)
return layer_output
class T5LayerFFCond(nn.Module):
r"""
T5 style feed-forward conditional layer.
Args:
d_model (`int`):
Size of the input hidden states.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float):
super().__init__()
self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
self.dropout = nn.Dropout(dropout_rate)
def forward(
self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
) -> torch.FloatTensor:
forwarded_states = self.layer_norm(hidden_states)
if conditioning_emb is not None:
forwarded_states = self.film(forwarded_states, conditioning_emb)
forwarded_states = self.DenseReluDense(forwarded_states)
hidden_states = hidden_states + self.dropout(forwarded_states)
return hidden_states
class T5DenseGatedActDense(nn.Module):
r"""
T5 style feed-forward layer with gated activations and dropout.
Args:
d_model (`int`):
Size of the input hidden states.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
"""
def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
super().__init__()
self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
self.wo = nn.Linear(d_ff, d_model, bias=False)
self.dropout = nn.Dropout(dropout_rate)
self.act = NewGELUActivation()
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
hidden_states = self.wo(hidden_states)
return hidden_states
class T5LayerNorm(nn.Module):
r"""
T5 style layer normalization module.
Args:
hidden_size (`int`):
Size of the input hidden states.
eps (`float`, `optional`, defaults to `1e-6`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(self, hidden_size: int, eps: float = 1e-6):
"""
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
# half-precision inputs is done in fp32
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
class NewGELUActivation(nn.Module):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
class T5FiLMLayer(nn.Module):
"""
T5 style FiLM Layer.
Args:
in_features (`int`):
Number of input features.
out_features (`int`):
Number of output features.
"""
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
emb = self.scale_bias(conditioning_emb)
scale, shift = torch.chunk(emb, 2, -1)
x = x * (1 + scale) + shift
return x
class T5FiLMLayer(T5FiLMLayer):
deprecation_message = "Importing `T5FiLMLayer` from `diffusers.models.t5_film_transformer` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.t5_film_transformer import T5FiLMLayer`, instead."
deprecate("T5FiLMLayer", "0.29", deprecation_message)

View File

@@ -11,449 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..models.embeddings import ImagePositionalEmbeddings
from ..utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
from .attention import BasicTransformerBlock
from .embeddings import PatchEmbed, PixArtAlphaTextProjection
from .lora import LoRACompatibleConv, LoRACompatibleLinear
from .modeling_utils import ModelMixin
from .normalization import AdaLayerNormSingle
from ..utils import deprecate
from .transformers.transformer_2d import Transformer2DModel, Transformer2DModelOutput
@dataclass
class Transformer2DModelOutput(BaseOutput):
"""
The output of [`Transformer2DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
distributions for the unnoised latent pixels.
"""
sample: torch.FloatTensor
class Transformer2DModelOutput(Transformer2DModelOutput):
deprecation_message = "Importing `Transformer2DModelOutput` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_2d import Transformer2DModelOutput`, instead."
deprecate("Transformer2DModelOutput", "0.29", deprecation_message)
class Transformer2DModel(ModelMixin, ConfigMixin):
"""
A 2D Transformer model for image-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
num_vector_embeds (`int`, *optional*):
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
num_embeds_ada_norm ( `int`, *optional*):
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
added to the hidden states.
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlocks` attention should contain a bias parameter.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
patch_size: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_type: str = "layer_norm",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
attention_type: str = "default",
caption_channels: int = None,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
self.is_input_vectorized = num_vector_embeds is not None
self.is_input_patches = in_channels is not None and patch_size is not None
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
deprecation_message = (
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
)
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
norm_type = "ada_norm"
if self.is_input_continuous and self.is_input_vectorized:
raise ValueError(
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
" sure that either `in_channels` or `num_vector_embeds` is None."
)
elif self.is_input_vectorized and self.is_input_patches:
raise ValueError(
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
" sure that either `num_vector_embeds` or `num_patches` is None."
)
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
raise ValueError(
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
)
# 2. Define input layers
if self.is_input_continuous:
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = linear_cls(in_channels, inner_dim)
else:
self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
self.height = sample_size
self.width = sample_size
self.num_vector_embeds = num_vector_embeds
self.num_latent_pixels = self.height * self.width
self.latent_image_embedding = ImagePositionalEmbeddings(
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
)
elif self.is_input_patches:
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
self.height = sample_size
self.width = sample_size
self.patch_size = patch_size
interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
interpolation_scale = max(interpolation_scale, 1)
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
double_self_attention=double_self_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
)
for d in range(num_layers)
]
)
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
if self.is_input_continuous:
# TODO: should use out_channels for continuous projections
if use_linear_projection:
self.proj_out = linear_cls(inner_dim, in_channels)
else:
self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
elif self.is_input_patches and norm_type != "ada_norm_single":
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
elif self.is_input_patches and norm_type == "ada_norm_single":
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
# 5. PixArt-Alpha blocks.
self.adaln_single = None
self.use_additional_conditions = False
if norm_type == "ada_norm_single":
self.use_additional_conditions = self.config.sample_size == 128
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
# additional conditions until we find better name
self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
self.caption_projection = None
if caption_channels is not None:
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
The [`Transformer2DModel`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input `hidden_states`.
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
attention_mask ( `torch.Tensor`, *optional*):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
encoder_attention_mask ( `torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
* Mask `(batch, sequence_length)` True = keep, False = discard.
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 1. Input
if self.is_input_continuous:
batch, _, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = (
self.proj_in(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_in(hidden_states)
)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = (
self.proj_in(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_in(hidden_states)
)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hidden_states = self.pos_embed(hidden_states)
if self.adaln_single is not None:
if self.use_additional_conditions and added_cond_kwargs is None:
raise ValueError(
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
)
batch_size = hidden_states.shape[0]
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = hidden_states.shape[0]
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
for block in self.transformer_blocks:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
cross_attention_kwargs,
class_labels,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
# 3. Output
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = (
self.proj_out(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_out(hidden_states)
)
else:
hidden_states = (
self.proj_out(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_out(hidden_states)
)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
elif self.is_input_vectorized:
hidden_states = self.norm_out(hidden_states)
logits = self.out(hidden_states)
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
logits = logits.permute(0, 2, 1)
# log(p(x_0))
output = F.log_softmax(logits.double(), dim=1).float()
if self.is_input_patches:
if self.config.norm_type != "ada_norm_single":
conditioning = self.transformer_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = self.proj_out_2(hidden_states)
elif self.config.norm_type == "ada_norm_single":
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# unpatchify
if self.adaln_single is None:
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)
class Transformer2DModel(Transformer2DModel):
deprecation_message = "Importing `Transformer2DModel` from `diffusers.models.transformer_2d` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_2d import Transformer2DModel`, instead."
deprecate("Transformer2DModel", "0.29", deprecation_message)

View File

@@ -11,369 +11,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional
import torch
from torch import nn
from ..configuration_utils import ConfigMixin, register_to_config
from ..utils import BaseOutput
from .attention import BasicTransformerBlock, TemporalBasicTransformerBlock
from .embeddings import TimestepEmbedding, Timesteps
from .modeling_utils import ModelMixin
from .resnet import AlphaBlender
from ..utils import deprecate
from .transformers.transformer_temporal import (
TransformerSpatioTemporalModel,
TransformerTemporalModel,
TransformerTemporalModelOutput,
)
@dataclass
class TransformerTemporalModelOutput(BaseOutput):
"""
The output of [`TransformerTemporalModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
The hidden states output conditioned on `encoder_hidden_states` input.
"""
sample: torch.FloatTensor
class TransformerTemporalModelOutput(TransformerTemporalModelOutput):
deprecation_message = "Importing `TransformerTemporalModelOutput` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.tranformer_temporal import TransformerTemporalModelOutput`, instead."
deprecate("TransformerTemporalModelOutput", "0.29", deprecation_message)
class TransformerTemporalModel(ModelMixin, ConfigMixin):
"""
A Transformer model for video-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlock` attention should contain a bias parameter.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
activation functions.
norm_elementwise_affine (`bool`, *optional*):
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers.
positional_embeddings: (`str`, *optional*):
The type of positional embeddings to apply to the sequence input before passing use.
num_positional_embeddings: (`int`, *optional*):
The maximum length of the sequence over which to apply positional embeddings.
"""
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
activation_fn: str = "geglu",
norm_elementwise_affine: bool = True,
double_self_attention: bool = True,
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Linear(in_channels, inner_dim)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
double_self_attention=double_self_attention,
norm_elementwise_affine=norm_elementwise_affine,
positional_embeddings=positional_embeddings,
num_positional_embeddings=num_positional_embeddings,
)
for d in range(num_layers)
]
)
self.proj_out = nn.Linear(inner_dim, in_channels)
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.LongTensor] = None,
timestep: Optional[torch.LongTensor] = None,
class_labels: torch.LongTensor = None,
num_frames: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> TransformerTemporalModelOutput:
"""
The [`TransformerTemporal`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input hidden_states.
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
num_frames (`int`, *optional*, defaults to 1):
The number of frames to be processed per batch. This is used to reshape the hidden states.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
returned, otherwise a `tuple` where the first element is the sample tensor.
"""
# 1. Input
batch_frames, channel, height, width = hidden_states.shape
batch_size = batch_frames // num_frames
residual = hidden_states
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
hidden_states = self.proj_in(hidden_states)
# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
# 3. Output
hidden_states = self.proj_out(hidden_states)
hidden_states = (
hidden_states[None, None, :]
.reshape(batch_size, height, width, num_frames, channel)
.permute(0, 3, 4, 1, 2)
.contiguous()
)
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
output = hidden_states + residual
if not return_dict:
return (output,)
return TransformerTemporalModelOutput(sample=output)
class TransformerTemporalModel(TransformerTemporalModel):
deprecation_message = "Importing `TransformerTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.tranformer_temporal import TransformerTemporalModel`, instead."
deprecate("TransformerTemporalModel", "0.29", deprecation_message)
class TransformerSpatioTemporalModel(nn.Module):
"""
A Transformer model for video-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
out_channels (`int`, *optional*):
The number of channels in the output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
"""
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: int = 320,
out_channels: Optional[int] = None,
num_layers: int = 1,
cross_attention_dim: Optional[int] = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.inner_dim = inner_dim
# 2. Define input layers
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
self.proj_in = nn.Linear(in_channels, inner_dim)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
cross_attention_dim=cross_attention_dim,
)
for d in range(num_layers)
]
)
time_mix_inner_dim = inner_dim
self.temporal_transformer_blocks = nn.ModuleList(
[
TemporalBasicTransformerBlock(
inner_dim,
time_mix_inner_dim,
num_attention_heads,
attention_head_dim,
cross_attention_dim=cross_attention_dim,
)
for _ in range(num_layers)
]
)
time_embed_dim = in_channels * 4
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
self.time_proj = Timesteps(in_channels, True, 0)
self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
# TODO: should use out_channels for continuous projections
self.proj_out = nn.Linear(inner_dim, in_channels)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input hidden_states.
num_frames (`int`):
The number of frames to be processed per batch. This is used to reshape the hidden states.
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
A tensor indicating whether the input contains only images. 1 indicates that the input contains only
images, 0 indicates that the input contains video frames.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
tuple.
Returns:
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
returned, otherwise a `tuple` where the first element is the sample tensor.
"""
# 1. Input
batch_frames, _, height, width = hidden_states.shape
num_frames = image_only_indicator.shape[-1]
batch_size = batch_frames // num_frames
time_context = encoder_hidden_states
time_context_first_timestep = time_context[None, :].reshape(
batch_size, num_frames, -1, time_context.shape[-1]
)[:, 0]
time_context = time_context_first_timestep[None, :].broadcast_to(
height * width, batch_size, 1, time_context.shape[-1]
)
time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
residual = hidden_states
hidden_states = self.norm(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
num_frames_emb = num_frames_emb.repeat(batch_size, 1)
num_frames_emb = num_frames_emb.reshape(-1)
t_emb = self.time_proj(num_frames_emb)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_pos_embed(t_emb)
emb = emb[:, None, :]
# 2. Blocks
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
block,
hidden_states,
None,
encoder_hidden_states,
None,
use_reentrant=False,
)
else:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
)
hidden_states_mix = hidden_states
hidden_states_mix = hidden_states_mix + emb
hidden_states_mix = temporal_block(
hidden_states_mix,
num_frames=num_frames,
encoder_hidden_states=time_context,
)
hidden_states = self.time_mixer(
x_spatial=hidden_states,
x_temporal=hidden_states_mix,
image_only_indicator=image_only_indicator,
)
# 3. Output
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
if not return_dict:
return (output,)
return TransformerTemporalModelOutput(sample=output)
class TransformerSpatioTemporalModel(TransformerSpatioTemporalModel):
deprecation_message = "Importing `TransformerSpatioTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.tranformer_temporal import TransformerSpatioTemporalModel`, instead."
deprecate("TransformerTemporalModelOutput", "0.29", deprecation_message)

View File

@@ -0,0 +1,9 @@
from ...utils import is_torch_available
if is_torch_available():
from .dual_transformer_2d import DualTransformer2DModel
from .prior_transformer import PriorTransformer
from .t5_film_transformer import T5FilmDecoder
from .transformer_2d import Transformer2DModel
from .transformer_temporal import TransformerTemporalModel

View File

@@ -0,0 +1,155 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
from torch import nn
from .transformer_2d import Transformer2DModel, Transformer2DModelOutput
class DualTransformer2DModel(nn.Module):
"""
Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
Pass if the input is continuous. The number of channels in the input and output.
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
Note that this is fixed at training time as it is used for learning a number of position embeddings. See
`ImagePositionalEmbeddings`.
num_vector_embeds (`int`, *optional*):
Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
The number of diffusion steps used during training. Note that this is fixed at training time as it is used
to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
up to but not more than steps than `num_embeds_ada_norm`.
attention_bias (`bool`, *optional*):
Configure if the TransformerBlocks' attention should contain a bias parameter.
"""
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
):
super().__init__()
self.transformers = nn.ModuleList(
[
Transformer2DModel(
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
in_channels=in_channels,
num_layers=num_layers,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attention_bias=attention_bias,
sample_size=sample_size,
num_vector_embeds=num_vector_embeds,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
)
for _ in range(2)
]
)
# Variables that can be set by a pipeline:
# The ratio of transformer1 to transformer2's output states to be combined during inference
self.mix_ratio = 0.5
# The shape of `encoder_hidden_states` is expected to be
# `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
self.condition_lengths = [77, 257]
# Which transformer to use to encode which condition.
# E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
self.transformer_index_for_condition = [1, 0]
def forward(
self,
hidden_states,
encoder_hidden_states,
timestep=None,
attention_mask=None,
cross_attention_kwargs=None,
return_dict: bool = True,
):
"""
Args:
hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
hidden_states.
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.long`, *optional*):
Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
attention_mask (`torch.FloatTensor`, *optional*):
Optional attention mask to be applied in Attention.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
Returns:
[`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
[`~models.transformer_2d.Transformer2DModelOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
input_states = hidden_states
encoded_states = []
tokens_start = 0
# attention_mask is not used yet
for i in range(2):
# for each of the two transformers, pass the corresponding condition tokens
condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
transformer_index = self.transformer_index_for_condition[i]
encoded_state = self.transformers[transformer_index](
input_states,
encoder_hidden_states=condition_state,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]
encoded_states.append(encoded_state - input_states)
tokens_start += self.condition_lengths[i]
output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
output_states = output_states + input_states
if not return_dict:
return (output_states,)
return Transformer2DModelOutput(sample=output_states)

View File

@@ -0,0 +1,380 @@
from dataclasses import dataclass
from typing import Dict, Optional, Union
import torch
import torch.nn.functional as F
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
from ...utils import BaseOutput
from ..attention import BasicTransformerBlock
from ..attention_processor import (
ADDED_KV_ATTENTION_PROCESSORS,
CROSS_ATTENTION_PROCESSORS,
AttentionProcessor,
AttnAddedKVProcessor,
AttnProcessor,
)
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
@dataclass
class PriorTransformerOutput(BaseOutput):
"""
The output of [`PriorTransformer`].
Args:
predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
The predicted CLIP image embedding conditioned on the CLIP text embedding input.
"""
predicted_image_embedding: torch.FloatTensor
class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
"""
A Prior Transformer model.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 32): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_layers (`int`, *optional*, defaults to 20): The number of layers of Transformer blocks to use.
embedding_dim (`int`, *optional*, defaults to 768): The dimension of the model input `hidden_states`
num_embeddings (`int`, *optional*, defaults to 77):
The number of embeddings of the model input `hidden_states`
additional_embeddings (`int`, *optional*, defaults to 4): The number of additional tokens appended to the
projected `hidden_states`. The actual length of the used `hidden_states` is `num_embeddings +
additional_embeddings`.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
time_embed_act_fn (`str`, *optional*, defaults to 'silu'):
The activation function to use to create timestep embeddings.
norm_in_type (`str`, *optional*, defaults to None): The normalization layer to apply on hidden states before
passing to Transformer blocks. Set it to `None` if normalization is not needed.
embedding_proj_norm_type (`str`, *optional*, defaults to None):
The normalization layer to apply on the input `proj_embedding`. Set it to `None` if normalization is not
needed.
encoder_hid_proj_type (`str`, *optional*, defaults to `linear`):
The projection layer to apply on the input `encoder_hidden_states`. Set it to `None` if
`encoder_hidden_states` is `None`.
added_emb_type (`str`, *optional*, defaults to `prd`): Additional embeddings to condition the model.
Choose from `prd` or `None`. if choose `prd`, it will prepend a token indicating the (quantized) dot
product between the text embedding and image embedding as proposed in the unclip paper
https://arxiv.org/abs/2204.06125 If it is `None`, no additional embeddings will be prepended.
time_embed_dim (`int, *optional*, defaults to None): The dimension of timestep embeddings.
If None, will be set to `num_attention_heads * attention_head_dim`
embedding_proj_dim (`int`, *optional*, default to None):
The dimension of `proj_embedding`. If None, will be set to `embedding_dim`.
clip_embed_dim (`int`, *optional*, default to None):
The dimension of the output. If None, will be set to `embedding_dim`.
"""
@register_to_config
def __init__(
self,
num_attention_heads: int = 32,
attention_head_dim: int = 64,
num_layers: int = 20,
embedding_dim: int = 768,
num_embeddings=77,
additional_embeddings=4,
dropout: float = 0.0,
time_embed_act_fn: str = "silu",
norm_in_type: Optional[str] = None, # layer
embedding_proj_norm_type: Optional[str] = None, # layer
encoder_hid_proj_type: Optional[str] = "linear", # linear
added_emb_type: Optional[str] = "prd", # prd
time_embed_dim: Optional[int] = None,
embedding_proj_dim: Optional[int] = None,
clip_embed_dim: Optional[int] = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.additional_embeddings = additional_embeddings
time_embed_dim = time_embed_dim or inner_dim
embedding_proj_dim = embedding_proj_dim or embedding_dim
clip_embed_dim = clip_embed_dim or embedding_dim
self.time_proj = Timesteps(inner_dim, True, 0)
self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, out_dim=inner_dim, act_fn=time_embed_act_fn)
self.proj_in = nn.Linear(embedding_dim, inner_dim)
if embedding_proj_norm_type is None:
self.embedding_proj_norm = None
elif embedding_proj_norm_type == "layer":
self.embedding_proj_norm = nn.LayerNorm(embedding_proj_dim)
else:
raise ValueError(f"unsupported embedding_proj_norm_type: {embedding_proj_norm_type}")
self.embedding_proj = nn.Linear(embedding_proj_dim, inner_dim)
if encoder_hid_proj_type is None:
self.encoder_hidden_states_proj = None
elif encoder_hid_proj_type == "linear":
self.encoder_hidden_states_proj = nn.Linear(embedding_dim, inner_dim)
else:
raise ValueError(f"unsupported encoder_hid_proj_type: {encoder_hid_proj_type}")
self.positional_embedding = nn.Parameter(torch.zeros(1, num_embeddings + additional_embeddings, inner_dim))
if added_emb_type == "prd":
self.prd_embedding = nn.Parameter(torch.zeros(1, 1, inner_dim))
elif added_emb_type is None:
self.prd_embedding = None
else:
raise ValueError(
f"`added_emb_type`: {added_emb_type} is not supported. Make sure to choose one of `'prd'` or `None`."
)
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
activation_fn="gelu",
attention_bias=True,
)
for d in range(num_layers)
]
)
if norm_in_type == "layer":
self.norm_in = nn.LayerNorm(inner_dim)
elif norm_in_type is None:
self.norm_in = None
else:
raise ValueError(f"Unsupported norm_in_type: {norm_in_type}.")
self.norm_out = nn.LayerNorm(inner_dim)
self.proj_to_clip_embeddings = nn.Linear(inner_dim, clip_embed_dim)
causal_attention_mask = torch.full(
[num_embeddings + additional_embeddings, num_embeddings + additional_embeddings], -10000.0
)
causal_attention_mask.triu_(1)
causal_attention_mask = causal_attention_mask[None, ...]
self.register_buffer("causal_attention_mask", causal_attention_mask, persistent=False)
self.clip_mean = nn.Parameter(torch.zeros(1, clip_embed_dim))
self.clip_std = nn.Parameter(torch.zeros(1, clip_embed_dim))
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
`dict` of attention processors: A dictionary containing all attention processors used in the model with
indexed by its weight name.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnAddedKVProcessor()
elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
processor = AttnProcessor()
else:
raise ValueError(
f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
)
self.set_attn_processor(processor)
def forward(
self,
hidden_states,
timestep: Union[torch.Tensor, float, int],
proj_embedding: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.BoolTensor] = None,
return_dict: bool = True,
):
"""
The [`PriorTransformer`] forward method.
Args:
hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
The currently predicted image embeddings.
timestep (`torch.LongTensor`):
Current denoising step.
proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
Projected embedding vector the denoising process is conditioned on.
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
Hidden states of the text embeddings the denoising process is conditioned on.
attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
Text mask for the text embeddings.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.prior_transformer.PriorTransformerOutput`] instead of a plain
tuple.
Returns:
[`~models.prior_transformer.PriorTransformerOutput`] or `tuple`:
If return_dict is True, a [`~models.prior_transformer.PriorTransformerOutput`] is returned, otherwise a
tuple is returned where the first element is the sample tensor.
"""
batch_size = hidden_states.shape[0]
timesteps = timestep
if not torch.is_tensor(timesteps):
timesteps = torch.tensor([timesteps], dtype=torch.long, device=hidden_states.device)
elif torch.is_tensor(timesteps) and len(timesteps.shape) == 0:
timesteps = timesteps[None].to(hidden_states.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps * torch.ones(batch_size, dtype=timesteps.dtype, device=timesteps.device)
timesteps_projected = self.time_proj(timesteps)
# timesteps does not contain any weights and will always return f32 tensors
# but time_embedding might be fp16, so we need to cast here.
timesteps_projected = timesteps_projected.to(dtype=self.dtype)
time_embeddings = self.time_embedding(timesteps_projected)
if self.embedding_proj_norm is not None:
proj_embedding = self.embedding_proj_norm(proj_embedding)
proj_embeddings = self.embedding_proj(proj_embedding)
if self.encoder_hidden_states_proj is not None and encoder_hidden_states is not None:
encoder_hidden_states = self.encoder_hidden_states_proj(encoder_hidden_states)
elif self.encoder_hidden_states_proj is not None and encoder_hidden_states is None:
raise ValueError("`encoder_hidden_states_proj` requires `encoder_hidden_states` to be set")
hidden_states = self.proj_in(hidden_states)
positional_embeddings = self.positional_embedding.to(hidden_states.dtype)
additional_embeds = []
additional_embeddings_len = 0
if encoder_hidden_states is not None:
additional_embeds.append(encoder_hidden_states)
additional_embeddings_len += encoder_hidden_states.shape[1]
if len(proj_embeddings.shape) == 2:
proj_embeddings = proj_embeddings[:, None, :]
if len(hidden_states.shape) == 2:
hidden_states = hidden_states[:, None, :]
additional_embeds = additional_embeds + [
proj_embeddings,
time_embeddings[:, None, :],
hidden_states,
]
if self.prd_embedding is not None:
prd_embedding = self.prd_embedding.to(hidden_states.dtype).expand(batch_size, -1, -1)
additional_embeds.append(prd_embedding)
hidden_states = torch.cat(
additional_embeds,
dim=1,
)
# Allow positional_embedding to not include the `addtional_embeddings` and instead pad it with zeros for these additional tokens
additional_embeddings_len = additional_embeddings_len + proj_embeddings.shape[1] + 1
if positional_embeddings.shape[1] < hidden_states.shape[1]:
positional_embeddings = F.pad(
positional_embeddings,
(
0,
0,
additional_embeddings_len,
self.prd_embedding.shape[1] if self.prd_embedding is not None else 0,
),
value=0.0,
)
hidden_states = hidden_states + positional_embeddings
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
if self.norm_in is not None:
hidden_states = self.norm_in(hidden_states)
for block in self.transformer_blocks:
hidden_states = block(hidden_states, attention_mask=attention_mask)
hidden_states = self.norm_out(hidden_states)
if self.prd_embedding is not None:
hidden_states = hidden_states[:, -1]
else:
hidden_states = hidden_states[:, additional_embeddings_len:]
predicted_image_embedding = self.proj_to_clip_embeddings(hidden_states)
if not return_dict:
return (predicted_image_embedding,)
return PriorTransformerOutput(predicted_image_embedding=predicted_image_embedding)
def post_process_latents(self, prior_latents):
prior_latents = (prior_latents * self.clip_std) + self.clip_mean
return prior_latents

View File

@@ -0,0 +1,438 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional, Tuple
import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ..attention_processor import Attention
from ..embeddings import get_timestep_embedding
from ..modeling_utils import ModelMixin
class T5FilmDecoder(ModelMixin, ConfigMixin):
r"""
T5 style decoder with FiLM conditioning.
Args:
input_dims (`int`, *optional*, defaults to `128`):
The number of input dimensions.
targets_length (`int`, *optional*, defaults to `256`):
The length of the targets.
d_model (`int`, *optional*, defaults to `768`):
Size of the input hidden states.
num_layers (`int`, *optional*, defaults to `12`):
The number of `DecoderLayer`'s to use.
num_heads (`int`, *optional*, defaults to `12`):
The number of attention heads to use.
d_kv (`int`, *optional*, defaults to `64`):
Size of the key-value projection vectors.
d_ff (`int`, *optional*, defaults to `2048`):
The number of dimensions in the intermediate feed-forward layer of `DecoderLayer`'s.
dropout_rate (`float`, *optional*, defaults to `0.1`):
Dropout probability.
"""
@register_to_config
def __init__(
self,
input_dims: int = 128,
targets_length: int = 256,
max_decoder_noise_time: float = 2000.0,
d_model: int = 768,
num_layers: int = 12,
num_heads: int = 12,
d_kv: int = 64,
d_ff: int = 2048,
dropout_rate: float = 0.1,
):
super().__init__()
self.conditioning_emb = nn.Sequential(
nn.Linear(d_model, d_model * 4, bias=False),
nn.SiLU(),
nn.Linear(d_model * 4, d_model * 4, bias=False),
nn.SiLU(),
)
self.position_encoding = nn.Embedding(targets_length, d_model)
self.position_encoding.weight.requires_grad = False
self.continuous_inputs_projection = nn.Linear(input_dims, d_model, bias=False)
self.dropout = nn.Dropout(p=dropout_rate)
self.decoders = nn.ModuleList()
for lyr_num in range(num_layers):
# FiLM conditional T5 decoder
lyr = DecoderLayer(d_model=d_model, d_kv=d_kv, num_heads=num_heads, d_ff=d_ff, dropout_rate=dropout_rate)
self.decoders.append(lyr)
self.decoder_norm = T5LayerNorm(d_model)
self.post_dropout = nn.Dropout(p=dropout_rate)
self.spec_out = nn.Linear(d_model, input_dims, bias=False)
def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
return mask.unsqueeze(-3)
def forward(self, encodings_and_masks, decoder_input_tokens, decoder_noise_time):
batch, _, _ = decoder_input_tokens.shape
assert decoder_noise_time.shape == (batch,)
# decoder_noise_time is in [0, 1), so rescale to expected timing range.
time_steps = get_timestep_embedding(
decoder_noise_time * self.config.max_decoder_noise_time,
embedding_dim=self.config.d_model,
max_period=self.config.max_decoder_noise_time,
).to(dtype=self.dtype)
conditioning_emb = self.conditioning_emb(time_steps).unsqueeze(1)
assert conditioning_emb.shape == (batch, 1, self.config.d_model * 4)
seq_length = decoder_input_tokens.shape[1]
# If we want to use relative positions for audio context, we can just offset
# this sequence by the length of encodings_and_masks.
decoder_positions = torch.broadcast_to(
torch.arange(seq_length, device=decoder_input_tokens.device),
(batch, seq_length),
)
position_encodings = self.position_encoding(decoder_positions)
inputs = self.continuous_inputs_projection(decoder_input_tokens)
inputs += position_encodings
y = self.dropout(inputs)
# decoder: No padding present.
decoder_mask = torch.ones(
decoder_input_tokens.shape[:2], device=decoder_input_tokens.device, dtype=inputs.dtype
)
# Translate encoding masks to encoder-decoder masks.
encodings_and_encdec_masks = [(x, self.encoder_decoder_mask(decoder_mask, y)) for x, y in encodings_and_masks]
# cross attend style: concat encodings
encoded = torch.cat([x[0] for x in encodings_and_encdec_masks], dim=1)
encoder_decoder_mask = torch.cat([x[1] for x in encodings_and_encdec_masks], dim=-1)
for lyr in self.decoders:
y = lyr(
y,
conditioning_emb=conditioning_emb,
encoder_hidden_states=encoded,
encoder_attention_mask=encoder_decoder_mask,
)[0]
y = self.decoder_norm(y)
y = self.post_dropout(y)
spec_out = self.spec_out(y)
return spec_out
class DecoderLayer(nn.Module):
r"""
T5 decoder layer.
Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`, *optional*, defaults to `1e-6`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(
self, d_model: int, d_kv: int, num_heads: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float = 1e-6
):
super().__init__()
self.layer = nn.ModuleList()
# cond self attention: layer 0
self.layer.append(
T5LayerSelfAttentionCond(d_model=d_model, d_kv=d_kv, num_heads=num_heads, dropout_rate=dropout_rate)
)
# cross attention: layer 1
self.layer.append(
T5LayerCrossAttention(
d_model=d_model,
d_kv=d_kv,
num_heads=num_heads,
dropout_rate=dropout_rate,
layer_norm_epsilon=layer_norm_epsilon,
)
)
# Film Cond MLP + dropout: last layer
self.layer.append(
T5LayerFFCond(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate, layer_norm_epsilon=layer_norm_epsilon)
)
def forward(
self,
hidden_states: torch.FloatTensor,
conditioning_emb: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
encoder_decoder_position_bias=None,
) -> Tuple[torch.FloatTensor]:
hidden_states = self.layer[0](
hidden_states,
conditioning_emb=conditioning_emb,
attention_mask=attention_mask,
)
if encoder_hidden_states is not None:
encoder_extended_attention_mask = torch.where(encoder_attention_mask > 0, 0, -1e10).to(
encoder_hidden_states.dtype
)
hidden_states = self.layer[1](
hidden_states,
key_value_states=encoder_hidden_states,
attention_mask=encoder_extended_attention_mask,
)
# Apply Film Conditional Feed Forward layer
hidden_states = self.layer[-1](hidden_states, conditioning_emb)
return (hidden_states,)
class T5LayerSelfAttentionCond(nn.Module):
r"""
T5 style self-attention layer with conditioning.
Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
dropout_rate (`float`):
Dropout probability.
"""
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float):
super().__init__()
self.layer_norm = T5LayerNorm(d_model)
self.FiLMLayer = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
self.dropout = nn.Dropout(dropout_rate)
def forward(
self,
hidden_states: torch.FloatTensor,
conditioning_emb: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
# pre_self_attention_layer_norm
normed_hidden_states = self.layer_norm(hidden_states)
if conditioning_emb is not None:
normed_hidden_states = self.FiLMLayer(normed_hidden_states, conditioning_emb)
# Self-attention block
attention_output = self.attention(normed_hidden_states)
hidden_states = hidden_states + self.dropout(attention_output)
return hidden_states
class T5LayerCrossAttention(nn.Module):
r"""
T5 style cross-attention layer.
Args:
d_model (`int`):
Size of the input hidden states.
d_kv (`int`):
Size of the key-value projection vectors.
num_heads (`int`):
Number of attention heads.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(self, d_model: int, d_kv: int, num_heads: int, dropout_rate: float, layer_norm_epsilon: float):
super().__init__()
self.attention = Attention(query_dim=d_model, heads=num_heads, dim_head=d_kv, out_bias=False, scale_qk=False)
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
self.dropout = nn.Dropout(dropout_rate)
def forward(
self,
hidden_states: torch.FloatTensor,
key_value_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
normed_hidden_states = self.layer_norm(hidden_states)
attention_output = self.attention(
normed_hidden_states,
encoder_hidden_states=key_value_states,
attention_mask=attention_mask.squeeze(1),
)
layer_output = hidden_states + self.dropout(attention_output)
return layer_output
class T5LayerFFCond(nn.Module):
r"""
T5 style feed-forward conditional layer.
Args:
d_model (`int`):
Size of the input hidden states.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
layer_norm_epsilon (`float`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(self, d_model: int, d_ff: int, dropout_rate: float, layer_norm_epsilon: float):
super().__init__()
self.DenseReluDense = T5DenseGatedActDense(d_model=d_model, d_ff=d_ff, dropout_rate=dropout_rate)
self.film = T5FiLMLayer(in_features=d_model * 4, out_features=d_model)
self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
self.dropout = nn.Dropout(dropout_rate)
def forward(
self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
) -> torch.FloatTensor:
forwarded_states = self.layer_norm(hidden_states)
if conditioning_emb is not None:
forwarded_states = self.film(forwarded_states, conditioning_emb)
forwarded_states = self.DenseReluDense(forwarded_states)
hidden_states = hidden_states + self.dropout(forwarded_states)
return hidden_states
class T5DenseGatedActDense(nn.Module):
r"""
T5 style feed-forward layer with gated activations and dropout.
Args:
d_model (`int`):
Size of the input hidden states.
d_ff (`int`):
Size of the intermediate feed-forward layer.
dropout_rate (`float`):
Dropout probability.
"""
def __init__(self, d_model: int, d_ff: int, dropout_rate: float):
super().__init__()
self.wi_0 = nn.Linear(d_model, d_ff, bias=False)
self.wi_1 = nn.Linear(d_model, d_ff, bias=False)
self.wo = nn.Linear(d_ff, d_model, bias=False)
self.dropout = nn.Dropout(dropout_rate)
self.act = NewGELUActivation()
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
hidden_gelu = self.act(self.wi_0(hidden_states))
hidden_linear = self.wi_1(hidden_states)
hidden_states = hidden_gelu * hidden_linear
hidden_states = self.dropout(hidden_states)
hidden_states = self.wo(hidden_states)
return hidden_states
class T5LayerNorm(nn.Module):
r"""
T5 style layer normalization module.
Args:
hidden_size (`int`):
Size of the input hidden states.
eps (`float`, `optional`, defaults to `1e-6`):
A small value used for numerical stability to avoid dividing by zero.
"""
def __init__(self, hidden_size: int, eps: float = 1e-6):
"""
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
# half-precision inputs is done in fp32
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
# convert into half-precision if necessary
if self.weight.dtype in [torch.float16, torch.bfloat16]:
hidden_states = hidden_states.to(self.weight.dtype)
return self.weight * hidden_states
class NewGELUActivation(nn.Module):
"""
Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
"""
def forward(self, input: torch.Tensor) -> torch.Tensor:
return 0.5 * input * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * torch.pow(input, 3.0))))
class T5FiLMLayer(nn.Module):
"""
T5 style FiLM Layer.
Args:
in_features (`int`):
Number of input features.
out_features (`int`):
Number of output features.
"""
def __init__(self, in_features: int, out_features: int):
super().__init__()
self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
emb = self.scale_bias(conditioning_emb)
scale, shift = torch.chunk(emb, 2, -1)
x = x * (1 + scale) + shift
return x

View File

@@ -0,0 +1,458 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional
import torch
import torch.nn.functional as F
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, is_torch_version
from ..attention import BasicTransformerBlock
from ..embeddings import ImagePositionalEmbeddings, PatchEmbed, PixArtAlphaTextProjection
from ..lora import LoRACompatibleConv, LoRACompatibleLinear
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle
@dataclass
class Transformer2DModelOutput(BaseOutput):
"""
The output of [`Transformer2DModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
distributions for the unnoised latent pixels.
"""
sample: torch.FloatTensor
class Transformer2DModel(ModelMixin, ConfigMixin):
"""
A 2D Transformer model for image-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
num_vector_embeds (`int`, *optional*):
The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
Includes the class for the masked latent pixel.
activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
num_embeds_ada_norm ( `int`, *optional*):
The number of diffusion steps used during training. Pass if at least one of the norm_layers is
`AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
added to the hidden states.
During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlocks` attention should contain a bias parameter.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
num_vector_embeds: Optional[int] = None,
patch_size: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
double_self_attention: bool = False,
upcast_attention: bool = False,
norm_type: str = "layer_norm",
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
attention_type: str = "default",
caption_channels: int = None,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
# 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
# Define whether input is continuous or discrete depending on configuration
self.is_input_continuous = (in_channels is not None) and (patch_size is None)
self.is_input_vectorized = num_vector_embeds is not None
self.is_input_patches = in_channels is not None and patch_size is not None
if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
deprecation_message = (
f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
" incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config."
" Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect"
" results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
" would be very nice if you could open a Pull request for the `transformer/config.json` file"
)
deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
norm_type = "ada_norm"
if self.is_input_continuous and self.is_input_vectorized:
raise ValueError(
f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
" sure that either `in_channels` or `num_vector_embeds` is None."
)
elif self.is_input_vectorized and self.is_input_patches:
raise ValueError(
f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
" sure that either `num_vector_embeds` or `num_patches` is None."
)
elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
raise ValueError(
f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
)
# 2. Define input layers
if self.is_input_continuous:
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = linear_cls(in_channels, inner_dim)
else:
self.proj_in = conv_cls(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
self.height = sample_size
self.width = sample_size
self.num_vector_embeds = num_vector_embeds
self.num_latent_pixels = self.height * self.width
self.latent_image_embedding = ImagePositionalEmbeddings(
num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
)
elif self.is_input_patches:
assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
self.height = sample_size
self.width = sample_size
self.patch_size = patch_size
interpolation_scale = self.config.sample_size // 64 # => 64 (= 512 pixart) has interpolation scale 1
interpolation_scale = max(interpolation_scale, 1)
self.pos_embed = PatchEmbed(
height=sample_size,
width=sample_size,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=inner_dim,
interpolation_scale=interpolation_scale,
)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
only_cross_attention=only_cross_attention,
double_self_attention=double_self_attention,
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
attention_type=attention_type,
)
for d in range(num_layers)
]
)
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
if self.is_input_continuous:
# TODO: should use out_channels for continuous projections
if use_linear_projection:
self.proj_out = linear_cls(inner_dim, in_channels)
else:
self.proj_out = conv_cls(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
elif self.is_input_vectorized:
self.norm_out = nn.LayerNorm(inner_dim)
self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
elif self.is_input_patches and norm_type != "ada_norm_single":
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
elif self.is_input_patches and norm_type == "ada_norm_single":
self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
# 5. PixArt-Alpha blocks.
self.adaln_single = None
self.use_additional_conditions = False
if norm_type == "ada_norm_single":
self.use_additional_conditions = self.config.sample_size == 128
# TODO(Sayak, PVP) clean this, for now we use sample size to determine whether to use
# additional conditions until we find better name
self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)
self.caption_projection = None
if caption_channels is not None:
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
if hasattr(module, "gradient_checkpointing"):
module.gradient_checkpointing = value
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
timestep: Optional[torch.LongTensor] = None,
added_cond_kwargs: Dict[str, torch.Tensor] = None,
class_labels: Optional[torch.LongTensor] = None,
cross_attention_kwargs: Dict[str, Any] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
The [`Transformer2DModel`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input `hidden_states`.
encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
cross_attention_kwargs ( `Dict[str, Any]`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
attention_mask ( `torch.Tensor`, *optional*):
An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
negative values to the attention scores corresponding to "discard" tokens.
encoder_attention_mask ( `torch.Tensor`, *optional*):
Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
* Mask `(batch, sequence_length)` True = keep, False = discard.
* Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
above. This bias will be added to the cross-attention scores.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
# [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
if attention_mask is not None and attention_mask.ndim == 2:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# Retrieve lora scale.
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
# 1. Input
if self.is_input_continuous:
batch, _, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = (
self.proj_in(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_in(hidden_states)
)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
hidden_states = (
self.proj_in(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_in(hidden_states)
)
elif self.is_input_vectorized:
hidden_states = self.latent_image_embedding(hidden_states)
elif self.is_input_patches:
height, width = hidden_states.shape[-2] // self.patch_size, hidden_states.shape[-1] // self.patch_size
hidden_states = self.pos_embed(hidden_states)
if self.adaln_single is not None:
if self.use_additional_conditions and added_cond_kwargs is None:
raise ValueError(
"`added_cond_kwargs` cannot be None when using additional conditions for `adaln_single`."
)
batch_size = hidden_states.shape[0]
timestep, embedded_timestep = self.adaln_single(
timestep, added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
# 2. Blocks
if self.caption_projection is not None:
batch_size = hidden_states.shape[0]
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
for block in self.transformer_blocks:
if self.training and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
attention_mask,
encoder_hidden_states,
encoder_attention_mask,
timestep,
cross_attention_kwargs,
class_labels,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states,
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
# 3. Output
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = (
self.proj_out(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_out(hidden_states)
)
else:
hidden_states = (
self.proj_out(hidden_states, scale=lora_scale)
if not USE_PEFT_BACKEND
else self.proj_out(hidden_states)
)
hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
elif self.is_input_vectorized:
hidden_states = self.norm_out(hidden_states)
logits = self.out(hidden_states)
# (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
logits = logits.permute(0, 2, 1)
# log(p(x_0))
output = F.log_softmax(logits.double(), dim=1).float()
if self.is_input_patches:
if self.config.norm_type != "ada_norm_single":
conditioning = self.transformer_blocks[0].norm1.emb(
timestep, class_labels, hidden_dtype=hidden_states.dtype
)
shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
hidden_states = self.proj_out_2(hidden_states)
elif self.config.norm_type == "ada_norm_single":
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# unpatchify
if self.adaln_single is None:
height = width = int(hidden_states.shape[1] ** 0.5)
hidden_states = hidden_states.reshape(
shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
)
hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
output = hidden_states.reshape(
shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)

View File

@@ -0,0 +1,379 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Dict, Optional
import torch
from torch import nn
from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ..attention import BasicTransformerBlock, TemporalBasicTransformerBlock
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..resnet import AlphaBlender
@dataclass
class TransformerTemporalModelOutput(BaseOutput):
"""
The output of [`TransformerTemporalModel`].
Args:
sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
The hidden states output conditioned on `encoder_hidden_states` input.
"""
sample: torch.FloatTensor
class TransformerTemporalModel(ModelMixin, ConfigMixin):
"""
A Transformer model for video-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
attention_bias (`bool`, *optional*):
Configure if the `TransformerBlock` attention should contain a bias parameter.
sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
This is fixed during training since it is used to learn a number of position embeddings.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
activation functions.
norm_elementwise_affine (`bool`, *optional*):
Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
double_self_attention (`bool`, *optional*):
Configure if each `TransformerBlock` should contain two self-attention layers.
positional_embeddings: (`str`, *optional*):
The type of positional embeddings to apply to the sequence input before passing use.
num_positional_embeddings: (`int`, *optional*):
The maximum length of the sequence over which to apply positional embeddings.
"""
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
sample_size: Optional[int] = None,
activation_fn: str = "geglu",
norm_elementwise_affine: bool = True,
double_self_attention: bool = True,
positional_embeddings: Optional[str] = None,
num_positional_embeddings: Optional[int] = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Linear(in_channels, inner_dim)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
double_self_attention=double_self_attention,
norm_elementwise_affine=norm_elementwise_affine,
positional_embeddings=positional_embeddings,
num_positional_embeddings=num_positional_embeddings,
)
for d in range(num_layers)
]
)
self.proj_out = nn.Linear(inner_dim, in_channels)
def forward(
self,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.LongTensor] = None,
timestep: Optional[torch.LongTensor] = None,
class_labels: torch.LongTensor = None,
num_frames: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
) -> TransformerTemporalModelOutput:
"""
The [`TransformerTemporal`] forward method.
Args:
hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
Input hidden_states.
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
timestep ( `torch.LongTensor`, *optional*):
Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
`AdaLayerZeroNorm`.
num_frames (`int`, *optional*, defaults to 1):
The number of frames to be processed per batch. This is used to reshape the hidden states.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
tuple.
Returns:
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
returned, otherwise a `tuple` where the first element is the sample tensor.
"""
# 1. Input
batch_frames, channel, height, width = hidden_states.shape
batch_size = batch_frames // num_frames
residual = hidden_states
hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
hidden_states = self.proj_in(hidden_states)
# 2. Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
cross_attention_kwargs=cross_attention_kwargs,
class_labels=class_labels,
)
# 3. Output
hidden_states = self.proj_out(hidden_states)
hidden_states = (
hidden_states[None, None, :]
.reshape(batch_size, height, width, num_frames, channel)
.permute(0, 3, 4, 1, 2)
.contiguous()
)
hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
output = hidden_states + residual
if not return_dict:
return (output,)
return TransformerTemporalModelOutput(sample=output)
class TransformerSpatioTemporalModel(nn.Module):
"""
A Transformer model for video-like data.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
out_channels (`int`, *optional*):
The number of channels in the output (specify if the input is **continuous**).
num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
"""
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: int = 320,
out_channels: Optional[int] = None,
num_layers: int = 1,
cross_attention_dim: Optional[int] = None,
):
super().__init__()
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
self.inner_dim = inner_dim
# 2. Define input layers
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
self.proj_in = nn.Linear(in_channels, inner_dim)
# 3. Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
cross_attention_dim=cross_attention_dim,
)
for d in range(num_layers)
]
)
time_mix_inner_dim = inner_dim
self.temporal_transformer_blocks = nn.ModuleList(
[
TemporalBasicTransformerBlock(
inner_dim,
time_mix_inner_dim,
num_attention_heads,
attention_head_dim,
cross_attention_dim=cross_attention_dim,
)
for _ in range(num_layers)
]
)
time_embed_dim = in_channels * 4
self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
self.time_proj = Timesteps(in_channels, True, 0)
self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
# 4. Define output layers
self.out_channels = in_channels if out_channels is None else out_channels
# TODO: should use out_channels for continuous projections
self.proj_out = nn.Linear(inner_dim, in_channels)
self.gradient_checkpointing = False
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
image_only_indicator: Optional[torch.Tensor] = None,
return_dict: bool = True,
):
"""
Args:
hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
Input hidden_states.
num_frames (`int`):
The number of frames to be processed per batch. This is used to reshape the hidden states.
encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
self-attention.
image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
A tensor indicating whether the input contains only images. 1 indicates that the input contains only
images, 0 indicates that the input contains video frames.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
tuple.
Returns:
[`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
returned, otherwise a `tuple` where the first element is the sample tensor.
"""
# 1. Input
batch_frames, _, height, width = hidden_states.shape
num_frames = image_only_indicator.shape[-1]
batch_size = batch_frames // num_frames
time_context = encoder_hidden_states
time_context_first_timestep = time_context[None, :].reshape(
batch_size, num_frames, -1, time_context.shape[-1]
)[:, 0]
time_context = time_context_first_timestep[None, :].broadcast_to(
height * width, batch_size, 1, time_context.shape[-1]
)
time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
residual = hidden_states
hidden_states = self.norm(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
hidden_states = self.proj_in(hidden_states)
num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
num_frames_emb = num_frames_emb.repeat(batch_size, 1)
num_frames_emb = num_frames_emb.reshape(-1)
t_emb = self.time_proj(num_frames_emb)
# `Timesteps` does not contain any weights and will always return f32 tensors
# but time_embedding might actually be running in fp16. so we need to cast here.
# there might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=hidden_states.dtype)
emb = self.time_pos_embed(t_emb)
emb = emb[:, None, :]
# 2. Blocks
for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
if self.training and self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(
block,
hidden_states,
None,
encoder_hidden_states,
None,
use_reentrant=False,
)
else:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
)
hidden_states_mix = hidden_states
hidden_states_mix = hidden_states_mix + emb
hidden_states_mix = temporal_block(
hidden_states_mix,
num_frames=num_frames,
encoder_hidden_states=time_context,
)
hidden_states = self.time_mixer(
x_spatial=hidden_states,
x_temporal=hidden_states_mix,
image_only_indicator=image_only_indicator,
)
# 3. Output
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
if not return_dict:
return (output,)
return TransformerTemporalModelOutput(sample=output)

View File

@@ -22,7 +22,6 @@ from ...utils import is_torch_version, logging
from ...utils.torch_utils import apply_freeu
from ..activations import get_activation
from ..attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
from ..dual_transformer_2d import DualTransformer2DModel
from ..normalization import AdaGroupNorm
from ..resnet import (
Downsample2D,
@@ -34,7 +33,8 @@ from ..resnet import (
ResnetBlockCondNorm2D,
Upsample2D,
)
from ..transformer_2d import Transformer2DModel
from ..transformers.dual_transformer_2d import DualTransformer2DModel
from ..transformers.transformer_2d import Transformer2DModel
logger = logging.get_logger(__name__) # pylint: disable=invalid-name

View File

@@ -20,7 +20,6 @@ from torch import nn
from ...utils import is_torch_version
from ...utils.torch_utils import apply_freeu
from ..attention import Attention
from ..dual_transformer_2d import DualTransformer2DModel
from ..resnet import (
Downsample2D,
ResnetBlock2D,
@@ -28,8 +27,9 @@ from ..resnet import (
TemporalConvLayer,
Upsample2D,
)
from ..transformer_2d import Transformer2DModel
from ..transformer_temporal import (
from ..transformers.dual_transformer_2d import DualTransformer2DModel
from ..transformers.transformer_2d import Transformer2DModel
from ..transformers.transformer_temporal import (
TransformerSpatioTemporalModel,
TransformerTemporalModel,
)

View File

@@ -33,7 +33,7 @@ from ..attention_processor import (
)
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..transformer_temporal import TransformerTemporalModel
from ..transformers.transformer_temporal import TransformerTemporalModel
from .unet_3d_blocks import (
CrossAttnDownBlock3D,
CrossAttnUpBlock3D,

View File

@@ -29,7 +29,7 @@ from ..attention_processor import (
)
from ..embeddings import TimestepEmbedding, Timesteps
from ..modeling_utils import ModelMixin
from ..transformer_temporal import TransformerTemporalModel
from ..transformers.transformer_temporal import TransformerTemporalModel
from .unet_2d_blocks import UNetMidBlock2DCrossAttn
from .unet_2d_condition import UNet2DConditionModel
from .unet_3d_blocks import (

View File

@@ -35,7 +35,7 @@ from ...models.embeddings import (
)
from ...models.modeling_utils import ModelMixin
from ...models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from ...models.transformer_2d import Transformer2DModel
from ...models.transformers.transformer_2d import Transformer2DModel
from ...models.unets.unet_2d_blocks import DownBlock2D, UpBlock2D
from ...models.unets.unet_2d_condition import UNet2DConditionOutput
from ...utils import BaseOutput, is_torch_version, logging

View File

@@ -19,7 +19,6 @@ from ....models.attention_processor import (
AttnAddedKVProcessor2_0,
AttnProcessor,
)
from ....models.dual_transformer_2d import DualTransformer2DModel
from ....models.embeddings import (
GaussianFourierProjection,
ImageHintTimeEmbedding,
@@ -32,7 +31,8 @@ from ....models.embeddings import (
Timesteps,
)
from ....models.resnet import ResnetBlockCondNorm2D
from ....models.transformer_2d import Transformer2DModel
from ....models.transformers.dual_transformer_2d import DualTransformer2DModel
from ....models.transformers.transformer_2d import Transformer2DModel
from ....models.unets.unet_2d_condition import UNet2DConditionOutput
from ....utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from ....utils.torch_utils import apply_freeu

View File

@@ -10,7 +10,7 @@ from ...models.attention import FeedForward
from ...models.attention_processor import Attention
from ...models.embeddings import TimestepEmbedding, Timesteps, get_2d_sincos_pos_embed
from ...models.normalization import AdaLayerNorm
from ...models.transformer_2d import Transformer2DModelOutput
from ...models.transformers.transformer_2d import Transformer2DModelOutput
from ...utils import logging

View File

@@ -24,7 +24,7 @@ from diffusers.models.attention import GEGLU, AdaLayerNorm, ApproximateGELU
from diffusers.models.embeddings import get_timestep_embedding
from diffusers.models.lora import LoRACompatibleLinear
from diffusers.models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
from diffusers.models.transformer_2d import Transformer2DModel
from diffusers.models.transformers.transformer_2d import Transformer2DModel
from diffusers.utils.testing_utils import (
backend_manual_seed,
require_torch_accelerator_with_fp64,

View File

View File

@@ -30,7 +30,7 @@ from diffusers.utils.testing_utils import (
torch_device,
)
from .test_modeling_common import ModelTesterMixin
from ..test_modeling_common import ModelTesterMixin
enable_full_determinism()